|
|
"""Mòdul per a l'agent de "reflexion".
|
|
|
|
|
|
Entrenament:
|
|
|
|
|
|
- A partir de parelles (une_ad_auto, une_ad_hitl) per a cada sha1sum, es
|
|
|
comparen les pistes d'audiodescripció (línies amb "(AD)") amb intervals
|
|
|
de temps coincidents.
|
|
|
- Per a cada pista es calcula la durada i les longituds (caràcters i paraules)
|
|
|
i s'etiqueta el cas com S/E/R/X/C:
|
|
|
* S: mateixa longitud aproximada.
|
|
|
* E: allargament de la frase.
|
|
|
* R: reducció de la frase.
|
|
|
* X: eliminació de la frase a la versió HITL.
|
|
|
* C: creació de frase, la versió automàtica era buida/inexistent.
|
|
|
- Es desa un CSV amb les mostres i s'entrena un KNN (K=5) que assigna
|
|
|
probabilitats a cadascun dels casos.
|
|
|
|
|
|
Aplicació:
|
|
|
|
|
|
- Per a un SRT donat, es calculen les mateixes variables per a cada pista
|
|
|
d'(AD) i s'aplica el model KNN per decidir S/E/R/X/C.
|
|
|
- S/C: es deixa el text tal qual.
|
|
|
- X: s'elimina la pista.
|
|
|
- E/R: es demana a GPT-4o-mini que allargui/escurci lleugerament la frase,
|
|
|
en una sola crida per a totes les frases afectades.
|
|
|
"""
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import csv
|
|
|
import json
|
|
|
import logging
|
|
|
import math
|
|
|
import os
|
|
|
from dataclasses import dataclass
|
|
|
from pathlib import Path
|
|
|
from typing import Dict, Iterable, List, Optional, Tuple
|
|
|
|
|
|
from langchain_core.messages import HumanMessage, SystemMessage
|
|
|
from langchain_openai import ChatOpenAI
|
|
|
|
|
|
try:
|
|
|
from sklearn.neighbors import KNeighborsClassifier
|
|
|
import joblib
|
|
|
except Exception:
|
|
|
KNeighborsClassifier = None
|
|
|
joblib = None
|
|
|
|
|
|
from .introspection import _iter_une_vs_hitl_pairs
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
# Directory containing this module; working artifacts live next to it.
BASE_DIR = Path(__file__).resolve().parent
# Scratch directory for reflexion artifacts; created eagerly at import time.
REFINEMENT_TEMP_DIR = BASE_DIR / "temp"
REFINEMENT_TEMP_DIR.mkdir(exist_ok=True, parents=True)

# CSV of training samples: duration, char_len, word_len, label.
REFLEXION_CSV_PATH = REFINEMENT_TEMP_DIR / "reflexion.csv"
# Serialized KNN classifier produced by train_reflexion_model().
REFLEXION_MODEL_PATH = REFINEMENT_TEMP_DIR / "reflexion_knn.joblib"
|
|
|
|
|
|
|
|
|
@dataclass
class AdCue:
    """One audio-description cue extracted from an SRT file."""

    start: float  # cue start time, in seconds
    end: float  # cue end time, in seconds
    text: str  # AD text with the "(AD)" prefix already stripped
    block_lines: List[str]  # original SRT block lines, kept for reconstruction

    @property
    def duration(self) -> float:
        """Cue length in seconds, clamped to be non-negative."""
        span = self.end - self.start
        return span if span > 0.0 else 0.0

    @property
    def char_len(self) -> int:
        """Number of characters in the AD text."""
        return len(self.text)

    @property
    def word_len(self) -> int:
        """Number of whitespace-separated words in the AD text."""
        return len(self.text.split())
|
|
|
|
|
|
|
|
|
def _parse_timestamp(ts: str) -> float:
|
|
|
"""Converteix un timestamp SRT HH:MM:SS,mmm a segons."""
|
|
|
|
|
|
try:
|
|
|
hh, mm, rest = ts.split(":")
|
|
|
ss, ms = rest.split(",")
|
|
|
return int(hh) * 3600 + int(mm) * 60 + int(ss) + int(ms) / 1000.0
|
|
|
except Exception:
|
|
|
return 0.0
|
|
|
|
|
|
|
|
|
def _parse_srt_ad_cues(srt_content: str) -> List[AdCue]:
    """Extract (AD) cues from an SRT document.

    Returns a list of AdCue with start/end times, the cue text (with the
    "(AD)" marker removed) and the original block lines so the SRT can be
    rebuilt later.

    Blocks that contain no "(AD)" line are skipped entirely, so every
    returned cue has non-empty text.
    """

    lines = srt_content.splitlines()
    i = 0
    cues: List[AdCue] = []

    while i < len(lines):

        # Skip blank separator lines between blocks.
        if not lines[i].strip():
            i += 1
            continue

        # First non-blank line of a block: the numeric index.
        idx_line = lines[i].strip()
        i += 1
        if i >= len(lines):
            break

        # The next line must be the "start --> end" timing line; otherwise
        # resynchronize by re-treating the current line as a block start.
        if "-->" not in lines[i]:

            continue

        time_line = lines[i].strip()
        i += 1
        try:
            start_str, end_str = [part.strip() for part in time_line.split("-->")]
        except ValueError:
            # More than one "-->" on the line: malformed block, skip it.
            continue

        start = _parse_timestamp(start_str)
        end = _parse_timestamp(end_str)

        # Collect the block's text lines up to the next blank separator.
        text_lines: List[str] = []
        while i < len(lines) and lines[i].strip():
            text_lines.append(lines[i])
            i += 1

        # Keep only the portion after the "(AD)" marker on each text line.
        ad_text_parts: List[str] = []
        for tl in text_lines:
            if "(AD)" in tl:

                after = tl.split("(AD)", 1)[1].strip()
                if after:
                    ad_text_parts.append(after)

        if not ad_text_parts:
            # No audio-description content in this block.
            continue

        ad_text = " ".join(ad_text_parts).strip()
        block_lines = [idx_line, time_line] + text_lines
        cues.append(AdCue(start=start, end=end, text=ad_text, block_lines=block_lines))

    return cues
|
|
|
|
|
|
|
|
|
def _intervals_overlap(a_start: float, a_end: float, b_start: float, b_end: float) -> bool:
|
|
|
return max(a_start, b_start) < min(a_end, b_end)
|
|
|
|
|
|
|
|
|
def _build_training_rows() -> List[Tuple[float, int, int, str]]:
    """Build training rows ``(duration, char_len, word_len, label)`` from the
    (une_ad_auto, une_ad_hitl) SRT pairs yielded by ``_iter_une_vs_hitl_pairs``.

    Labels: S = similar length, E = lengthened, R = shortened,
    X = removed in the HITL version, C = created in the HITL version.
    """

    rows: List[Tuple[float, int, int, str]] = []

    for sha1sum, une_auto, une_hitl in _iter_une_vs_hitl_pairs():
        auto_cues = _parse_srt_ad_cues(une_auto)
        hitl_cues = _parse_srt_ad_cues(une_hitl)

        for ac in auto_cues:
            # First HITL cue whose time interval overlaps this auto cue.
            matching: Optional[AdCue] = None
            for hc in hitl_cues:
                if _intervals_overlap(ac.start, ac.end, hc.start, hc.end):
                    matching = hc
                    break

            if matching is None:
                # Auto cue with no HITL counterpart: the human deleted it.
                if ac.text.strip():
                    rows.append((ac.duration, ac.char_len, ac.word_len, "X"))
                continue

            auto_text = ac.text.strip()
            hitl_text = matching.text.strip()

            if not auto_text and hitl_text:
                # NOTE(review): _parse_srt_ad_cues only yields cues with
                # non-empty text, so this branch looks unreachable — confirm.
                rows.append((matching.duration, 0, 0, "C"))
                continue

            if not auto_text and not hitl_text:
                continue

            auto_chars = len(auto_text)
            hitl_chars = len(hitl_text)

            # Same length within max(5 chars, 10%) -> S; longer -> E; else R.
            diff = hitl_chars - auto_chars
            if abs(diff) <= max(5, 0.1 * auto_chars):
                label = "S"
            elif diff > 0:
                label = "E"
            else:
                label = "R"

            rows.append((ac.duration, ac.char_len, ac.word_len, label))

        # HITL cues with no overlapping auto cue: content created by the human.
        for hc in hitl_cues:
            has_auto = any(
                _intervals_overlap(hc.start, hc.end, ac.start, ac.end) for ac in auto_cues
            )
            if not has_auto and hc.text.strip():
                rows.append((hc.duration, 0, 0, "C"))

    return rows
