import os
import io
import re
from typing import List, Tuple, Dict, Optional
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# DOCX read & write
import docx
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.text.paragraph import Paragraph
# PDF read & write
import fitz # PyMuPDF
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib.enums import TA_JUSTIFY
from reportlab.platypus import SimpleDocTemplate, Paragraph as RLParagraph, Spacer, PageBreak
from reportlab.lib.units import cm
# ================= CONFIG =================
MODEL_REPO = "Toadoum/nllb-200-distilled-600M-sherbo-v1"
ENG_CODE = "eng_Latn" # English (source)
SHER_CODE = "sher_Latn" # Sherbro (target)
# Inference
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.0 # unused: decoding is deterministic (do_sample=False everywhere)
NUM_BEAMS = 4 # Increased for better quality
# Performance knobs
MAX_SRC_TOKENS = 420 # per chunk; reduce to ~320 if you want even faster
BATCH_SIZE = 8 # reduced batch size for larger model
# Device selection: GPU index 0 if available, else CPU (pipeline accepts either form)
device = 0 if torch.cuda.is_available() else "cpu"
# Load model & tokenizer once
print("Loading model...")
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_REPO)
print("Model loaded successfully!")
except Exception as e:
print(f"Error loading model: {e}")
print("Trying to load with local_files_only...")
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, local_files_only=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_REPO, local_files_only=True)
print("Model loaded from cache successfully!")
    except Exception:
        raise
print(f"Device set to use {device}")
# Use dtype instead of torch_dtype to avoid deprecation warning
translator = pipeline(
task="translation",
model=model,
tokenizer=tokenizer,
device=device,
dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
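# Quick smoke test (illustrative; the sample sentence is arbitrary). Uncomment to
# verify the pipeline end to end:
# print(translator("Good morning.", src_lang=ENG_CODE, tgt_lang=SHER_CODE,
#                  max_new_tokens=MAX_NEW_TOKENS, num_beams=NUM_BEAMS)[0]["translation_text"])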
# Simple text-box translation
def translate_text_simple(text: str) -> str:
if not text or not text.strip():
return ""
with torch.no_grad():
out = translator(
text,
src_lang=ENG_CODE,
tgt_lang=SHER_CODE,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False,
num_beams=NUM_BEAMS,
)
return out[0]["translation_text"]
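# Note: translate_text_simple sends the whole input in a single model call (no
# chunking), so very long texts can exceed the model's context window and output
# is capped at MAX_NEW_TOKENS; translate_long_text() below handles longer inputs.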
# ---------- Chunking + Batched Translation + Cache ----------
def tokenize_len(s: str) -> int:
return len(tokenizer.encode(s, add_special_tokens=False))
def chunk_text_for_translation(text: str, max_src_tokens: int = MAX_SRC_TOKENS) -> List[str]:
"""Split text by sentence-ish boundaries and merge under token limit."""
if not text.strip():
return []
parts = re.split(r'(\s*[\.\!\?…:;]\s+)', text)
sentences = []
for i in range(0, len(parts), 2):
s = parts[i]
p = parts[i+1] if i+1 < len(parts) else ""
unit = (s + (p or "")).strip()
if unit:
sentences.append(unit)
chunks, current = [], ""
for sent in sentences:
candidate = (current + " " + sent).strip() if current else sent
if current and tokenize_len(candidate) > max_src_tokens:
chunks.append(current.strip())
current = sent
else:
current = candidate
if current.strip():
chunks.append(current.strip())
return chunks
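# Illustrative behaviour (hypothetical input): a long paragraph is split at
# sentence-ish punctuation and greedily re-merged, so each returned chunk stays
# under MAX_SRC_TOKENS while still ending on a sentence boundary; a short
# paragraph comes back unchanged as a single one-element list.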
# module-level cache: identical chunks translated once
TRANSLATION_CACHE: Dict[str, str] = {}
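# The cache is keyed on the exact stripped chunk text and lives for the whole
# process: repeated chunks (headers, boilerplate) are translated once, but entries
# are never evicted, so memory grows with the number of distinct chunks seen.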
def translate_chunks_list(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[str]:
"""
Translate a list of chunks with de-dup + batching.
Returns translations in the same order as input.
"""
    # Normalize, then de-duplicate (dict.fromkeys preserves order) and skip
    # anything already cached
    norm_chunks = [c.strip() for c in chunks]
    to_translate = [c for c in dict.fromkeys(norm_chunks)
                    if c and c not in TRANSLATION_CACHE]
# Batched calls
with torch.no_grad():
for i in range(0, len(to_translate), batch_size):
batch = to_translate[i:i + batch_size]
try:
outs = translator(
batch,
src_lang=ENG_CODE,
tgt_lang=SHER_CODE,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False,
num_beams=NUM_BEAMS,
)
for src, o in zip(batch, outs):
TRANSLATION_CACHE[src] = o["translation_text"]
except Exception as e:
print(f"Error translating batch {i}: {e}")
# Fallback: translate one by one
for src in batch:
try:
out = translator(
src,
src_lang=ENG_CODE,
tgt_lang=SHER_CODE,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False,
num_beams=1, # simpler for fallback
)
TRANSLATION_CACHE[src] = out[0]["translation_text"]
                    except Exception as item_err:
                        print(f"Error translating chunk: {item_err}")
                        TRANSLATION_CACHE[src] = f"[Translation Error: {src[:50]}...]"
return [TRANSLATION_CACHE.get(c, "") for c in norm_chunks]
def translate_long_text(text: str) -> str:
"""Chunk → batch translate → rejoin for one paragraph/block."""
chs = chunk_text_for_translation(text)
if not chs:
return ""
trs = translate_chunks_list(chs)
# join with space to reconstruct paragraph smoothly
return " ".join(trs).strip()
# ---------- DOCX helpers (now fully batched across the whole doc) ----------
def is_heading(par: Paragraph) -> Tuple[bool, int]:
style = (par.style.name or "").lower()
if "heading" in style:
for lvl in range(1, 10):
if str(lvl) in style:
return True, lvl
return True, 1
return False, 0
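# Styles named "Heading 1" .. "Heading 9" map to their numeric level; any other
# style whose name contains "heading" falls back to level 1, and everything else
# is treated as body text.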
def translate_docx_bytes(file_bytes: bytes) -> bytes:
"""
Read .docx → collect ALL chunks (paras + table cells) → single batched translation → rebuild .docx.
"""
f = io.BytesIO(file_bytes)
src_doc = docx.Document(f)
# 1) Collect work units
work = [] # list of dict entries describing items with ranges into all_chunks
all_chunks: List[str] = []
# paragraphs
for par in src_doc.paragraphs:
txt = par.text
if not txt.strip():
work.append({"kind": "blank"})
continue
is_head, lvl = is_heading(par)
if is_head:
# treat as single chunk (usually short)
work.append({"kind": "heading", "level": min(max(lvl, 1), 9), "range": (len(all_chunks), len(all_chunks)+1)})
all_chunks.append(txt.strip())
else:
chs = chunk_text_for_translation(txt)
if chs:
start = len(all_chunks)
all_chunks.extend(chs)
work.append({"kind": "para", "range": (start, start+len(chs))})
else:
work.append({"kind": "blank"})
# tables
    for table in src_doc.tables:
        t_desc = {"kind": "table", "rows": len(table.rows), "cols": len(table.columns), "cells": []}
        for row in table.rows:
            row_cells = []
            for cell in row.cells:
cell_text = "\n".join([p.text for p in cell.paragraphs]).strip()
if cell_text:
chs = chunk_text_for_translation(cell_text)
if chs:
start = len(all_chunks)
all_chunks.extend(chs)
row_cells.append({"range": (start, start+len(chs))})
else:
row_cells.append({"range": None})
else:
row_cells.append({"range": None})
t_desc["cells"].append(row_cells)
work.append(t_desc)
# 2) Translate all chunks at once (de-dup + batching)
if all_chunks:
translated_all = translate_chunks_list(all_chunks)
else:
translated_all = []
# 3) Rebuild new document
new_doc = docx.Document()
cursor = 0 # index into translated_all
# helper to consume a range and join back
    def join_range(rng: Optional[Tuple[int, int]]) -> str:
if rng is None:
return ""
s, e = rng
return " ".join(translated_all[s:e]).strip()
# rebuild paragraphs
for item in work:
if item["kind"] == "blank":
new_doc.add_paragraph("")
elif item["kind"] == "heading":
text = join_range(item["range"])
new_doc.add_heading(text, level=item["level"])
elif item["kind"] == "para":
text = join_range(item["range"])
p = new_doc.add_paragraph(text)
p.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY
elif item["kind"] == "table":
tbl = new_doc.add_table(rows=item["rows"], cols=item["cols"])
for r_idx in range(item["rows"]):
for c_idx in range(item["cols"]):
cell_info = item["cells"][r_idx][c_idx]
txt = join_range(cell_info["range"])
tgt_cell = tbl.cell(r_idx, c_idx)
tgt_cell.text = txt
for p in tgt_cell.paragraphs:
p.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY
out = io.BytesIO()
new_doc.save(out)
return out.getvalue()
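# Note: the rebuilt .docx preserves paragraph/heading/table structure only; run-level
# formatting (bold, italics, fonts), images and section layout are not carried over.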
# ---------- PDF helpers (batched across the whole PDF) ----------
def extract_pdf_text_blocks(pdf_bytes: bytes) -> List[List[str]]:
"""
Returns list of pages, each a list of block texts (visual order).
"""
pages_blocks: List[List[str]] = []
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
for page in doc:
blocks = page.get_text("blocks")
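        # PyMuPDF returns each block as (x0, y0, x1, y1, text, block_no, block_type);
        # sorting on (y0, x0) approximates top-to-bottom, left-to-right reading order.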
blocks.sort(key=lambda b: (round(b[1], 1), round(b[0], 1)))
page_texts = []
for b in blocks:
text = b[4].strip()
if text:
page_texts.append(text)
pages_blocks.append(page_texts)
doc.close()
return pages_blocks
def build_pdf_from_blocks(translated_pages: List[List[str]]) -> bytes:
"""
Build a clean paginated PDF with justified paragraphs.
"""
buf = io.BytesIO()
doc = SimpleDocTemplate(
buf, pagesize=A4,
rightMargin=2*cm, leftMargin=2*cm,
topMargin=2*cm, bottomMargin=2*cm
)
styles = getSampleStyleSheet()
body = styles["BodyText"]
body.alignment = TA_JUSTIFY
body.leading = 14
story = []
first = True
for blocks in translated_pages:
        if not first:
            story.append(PageBreak())  # start each source page on a new output page
        first = False
for blk in blocks:
            story.append(RLParagraph(blk.replace("\n", "<br/>"), body))
story.append(Spacer(1, 0.35*cm))
doc.build(story)
return buf.getvalue()
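# Note: the output reflows text into a single-column A4 layout. Original fonts,
# multi-column layout and images are not reproduced; only the translated block
# text is kept, one output page per source page.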
def translate_pdf_bytes(file_bytes: bytes) -> bytes:
"""
Read PDF → collect ALL block chunks across pages → single batched translation → rebuild simple justified PDF.
"""
pages_blocks = extract_pdf_text_blocks(file_bytes)
# 1) collect chunks for the entire PDF
all_chunks: List[str] = []
plan = [] # list of pages, each a list of ranges for blocks
for blocks in pages_blocks:
page_plan = []
for blk in blocks:
chs = chunk_text_for_translation(blk)
if chs:
start = len(all_chunks)
all_chunks.extend(chs)
page_plan.append((start, start + len(chs)))
else:
page_plan.append(None)
plan.append(page_plan)
# 2) translate all chunks at once
translated_all = translate_chunks_list(all_chunks) if all_chunks else []
# 3) reconstruct per block
translated_pages: List[List[str]] = []
for page_plan in plan:
page_out = []
for rng in page_plan:
if rng is None:
page_out.append("")
else:
s, e = rng
page_out.append(" ".join(translated_all[s:e]).strip())
translated_pages.append(page_out)
return build_pdf_from_blocks(translated_pages)
# ---------- Gradio file handler (robust) ----------
def translate_document(file_obj):
"""
Accepts gr.File input (NamedString, filepath str, or dict with binary).
Returns (output_file_path, status_message).
"""
if file_obj is None:
return None, "Please select a .docx or .pdf file"
try:
name = "document"
data = None
# Case A: plain filepath string
if isinstance(file_obj, str):
name = os.path.basename(file_obj)
with open(file_obj, "rb") as f:
data = f.read()
# Case B: Gradio NamedString with .name (orig name) and .value (temp path)
elif hasattr(file_obj, "name") and hasattr(file_obj, "value"):
name = os.path.basename(file_obj.name or "document")
with open(file_obj.value, "rb") as f:
data = f.read()
# Case C: dict (type="binary")
elif isinstance(file_obj, dict) and "name" in file_obj and "data" in file_obj:
name = os.path.basename(file_obj["name"] or "document")
d = file_obj["data"]
data = d.read() if hasattr(d, "read") else d
else:
return None, "Unsupported file input type."
if data is None:
return None, "Could not read the selected file."
if name.lower().endswith(".docx"):
out_bytes = translate_docx_bytes(data)
out_path = "translated_sherbro.docx"
with open(out_path, "wb") as f:
f.write(out_bytes)
return out_path, "✅ DOCX translation completed."
elif name.lower().endswith(".pdf"):
out_bytes = translate_pdf_bytes(data)
out_path = "translated_sherbro.pdf"
with open(out_path, "wb") as f:
f.write(out_bytes)
return out_path, "✅ PDF translation completed."
else:
return None, "Unsupported file type. Please choose .docx or .pdf"
except Exception as e:
return None, f"❌ Error during translation: {str(e)}"
# ================== UI ==================
# Very simple version, compatible with older Gradio releases
demo = gr.Blocks(title="English → Sherbro Translation Demo")
with demo:
# Custom CSS via HTML
gr.HTML("""
""")
with gr.Group(elem_classes=["header-card"]):
gr.HTML(
"""