FlameF0X committed
Commit b7599e3 · verified · 1 Parent(s): 5f2f81d

Create user.py

Files changed (1):
  1. user.py +220 -0
user.py ADDED
@@ -0,0 +1,220 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import json, os, numpy as np

# ============================================================================
class ChunkTokenizer:
    def __init__(self):
        self.chunk_to_idx = {}
        self.idx_to_chunk = {}
        self.vocab_size = 0

    def load(self, path):
        with open(path, 'r') as f:
            vocab_data = json.load(f)
        self.chunk_to_idx = vocab_data['chunk_to_idx']
        self.idx_to_chunk = {int(k): v for k, v in vocab_data['idx_to_chunk'].items()}
        self.vocab_size = vocab_data['vocab_size']
        print(f"Loaded tokenizer ({self.vocab_size} tokens)")

    def encode(self, text):
        text = text.lower()
        pos, indices = 0, []
        while pos < len(text):
            for size in (3, 2, 1):
                chunk = text[pos:pos+size]
                if chunk in self.chunk_to_idx:
                    indices.append(self.chunk_to_idx[chunk])
                    pos += size
                    break
            else:
                pos += 1
        return indices

    def decode(self, indices):
        return ''.join([self.idx_to_chunk.get(int(i), '') for i in indices])


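# encode() is greedy longest-match-first over lowercased text: at each position
# it tries a 3-character chunk, then 2, then 1, and silently skips a character
# that matches nothing in the vocabulary. Illustration only, assuming a
# hypothetical vocabulary containing "the", "re" and single letters (the real
# entries come from tokenizer.json):
#
#   tok.encode("There")                  # -> ids for ["the", "re"]
#   tok.decode(tok.encode("There"))      # -> "there" (text is lowercased first)
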
# ============================================================================
class LoRPtLinear(nn.Module):
    def __init__(self, in_features, out_features, rank=64):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(out_features, rank) * 0.02)
        self.lora_B = nn.Parameter(torch.randn(rank, in_features) * 0.02)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return F.linear(x, self.lora_A @ self.lora_B, self.bias)


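# The weight here is a low-rank factorization W = lora_A @ lora_B with shape
# (out_features, in_features), so the layer stores out*rank + rank*in parameters
# instead of out*in. For example, with the FFN sizes used below (in=512,
# out=2048, rank=64) that is 2048*64 + 64*512 = 163,840 weights versus
# 1,048,576 for a full nn.Linear, roughly a 6x reduction (plus the
# out_features bias in both cases).
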
class RWKVMambaHybrid(nn.Module):
    def __init__(self, d_model, d_state=32):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.w_mix = nn.Parameter(torch.ones(d_model) * 0.5)
        self.A = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
        self.B = nn.Parameter(torch.randn(d_state, d_model) * 0.01)
        self.C = nn.Parameter(torch.randn(d_model, d_state) * 0.01)
        self.D = nn.Parameter(torch.ones(d_model) * 0.1)

    def forward(self, x):
        B, T, C = x.shape
        h = torch.zeros(B, C, device=x.device)
        s = torch.zeros(B, self.d_state, device=x.device)
        outputs = []
        for t in range(T):
            x_t = x[:, t, :]
            h = self.w_mix * h + (1 - self.w_mix) * x_t
            s = s @ self.A.T + x_t @ self.B.T
            y_t = s @ self.C.T + h * self.D
            outputs.append(y_t)
        return torch.stack(outputs, dim=1)


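# Reading directly from the loop above, each time step combines two recurrences:
# an RWKV-style exponential token mix
#     h_t = w_mix * h_{t-1} + (1 - w_mix) * x_t,
# and a Mamba/SSM-style linear state update
#     s_t = s_{t-1} A^T + x_t B^T,   y_t = s_t C^T + h_t * D.
# Both states are carried sequentially across the T positions, so the layer is
# O(T) in sequence length but not parallel over time.
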
class KQVAttention(nn.Module):
    def __init__(self, d_model, n_heads=16, rank=64):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_down = nn.Linear(d_model, rank)
        self.q_up = nn.Linear(rank, d_model)
        self.k_down = nn.Linear(d_model, rank)
        self.k_up = nn.Linear(rank, d_model)
        self.v_down = nn.Linear(d_model, rank)
        self.v_up = nn.Linear(rank, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        q = self.q_up(self.q_down(x))
        k = self.k_up(self.k_down(x))
        v = self.v_up(self.v_down(x))
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) / np.sqrt(self.head_dim)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


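# Q, K and V are each produced through a rank-64 bottleneck (d_model -> rank ->
# d_model) rather than one dense d_model x d_model projection: with d_model=512
# and rank=64 each down/up pair holds 512*64 + 64*512 = 65,536 weights versus
# 262,144 for a dense projection. The attention itself is standard multi-head
# scaled dot-product attention; positions where the causal mask is 0 are set to
# -inf before the softmax.
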
class i3Block(nn.Module):
    def __init__(self, d_model, n_heads=16, d_state=32, rank=64, ffn_mult=4):
        super().__init__()
        self.hybrid = RWKVMambaHybrid(d_model, d_state)
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = KQVAttention(d_model, n_heads, rank)
        self.ln2 = nn.LayerNorm(d_model)
        d_ff = d_model * ffn_mult
        self.ffn = nn.Sequential(
            LoRPtLinear(d_model, d_ff, rank),
            nn.GELU(),
            LoRPtLinear(d_ff, d_model, rank)
        )
        self.ln3 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        x = x + self.hybrid(self.ln1(x))
        x = x + self.attn(self.ln2(x), mask)
        x = x + self.ffn(self.ln3(x))
        return x


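# Each i3Block is pre-norm with three residual branches applied in order: the
# recurrent RWKV/Mamba hybrid, then the low-rank multi-head attention, then a
# GELU feed-forward built from two LoRPtLinear layers (d_model -> 4*d_model -> d_model).
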
class i3Model(nn.Module):
    def __init__(self, vocab_size, d_model=512, n_layers=24, n_heads=16,
                 max_seq_len=256, rank=64, d_state=32):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        self.layers = nn.ModuleList([
            i3Block(d_model, n_heads, d_state, rank)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = LoRPtLinear(d_model, vocab_size, rank)

    def forward(self, idx):
        B, T = idx.shape
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.embed(idx) + self.pos_embed(pos)
        mask = torch.tril(torch.ones(T, T, device=idx.device)).view(1, 1, T, T)
        for layer in self.layers:
            x = layer(x, mask)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=100, temperature=0.8, top_k=40):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits = self(idx_cond)[:, -1, :] / temperature
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, 1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


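# generate() samples autoregressively: the context is cropped to the last
# max_seq_len tokens, the final position's logits are divided by the
# temperature, everything outside the top_k largest logits is set to -inf,
# and the next token is drawn with torch.multinomial from the resulting softmax.
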
# ============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = ChunkTokenizer()
tokenizer.load("tokenizer.json")

model = i3Model(
    vocab_size=tokenizer.vocab_size,
    d_model=512,
    n_layers=24,
    n_heads=16,
    max_seq_len=256,
    rank=64,
    d_state=32
).to(device)

state_dict = torch.load("pytorch_model.bin", map_location=device)
model.load_state_dict(state_dict)
model.eval()
print("✓ Model loaded successfully")


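# tokenizer.json and pytorch_model.bin are expected next to this script. If the
# files live in a Hub repo instead, something like this could fetch them first
# (the repo id below is a placeholder, not taken from this commit):
#
#   from huggingface_hub import hf_hub_download
#   tok_path = hf_hub_download(repo_id="<user>/<repo>", filename="tokenizer.json")
#   ckpt_path = hf_hub_download(repo_id="<user>/<repo>", filename="pytorch_model.bin")
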
# ============================================================================
@torch.no_grad()
def infer(prompt, max_new_tokens=100, temperature=0.8, top_k=40):
    input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(device)
    output = model.generate(input_ids, max_new_tokens=max_new_tokens,
                            temperature=temperature, top_k=top_k)
    return tokenizer.decode(output[0].cpu().numpy())


def chat_loop():
    print("=== i3 Interactive Chat ([INST] format) ===")
    history = ""
    while True:
        user_input = input("[You] ")
        if user_input.strip().lower() in {"quit", "exit"}:
            break
        prompt = f"{history}[INST] {user_input.strip()} [/INST]"
        reply = infer(prompt, max_new_tokens=120)
        reply_clean = reply.replace(prompt.lower(), "").strip()
        print("[i3]:", reply_clean)
        history += f"[INST] {user_input.strip()} [/INST] {reply_clean} "


# ============================================================================
print("\nExample:")
prompt = "[INST] What can we do to make people happier [/INST]"
print("Prompt:", prompt)
print("Generated:", infer(prompt))

# Optionally start a chat loop:
# chat_loop()
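
# infer() returns the decoded text including the (lowercased) prompt, since
# generate() keeps the conditioning tokens; chat_loop() strips that echo before
# printing. Sampling can be adjusted per call, e.g. (values illustrative, not tuned):
#   infer("[INST] Tell me a short story [/INST]", max_new_tokens=200, temperature=1.0, top_k=20)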