import os
import io
import re
import tempfile
import traceback
from typing import List, Tuple, Dict, Optional
from xml.sax.saxutils import escape as xml_escape

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# --- Word documents ---
import docx
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.oxml.ns import qn
from docx.table import Table
from docx.text.paragraph import Paragraph

# --- PDF read & write ---
import fitz  # PyMuPDF
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib.enums import TA_JUSTIFY
from reportlab.platypus import SimpleDocTemplate, Paragraph as RLParagraph, Spacer, PageBreak
from reportlab.lib.units import cm

# ================= CONFIG =================
MODEL_REPO = "Toadoum/nllb-200-distilled-600M-sherbo-v1"
ENG_CODE = "eng_Latn"   # English (source)
SHER_CODE = "sher_Latn"  # Sherbro (target)

# Inference
MAX_NEW_TOKENS = 256   # minimum generation budget; raised for long inputs by _generation_budget()
TEMPERATURE = 0.0      # not referenced: decoding is deterministic (do_sample=False)
NUM_BEAMS = 4          # beam search for better quality
TARGET_LENGTH_RATIO = 1.5  # allowance for the translation being longer than the source

# Performance knobs
MAX_SRC_TOKENS = 420   # per chunk; reduce to ~320 if you want even faster
BATCH_SIZE = 8         # chunks per pipeline call (reduced for the larger model)
MAX_CACHE_ENTRIES = 20_000  # bound on TRANSLATION_CACHE so a long-running server cannot grow forever

# Pipeline device: GPU 0 when CUDA is available, otherwise CPU.
device = 0 if torch.cuda.is_available() else "cpu"

