File size: 5,825 Bytes
21e9cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# app.py
# Run: python app.py  (the model is trained at startup, before the server starts)
# Then open http://localhost:7860 in a browser.

import re, random, torch, torch.nn as nn, torch.nn.functional as F
from flask import Flask, request, jsonify

# ---------------- Dataset ----------------
# Training corpus: (question, answer) pairs, all following one template.
# The list is generated from (country, capital) pairs so that every question
# is guaranteed to use exactly the same wording.
DATA_QA = [
    (f"What is the capital of {country}?", capital)
    for country, capital in (
        ("India", "New Delhi."),
        ("USA", "Washington, D.C."),
        ("France", "Paris."),
        ("Japan", "Tokyo."),
        ("China", "Beijing."),
        ("Russia", "Moscow."),
        ("Brazil", "Brasilia."),
        ("Canada", "Ottawa."),
    )
]

# ---------------- Tokenizer ----------------
# Single punctuation characters that basic_tokenize splits off into their own tokens.
PUNCT = list("?.,:;!/")
# Reserved tokens. build_vocab places them first, so their order here fixes their ids.
SPECIALS = "<PAD> <BOS> <EOS> <UNK>".split()

def basic_tokenize(text):
    """Lowercase *text* and split it into word and punctuation tokens.

    Each character in PUNCT is surrounded by spaces so that it ends up as a
    standalone token, e.g. ``"Paris."`` -> ``["paris", "."]``.
    """
    normalized = text.lower().strip()
    for mark in PUNCT:
        normalized = normalized.replace(mark, " " + mark + " ")
    return normalized.split()

def build_vocab(pairs):
    """Build the token<->id mappings from (question, answer) pairs.

    Questions are tokenized with a ``"q: "`` prefix and answers with ``"a: "``,
    which matches how the training sequences are encoded. The SPECIALS come
    first (ids 0..len(SPECIALS)-1), followed by every other observed token
    in sorted order.

    Returns:
        ``(stoi, itos)``: a dict mapping token -> id, and the list mapping id -> token.
    """
    observed = set()
    for question, reply in pairs:
        observed.update(basic_tokenize("q: " + question) + basic_tokenize("a: " + reply))
    itos = list(SPECIALS) + [tok for tok in sorted(observed) if tok not in SPECIALS]
    stoi = {tok: idx for idx, tok in enumerate(itos)}
    return stoi, itos

# Vocabulary built once from the full dataset at import time.
stoi, itos = build_vocab(DATA_QA)
# Ids of the reserved tokens, unpacked in SPECIALS order.
PAD, BOS, EOS, UNK = (stoi[name] for name in SPECIALS)

def encode(text):
    """Tokenize *text* and map each token to its id; unknown tokens map to UNK."""
    lookup = stoi.get
    return [lookup(token, UNK) for token in basic_tokenize(text)]
def wrap_bos_eos(tokens):
    """Return a new list: BOS, then the ids in *tokens* (a list), then EOS."""
    prefix, suffix = [BOS], [EOS]
    return prefix + tokens + suffix

