# GLM-4.6-FP8-API / app.py
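"""FastAPI REST service that exposes the zai-org/GLM-4.6-FP8 model through /chat and /generate endpoints."""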
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import uvicorn

app = FastAPI(
    title="GLM-4.6-FP8 API",
    description="Functional REST API for GLM-4.6-FP8 with multilingual support",
    version="1.0.0",
)

# Cached model and tokenizer, populated at startup
model = None
tokenizer = None
device = "cuda" if torch.cuda.is_available() else "cpu"
class ChatRequest(BaseModel):
    message: str
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95


class ChatResponse(BaseModel):
    response: str
    model: str = "GLM-4.6-FP8"
    device: str = device
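
# Note: @app.on_event("startup") is deprecated in newer FastAPI releases in
# favor of lifespan handlers; it still works but emits a deprecation warning.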
@app.on_event("startup")
async def startup_event():
    """Load the GLM-4.6-FP8 tokenizer and model once, at application startup."""
    global model, tokenizer
    try:
        print("Loading GLM-4.6-FP8 model...")
        tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-4.6-FP8")
        # device_map="auto" shards the weights across the available devices
        # (requires the accelerate package)
        model = AutoModelForCausalLM.from_pretrained(
            "zai-org/GLM-4.6-FP8",
            device_map="auto",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            trust_remote_code=True,
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
@app.get("/")
async def root():
return {
"message": "GLM-4.6-FP8 API",
"version": "1.0.0",
"device": device,
"model_loaded": model is not None,
"endpoints": {
"chat": "/chat",
"generate": "/generate",
"health": "/health"
}
}
@app.get("/health")
async def health():
return {
"status": "ok",
"model_loaded": model is not None,
"device": device
}
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
global model, tokenizer
if model is None or tokenizer is None:
raise HTTPException(status_code=503, detail="Modelo não está carregado")
try:
# Tokenizar entrada
inputs = tokenizer(request.message, return_tensors="pt").to(device)
# Gerar resposta
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
do_sample=True
)
# Decodificar resposta
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return ChatResponse(response=response_text)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erro na geração: {str(e)}")
@app.post("/generate", response_model=ChatResponse)
async def generate(request: ChatRequest):
"""Alias para /chat com formato alternativo"""
return await chat(request)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
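
# Example request (a minimal sketch; assumes the server is running locally on
# port 7860 and the requests package is installed):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/chat",
#       json={"message": "Hello, GLM!", "max_tokens": 128},
#   )
#   print(resp.json()["response"])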