import spaces
import torch
import numpy as np
from typing import Generator
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE

# Global variables to store the model and tokenizer
tokenizer = None
model = None


def initialize_model():
    """Initializes and loads the model and tokenizer once onto the GPU."""
    global tokenizer, model
    if model is None:
        try:
            print(f"Loading model {MODEL_NAME}...")
            # Use bfloat16 for efficiency on modern GPUs
            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=dtype,
                device_map="auto",
            )
            model.eval()

            # Set padding token if not defined
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id

            print("Model loaded successfully.")
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise
    return tokenizer, model


# Call initialization
try:
    initialize_model()
except Exception as e:
    print(f"Warning: Global model initialization failed: {e}")


@spaces.GPU(duration=120)
def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
    """
    Generates a response from the KAT model with proper streaming.
    """
    global tokenizer, model

    # Fallback initialization
    if model is None or tokenizer is None:
        initialize_model()

    # Convert Gradio history format to the model's chat template format
    messages = []
    for human, bot in history:
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})

    # Add the current prompt
    messages.append({"role": "user", "content": prompt})

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize with attention mask
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)

    # Store initial input length
    initial_length = input_ids.shape[-1]

    # Generate with streaming using yield-based approach
    accumulated_text = ""
    generated_tokens = 0

    # Generate tokens incrementally
    while generated_tokens < MAX_NEW_TOKENS:
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )

        # Get next token probabilities
        next_token_logits = outputs.logits[:, -1, :]

        # Apply temperature
        if TEMPERATURE > 0:
            next_token_logits = next_token_logits / TEMPERATURE

        # Apply softmax and sample
        probs = torch.softmax(next_token_logits, dim=-1)
        if DO_SAMPLE:
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        # Check for EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

        # Decode the new token
        new_token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)

        # Update accumulated text
        accumulated_text += new_token_text

        # Yield the current accumulated text
        yield accumulated_text

        # Prepare for next iteration
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

        # Increment generated tokens counter
        generated_tokens += 1

    # Final yield to ensure complete text
    if accumulated_text:
        yield accumulated_text.strip()
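

# ---------------------------------------------------------------------------
# Illustrative usage (assumption: not part of the original Space code, added
# only as a minimal sketch of how the streaming generator is consumed).
# Each yielded value is the full accumulated text so far, so a consumer only
# needs to keep the latest value it receives.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_history = []  # empty chat history: list of (user, assistant) tuples
    final_text = ""
    for partial in stream_generate_response("Hello! What can you do?", demo_history):
        final_text = partial  # each yield supersedes the previous one
    print(final_text)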