# ---------------- Model ----------------
class TinyTransformer(nn.Module):
    """Small causally-masked Transformer language model.

    The architecture is token embeddings plus learned positional embeddings,
    a stack of ``nn.TransformerEncoderLayer`` blocks run with a causal mask,
    a final LayerNorm, and a linear projection to vocabulary logits.
    ``max_len`` is the size of the positional-embedding table, and therefore
    the longest sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size, d_model=64, n_heads=2, n_layers=2, d_ff=128, max_len=64):
        super().__init__()
        # Submodules are created in a fixed order. That order determines RNG
        # consumption for weight init and the order of parameters seen by
        # the optimizer.
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer_template = nn.TransformerEncoderLayer(
            d_model, n_heads, d_ff, dropout=0.1, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(layer_template, n_layers)
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        self.max_len = max_len

    def causal_mask(self, T, device):
        """Return a (T, T) bool mask that is True strictly above the diagonal.

        A True entry marks a key position that the query row may not attend
        to, so each position can only see itself and earlier positions.
        """
        upper = torch.ones(T, T, device=device).triu(diagonal=1)
        return upper.bool()

    def forward(self, idx):
        """Map token ids of shape (batch, seq) to logits of shape (batch, seq, vocab_size)."""
        batch, seq_len = idx.size()
        positions = torch.arange(0, seq_len, device=idx.device).unsqueeze(0).expand(batch, seq_len)
        hidden = self.tok_emb(idx) + self.pos_emb(positions)
        hidden = self.transformer(hidden, mask=self.causal_mask(seq_len, idx.device))
        return self.head(self.ln(hidden))

# ---------------- Training ----------------
def make_sequences():
    """Encode every DATA_QA pair as one training sequence of token ids.

    Layout: ``<BOS> q : <question> a : <answer> <EOS>``.
    """
    return [
        wrap_bos_eos(encode("q: " + question) + encode("a: " + reply))
        for question, reply in DATA_QA
    ]

def train_model(epochs=50, lr=3e-3, seed=None, device="cpu"):
    """Train a fresh TinyTransformer on the DATA_QA sequences and return it.

    Each epoch shuffles the sequences and takes one AdamW step per sequence
    (batch size 1) on the next-token cross-entropy loss. With the default
    arguments this is the original setup: 50 epochs, lr 3e-3, CPU, unseeded.

    Args:
        epochs: Number of passes over the training sequences.
        lr: AdamW learning rate.
        seed: If not None, seeds Python's ``random`` module (used for
            shuffling) and torch's global RNG (used for weight init and
            dropout) so the run is reproducible. This reseeds those global
            generators for the whole process.
        device: Device (str or ``torch.device``) to train on. The model and
            every batch are placed there.

    Returns:
        The trained model, still in training mode. Call ``eval()`` before
        running inference.
    """
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
    device = torch.device(device)
    model = TinyTransformer(len(itos)).to(device)
    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    seqs = make_sequences()
    for _ in range(epochs):
        random.shuffle(seqs)
        for seq in seqs:
            # Teacher forcing: the input is seq[:-1] and the target is the same sequence shifted by one.
            x = torch.tensor(seq[:-1], device=device).unsqueeze(0)
            y = torch.tensor(seq[1:], device=device).unsqueeze(0)
            logits = model(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=PAD
            )
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model

# Train once at import time, which blocks until training finishes. Then
# switch to inference mode: nn.Module.eval() disables dropout and returns
# the module itself.
model = train_model().eval()

# ---------------- Inference ----------------
def generate_answer(question, max_new_tokens=20):
    """Greedily decode an answer to *question* with the module-level ``model``.

    The prompt mirrors the training layout, ``<BOS> q : <question> a :``
    (the trailing EOS is dropped so the model continues after ``a :``).
    Tokens are appended one at a time by argmax until EOS is produced,
    ``max_new_tokens`` steps have run, or the sequence reaches
    ``model.max_len`` (the positional-embedding table size).

    Args:
        question: Raw user question text.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated tokens joined by single spaces. Text is lowercased and
        punctuation appears as separate tokens (e.g. ``"paris ."``). Returns
        an empty string if nothing was generated.
    """
    q_ids = encode("q: " + question)
    a_prefix = encode("a:")
    prompt = wrap_bos_eos(q_ids + a_prefix)[:-1]
    prefix_len = len(prompt)  # == 1 + len(q_ids) + len(a_prefix)
    device = next(model.parameters()).device
    x = torch.tensor([prompt], device=device)
    # Inference only: no_grad avoids building an autograd graph on every step.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if x.size(1) >= model.max_len:
                break
            logits = model(x)
            next_id = logits[:, -1, :].argmax(-1).item()
            if next_id == EOS:
                break
            x = torch.cat([x, torch.tensor([[next_id]], device=device)], dim=1)
    answer_ids = x[0, prefix_len:].tolist()
    return " ".join(itos[i] for i in answer_ids if i not in (PAD, BOS, EOS))

# ---------------- Flask App ----------------
# Server-side ban filter: the whole word "sex", case-insensitive.
BAN_REGEX = re.compile(r"(?i)\bsex\b")
# Flask application that serves the chat page and the /answer JSON endpoint.
app = Flask(__name__)

@app.route("/")
def index():
    """Serve the single-page chat UI (inline HTML, CSS and JavaScript).

    The page keeps an optional username and a client-side ``banned`` flag
    (reset by Clear). It posts each message to ``/answer`` as ``{"q": text}``
    and shows the reply in a chat bubble. Bubbles are filled via
    ``textContent``, so user and model text is never interpreted as HTML.
    Failed requests (network error, non-2xx status, or a non-JSON body) are
    shown as an error bubble instead of failing silently.
    """
    # NOTE: this is a regular (non-raw) Python string, so each "\\b" below
    # reaches the browser as the JS regex word boundary "\b".
    return """
