# super_AI / huggingface_client.py
# Uploaded by bhavesh122 ("Upload 20 files", commit 78ca118, verified).
import httpx
import json
import os
import random
from typing import AsyncGenerator, Optional, List, Any
from ai_client import AIClient
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment so the
# HF_TOKEN / HF_TOKEN_<n> lookups in HuggingFaceClient below can find them.
load_dotenv()
class HuggingFaceClient(AIClient):
    """Streaming chat client for the Hugging Face router's chat-completions API.

    Tokens are read from the environment once, at construction time. The
    numbered variables ``HF_TOKEN_1`` .. ``HF_TOKEN_<MAX_NUMBERED_TOKENS>`` are
    collected first. The single ``HF_TOKEN`` variable is used only when none of
    them is set. Each request uses one of the collected tokens, chosen at
    random, to spread load across them.
    """

    # Highest <n> probed for HF_TOKEN_<n> (inclusive).
    MAX_NUMBERED_TOKENS = 10

    # Chat-completions endpoint of the Hugging Face inference router.
    CHAT_COMPLETIONS_URL = "https://router.huggingface.co/v1/chat/completions"

    def __init__(self):
        self.tokens: List[str] = self._load_tokens()
        # One shared connection pool for all requests; release it via close().
        self.client = httpx.AsyncClient(timeout=60.0)

    @classmethod
    def _load_tokens(cls) -> List[str]:
        """Return the non-empty, whitespace-stripped HF tokens from the environment."""
        tokens: List[str] = []
        # range() excludes its stop value, hence +1 so HF_TOKEN_<MAX> is included.
        for i in range(1, cls.MAX_NUMBERED_TOKENS + 1):
            # Strip before testing so a whitespace-only value is not stored as "".
            token = (os.getenv(f"HF_TOKEN_{i}") or "").strip()
            if token:
                tokens.append(token)
        if not tokens:
            # Fall back to the single HF_TOKEN only if no numbered token exists.
            main_token = (os.getenv("HF_TOKEN") or "").strip()
            if main_token:
                tokens.append(main_token)
        return tokens

    def _get_token(self) -> str:
        """Return a randomly chosen configured token, or "" when there are none."""
        if not self.tokens:
            return ""
        # Plain random choice: this spreads load; it is not a security decision.
        return random.choice(self.tokens)

    def _build_headers(self) -> dict:
        """Build request headers, adding Authorization only when a token exists."""
        headers = {"Content-Type": "application/json"}
        token = self._get_token()
        # Without a token, omit the header rather than send a malformed "Bearer ".
        if token:
            headers["Authorization"] = f"Bearer {token}"
        return headers

    @staticmethod
    def _extract_delta_content(chunk: Any) -> str:
        """Return the text in ``choices[0].delta.content`` of a parsed chunk, or "".

        Returns "" instead of raising when ``choices`` is missing or empty, or
        when ``delta`` or ``content`` is null or has an unexpected type.
        """
        if not isinstance(chunk, dict):
            return ""
        choices = chunk.get("choices")
        if not isinstance(choices, list) or not choices:
            return ""
        first = choices[0]
        if not isinstance(first, dict):
            return ""
        delta = first.get("delta")
        if not isinstance(delta, dict):
            return ""
        content = delta.get("content")
        return content if isinstance(content, str) else ""

    async def async_stream_request(
        self,
        model_id: str,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion from the Hugging Face router.

        Args:
            model_id: Model identifier, sent as the request's ``model`` field.
            prompt: Content of the user message.
            system_prompt: Optional system message placed before the user message.
            **kwargs: ``max_tokens`` (default 2048) and ``temperature``
                (default 0.7) are forwarded in the payload; other keys are ignored.

        Yields:
            Each non-empty ``delta.content`` text fragment, in arrival order.

        Raises:
            Exception: ``"Error from HF (<status>): <body>"`` when the router
                answers with a non-200 status, or ``"Connection error: ..."``
                (chained to the original exception) when httpx raises an
                ``httpx.HTTPError`` such as a connect/read failure or timeout.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        payload = {
            "model": model_id,
            "messages": messages,
            "stream": True,
            "max_tokens": kwargs.get("max_tokens", 2048),
            "temperature": kwargs.get("temperature", 0.7),
        }

        try:
            async with self.client.stream(
                "POST",
                self.CHAT_COMPLETIONS_URL,
                headers=self._build_headers(),
                json=payload,
            ) as response:
                if response.status_code != 200:
                    err_body = await response.aread()
                    # errors="replace": a non-UTF-8 error body must not mask the status.
                    raise Exception(
                        f"Error from HF ({response.status_code}): "
                        f"{err_body.decode(errors='replace')}"
                    )
                async for line in response.aiter_lines():
                    # SSE allows "data:" with or without a following space.
                    if not line.startswith("data:"):
                        continue
                    data_str = line[len("data:"):].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data_str)
                    except json.JSONDecodeError:
                        # Skip empty or non-JSON events instead of aborting the stream.
                        continue
                    content = self._extract_delta_content(chunk)
                    if content:
                        yield content
        except httpx.HTTPError as e:
            # Wrap only transport/HTTP-layer failures; the HF status error above
            # propagates unchanged instead of being mislabelled as a connection error.
            raise Exception(f"Connection error: {e}") from e

    async def close(self) -> None:
        """Close the shared HTTP client and its pooled connections."""
        await self.client.aclose()