Spaces:
Sleeping
Sleeping
| # app.py | |
| # Run: python app.py | |
| # Then open http://localhost:7860 in browser | |
| import re, random, torch, torch.nn as nn, torch.nn.functional as F | |
| from flask import Flask, request, jsonify | |
| # ---------------- Dataset ---------------- | |
# Question/answer pairs that form the entire training corpus.
DATA_QA = [
    ("What is the capital of India?", "New Delhi."),
    ("What is the capital of USA?", "Washington, D.C."),
    ("What is the capital of France?", "Paris."),
    ("What is the capital of Japan?", "Tokyo."),
    ("What is the capital of China?", "Beijing."),
    ("What is the capital of Russia?", "Moscow."),
    ("What is the capital of Brazil?", "Brasilia."),
    ("What is the capital of Canada?", "Ottawa."),
]
| # ---------------- Tokenizer ---------------- | |
# Punctuation marks that the tokenizer splits off into their own tokens.
PUNCT = ["?", ".", ",", ":", ";", "!", "/"]
# Reserved tokens; build_vocab places them first, so their order here
# determines their vocabulary ids.
SPECIALS = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"]
def basic_tokenize(text):
    """Lowercase *text* and split it into word and punctuation tokens.

    Each mark listed in ``PUNCT`` is surrounded with spaces so that it
    becomes a separate token when the string is split on whitespace.
    """
    spaced = text.lower().strip()
    for mark in PUNCT:
        # Splitting on the mark and re-joining with a padded mark is the
        # same as replacing every occurrence with " <mark> ".
        spaced = f" {mark} ".join(spaced.split(mark))
    return spaced.split()
def build_vocab(pairs):
    """Build the token/id lookup tables for (question, answer) *pairs*.

    Returns:
        ``(stoi, itos)``. ``itos`` lists the ``SPECIALS`` first, followed by
        every other token found in the pairs in sorted order; ``stoi`` maps
        each token to its index in ``itos``.
    """
    seen = set()
    for question, reply in pairs:
        # Questions and answers are tokenized with the same "q: " / "a: "
        # prefixes that are used when sequences are encoded.
        seen.update(basic_tokenize("q: " + question))
        seen.update(basic_tokenize("a: " + reply))
    itos = list(SPECIALS)
    itos.extend(tok for tok in sorted(seen) if tok not in SPECIALS)
    stoi = {tok: idx for idx, tok in enumerate(itos)}
    return stoi, itos
# Build the vocabulary once at import time and name the special-token ids.
stoi, itos = build_vocab(DATA_QA)
PAD, BOS, EOS, UNK = (stoi[name] for name in SPECIALS)
def encode(text):
    """Tokenize *text* and map each token to its id, using ``UNK`` for unknown tokens."""
    return [
        stoi.get(tok, UNK)
        for tok in basic_tokenize(text)
    ]
def wrap_bos_eos(tokens):
    """Return a new list: ``BOS``, then the ids in *tokens*, then ``EOS``."""
    wrapped = [BOS] + tokens
    return wrapped + [EOS]
| # ---------------- Model ---------------- | |
class TinyTransformer(nn.Module):
    """Small causal language model built on ``nn.TransformerEncoder``.

    Token and learned position embeddings are summed, passed through a stack
    of encoder layers under a causal attention mask, layer-normalised and
    projected to per-position vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding / hidden width.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of encoder layers.
        d_ff: Width of each layer's feed-forward sub-block.
        max_len: Longest sequence length the position embedding can index.
    """

    def __init__(self, vocab_size, d_model=64, n_heads=2, n_layers=2, d_ff=128, max_len=64):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model, n_heads, d_ff, dropout=0.1, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(enc_layer, n_layers)
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        self.max_len = max_len

    def causal_mask(self, T, device):
        """Return a ``(T, T)`` bool mask that is True strictly above the diagonal.

        With a boolean attention mask, True entries are the positions that
        must not be attended to, so each position only sees itself and
        earlier positions.
        """
        return torch.triu(torch.ones(T, T, device=device), 1) == 1

    def forward(self, idx):
        """Compute logits of shape ``(B, T, vocab_size)`` for ids ``idx`` of shape ``(B, T)``.

        Raises:
            ValueError: If ``T`` is greater than ``max_len``; the position
                embedding has no entries beyond that length.
        """
        B, T = idx.size()
        if T > self.max_len:
            # Fail with a clear message instead of an opaque embedding
            # index error.
            raise ValueError(
                f"sequence length {T} exceeds max_len {self.max_len}"
            )
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0).expand(B, T)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        x = self.transformer(x, mask=self.causal_mask(T, idx.device))
        x = self.ln(x)
        return self.head(x)
| # ---------------- Training ---------------- | |
def make_sequences():
    """Return one token-id sequence per entry of ``DATA_QA``.

    Each sequence is the encoded ``"q: <question>"`` followed by the encoded
    ``"a: <answer>"``, wrapped in ``BOS`` / ``EOS``.
    """
    return [
        wrap_bos_eos(encode("q: " + question) + encode("a: " + reply))
        for question, reply in DATA_QA
    ]
