"""Adaptive Gradio chatbot that periodically retrains on its own conversations.

Every user turn is appended to ``training_data``; once
``TRAIN_EVERY_N_PROMPTS`` exchanges have accumulated, a lightweight
LoRA fine-tune pass is triggered and the buffer is cleared.
"""

import os
import time

import gradio as gr
from unsloth import FastLanguageModel
from transformers import TrainingArguments, TextStreamer
from datasets import load_dataset

# --- Environment fixes ---
# Avoid OpenMP thread over-subscription when tokenizing/generating.
os.environ["OMP_NUM_THREADS"] = "1"

# --- Load model ---
model_name = "EpistemeAI/Episteme-gptoss-20b-RL"
print(f"Loading model: {model_name}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    load_in_4bit=False,
    dtype=None,          # let unsloth pick the best dtype for the hardware
    device_map="auto",
)

# --- Simple in-process memory for the ongoing chat ---
conversation_history = []    # list of {"role": ..., "content": ...} dicts (Gradio "messages" format)
training_data = []           # buffered {"instruction", "expected_output"} pairs awaiting auto-train
TRAIN_EVERY_N_PROMPTS = 8    # auto-train after this many user turns
training_status = "Idle"     # polled by the UI timer via get_training_status()


# --- Function: Generate chat response ---
def chat_response(user_input):
    """Append *user_input* to the conversation, generate a reply, and
    auto-train once enough exchanges have accumulated.

    Implemented as a generator so Gradio can display an interim
    "training in progress" status before the final reply is yielded.

    Yields:
        [conversation_history, status_text] pairs for the
        (chatbot, status_display) outputs.
    """
    global training_data, training_status

    # Record the prompt.
    conversation_history.append({"role": "user", "content": user_input})

    # Build a flat role-prefixed prompt from the whole conversation.
    full_prompt = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in conversation_history
    )
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # Stream tokens to stdout while generating. NOTE: TextStreamer has no
    # `.text` attribute (it only prints), so the original
    # `streamer.text.strip()` raised AttributeError — instead we capture
    # the reply by decoding the newly generated tokens.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    outputs = model.generate(**inputs, streamer=streamer, max_new_tokens=200)
    prompt_len = inputs["input_ids"].shape[1]
    response_text = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # Store the assistant message and queue the exchange for training.
    conversation_history.append({"role": "assistant", "content": response_text})
    training_data.append({"instruction": user_input, "expected_output": response_text})

    # Auto-train trigger: fine-tune on the buffered exchanges, then reset.
    if len(training_data) >= TRAIN_EVERY_N_PROMPTS:
        training_status = "Training..."
        yield [
            conversation_history,
            "Auto-training in progress... this may take a moment.",
        ]
        try:
            train_model(training_data)
        except Exception as exc:
            # Best-effort: a failed training pass must not kill the chat.
            print(f"⚠️ Auto-training failed: {exc}")
        training_data.clear()
        training_status = "Idle"

    yield [conversation_history, f"Assistant: {response_text}"]


# --- Function: Train model on recent conversations ---
def train_model(data_samples):
    """Run one LoRA fine-tuning epoch over *data_samples*.

    Args:
        data_samples: list of {"instruction", "expected_output"} dicts
            collected by chat_response().
    """
    # Columnar layout expected by the training call below.
    dataset = {
        "instruction": [d["instruction"] for d in data_samples],
        "expected_output": [d["expected_output"] for d in data_samples],
    }

    training_args = TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        learning_rate=2e-4,
        logging_dir="./logs",
        save_strategy="no",          # throwaway incremental pass — don't checkpoint
        optim="adamw_bnb_8bit",
        fp16=True,
        lr_scheduler_type="cosine",
    )

    # NOTE(review): public unsloth releases do not expose a
    # `train_model` method on the model object — this looks like a
    # project-specific extension. Confirm it exists, or replace with a
    # trl SFTTrainer pipeline; chat_response() guards this call with
    # try/except so a failure here is non-fatal.
    model.train_model(
        dataset=dataset,
        training_args=training_args,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.1,
    )
    print("✅ Model retrained successfully")


# --- Function: Display training status ---
def get_training_status():
    """Return the current training status line for the UI timer."""
    return f"Training status: {training_status}"


# --- Function: Clear chat ---
def clear_chat():
    """Reset the conversation history and report it to the status box."""
    global conversation_history
    conversation_history = []
    return [], "Chat cleared."


# --- Build Gradio UI ---
with gr.Blocks(title="🧠 Auto-Training Chatbot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("### 🤖 Adaptive AI Chatbot (with Auto-Retraining)")

    with gr.Row():
        chatbot = gr.Chatbot(
            height=500,
            label="Chat Interface",
            type="messages",  # matches the role/content dicts we maintain
        )
        with gr.Column():
            status_display = gr.Textbox(
                label="Training Status",
                value="Idle",
                interactive=False,
            )
            clear_btn = gr.Button("Clear Chat")

    user_input = gr.Textbox(
        label="Your message",
        placeholder="Type your question or message here...",
    )
    send_btn = gr.Button("Send")

    # Bind events.
    send_btn.click(
        fn=chat_response,
        inputs=[user_input],
        outputs=[chatbot, status_display],
    )
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, status_display],
    )

    # Poll the training status every 2 seconds.
    timer = gr.Timer(interval=2.0, active=True)
    timer.tick(fn=get_training_status, outputs=[status_display])


# --- Launch ---
if __name__ == "__main__":
    demo.launch()