Spaces:
Running
Running
| import httpx | |
| import json | |
| import os | |
| import random | |
| from typing import AsyncGenerator, Optional, List, Any | |
| from ai_client import AIClient | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
class HuggingFaceClient(AIClient):
    """Streaming chat client for the Hugging Face router chat-completions API.

    Access tokens are read from the environment variables ``HF_TOKEN_1`` ..
    ``HF_TOKEN_10``; if none of those is set, the single ``HF_TOKEN`` variable
    is used as a fallback. One of the collected tokens is picked at random for
    each request.
    """

    # Highest index ``n`` probed when looking up ``HF_TOKEN_<n>``.
    _MAX_NUMBERED_TOKENS = 10

    def __init__(self):
        """Collect the configured tokens and create the shared HTTP client."""
        # Gather all numbered HF tokens from the environment (1..10 inclusive).
        self.tokens: List[str] = []
        for i in range(1, self._MAX_NUMBERED_TOKENS + 1):
            token = (os.getenv(f"HF_TOKEN_{i}") or "").strip()
            # Skip unset and whitespace-only values so an empty string is
            # never selected as a token later on.
            if token:
                self.tokens.append(token)

        if not self.tokens:
            # Fall back to the single HF_TOKEN variable if available.
            main_token = (os.getenv("HF_TOKEN") or "").strip()
            if main_token:
                self.tokens.append(main_token)

        self.client = httpx.AsyncClient(timeout=60.0)

    def _get_token(self) -> str:
        """Return a randomly chosen token, or ``""`` when none is configured."""
        if not self.tokens:
            return ""
        return random.choice(self.tokens)

    @staticmethod
    def _extract_content(data: Any) -> Any:
        """Return ``choices[0].delta.content`` from a decoded stream chunk.

        Returns ``""`` whenever the chunk does not have that shape (not a
        dict, missing or empty ``choices``, missing or null ``delta``, or
        missing or null ``content``), so the caller can simply skip the chunk.
        """
        if not isinstance(data, dict):
            return ""
        choices = data.get("choices")
        if not isinstance(choices, list) or not choices:
            return ""
        first_choice = choices[0]
        if not isinstance(first_choice, dict):
            return ""
        delta = first_choice.get("delta")
        if not isinstance(delta, dict):
            return ""
        return delta.get("content") or ""

    async def async_stream_request(
        self,
        model_id: str,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any
    ) -> AsyncGenerator[str, None]:
        """Stream the model's reply to ``prompt`` piece by piece.

        Args:
            model_id: Value sent as the ``model`` field of the request.
            prompt: Text sent as the user message.
            system_prompt: Optional text sent as a leading system message.
            **kwargs: ``max_tokens`` (default 2048) and ``temperature``
                (default 0.7) are forwarded in the payload; other keys are
                ignored.

        Yields:
            Each non-empty ``content`` fragment found in the streamed chunks,
            in arrival order, until the ``[DONE]`` marker or end of stream.

        Raises:
            Exception: If the response status is not 200 (message starts with
                ``"Error from HF"``), or if the HTTP exchange itself fails
                (message starts with ``"Connection error"``; the underlying
                httpx error is attached as ``__cause__``).
        """
        token = self._get_token()
        headers = {"Content-Type": "application/json"}
        if token:
            # Only send credentials when there are any; "Bearer " followed by
            # nothing is a malformed header value.
            headers["Authorization"] = f"Bearer {token}"

        # Chat-style message list: optional system message, then the user one.
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        url = "https://router.huggingface.co/v1/chat/completions"
        payload = {
            "model": model_id,
            "messages": messages,
            "stream": True,
            "max_tokens": kwargs.get("max_tokens", 2048),
            "temperature": kwargs.get("temperature", 0.7),
        }

        try:
            async with self.client.stream(
                "POST", url, headers=headers, json=payload
            ) as response:
                if response.status_code != 200:
                    err_body = await response.aread()
                    # errors="replace" keeps a non-UTF-8 error body from
                    # masking the real failure with a UnicodeDecodeError.
                    err_text = err_body.decode(errors="replace")
                    raise Exception(
                        f"Error from HF ({response.status_code}): {err_text}"
                    )

                async for line in response.aiter_lines():
                    # The space after "data:" is optional, so match on the
                    # bare field name and strip afterwards.
                    if not line.startswith("data:"):
                        continue
                    data_str = line[len("data:"):].strip()
                    if data_str == "[DONE]":
                        break

                    # Keep the try body limited to parsing so that nothing
                    # raised at the yield below can be swallowed here.
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue

                    content = self._extract_content(data)
                    if content:
                        yield content
        except httpx.HTTPError as e:
            # Only transport/protocol failures are labelled as connection
            # errors; the status error raised above propagates unchanged.
            raise Exception(f"Connection error: {str(e)}") from e

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self.client.aclose()