# Load model & tokenizer once, at import time.
print("Loading model...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_REPO)
    print("Model loaded successfully!")
except Exception as e:
    # Retry from the local Hugging Face cache (e.g. no network). If this also
    # fails, the exception propagates and the app does not start.
    print(f"Error loading model: {e}")
    print("Trying to load with local_files_only...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, local_files_only=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_REPO, local_files_only=True)
    print("Model loaded from cache successfully!")

print(f"Device set to use {device}")

# Use dtype instead of torch_dtype to avoid deprecation warning.
# NOTE(review): `model` is already instantiated here; confirm the installed
# transformers version actually casts it to float16 (it may only honour
# `dtype` when the pipeline loads the model by name).
translator = pipeline(
    task="translation",
    model=model,
    tokenizer=tokenizer,
    device=device,
    dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

# Upper bound for max_new_tokens: the model's configured position-table size.
_MAX_MODEL_POSITIONS = int(getattr(model.config, "max_position_embeddings", 1024) or 1024)


# ---------- Low-level translation helpers ----------
def tokenize_len(s: str) -> int:
    """Return the number of tokenizer tokens in *s* (without special tokens)."""
    return len(tokenizer.encode(s, add_special_tokens=False))


def _generation_budget(texts: List[str]) -> int:
    """Return a max_new_tokens value large enough for the longest of *texts*.

    A fixed budget of MAX_NEW_TOKENS (256) could cut off translations of
    chunks holding up to MAX_SRC_TOKENS (420) source tokens, so the budget
    grows with the source length but never drops below MAX_NEW_TOKENS.
    """
    longest = max((tokenize_len(t) for t in texts), default=0)
    wanted = int(longest * TARGET_LENGTH_RATIO) + 16
    return max(MAX_NEW_TOKENS, min(wanted, _MAX_MODEL_POSITIONS))


def _translate(texts: List[str], num_beams: int = NUM_BEAMS) -> List[str]:
    """Translate *texts* English → Sherbro in one pipeline call; outputs keep input order.

    Exceptions from the pipeline propagate to the caller.
    """
    if not texts:
        return []
    with torch.no_grad():
        outs = translator(
            texts,
            src_lang=ENG_CODE,
            tgt_lang=SHER_CODE,
            max_new_tokens=_generation_budget(texts),
            do_sample=False,
            num_beams=num_beams,
            # Pipelines do not batch by default; without this the list is
            # still fed to the model one item at a time.
            batch_size=len(texts),
        )
    return [o["translation_text"] for o in outs]


# Simple text box translation
def translate_text_simple(text: str) -> str:
    """Translate text-box input from English to Sherbro.

    Short input is translated in a single call. Input longer than
    MAX_SRC_TOKENS is translated line by line through the chunked path so it
    is not truncated, keeping the user's line breaks. Blank input returns "".
    """
    if not text or not text.strip():
        return ""
    if tokenize_len(text) <= MAX_SRC_TOKENS:
        return _translate([text])[0]
    return "\n".join(translate_long_text(line) for line in text.split("\n"))


# ---------- Chunking + Batched Translation + Cache ----------
# Sentence-ish boundary: punctuation followed by whitespace. The capturing
# group makes re.split keep the delimiters at odd indices.
_SENTENCE_BOUNDARY_RE = re.compile(r'(\s*[\.\!\?…:;]\s+)')


def _split_oversized(sentence: str, max_src_tokens: int) -> List[str]:
    """Greedily pack the words of an over-long *sentence* into pieces of at most
    *max_src_tokens* tokens. A single word longer than the limit stays whole."""
    pieces: List[str] = []
    current = ""
    for word in sentence.split():
        candidate = f"{current} {word}" if current else word
        if current and tokenize_len(candidate) > max_src_tokens:
            pieces.append(current)
            current = word
        else:
            current = candidate
    if current:
        pieces.append(current)
    return pieces


def chunk_text_for_translation(text: str, max_src_tokens: int = MAX_SRC_TOKENS) -> List[str]:
    """Split *text* at sentence-ish boundaries and merge sentences into chunks
    of at most *max_src_tokens* tokens.

    A single sentence longer than the limit is split on whitespace so that no
    chunk exceeds it (except a single word that is itself too long).
    Returns [] for blank text.
    """
    if not text.strip():
        return []
    parts = _SENTENCE_BOUNDARY_RE.split(text)
    sentences: List[str] = []
    # parts alternates text / delimiter and always has odd length; pad the
    # delimiters with "" so each text piece pairs with its trailing delimiter.
    for body, delim in zip(parts[0::2], parts[1::2] + [""]):
        unit = (body + delim).strip()
        if not unit:
            continue
        if tokenize_len(unit) > max_src_tokens:
            sentences.extend(_split_oversized(unit, max_src_tokens))
        else:
            sentences.append(unit)

    chunks: List[str] = []
    current = ""
    for sent in sentences:
        candidate = (current + " " + sent).strip() if current else sent
        if current and tokenize_len(candidate) > max_src_tokens:
            chunks.append(current.strip())
            current = sent
        else:
            current = candidate
    if current.strip():
        chunks.append(current.strip())
    return chunks


# Module-level cache: identical chunks are translated once (insertion-ordered;
# oldest entries are evicted beyond MAX_CACHE_ENTRIES).
TRANSLATION_CACHE: Dict[str, str] = {}


def _cache_store(results: Dict[str, str]) -> None:
    """Add *results* to TRANSLATION_CACHE, evicting the oldest-inserted entries
    when the cache grows past MAX_CACHE_ENTRIES."""
    TRANSLATION_CACHE.update(results)
    overflow = len(TRANSLATION_CACHE) - MAX_CACHE_ENTRIES
    if overflow > 0:
        for key in list(TRANSLATION_CACHE)[:overflow]:
            TRANSLATION_CACHE.pop(key, None)


def translate_chunks_list(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[str]:
    """Translate a list of chunks with de-duplication, caching and batching.

    Returns translations in the same order as *chunks*; blank chunks map to "".
    A failing batch is retried item by item with greedy decoding; an item that
    still fails yields an "[Translation Error: ...]" placeholder, which is NOT
    cached so a later request can retry it.
    """
    norm_chunks = [c.strip() for c in chunks]
    # dict.fromkeys de-duplicates while preserving first-seen order.
    pending = [c for c in dict.fromkeys(norm_chunks) if c and c not in TRANSLATION_CACHE]
    step = max(1, batch_size)

    fresh: Dict[str, str] = {}
    failed = set()
    for start in range(0, len(pending), step):
        batch = pending[start:start + step]
        try:
            fresh.update(zip(batch, _translate(batch)))
        except Exception as e:
            print(f"Error translating batch {start}: {e}")
            # Fallback: translate one by one with a simpler decoding setup.
            for src in batch:
                try:
                    fresh[src] = _translate([src], num_beams=1)[0]
                except Exception as item_err:
                    print(f"Error translating chunk {src[:50]!r}: {item_err}")
                    fresh[src] = f"[Translation Error: {src[:50]}...]"
                    failed.add(src)

    # Resolve results before touching the cache: eviction could otherwise drop
    # entries this call still needs.
    results = [fresh[c] if c in fresh else TRANSLATION_CACHE.get(c, "") for c in norm_chunks]
    _cache_store({src: tr for src, tr in fresh.items() if src not in failed})
    return results


def translate_long_text(text: str) -> str:
    """Chunk → batch translate → rejoin one paragraph/block (chunks joined by spaces)."""
    chs = chunk_text_for_translation(text)
    if not chs:
        return ""
    return " ".join(translate_chunks_list(chs)).strip()


# ---------- DOCX helpers (batched across the whole doc) ----------
def is_heading(par: Paragraph) -> Tuple[bool, int]:
    """Return (True, level) when the paragraph style name contains "heading".

    The level is the first of the digits 1-9 found in the style name, or 1
    when none is present; non-heading paragraphs return (False, 0).
    """
    style_obj = par.style
    style = ((style_obj.name if style_obj is not None else "") or "").lower()
    if "heading" in style:
        for lvl in range(1, 10):
            if str(lvl) in style:
                return True, lvl
        return True, 1
    return False, 0


def _iter_block_items(document):
    """Yield the body's top-level paragraphs and tables in document order."""
    for child in document.element.body.iterchildren():
        if child.tag == qn("w:p"):
            yield Paragraph(child, document)
        elif child.tag == qn("w:tbl"):
            yield Table(child, document)


def translate_docx_bytes(file_bytes: bytes) -> bytes:
    """
    Read .docx → collect ALL chunks (paragraphs + table cells, in document
    order) → single batched translation → rebuild a new .docx.

    Only text is carried over: headings keep their level, body paragraphs and
    table cells are justified; other formatting is not preserved.
    """
    src_doc = docx.Document(io.BytesIO(file_bytes))

    # 1) Collect work units; "range" entries index into all_chunks.
    work: List[dict] = []
    all_chunks: List[str] = []

    def add_chunks(chs: List[str]) -> Optional[Tuple[int, int]]:
        """Append *chs* to all_chunks and return their (start, end) range, or None if empty."""
        if not chs:
            return None
        start = len(all_chunks)
        all_chunks.extend(chs)
        return start, len(all_chunks)

    for block in _iter_block_items(src_doc):
        if isinstance(block, Paragraph):
            txt = block.text
            if not txt.strip():
                work.append({"kind": "blank"})
                continue
            is_head, lvl = is_heading(block)
            if is_head:
                # Headings are treated as a single chunk (usually short).
                work.append({"kind": "heading",
                             "level": min(max(lvl, 1), 9),
                             "range": add_chunks([txt.strip()])})
            else:
                rng = add_chunks(chunk_text_for_translation(txt))
                work.append({"kind": "para", "range": rng} if rng else {"kind": "blank"})
        else:
            rows: List[List[Optional[Tuple[int, int]]]] = []
            for row in block.rows:
                row_ranges = []
                for cell in row.cells:
                    cell_text = "\n".join(p.text for p in cell.paragraphs).strip()
                    row_ranges.append(add_chunks(chunk_text_for_translation(cell_text)))
                rows.append(row_ranges)
            # Rows may report a different number of cells than table.columns;
            # size the output table to fit the widest row.
            cols = max([len(block.columns)] + [len(r) for r in rows])
            work.append({"kind": "table", "cols": cols, "cells": rows})

    # 2) Translate all chunks at once (de-dup + batching).
    translated_all = translate_chunks_list(all_chunks) if all_chunks else []

    def join_range(rng: Optional[Tuple[int, int]]) -> str:
        """Join the translations for a (start, end) range; "" for None."""
        if rng is None:
            return ""
        s, e = rng
        return " ".join(translated_all[s:e]).strip()

    # 3) Rebuild a new document in the original order.
    new_doc = docx.Document()
    for item in work:
        kind = item["kind"]
        if kind == "blank":
            new_doc.add_paragraph("")
        elif kind == "heading":
            new_doc.add_heading(join_range(item["range"]), level=item["level"])
        elif kind == "para":
            p = new_doc.add_paragraph(join_range(item["range"]))
            p.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY
        elif kind == "table":
            cells = item["cells"]
            if not cells or item["cols"] == 0:
                continue
            tbl = new_doc.add_table(rows=len(cells), cols=item["cols"])
            for r_idx, row_ranges in enumerate(cells):
                for c_idx, rng in enumerate(row_ranges):
                    tgt_cell = tbl.cell(r_idx, c_idx)
                    tgt_cell.text = join_range(rng)
                    for p in tgt_cell.paragraphs:
                        p.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY

    out = io.BytesIO()
    new_doc.save(out)
    return out.getvalue()


# ---------- PDF helpers (batched across the whole PDF) ----------
def extract_pdf_text_blocks(pdf_bytes: bytes) -> List[List[str]]:
    """
    Return a list of pages, each a list of non-empty text-block strings sorted
    top-to-bottom, then left-to-right. Image blocks are skipped.
    """
    pages_blocks: List[List[str]] = []
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        for page in doc:
            # Block tuples: (x0, y0, x1, y1, text, block_no, block_type);
            # block_type 1 is an image whose "text" is a placeholder like "<image: ...>".
            blocks = [b for b in page.get_text("blocks") if len(b) < 7 or b[6] == 0]
            blocks.sort(key=lambda b: (round(b[1], 1), round(b[0], 1)))
            pages_blocks.append([t for t in (b[4].strip() for b in blocks) if t])
    finally:
        doc.close()
    return pages_blocks


def build_pdf_from_blocks(translated_pages: List[List[str]]) -> bytes:
    """
    Build a clean paginated A4 PDF with justified paragraphs; each source page
    starts on a new output page.
    """
    buf = io.BytesIO()
    doc = SimpleDocTemplate(
        buf, pagesize=A4,
        rightMargin=2*cm, leftMargin=2*cm, topMargin=2*cm, bottomMargin=2*cm
    )
    body = getSampleStyleSheet()["BodyText"]
    body.alignment = TA_JUSTIFY
    body.leading = 14

    story = []
    for page_idx, blocks in enumerate(translated_pages):
        if page_idx:
            story.append(PageBreak())
        for blk in blocks:
            # RLParagraph parses an XML-like markup: escape &, <, > in the
            # model output before turning newlines into <br/> tags.
            markup = xml_escape(blk).replace("\n", "<br/>")
            story.append(RLParagraph(markup, body))
            story.append(Spacer(1, 0.35*cm))
    doc.build(story)
    return buf.getvalue()


def translate_pdf_bytes(file_bytes: bytes) -> bytes:
    """
    Read PDF → collect ALL block chunks across pages → single batched
    translation → rebuild a simple justified PDF (layout is not preserved).
    """
    pages_blocks = extract_pdf_text_blocks(file_bytes)

    # 1) Collect chunks for the entire PDF; plan holds one range per block.
    all_chunks: List[str] = []
    plan: List[List[Optional[Tuple[int, int]]]] = []
    for blocks in pages_blocks:
        page_plan: List[Optional[Tuple[int, int]]] = []
        for blk in blocks:
            chs = chunk_text_for_translation(blk)
            if chs:
                start = len(all_chunks)
                all_chunks.extend(chs)
                page_plan.append((start, start + len(chs)))
            else:
                page_plan.append(None)
        plan.append(page_plan)

    # 2) Translate all chunks at once.
    translated_all = translate_chunks_list(all_chunks) if all_chunks else []

    # 3) Reconstruct per block.
    translated_pages: List[List[str]] = [
        ["" if rng is None else " ".join(translated_all[rng[0]:rng[1]]).strip()
         for rng in page_plan]
        for page_plan in plan
    ]
    return build_pdf_from_blocks(translated_pages)


# ---------- Gradio file handler (robust) ----------
def _write_output(data: bytes, filename: str) -> str:
    """Write *data* to *filename* inside a fresh temporary directory; return the path.

    A per-request directory avoids a fixed, shared path in the working
    directory that concurrent requests could overwrite.
    """
    out_path = os.path.join(tempfile.mkdtemp(prefix="sherbro_"), filename)
    with open(out_path, "wb") as f:
        f.write(data)
    return out_path


def translate_document(file_obj):
    """
    Accepts gr.File input (NamedString, filepath str, or dict with binary).
    Returns (output_file_path, status_message); the path is None on failure.
    """
    if file_obj is None:
        return None, "Please select a .docx or .pdf file"

    try:
        name = "document"
        data = None

        # Case A: plain filepath string
        if isinstance(file_obj, str):
            name = os.path.basename(file_obj)
            with open(file_obj, "rb") as f:
                data = f.read()
        # Case B: Gradio NamedString with .name (orig name) and .value (temp path)
        elif hasattr(file_obj, "name") and hasattr(file_obj, "value"):
            name = os.path.basename(file_obj.name or "document")
            with open(file_obj.value, "rb") as f:
                data = f.read()
        # Case C: dict (type="binary")
        elif isinstance(file_obj, dict) and "name" in file_obj and "data" in file_obj:
            name = os.path.basename(file_obj["name"] or "document")
            d = file_obj["data"]
            data = d.read() if hasattr(d, "read") else d
        else:
            return None, "Unsupported file input type."

        if data is None:
            return None, "Could not read the selected file."

        lower_name = name.lower()
        if lower_name.endswith(".docx"):
            out_path = _write_output(translate_docx_bytes(data), "translated_sherbro.docx")
            return out_path, "✅ DOCX translation completed."
        if lower_name.endswith(".pdf"):
            out_path = _write_output(translate_pdf_bytes(data), "translated_sherbro.pdf")
            return out_path, "✅ PDF translation completed."
        return None, "Unsupported file type. Please choose .docx or .pdf"

    except Exception as e:
        # Keep the full traceback in the server log; show a short message in the UI.
        traceback.print_exc()
        return None, f"❌ Error during translation: {str(e)}"


# ================== UI ==================
# Deliberately simple layout, compatible with older Gradio versions.
demo = gr.Blocks(title="English → Sherbro Translation Demo")

with demo:
    # Custom CSS via HTML: minimal styles for the classes used below.
    gr.HTML(
        """
<style>
.header-card { padding: 16px; text-align: center; }
.header-title { font-size: 1.6em; font-weight: 700; }
.header-subtitle, .header-model { opacity: 0.8; }
.info-box { padding: 12px 16px; border-radius: 8px; margin: 8px 0; }
.tip { font-size: 0.85em; opacity: 0.7; margin-top: 4px; }
</style>
"""
    )

    with gr.Group(elem_classes=["header-card"]):
        gr.HTML(
            """
<div class="header-title">English → Sherbro Translation Demo</div>
<div class="header-subtitle">NLLB-200 Distilled 600M Model - For Bible Society Evaluation</div>
<div class="header-model">Model: nllb-200-distilled-600M-sherbo-v1</div>
"""
        )

    # Information box
    gr.HTML(
        """
<div class="info-box">
  <h3>📖 About This Demo</h3>
  <p>This demo showcases English to Sherbro translation using the NLLB-200 model. You can translate individual text passages or entire documents (DOCX/PDF). The model is optimized for biblical and religious text translation.</p>
</div>
"""
    )

    with gr.Tabs():
        # -------- Tab 1: Text Translation --------
        with gr.Tab("Text Translation"):
            with gr.Row():
                with gr.Column(scale=5):
                    src = gr.Textbox(
                        label="Source Text (English)",
                        placeholder="Enter your English text here...",
                        lines=8,
                        autofocus=True,
                    )
                    with gr.Row():
                        btn = gr.Button("Translate", variant="primary", scale=3)
                        clear_btn = gr.Button("Clear", scale=1)

                    gr.Examples(
                        examples=[
                            ["The Lord is my shepherd; I shall not want."],
                            ["For God so loved the world, that he gave his only begotten Son."],
                            ["Blessed are the peacemakers, for they will be called children of God."],
                            ["Let everything that has breath praise the Lord."],
                        ],
                        inputs=[src],
                        label="Bible Examples (click to use)",
                    )

                with gr.Column(scale=5):
                    tgt = gr.Textbox(
                        label="Translation (Sherbro)",
                        lines=8,
                        interactive=False,
                    )
                    gr.HTML('<div class="tip">Tip: Select text and copy with Ctrl+C (Cmd+C on Mac)</div>')
                    gr.Markdown('')

        # -------- Tab 2: Document Translation --------
        with gr.Tab("Document Translation (.docx / .pdf)"):
            with gr.Row():
                with gr.Column(scale=5):
                    doc_inp = gr.File(
                        label="Select Document (.docx or .pdf)",
                        file_types=[".docx", ".pdf"],
                        type="filepath",
                    )
                    run_doc = gr.Button("Translate Document", variant="primary")
                with gr.Column(scale=5):
                    doc_out = gr.File(label="Translated File (download)")
                    doc_status = gr.Markdown("")

            run_doc.click(translate_document, inputs=doc_inp, outputs=[doc_out, doc_status])

    # Text actions
    btn.click(translate_text_simple, inputs=src, outputs=tgt)
    clear_btn.click(lambda: ("", ""), outputs=[src, tgt])


if __name__ == "__main__":
    # For Hugging Face Spaces deployment.
    # NOTE(review): share=True requests a public share link when run locally;
    # confirm this exposure is intended outside Spaces.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )