| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Minimal end-to-end example for sui-1-24b summarization. |
| |
| Usage: |
| # Summarize a file |
| uv run example.py document.txt |
| |
| # Summarize inline text |
| uv run example.py --text "Your long text here..." |
| |
| # With custom parameters |
| uv run example.py document.txt --words 300 --tags 8 --language en |
| """ |
|
|
| import argparse |
| import hashlib |
| import json |
| import re |
| import sys |
| from pathlib import Path |
|
|
| |
# spaCy pipeline used for sentence splitting, keyed by the --language code.
_SPACY_MODELS = {
    "en": "en_core_web_sm",
    "de": "de_core_news_sm",
    "es": "es_core_news_sm",
    "fr": "fr_core_news_sm",
    "it": "it_core_news_sm",
}

# Human-readable language name inserted into the prompt.
_LANGUAGE_NAMES = {
    "en": "English",
    "de": "German",
    "es": "Spanish",
    "fr": "French",
    "it": "Italian",
}

# Release version used to build the install hint for a missing spaCy model.
_SPACY_MODEL_VERSION = "3.8.0"


def _tag_sentences(sentences):
    """Wrap each non-empty sentence in an 8-hex-digit XML tag.

    Args:
        sentences: Iterable of sentence strings, in document order.

    Returns:
        A ``(tagged_text, tag_mapping)`` tuple: the concatenated
        ``<tag>sentence</tag>`` string and a dict mapping tag -> sentence.
    """
    tag_mapping = {}
    parts = []
    for i, raw in enumerate(sentences):
        sentence = raw.strip()
        if not sentence:
            continue
        # The tag is derived from the sentence index plus its first 50
        # characters, so identical sentences at different positions get
        # different tags.
        tag = hashlib.md5(f"{i}_{sentence[:50]}".encode()).hexdigest()[:8]
        tag_mapping[tag] = sentence
        parts.append(f"<{tag}>{sentence}</{tag}>")
    return "".join(parts), tag_mapping


def _build_prompt(tagged_text, words, tags, language_name):
    """Return the summarization prompt for the given tagged text and parameters."""
    return f"""You are a professional summarizer, following all given instructions with the utmost care.

<text>
{tagged_text}
</text>

# Output Format
The output must be in JSON format with the following structure:
1. A "structure" string containing your thoughts about the content and structure of the summary
2. An "xml_tags" list containing the XML tag identifiers from the tagged text (e.g., "<a1b2c3d4>")
3. A "summary" string containing the actual summary with inline XML tag references

# Instructions
1. Start by thinking about and explaining the structure and content of your summary. Select {tags} XML tags from the tagged text that capture the most significant data and facts.
2. Begin with an executive summary introducing the title, author (if available), and key findings.
3. Structure the summary in coherent paragraphs. Every paragraph should contain at least one XML tag reference.
4. Reference XML tags inline in square brackets (e.g., [<a1b2c3d4>]) immediately after the statement they support.
5. Each XML tag must appear exactly once in the summary.
6. Avoid a concluding paragraph that merely restates points.
7. Do not use bullet points or headings unless explicitly requested.

Parameters:
- Word count (excl. XML tags): {words}
- Number of XML tags: {tags}
- Language: {language_name}
"""


def _print_report(result, tag_mapping):
    """Parse the model's JSON answer and print the summary and its sources.

    Args:
        result: Raw text generated by the model.
        tag_mapping: Dict mapping tag -> source sentence, as produced by
            ``_tag_sentences``.
    """
    # Greedy match from the first "{" to the last "}" so any text the model
    # emits around the JSON object is ignored.
    json_match = re.search(r"\{[\s\S]*\}", result)
    if not json_match:
        print("Could not parse JSON response:")
        print(result)
        return

    try:
        data = json.loads(json_match.group())
    except json.JSONDecodeError as e:
        print(f"JSON parse error: {e}")
        print(result)
        return

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60 + "\n")

    # Tolerate a missing, null or non-string "summary" field.
    summary = data.get("summary") or ""
    if not isinstance(summary, str):
        summary = str(summary)
    # Turn inline "[<a1b2c3d4>]" references into the shorter "[a1b2c3d4]".
    print(re.sub(r"\[<([a-f0-9]{8})>\]", r"[\1]", summary))

    print("\n" + "-" * 60)
    print("SOURCES")
    print("-" * 60)

    # Entries may be plain strings or dicts carrying an "xml_tag" key;
    # anything else is skipped.
    for entry in data.get("xml_tags") or []:
        if isinstance(entry, dict):
            entry = entry.get("xml_tag")
        if not isinstance(entry, str):
            continue
        clean_tag = entry.strip("<>")
        source = tag_mapping.get(clean_tag, "Not found")
        if len(source) > 100:
            source = source[:97] + "..."
        print(f"[{clean_tag}] {source}")


def main():
    """Command-line entry point: tag sentences, query the model, print the result.

    Reads the input text from ``--text`` or from the positional file path,
    splits it into sentences with spaCy, wraps each sentence in an XML tag,
    asks the model for a grounded summary and prints either the raw model
    output (``--raw``) or a formatted summary followed by the cited sources.

    Exits with status 2 on missing or unreadable input (via ``parser.error``)
    and with status 1 when the required spaCy model is not installed.
    """
    parser = argparse.ArgumentParser(
        description="Summarize text using sui-1-24b with source grounding",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("input", nargs="?", help="Input file path (or use --text)")
    parser.add_argument("--text", "-t", help="Input text directly")
    # Help strings use %(default)s so they cannot drift from the real defaults.
    parser.add_argument("--words", "-w", type=int, default=250,
                        help="Target word count (default: %(default)s)")
    parser.add_argument("--tags", "-n", type=int, default=4,
                        help="Number of XML tags to cite (default: %(default)s)")
    parser.add_argument("--language", "-l", default="en", choices=list(_SPACY_MODELS),
                        help="Language (default: %(default)s)")
    parser.add_argument("--model", "-m", default="ellamind/sui-1-24b",
                        help="Model path or HF repo")
    parser.add_argument("--tensor-parallel", "-tp", type=int, default=1,
                        help="Tensor parallel size (default: %(default)s)")
    parser.add_argument("--raw", action="store_true",
                        help="Print raw JSON output instead of formatted")
    args = parser.parse_args()

    if args.text:
        text = args.text
    elif args.input:
        try:
            # Explicit encoding: the platform default is not always UTF-8.
            text = Path(args.input).read_text(encoding="utf-8")
        except (OSError, UnicodeError) as e:
            parser.error(f"Cannot read input file '{args.input}': {e}")
    else:
        parser.error("Provide input file or --text")

    # Heavy third-party imports are deferred until after argument parsing so
    # that --help and usage errors do not have to load them.
    import spacy
    from vllm import LLM, SamplingParams

    model_name = _SPACY_MODELS[args.language]
    try:
        nlp = spacy.load(model_name)
    except OSError:
        print(f"Error: spaCy model '{model_name}' not found.", file=sys.stderr)
        print("For English, this should be bundled automatically.", file=sys.stderr)
        print("For other languages, install the model first:", file=sys.stderr)
        print(
            "  pip install https://github.com/explosion/spacy-models/releases/download/"
            f"{model_name}-{_SPACY_MODEL_VERSION}/"
            f"{model_name}-{_SPACY_MODEL_VERSION}-py3-none-any.whl",
            file=sys.stderr,
        )
        sys.exit(1)

    print("Tagging sentences...")
    doc = nlp(text)
    tagged_text, tag_mapping = _tag_sentences(sent.text for sent in doc.sents)
    print(f"Tagged {len(tag_mapping)} sentences")

    prompt = _build_prompt(
        tagged_text, args.words, args.tags, _LANGUAGE_NAMES[args.language]
    )

    print(f"Loading model: {args.model}")
    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tensor_parallel,
        dtype="bfloat16",
        tokenizer_mode="mistral",
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 0},
    )

    print("Generating summary...")
    sampling_params = SamplingParams(max_tokens=4096, temperature=0.0)
    outputs = llm.chat([[{"role": "user", "content": prompt}]], sampling_params)
    result = outputs[0].outputs[0].text

    if args.raw:
        print(result)
        return

    _print_report(result, tag_mapping)


if __name__ == "__main__":
    main()
|
|