# app.py — Hugging Face Space by legolasyiu (commit 325feac, verified)
import os
import time
import gradio as gr
from unsloth import FastLanguageModel
from transformers import TrainingArguments, TextStreamer
from datasets import load_dataset
# --- Environment Fixes ---
# Cap OpenMP to a single thread to avoid CPU oversubscription on small Spaces.
os.environ["OMP_NUM_THREADS"] = "1"

# --- Load model ---
# Hugging Face repo id of the checkpoint served by this app.
model_name = "EpistemeAI/Episteme-gptoss-20b-RL"
print(f"Loading model: {model_name}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    load_in_4bit=False,   # full-precision weights (no 4-bit quantization)
    dtype=None,           # let unsloth auto-select the compute dtype
    device_map="auto",    # place layers automatically across available devices
)

# --- Simple memory for ongoing chat ---
# Shared chat log of {"role": ..., "content": ...} dicts (user + assistant turns).
conversation_history = []
# Buffered {"instruction", "expected_output"} pairs awaiting the next retrain.
training_data = []
# Retrain once this many new exchanges have accumulated.
TRAIN_EVERY_N_PROMPTS = 8
# Human-readable state surfaced in the UI status box.
training_status = "Idle"
# --- Function: Generate chat response ---
def chat_response(user_input):
    """Generate an assistant reply and periodically trigger auto-retraining.

    Args:
        user_input: Raw text typed by the user.

    Yields:
        ``[conversation_history, status_text]`` pairs for the
        ``(chatbot, status_display)`` Gradio outputs — an intermediate
        "training in progress" update when a retrain fires, then the final
        reply.
    """
    global training_data, training_status
    # Record the prompt in the shared chat log.
    conversation_history.append({"role": "user", "content": user_input})
    # Flatten the whole history into a plain "role: content" prompt.
    full_prompt = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in conversation_history
    )
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    # Stream tokens to stdout while generating.
    # BUG FIX: TextStreamer has no `.text` attribute (it only prints), so the
    # reply must be recovered by decoding the newly generated token ids —
    # everything past the prompt length — instead of reading `streamer.text`.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    output_ids = model.generate(**inputs, streamer=streamer, max_new_tokens=200)
    prompt_len = inputs["input_ids"].shape[1]
    response_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()
    # Store the assistant turn and queue the exchange as a training sample.
    conversation_history.append({"role": "assistant", "content": response_text})
    training_data.append({"instruction": user_input, "expected_output": response_text})
    # Auto-train once enough new samples have accumulated.
    if len(training_data) >= TRAIN_EVERY_N_PROMPTS:
        training_status = "Training..."
        yield [
            conversation_history,
            "Auto-training in progress... this may take a moment."
        ]
        train_model(training_data)
        training_data.clear()
        training_status = "Idle"
    yield [conversation_history, f"Assistant: {response_text}"]
# --- Function: Train model on recent conversations ---
def train_model(data_samples):
    """Fine-tune the loaded model on the buffered chat exchanges.

    Args:
        data_samples: List of {"instruction", "expected_output"} dicts
            accumulated by chat_response.
    """
    # Columnar dict of the buffered prompt/response pairs.
    dataset = {
        "instruction": [d["instruction"] for d in data_samples],
        "expected_output": [d["expected_output"] for d in data_samples],
    }
    training_args = TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        learning_rate=2e-4,
        logging_dir="./logs",
        save_strategy="no",  # no checkpoints written; updated weights stay in memory
        optim="adamw_bnb_8bit",
        fp16=True,
        lr_scheduler_type="cosine",
    )
    # NOTE(review): unsloth's FastLanguageModel does not expose a
    # `train_model` method in its public API — this call looks like it would
    # raise AttributeError at runtime. The usual path is to wrap the model in
    # trl's SFTTrainer (with a datasets.Dataset, not a plain dict) — confirm
    # against the unsloth version pinned for this Space before relying on this.
    model.train_model(
        dataset=dataset,
        training_args=training_args,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.1,
    )
    print("✅ Model retrained successfully")
# --- Function: Display training status ---
def get_training_status():
    """Return a one-line summary of the current auto-training state."""
    return "Training status: {}".format(training_status)
# --- Function: Clear chat ---
def clear_chat():
    """Wipe the shared conversation log and report the reset to the UI.

    Returns:
        A (chat_messages, status_text) pair: an empty message list for the
        chatbot component and a confirmation string for the status box.
    """
    global conversation_history
    conversation_history = []
    empty_chat_log = []
    status_message = "Chat cleared."
    return empty_chat_log, status_message
# --- Build Gradio UI ---
# --- Build Gradio UI ---
with gr.Blocks(title="🧠 Auto-Training Chatbot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("### 🤖 Adaptive AI Chatbot (with Auto-Retraining)")
    with gr.Row():
        # type="messages" expects a list of {"role", "content"} dicts, which
        # matches the conversation_history format yielded by chat_response.
        chatbot = gr.Chatbot(
            height=500,
            label="Chat Interface",
            type="messages"
        )
        with gr.Column():
            status_display = gr.Textbox(
                label="Training Status",
                value="Idle",
                interactive=False
            )
            clear_btn = gr.Button("Clear Chat")
    user_input = gr.Textbox(
        label="Your message",
        placeholder="Type your question or message here..."
    )
    send_btn = gr.Button("Send")
    # Bind events
    # chat_response is a generator, so the chatbot and status box update on
    # each yield (intermediate "training..." message, then the final reply).
    send_btn.click(
        fn=chat_response,
        inputs=[user_input],
        outputs=[chatbot, status_display],
    )
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, status_display],
    )
    # Add auto-refresh every 2s
    # NOTE(review): gr.Timer was introduced in gradio 4.38 — confirm the
    # Space pins a gradio version that provides it.
    timer = gr.Timer(interval=2.0, active=True)
    timer.tick(fn=get_training_status, outputs=[status_display])
# --- Launch ---
# Launch the Gradio server (blocking) only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    demo.launch()