<!doctype html><html><head><meta charset='utf-8'><title>Chatbot</title>
<style>
body{font-family:sans-serif;background:#111;color:#eee}
#chat{height:300px;overflow-y:auto;border:1px solid #444;padding:10px;margin-bottom:10px}
.bubble{margin:5px;padding:8px;border-radius:6px}
.user{background:#2563eb;color:#fff}
.bot{background:#374151;color:#eee}
</style></head><body>
<h2>SLM Chatbot</h2>
<div><label>Username: <input id='username'></label><button onclick='setUser()'>Set</button></div>
<div id='chat'></div>
<input id='msg' placeholder='Type message'><button onclick='sendMsg()'>Submit</button>
<button onclick='clearChat()'>Clear</button>
<script>
let banned=false,username='';
function addBubble(sender,text){
  let chat=document.getElementById('chat');
  let div=document.createElement('div');
  div.className='bubble '+sender;
  div.textContent=(sender==='user'?username||'You':'Bot')+': '+text;
  chat.appendChild(div);
  chat.scrollTop=chat.scrollHeight;
}
function setUser(){username=document.getElementById('username').value;addBubble('bot','Hello '+username);}
async function sendMsg(){
  if(banned) return;
  let input=document.getElementById('msg');
  let text=input.value.trim();
  if(!text) return;
  if(/\\bsex\\b/i.test(text)){banned=true;addBubble('bot','banned');return;}
  addBubble('user',text);
  input.value='';
  try{
    let r=await fetch('/answer',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({q:text})});
    if(!r.ok) throw new Error('server returned HTTP '+r.status);
    let j=await r.json();
    addBubble('bot',j.answer);
  }catch(err){
    addBubble('bot','Error: '+err.message);
  }
}
function clearChat(){document.getElementById('chat').innerHTML='';banned=false;}
document.getElementById('msg').addEventListener('keydown',function(e){if(e.key==='Enter')sendMsg();});
</script></body></html>
"""

@app.route("/answer", methods=["POST"])
def answer():
    """Answer one chat message posted as the JSON object ``{"q": "<question>"}``.

    Returns ``{"answer": "banned"}`` when the text matches BAN_REGEX, and
    otherwise ``{"answer": <generated text>}``. A missing ``q`` is treated as
    an empty question. Two malformed inputs get HTTP 400 with
    ``{"error": ...}`` instead of an unhandled exception:
    - a body that is missing, not JSON, or not a JSON object;
    - a ``q`` value that is not a string.
    """
    # silent=True: return None instead of raising on a missing, non-JSON or
    # unparsable body.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400
    q = payload.get("q", "")
    if not isinstance(q, str):
        return jsonify({"error": "Field 'q' must be a string."}), 400
    if BAN_REGEX.search(q):
        return jsonify({"answer": "banned"})
    return jsonify({"answer": generate_answer(q)})

if __name__ == "__main__":
    # Serve on port 7860 on all network interfaces (0.0.0.0), not only loopback.
    app.run(port=7860, host="0.0.0.0")