File size: 4,653 Bytes
287f01b b17b915 287f01b b17b915 ed7696e b17b915 648c0b6 b17b915 287f01b 648c0b6 b17b915 648c0b6 b17b915 ed7696e b17b915 ed7696e 648c0b6 ed7696e b17b915 ed7696e 648c0b6 ed7696e b17b915 ed7696e 648c0b6 ed7696e b17b915 ed7696e 648c0b6 ed7696e b17b915 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
# llm_router.py — enruta llamadas a Spaces remotos según config.yaml
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pathlib import Path
import os
import yaml
from remote_clients import InstructClient, VisionClient, ToolsClient, ASRClient
import time
def load_yaml(path: str) -> Dict[str, Any]:
    """Load a YAML file into a dict.

    Returns an empty dict when the file does not exist or when the parsed
    document is empty/falsy (``yaml.safe_load`` yields ``None`` for an
    empty file).
    """
    target = Path(path)
    if target.exists():
        loaded = yaml.safe_load(target.read_text(encoding="utf-8"))
        return loaded or {}
    return {}
class LLMRouter:
def __init__(self, cfg: Dict[str, Any]):
self.cfg = cfg
self.rem = cfg.get("models", {}).get("routing", {}).get("use_remote_for", [])
base_user = cfg.get("remote_spaces", {}).get("user", "veureu")
eps = cfg.get("remote_spaces", {}).get("endpoints", {})
token_enabled = cfg.get("security", {}).get("use_hf_token", False)
hf_token = os.getenv(cfg.get("security", {}).get("hf_token_env", "HF_TOKEN")) if token_enabled else None
def mk_factory(endpoint_key: str, cls):
info = eps.get(endpoint_key, {})
base_url = info.get("base_url") or f"https://{base_user}-{info.get('space')}.hf.space"
use_gradio = (info.get("client", "gradio") == "gradio")
timeout = int(cfg.get("remote_spaces", {}).get("http", {}).get("timeout_seconds", 180))
def _factory():
return cls(base_url=base_url, use_gradio=use_gradio, hf_token=hf_token, timeout=timeout)
return _factory
self.client_factories = {
"salamandra-instruct": mk_factory("salamandra-instruct", InstructClient),
"salamandra-vision": mk_factory("salamandra-vision", VisionClient),
"salamandra-tools": mk_factory("salamandra-tools", ToolsClient),
"whisper-catalan": mk_factory("whisper-catalan", ASRClient),
}
self.service_names = {
"salamandra-instruct": "schat",
"salamandra-vision": "svision",
"salamandra-tools": "stools",
"whisper-catalan": "asr",
}
def _log_connect(self, model_key: str, phase: str, elapsed: float | None = None):
svc = self.service_names.get(model_key, model_key)
if phase == "connect":
print(f"[LLMRouter] Connecting to {svc} space...")
elif phase == "done":
print(f"[LLMRouter] Response from {svc} space received in {elapsed:.2f} s")
# ---- INSTRUCT ----
def instruct(self, prompt: str, system: Optional[str] = None, model: str = "salamandra-instruct", **kwargs) -> str:
if model in self.rem:
self._log_connect(model, "connect")
t0 = time.time()
client = self.client_factories[model]()
out = client.generate(prompt, system=system, **kwargs) # type: ignore
self._log_connect(model, "done", time.time() - t0)
return out
raise RuntimeError(f"Modelo local no implementado para: {model}")
# ---- VISION ----
def vision_describe(self, image_paths: List[str], context: Optional[Dict[str, Any]] = None, model: str = "salamandra-vision", **kwargs) -> List[str]:
if model in self.rem:
self._log_connect(model, "connect")
t0 = time.time()
client = self.client_factories[model]()
out = client.describe(image_paths, context=context, **kwargs) # type: ignore
self._log_connect(model, "done", time.time() - t0)
return out
raise RuntimeError(f"Modelo local no implementado para: {model}")
# ---- TOOLS ----
def chat_with_tools(self, messages: List[Dict[str, str]], tools: Optional[List[Dict[str, Any]]] = None, model: str = "salamandra-tools", **kwargs) -> Dict[str, Any]:
if model in self.rem:
self._log_connect(model, "connect")
t0 = time.time()
client = self.client_factories[model]()
out = client.chat(messages, tools=tools, **kwargs) # type: ignore
self._log_connect(model, "done", time.time() - t0)
return out
raise RuntimeError(f"Modelo local no implementado para: {model}")
# ---- ASR ----
def asr_transcribe(self, audio_path: str, model: str = "whisper-catalan", **kwargs) -> Dict[str, Any]:
if model in self.rem:
self._log_connect(model, "connect")
t0 = time.time()
client = self.client_factories[model]()
out = client.transcribe(audio_path, **kwargs) # type: ignore
self._log_connect(model, "done", time.time() - t0)
return out
raise RuntimeError(f"Modelo local no implementado para: {model}")
|