def train_model(epochs=50, lr=3e-3):
    """Train a fresh ``TinyTransformer`` on the ``DATA_QA`` sequences.

    Each optimiser step uses a single sequence: the model sees every token
    but the last and is trained with cross-entropy to predict the sequence
    shifted left by one. The sequence order is reshuffled every epoch.

    Args:
        epochs: Number of passes over the sequences.
        lr: Learning rate for AdamW.

    Returns:
        The trained model (still in training mode).
    """
    device = torch.device("cpu")
    model = TinyTransformer(len(itos)).to(device)
    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    # Build the (input, target) tensors once, on the same device as the
    # model, instead of re-creating them for every step of every epoch.
    batches = [
        (
            torch.tensor(seq[:-1], device=device).unsqueeze(0),
            torch.tensor(seq[1:], device=device).unsqueeze(0),
        )
        for seq in make_sequences()
    ]
    for _ in range(epochs):
        random.shuffle(batches)
        for x, y in batches:
            logits = model(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=PAD
            )
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model
# Train once at import time, then switch the module to evaluation mode
# for serving.
model = train_model()
model.eval()
| # ---------------- Inference ---------------- | |
def generate_answer(question, max_new_tokens=20):
    """Greedily decode an answer to *question* with the module-level ``model``.

    The prompt is ``BOS`` + encoded ``"q: <question>"`` + encoded ``"a:"``.
    Tokens are appended one at a time (argmax of the last position's logits)
    until ``EOS`` is predicted, ``max_new_tokens`` have been produced, or the
    sequence reaches ``model.max_len``.

    Args:
        question: The user's question text.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated tokens joined by single spaces, with any ``PAD`` /
        ``BOS`` / ``EOS`` ids dropped. Empty if nothing was generated.
    """
    q_ids = encode("q: " + question)
    a_prefix = encode("a:")
    prompt = [BOS] + q_ids + a_prefix
    x = torch.tensor(prompt).unsqueeze(0)
    # Inference only: skip autograd bookkeeping for every decoding step.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if x.size(1) >= model.max_len:
                break
            logits = model(x)
            next_id = logits[:, -1, :].argmax(-1).item()
            if next_id == EOS:
                break
            x = torch.cat([x, torch.tensor([[next_id]])], 1)
    # Everything after the prompt is the generated answer.
    answer_ids = x.squeeze(0).tolist()[len(prompt):]
    return " ".join(itos[i] for i in answer_ids if i not in (PAD, BOS, EOS))
| # ---------------- Flask App ---------------- | |
# Flask application object and the server-side filter pattern checked
# against every incoming question.
app = Flask(__name__)
BAN_REGEX = re.compile(r"(?i)\bsex\b")
@app.route("/")
def index():
    """Serve the single-page chat UI.

    The page lets the user set a display name, POSTs each message as JSON
    ``{"q": <text>}`` to ``/answer`` and shows the returned ``answer`` in a
    chat bubble. It also applies the same word filter client-side before
    sending.
    """
    return """
<!doctype html><html><head><meta charset='utf-8'><title>Chatbot</title>
<style>
body{font-family:sans-serif;background:#111;color:#eee}
#chat{height:300px;overflow-y:auto;border:1px solid #444;padding:10px;margin-bottom:10px}
.bubble{margin:5px;padding:8px;border-radius:6px}
.user{background:#2563eb;color:#fff}
.bot{background:#374151;color:#eee}
</style></head><body>
<h2>SLM Chatbot</h2>
<div><label>Username: <input id='username'></label><button onclick='setUser()'>Set</button></div>
<div id='chat'></div>
<input id='msg' placeholder='Type message'><button onclick='sendMsg()'>Submit</button>
<button onclick='clearChat()'>Clear</button>
<script>
let banned=false,username='';
function addBubble(sender,text){
let div=document.createElement('div');
div.className='bubble '+sender;
div.textContent=(sender==='user'?username||'You':'Bot')+': '+text;
document.getElementById('chat').appendChild(div);
}
function setUser(){username=document.getElementById('username').value;addBubble('bot','Hello '+username);}
async function sendMsg(){
if(banned) return;
let text=document.getElementById('msg').value.trim();
if(!text) return;
if(/\\bsex\\b/i.test(text)){banned=true;addBubble('bot','banned');return;}
addBubble('user',text);
let r=await fetch('/answer',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({q:text})});
let j=await r.json();
addBubble('bot',j.answer);
}
function clearChat(){document.getElementById('chat').innerHTML='';banned=false;}
</script></body></html>
"""
@app.route("/answer", methods=["POST"])
def answer():
    """Answer the question posted as JSON ``{"q": <text>}``.

    Returns:
        JSON ``{"answer": ...}``: the literal ``"banned"`` when the question
        matches ``BAN_REGEX``, otherwise the text from ``generate_answer``.
        A body that is not a JSON object, or whose ``q`` is not a string,
        gets a 400 response that still carries an (empty) ``answer`` key.
    """
    # silent=True yields None for a missing or malformed JSON body instead
    # of raising.
    payload = request.get_json(silent=True)
    q = payload.get("q", "") if isinstance(payload, dict) else None
    if not isinstance(q, str):
        return jsonify({
            "answer": "",
            "error": "expected a JSON object with a string field 'q'",
        }), 400
    if BAN_REGEX.search(q):
        return jsonify({"answer": "banned"})
    return jsonify({"answer": generate_answer(q)})
if __name__ == "__main__":
    # Start the development server on all interfaces, port 7860.
    app.run(host="0.0.0.0", port=7860)