|
|
|
|
|
|
|
|
|
def train_reflexion_model(max_examples: Optional[int] = None) -> None:
    """Train the reflexion KNN model and persist both CSV and model.

    - Writes ``reflexion.csv`` with ``duration,char_len,word_len,label`` rows.
    - Fits a distance-weighted KNN (K=5) and saves it to
      ``reflexion_knn.joblib``.

    Args:
        max_examples: optional cap on the number of training rows used.
    """

    if KNeighborsClassifier is None or joblib is None:
        logger.warning(
            "sklearn/joblib no disponibles; el mòdul de reflexion no es pot entrenar."
        )
        return

    samples = _build_training_rows()
    if not samples:
        logger.warning("No s'han pogut generar files d'entrenament per a reflexion.")
        return

    if max_examples is not None:
        samples = samples[:max_examples]

    # Persist the samples so they can be inspected or reused later.
    with REFLEXION_CSV_PATH.open("w", newline="", encoding="utf-8") as handle:
        csv_out = csv.writer(handle)
        csv_out.writerow(["duration", "char_len", "word_len", "label"])
        csv_out.writerows(
            [f"{duration:.3f}", chars, words, label]
            for duration, chars, words, label in samples
        )

    features = [[duration, chars, words] for duration, chars, words, _ in samples]
    targets = [label for _, _, _, label in samples]

    classifier = KNeighborsClassifier(n_neighbors=5, weights="distance")
    classifier.fit(features, targets)

    joblib.dump(classifier, REFLEXION_MODEL_PATH)
    logger.info(
        "Model de reflexion entrenat amb %d mostres i desat a %s",
        len(samples),
        REFLEXION_MODEL_PATH,
    )
|
|
|
|
|
|
|
|
|
def _load_reflexion_model():
    """Return the persisted reflexion KNN model, or ``None``.

    ``None`` means sklearn/joblib are unavailable, no model file exists yet,
    or deserialization failed (which is logged).
    """
    deps_ready = KNeighborsClassifier is not None and joblib is not None
    if not (deps_ready and REFLEXION_MODEL_PATH.exists()):
        return None
    try:
        return joblib.load(REFLEXION_MODEL_PATH)
    except Exception:
        logger.warning("No s'ha pogut carregar el model de reflexion de %s", REFLEXION_MODEL_PATH)
        return None
