|
|
import os |
|
|
import time |
|
|
import gradio as gr |
|
|
from unsloth import FastLanguageModel |
|
|
from transformers import TrainingArguments, TextStreamer |
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
# Limit OpenMP to a single thread to avoid CPU oversubscription when the
# model runtime spawns its own worker threads.
os.environ["OMP_NUM_THREADS"] = "1"


# Hugging Face model id to load (also echoed in the log line below).
model_name = "EpistemeAI/Episteme-gptoss-20b-RL"

print(f"Loading model: {model_name}")

# Load the base model and tokenizer through Unsloth.
#   load_in_4bit=False -> keep unquantized weights (no 4-bit bitsandbytes path)
#   dtype=None         -> let Unsloth choose a dtype for the hardware
#   device_map="auto"  -> place layers across available devices automatically
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    load_in_4bit=False,
    dtype=None,
    device_map="auto",
)


# Mutable module-level state shared by the Gradio callbacks below.
conversation_history = []  # chat so far, as {"role": ..., "content": ...} dicts
training_data = []  # {"instruction", "expected_output"} pairs awaiting training
TRAIN_EVERY_N_PROMPTS = 8  # retrain once this many new samples have accumulated
training_status = "Idle"  # human-readable state polled by the UI timer
|
|
|
|
|
|
|
|
def chat_response(user_input):
    """Generate a reply and collect the exchange for auto-training.

    Appends the user message to the global conversation, generates an
    assistant reply, records the (instruction, expected_output) pair, and
    retrains the model once TRAIN_EVERY_N_PROMPTS samples have accumulated.

    Parameters
    ----------
    user_input : str
        The raw message typed by the user.

    Yields
    ------
    list
        ``[conversation_history, status_string]`` pairs for the two Gradio
        outputs; Gradio streams each yield to the UI.
    """
    global training_data, training_status

    conversation_history.append({"role": "user", "content": user_input})

    # Flatten the chat into a plain "role: content" transcript.
    # NOTE(review): tokenizer.apply_chat_template would match the model's
    # trained chat format better — confirm which format this model expects.
    full_prompt = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in conversation_history
    )
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # Bug fix: TextStreamer only echoes tokens to stdout as they are
    # generated — it has no `.text` attribute, so the original
    # `streamer.text.strip()` raised AttributeError. Recover the reply from
    # the token ids returned by generate() instead, decoding only the tokens
    # produced after the prompt.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    output_ids = model.generate(**inputs, streamer=streamer, max_new_tokens=200)
    prompt_len = inputs["input_ids"].shape[1]
    response_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    conversation_history.append({"role": "assistant", "content": response_text})
    training_data.append({"instruction": user_input, "expected_output": response_text})

    # Auto-retrain once enough new samples have been collected, surfacing a
    # progress message to the UI first.
    if len(training_data) >= TRAIN_EVERY_N_PROMPTS:
        training_status = "Training..."
        yield [
            conversation_history,
            "Auto-training in progress... this may take a moment.",
        ]
        train_model(training_data)
        training_data.clear()
        training_status = "Idle"

    yield [conversation_history, f"Assistant: {response_text}"]
|
|
|
|
|
|
|
|
def train_model(data_samples):
    """Fine-tune the global model on the collected chat samples.

    Parameters
    ----------
    data_samples : list[dict]
        Items with keys "instruction" (the user prompt) and
        "expected_output" (the assistant reply previously generated for it).
    """
    # Re-shape the row-oriented samples into column-oriented form.
    # NOTE(review): this is a plain dict, not a datasets.Dataset — confirm the
    # training call below actually accepts this format.
    dataset = {
        "instruction": [d["instruction"] for d in data_samples],
        "expected_output": [d["expected_output"] for d in data_samples],
    }
    # One quick epoch with no checkpointing; 8-bit AdamW + fp16 keep the
    # in-session retraining cheap.
    training_args = TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        learning_rate=2e-4,
        logging_dir="./logs",
        save_strategy="no",
        optim="adamw_bnb_8bit",
        fp16=True,
        lr_scheduler_type="cosine",
    )

    # NOTE(review): `model.train_model(...)` is not a standard
    # transformers/Unsloth model method — verify this API exists. The usual
    # Unsloth path is FastLanguageModel.get_peft_model(...) plus trl's
    # SFTTrainer; the lora_* kwargs here suggest a LoRA fine-tune is intended.
    model.train_model(
        dataset=dataset,
        training_args=training_args,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.1,
    )
    print("✅ Model retrained successfully")
|
|
|
|
|
|
|
|
def get_training_status():
    """Report the current auto-training state for the UI status box."""
    return "Training status: {}".format(training_status)
|
|
|
|
|
|
|
|
def clear_chat():
    """Drop the stored conversation and return a reset chat UI state.

    Returns an empty message list for the chatbot component plus a
    confirmation string for the status box.
    """
    global conversation_history
    conversation_history = []
    empty_chat, notice = [], "Chat cleared."
    return empty_chat, notice
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="🧠 Auto-Training Chatbot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("### 🤖 Adaptive AI Chatbot (with Auto-Retraining)")

    with gr.Row():
        # type="messages" makes the component consume a list of
        # {"role": ..., "content": ...} dicts — the same shape as the
        # module-level conversation_history.
        chatbot = gr.Chatbot(
            height=500,
            label="Chat Interface",
            type="messages"
        )
        with gr.Column():
            # Read-only box showing the auto-training state; refreshed by the
            # timer below and by the chat/clear callbacks.
            status_display = gr.Textbox(
                label="Training Status",
                value="Idle",
                interactive=False
            )
            clear_btn = gr.Button("Clear Chat")

    user_input = gr.Textbox(
        label="Your message",
        placeholder="Type your question or message here..."
    )
    send_btn = gr.Button("Send")

    # chat_response is a generator, so Gradio streams each yielded
    # (chat, status) pair to the two outputs.
    send_btn.click(
        fn=chat_response,
        inputs=[user_input],
        outputs=[chatbot, status_display],
    )

    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, status_display],
    )

    # Poll the module-level training_status every 2 s so the status box keeps
    # updating while a retraining pass runs.
    timer = gr.Timer(interval=2.0, active=True)
    timer.tick(fn=get_training_status, outputs=[status_display])
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()