# app.py
# Single-page Gradio app for Hugging Face Spaces
# - Trains MiniGPT and classifier on startup (tiny datasets, short epochs by default)
# - Large, centered UI with three panels:
#     1) Instruction -> Response
#     2) Sentiment Classification
#     3) Next word + dataset sentence completion (prefix of two words)
# - Instant input moderation: banned words trigger immediate error and block
# - Greedy decoding for stable minimal outputs
import math, re, os, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import gradio as gr
# ----------------------------
# 1) Data preparation
# ----------------------------
lm_corpus = [
    "the cat sits on the mat",
    "the dog chases the ball",
    "a small model can learn patterns",
    "language models predict next tokens",
    "transformers use attention mechanism",
    "training on tiny data is limited",
    "we build a model from scratch",
    "this is a minimal example",
    "positional embeddings encode order",
    "causal masking prevents peeking ahead",
]
cls_data = [
    ("this is bad", 0),
    ("i dislike this", 0),
    ("terrible and awful", 0),
    ("this is good", 1),
    ("i like this", 1),
    ("wonderful and great", 1),
]
inst_data_base = [
    ("<INSTR> write a short greeting <ENDINSTR>", "<RESP> hello! <ENDRESP>"),
    ("<INSTR> answer briefly what is a cat <ENDINSTR>", "<RESP> a small animal. <ENDRESP>"),
    ("<INSTR> continue the sun is <ENDINSTR>", "<RESP> bright. <ENDRESP>"),
]
inst_data = inst_data_base * 64  # stabilize tiny-data learning
# ----------------------------
# Tokenization (word-level)
# ----------------------------
def normalize_text(s):
    s = s.lower().strip()
    s = re.sub(r'([.!?,:;])', r' \1 ', s)  # split punctuation into standalone tokens
    s = re.sub(r'\s+', ' ', s)
    return s

def build_vocab(texts):
    tokens = set()
    specials = ["<pad>", "<bos>", "<eos>"]
    for t in texts:
        t = normalize_text(t)
        for tok in t.split():
            tokens.add(tok)
    vocab = specials + sorted(tokens)
    stoi = {s: i for i, s in enumerate(vocab)}
    itos = {i: s for s, i in stoi.items()}
    return vocab, stoi, itos
all_texts = lm_corpus + [x for x, _ in cls_data] + [a for a, _ in inst_data_base] + [b for _, b in inst_data_base]
vocab, stoi, itos = build_vocab(all_texts)
PAD, BOS, EOS = stoi["<pad>"], stoi["<bos>"], stoi["<eos>"]
vocab_size = len(vocab)

def encode(text, max_len=None, add_special=True):
    text = normalize_text(text)
    toks = text.split()
    # Out-of-vocabulary words fall back to <pad>; acceptable for this closed tiny corpus
    ids = ([BOS] if add_special else []) + [stoi.get(tok, PAD) for tok in toks] + ([EOS] if add_special else [])
    if max_len is not None:
        ids = ids[:max_len]  # note: truncation may drop the trailing <eos>
        if len(ids) < max_len:
            ids = ids + [PAD] * (max_len - len(ids))
    return torch.tensor(ids, dtype=torch.long)

def decode(ids):
    toks = [itos.get(i, "") for i in ids]
    toks = [t for t in toks if t not in ("<pad>", "<bos>", "<eos>")]
    out = " ".join(toks)
    out = re.sub(r'\s+([.!?,:;])', r'\1', out)  # re-attach punctuation
    return out.strip()
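
# Round-trip sketch (illustrative; exact ids depend on the corpus above):
#   encode("the cat", add_special=True)          -> tensor([BOS, stoi["the"], stoi["cat"], EOS])
#   decode(encode("the cat sits.").tolist())     -> "the cat sits."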
# ----------------------------
# Datasets
# ----------------------------
class LMPretrainDataset(Dataset):
    def __init__(self, texts, block_size=64):
        self.samples = []
        for t in texts:
            ids = encode(t, max_len=block_size, add_special=True)
            self.samples.append((ids[:-1], ids[1:]))
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

class ClassificationDataset(Dataset):
    def __init__(self, pairs, block_size=64):
        self.samples = []
        for text, label in pairs:
            ids = encode(text, max_len=block_size, add_special=True)
            self.samples.append((ids, torch.tensor(label, dtype=torch.long)))
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

class InstructionDataset(Dataset):
    def __init__(self, pairs, block_size=64):
        self.samples = []
        for instr, resp in pairs:
            instr_ids = encode(instr, add_special=False).tolist()
            resp_ids = encode(resp, add_special=False).tolist()
            seq = [BOS] + instr_ids + [EOS] + [BOS] + resp_ids + [EOS]
            seq = seq[:block_size]
            if len(seq) < block_size:
                seq += [PAD] * (block_size - len(seq))
            ids = torch.tensor(seq, dtype=torch.long)
            self.samples.append((ids[:-1], ids[1:]))
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]
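
# Each instruction sample is packed as one causal-LM sequence:
#   <bos> <instr> ... <endinstr> <eos> <bos> <resp> ... <endresp> <eos>
# The loss is computed over the whole sequence, instruction tokens included;
# that is a deliberate simplification here (full SFT pipelines usually mask the prompt).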
# ----------------------------
# 2) Model architecture (GPT-style)
# ----------------------------
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embed, n_head, dropout=0.1):
        super().__init__()
        assert n_embed % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embed // n_head
        self.qkv = nn.Linear(n_embed, 3 * n_embed)
        self.proj = nn.Linear(n_embed, n_embed)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        # Causal mask is built lazily in forward() so it always matches the
        # current sequence length and device
        self.register_buffer("mask", None)

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if (self.mask is None) or (self.mask.size(-1) != T):
            self.mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
        att = att.masked_fill(self.mask == 0, float('-inf'))
        att = torch.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.proj(y)
        y = self.resid_drop(y)
        return y
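
# Shape walk-through (illustrative, with the defaults used below: n_embed=192, n_head=6,
# so head_dim=32): x: (B, T, 192) -> q,k,v: (B, 6, T, 32) -> att: (B, 6, T, T) -> y: (B, T, 192)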
class TransformerBlock(nn.Module):
    def __init__(self, n_embed, n_head, mlp_mult=4, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embed)
        self.attn = CausalSelfAttention(n_embed, n_head, dropout)
        self.ln2 = nn.LayerNorm(n_embed)
        self.mlp = nn.Sequential(
            nn.Linear(n_embed, mlp_mult * n_embed),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_mult * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Pre-norm residual connections, GPT-2 style
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
class MiniGPT(nn.Module):
    def __init__(self, vocab_size, n_embed=192, n_head=6, n_layer=4, block_size=64, dropout=0.1):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embed)
        self.pos_emb = nn.Embedding(block_size, n_embed)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([TransformerBlock(n_embed, n_head, 4, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.head = nn.Linear(n_embed, vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, idx):
        B, T = idx.size()
        tok = self.tok_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = self.drop(tok + pos)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate_greedy(self, idx, max_new_tokens=20):
        # Greedy decoding; the EOS early-exit (.item()) assumes batch size 1
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits = self(idx_cond)
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_id], dim=1)
            if next_id.item() == EOS:
                break
        return idx
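
# Usage sketch (assumes the model defined below has been trained; output is illustrative):
#   seed = torch.tensor([[BOS] + [stoi.get(w, PAD) for w in "the cat".split()]], device=device)
#   print(decode(model.generate_greedy(seed, max_new_tokens=8)[0].tolist()))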
# ----------------------------
# 3) Training pipeline
# ----------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
block_size = 64
lm_dl = DataLoader(LMPretrainDataset(lm_corpus, block_size), batch_size=16, shuffle=True)
cls_dl = DataLoader(ClassificationDataset(cls_data, block_size), batch_size=6, shuffle=True)
inst_dl = DataLoader(InstructionDataset(inst_data, block_size), batch_size=32, shuffle=True)
model = MiniGPT(vocab_size=vocab_size, n_embed=192, n_head=6, n_layer=4, block_size=block_size, dropout=0.1).to(device)

def pretrain(model, dataloader, epochs=8, lr=3e-4, grad_clip=1.0):
    opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)
    model.train()
    for _ in range(epochs):
        for inp, tgt in dataloader:
            inp, tgt = inp.to(device), tgt.to(device)
            logits = model(inp)
            B, T, V = logits.size()
            loss = loss_fn(logits.view(B * T, V), tgt.view(B * T))
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()
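
# The objective is standard next-token cross-entropy with <pad> positions ignored;
# for orientation, a mean per-token loss L corresponds to perplexity exp(L)
# (not computed or logged here).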
class ClassificationHead(nn.Module):
    def __init__(self, backbone: MiniGPT, n_classes=2, freeze_backbone=False):
        super().__init__()
        self.backbone = backbone
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
        n_embed = backbone.head.in_features
        self.classifier = nn.Sequential(nn.LayerNorm(n_embed), nn.Linear(n_embed, n_classes))

    def forward(self, idx):
        B, T = idx.size()
        tok = self.backbone.tok_emb(idx)
        pos = self.backbone.pos_emb(torch.arange(T, device=idx.device))
        x = self.backbone.drop(tok + pos)
        for blk in self.backbone.blocks:
            x = blk(x)
        x = self.backbone.ln_f(x)
        eos_mask = (idx == EOS)
        last_idx = torch.where(
            eos_mask.any(dim=1),
            eos_mask.float().argmax(dim=1),
            torch.full((B,), T - 1, device=idx.device),
        )
        pooled = x[torch.arange(B, device=idx.device), last_idx]
        return self.classifier(pooled)
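
# Pooling note: the head reads the hidden state at the first <eos>; since the
# backbone is causal, that position has attended to the whole text. If truncation
# removed <eos>, it falls back to the final position.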
clf = ClassificationHead(model, n_classes=2, freeze_backbone=False).to(device)

def finetune_classification(clf, dataloader, epochs=6, lr=8e-4):
    opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, clf.parameters()), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    clf.train()
    for _ in range(epochs):
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = clf(x)
            loss = loss_fn(logits, y)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

def finetune_instruction(model, dataloader, epochs=50, lr=1.5e-4, grad_clip=1.0):
    opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)
    model.train()
    for _ in range(epochs):
        for inp, tgt in dataloader:
            inp, tgt = inp.to(device), tgt.to(device)
            logits = model(inp)
            B, T, V = logits.size()
            loss = loss_fn(logits.view(B * T, V), tgt.view(B * T))
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()
# ----------------------------
# 4) Inference helpers
# ----------------------------
@torch.no_grad()
def classify_text(text):
    ids = encode(text, max_len=block_size, add_special=True).unsqueeze(0).to(device)
    logits = clf(ids)
    pred = logits.argmax(dim=-1).item()
    return "positive" if pred == 1 else "negative"

def generate_response(instruction, max_new_tokens=12):
    instr = f"<INSTR> {instruction} <ENDINSTR>"
    resp_start = "<RESP>"
    prefix_ids = encode(instr, add_special=False).tolist()
    resp_start_ids = encode(resp_start, add_special=False).tolist()
    seq = [BOS] + prefix_ids + [EOS] + resp_start_ids
    idx = torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)
    out = model.generate_greedy(idx, max_new_tokens=max_new_tokens)
    gen = out[0].tolist()
    toks = [itos[i] for i in gen]
    # Keep only the tokens between <resp> and the first end marker
    try:
        resp_pos = toks.index("<resp>")
    except ValueError:
        resp_pos = len(toks) - 1
    resp_toks = toks[resp_pos + 1:]
    if "<endresp>" in resp_toks:
        resp_toks = resp_toks[:resp_toks.index("<endresp>")]
    elif "<eos>" in resp_toks:
        resp_toks = resp_toks[:resp_toks.index("<eos>")]
    text = " ".join(resp_toks)
    text = re.sub(r'\s+([.!?,:;])', r'\1', text).strip()
    return text
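
# Example (illustrative; the actual wording depends on how training converged):
#   generate_response("write a short greeting")  # -> e.g. "hello!"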
# --- Next word + dataset sentence completion ---
@torch.no_grad()
def predict_next_word_and_complete(prefix_two_words, max_new_tokens=16):
    # Normalize and validate
    s = normalize_text(prefix_two_words)
    toks = s.split()
    if len(toks) < 2:
        return "(need at least two words)", "(no match)", "(no generation)"
    # Moderation handled separately at UI entry
    # Next-word prediction via LM. Note: no trailing <eos> here; appending one would
    # make the model predict the token *after* end-of-sequence instead of the next word.
    ids = torch.tensor([[BOS] + [stoi.get(tok, PAD) for tok in toks]],
                       dtype=torch.long, device=device)
    logits = model(ids)
    next_id = logits[:, -1, :].argmax(dim=-1).item()
    next_word = itos.get(next_id, "")
    # Dataset sentence completion: exact prefix match
    prefix = " ".join(toks[:2])  # strictly first two words
    matches = [sent for sent in lm_corpus if normalize_text(sent).startswith(prefix + " ")]
    matched = "; ".join(matches) if matches else "(no exact dataset sentence starts with those two words)"
    # Fallback generation to complete a sentence-like output
    gen_ids = model.generate_greedy(ids, max_new_tokens=max_new_tokens)
    gen_text = decode(gen_ids[0].tolist())
    return next_word, matched, gen_text
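
# Example (illustrative): with the corpus above, predict_next_word_and_complete("the cat")
# should surface "the cat sits on the mat" as the exact-prefix dataset match; the LM's
# next-word and generated completion depend on training.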
# ----------------------------
# 5) Moderation (instant lockout)
# ----------------------------
BANNED = {"hate", "kill", "self-harm", "suicide", "violence"}  # extend as needed

def check_banned(s: str):
    s_norm = normalize_text(s)
    toks = set(s_norm.split())
    bad = toks.intersection(BANNED)
    if bad:
        raise gr.Error(f"Input contains prohibited words: {', '.join(sorted(bad))}. Submission blocked.")
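
# Matching is whole-word only: check_banned("i hate this") raises gr.Error listing
# "hate", while banned substrings inside other words do not trigger it.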
# ----------------------------
# 6) Train-on-start (short epochs by default)
# FAST_TRAIN=1 (the default) keeps Spaces startup snappy; set FAST_TRAIN=0 for longer training
# ----------------------------
FAST = os.getenv("FAST_TRAIN", "1") == "1"
PRE_EPOCHS = 2 if FAST else 8
CLS_EPOCHS = 2 if FAST else 6
INST_EPOCHS = 6 if FAST else 50

def bootstrap():
    pretrain(model, lm_dl, epochs=PRE_EPOCHS, lr=3e-4)
    finetune_classification(clf, cls_dl, epochs=CLS_EPOCHS, lr=8e-4)
    # Note: instruction tuning further updates the shared backbone after the
    # classifier head was fitted; acceptable drift for this tiny demo
    finetune_instruction(model, inst_dl, epochs=INST_EPOCHS, lr=1.5e-4)
    # Switch to eval mode so dropout is disabled for all inference below
    model.eval()
    clf.eval()

bootstrap()
# ----------------------------
# 7) Gradio UI (large, centered)
# ----------------------------
def ui_generate(instruction, max_tokens):
    check_banned(instruction)
    resp = generate_response(instruction, max_new_tokens=int(max_tokens))  # slider values may arrive as floats
    return resp if resp.strip() else "(no response)"

def ui_classify(text):
    check_banned(text)
    return classify_text(text)

def ui_next_word(prefix_two_words, max_tokens):
    check_banned(prefix_two_words)
    next_word, matched, gen_text = predict_next_word_and_complete(prefix_two_words, max_new_tokens=int(max_tokens))
    return next_word, matched, gen_text

with gr.Blocks(title="Minimal GPT-style LLM (word-level, greedy)") as demo:
    gr.HTML(
        """
        <div style="text-align:center; max-width: 880px; margin:auto;">
          <h1 style="font-size: 32px; margin-bottom: 10px;">Minimal GPT-style LLM</h1>
          <p style="font-size: 16px;">
            Word-level tokenizer • Tiny transformer • Greedy decoding • Instruction fine-tuning • Sentiment classification • Next-word prediction
          </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Instruction to response")
            instr = gr.Textbox(
                label="Instruction",
                placeholder="e.g., write a short greeting",
                lines=2,
                elem_id="instr_box",
            )
            max_toks = gr.Slider(4, 32, value=12, step=1, label="Max new tokens")
            gen_btn = gr.Button("Generate response", variant="primary", elem_id="gen_btn")
            resp = gr.Textbox(label="Model response", lines=4, interactive=False)
            gen_btn.click(fn=ui_generate, inputs=[instr, max_toks], outputs=resp)
        with gr.Column(scale=1):
            gr.Markdown("### Sentiment classification")
            cls_in = gr.Textbox(
                label="Text",
                placeholder="e.g., i like this",
                lines=2,
                elem_id="cls_box",
            )
            cls_btn = gr.Button("Classify sentiment", variant="primary", elem_id="cls_btn")
            cls_out = gr.Textbox(label="Prediction", lines=1, interactive=False)
            cls_btn.click(fn=ui_classify, inputs=cls_in, outputs=cls_out)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Next word + dataset sentence completion")
            two_words = gr.Textbox(
                label="Enter at least two words (prefix)",
                placeholder="e.g., the cat",
                lines=1,
                elem_id="nw_box",
            )
            max_toks_nw = gr.Slider(4, 32, value=16, step=1, label="Max new tokens for generation")
            nw_btn = gr.Button("Predict next word & complete", variant="primary", elem_id="nw_btn")
            next_word_out = gr.Textbox(label="Next word (LM greedy)", lines=1, interactive=False)
            matched_out = gr.Textbox(label="Dataset sentence match (exact prefix)", lines=2, interactive=False)
            gen_out = gr.Textbox(label="Generated completion (fallback)", lines=3, interactive=False)
            nw_btn.click(fn=ui_next_word, inputs=[two_words, max_toks_nw], outputs=[next_word_out, matched_out, gen_out])
    gr.HTML(
        """
        <style>
        #instr_box textarea, #cls_box textarea, #nw_box textarea {
            font-size: 18px; text-align: center;
        }
        #gen_btn, #cls_btn, #nw_btn {
            font-size: 18px; width: 100%; height: 52px;
        }
        .gradio-container { max-width: 980px !important; margin: auto !important; }
        </style>
        """
    )

if __name__ == "__main__":
    demo.launch()