|
|
|
|
|
|
|
|
|
def _get_llm() -> Optional[ChatOpenAI]:
    """Instantiate the GPT-4o-mini client used for E/R adjustments.

    Returns ``None`` (after logging) when the API key is missing or the
    client cannot be constructed, so callers can degrade gracefully.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if not key:
        logger.warning("OPENAI_API_KEY no está configurada; se omite la reflexion.")
        return None
    try:
        return ChatOpenAI(model="gpt-4o-mini", temperature=0.0, api_key=key)
    except Exception as exc:
        logger.error("No se pudo inicializar ChatOpenAI para reflexion: %s", exc)
        return None
|
|
|
|
|
|
|
|
|
def _apply_knn_to_cues(cues: List[AdCue]) -> List[str]:
    """Assign an S/E/R/X/C label to every cue using the trained KNN.

    Falls back to labelling everything "S" (keep as-is) when no model is
    available or prediction raises.
    """

    model = _load_reflexion_model()
    if model is None:
        return ["S"] * len(cues)

    features = [[cue.duration, cue.char_len, cue.word_len] for cue in cues]
    try:
        probabilities = model.predict_proba(features)
        class_names = list(model.classes_)
        # Most probable class per cue.
        return [str(class_names[int(row.argmax())]) for row in probabilities]
    except Exception as exc:
        logger.error("Error aplicant el model de reflexion: %s", exc)
        return ["S"] * len(cues)
|
|
|
|
|
|
|
|
|
def _ask_llm_for_length_adjustments(cues: List[AdCue], labels: List[str]) -> Dict[int, str]:
    """Ask the LLM to slightly lengthen (E) or shorten (R) labelled cues.

    Sends a single JSON request covering every affected cue and parses the
    JSON answer.

    Returns:
        A map ``{cue_index -> new_text}``; empty on any failure (no LLM,
        call error, or a non-JSON response).
    """

    llm = _get_llm()
    if llm is None:
        return {}

    # Only E/R cues are sent; ids are the cue list indices, as strings.
    items: List[Dict[str, str]] = []
    for idx, (cue, lab) in enumerate(zip(cues, labels)):
        if lab not in {"E", "R"}:
            continue
        items.append({"id": str(idx), "case": lab, "text": cue.text})

    if not items:
        return {}

    system = SystemMessage(
        content=(
            "Ets un assistent que ajusta lleugerament la longitud de frases d'"
            "audiodescripció en català. \n"
            "Rebràs una llista d'objectes JSON amb camps 'id', 'case' (E o R) i "
            "'text'. \n"
            "Per a cada element has de tornar un nou text que: \n"
            "- Si 'case' és 'E': sigui una mica més llarg (afegint detalls" \
            " suaus, sense canviar el sentit).\n"
            "- Si 'case' és 'R': sigui una mica més curt, més concís, mantenint el" \
            " sentit principal.\n"
            "Respon EXCLUSIVAMENT en JSON de la forma:\n"
            "{\"segments\":[{\"id\":\"...\",\"new_text\":\"...\"}, ...]}"
        )
    )

    user = HumanMessage(content=json.dumps({"segments": items}, ensure_ascii=False))

    try:
        resp = llm.invoke([system, user])
    except Exception as exc:
        logger.error("Error llamando al LLM en reflexion (ajustes E/R): %s", exc)
        return {}

    # resp.content may be a list of parts; coerce to str before parsing.
    text = resp.content if isinstance(resp.content, str) else str(resp.content)
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        logger.warning("Respuesta del LLM en reflexion no es JSON válido: %s", text[:2000])
        return {}

    # Keep only well-formed segments with a numeric id and non-empty text.
    result: Dict[int, str] = {}
    for seg in data.get("segments", []):
        try:
            idx = int(seg.get("id"))
        except Exception:
            continue
        new_text = str(seg.get("new_text", "")).strip()
        if new_text:
            result[idx] = new_text

    return result
|
|
|
|
|
|
|
|
|
def refine_srt_with_reflexion(srt_content: str) -> str:
    """Apply the "reflexion" pass over an SRT document.

    - Uses the trained KNN to decide, per (AD) cue, whether to keep it (S/C),
      remove it (X) or adjust its length (E/R).
    - For E/R cues, asks the LLM for a slightly longer/shorter version.
    - Without a model or LLM the SRT comes back effectively unchanged.
    """

    cues = _parse_srt_ad_cues(srt_content)
    if not cues:
        return srt_content

    labels = _apply_knn_to_cues(cues)

    # One LLM call covering every E/R cue; empty map when nothing to adjust.
    adjustments = _ask_llm_for_length_adjustments(cues, labels)

    # Index cues by their (start, end) interval so blocks can be matched
    # while re-scanning the SRT below.
    # NOTE(review): duplicated intervals collapse to the last cue — confirm
    # SRTs never contain two (AD) blocks with identical timing.
    cue_by_interval: Dict[Tuple[float, float], Tuple[int, AdCue]] = {}
    for idx, cue in enumerate(cues):
        cue_by_interval[(cue.start, cue.end)] = (idx, cue)

    lines = srt_content.splitlines()
    i = 0
    out_lines: List[str] = []

    # Re-scan the SRT block by block, mirroring _parse_srt_ad_cues.
    while i < len(lines):
        if not lines[i].strip():
            out_lines.append(lines[i])
            i += 1
            continue

        idx_line = lines[i]
        i += 1
        if i >= len(lines):
            out_lines.append(idx_line)
            break

        time_line = lines[i]
        i += 1
        if "-->" not in time_line:
            # Not a well-formed block: copy both lines through untouched.
            out_lines.append(idx_line)
            out_lines.append(time_line)
            continue

        # NaN keys never match cue_by_interval, so blocks with unparsable
        # timing are copied through unchanged.
        try:
            start_str, end_str = [part.strip() for part in time_line.strip().split("-->")]
            start = _parse_timestamp(start_str)
            end = _parse_timestamp(end_str)
        except Exception:
            start = end = math.nan

        text_block: List[str] = []
        while i < len(lines) and lines[i].strip():
            text_block.append(lines[i])
            i += 1

        key = (start, end)
        if key not in cue_by_interval:
            # Block without a tracked (AD) cue: copy through untouched.
            out_lines.append(idx_line)
            out_lines.append(time_line)
            out_lines.extend(text_block)
            if i < len(lines) and not lines[i].strip():
                out_lines.append(lines[i])
                i += 1
            continue

        cue_idx, cue = cue_by_interval[key]
        label = labels[cue_idx] if cue_idx < len(labels) else "S"

        if label == "X":
            # Drop the whole block, including its trailing blank separator.
            if i < len(lines) and not lines[i].strip():
                i += 1
            continue

        # E/R cues with an LLM rewrite: replace the first "(AD)" line's text.
        new_text = adjustments.get(cue_idx)
        if new_text:

            new_block: List[str] = []
            replaced = False
            for tl in text_block:
                if "(AD)" in tl and not replaced:
                    prefix, _ = tl.split("(AD)", 1)
                    new_block.append(prefix + "(AD) " + new_text)
                    replaced = True
                else:
                    new_block.append(tl)
            text_block = new_block

        out_lines.append(idx_line)
        out_lines.append(time_line)
        out_lines.extend(text_block)
        if i < len(lines) and not lines[i].strip():
            out_lines.append(lines[i])
            i += 1

    return "\n".join(out_lines)
|
|
|
|