Kiria-Nozan committed on
Commit
92779dc
·
verified ·
1 Parent(s): 2e96d8d

initial release

Browse files
Files changed (5) hide show
  1. models/__init__.py +4 -0
  2. models/autoregressive.py +358 -0
  3. models/dimamba.py +1136 -0
  4. models/dit.py +514 -0
  5. models/ema.py +97 -0
models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from . import dit
2
+ from . import dimamba
3
+ from . import ema
4
+ from . import autoregressive
models/autoregressive.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import typing
3
+
4
+ import flash_attn
5
+ import flash_attn.layers.rotary
6
+ import huggingface_hub
7
+ import omegaconf
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+
13
+ # Flags required to enable jit fusion kernels
14
+ torch._C._jit_set_profiling_mode(False)
15
+ torch._C._jit_set_profiling_executor(False)
16
+ torch._C._jit_override_can_fuse_on_cpu(True)
17
+ torch._C._jit_override_can_fuse_on_gpu(True)
18
+
19
+
20
def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool,
) -> torch.Tensor:
    """Fused bias-add, dropout, scale, and residual-add.

    Computes ``residual + scale * dropout(x + bias)``; the bias and the
    residual terms are each skipped when ``None``. Written with explicit
    ``is None`` branches so TorchScript can refine the Optional types.
    """
    if bias is None:
        h = x
    else:
        h = x + bias
    out = scale * F.dropout(h, p=prob, training=training)
    if residual is not None:
        out = residual + out
    return out
38
+
39
+
40
def get_bias_dropout_add_scale(training):
    """Return a closure over ``training`` for :func:`bias_dropout_add_scale`.

    The returned callable has signature ``(x, bias, scale, residual, prob)``.
    """
    def _fused(x, bias, scale, residual, prob):
        return bias_dropout_add_scale(x, bias, scale, residual, prob, training)

    return _fused
47
+
48
+
49
@torch.jit.script
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
) -> torch.Tensor:
    """TorchScript-fused variant with dropout active (training=True)."""
    return bias_dropout_add_scale(x, bias, scale, residual, prob, True)
60
+
61
+
62
@torch.jit.script
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
) -> torch.Tensor:
    """TorchScript-fused variant with dropout disabled (training=False)."""
    return bias_dropout_add_scale(x, bias, scale, residual, prob, False)
73
+
74
+
75
class Rotary(torch.nn.Module):
    """Rotary positional embedding (RoPE) table cache.

    Produces cos/sin tables shaped ``(1, seq_len, 3, 1, dim)`` so they can be
    broadcast against a packed qkv tensor. The tables for the v slot
    (index 2 of the qkv axis) are overwritten with cos=1 / sin=0 so the
    rotation leaves v unchanged.
    """

    def __init__(self, dim, base=10_000):
        super().__init__()
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2).float() / dim)
        )
        self.register_buffer('inv_freq', inv_freq)
        # Cache is keyed only on sequence length.
        # NOTE(review): the cache is not refreshed on device/dtype changes;
        # presumably inputs always live on one device after the first call —
        # confirm against callers.
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        # x is only used for its length along seq_dim and its device.
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(
                x.shape[seq_dim], device=x.device
            ).type_as(self.inv_freq)
            freqs = torch.einsum(
                'i,j->ij', t, self.inv_freq.clone()
            )
            # Duplicate the frequency table so the last dim matches the
            # full head dim (rotate_half splits it back in two).
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            # dims are: batch, seq_len, qkv, head, dim
            self.cos_cached = emb.cos()[
                None, :, None, None, :
            ].repeat(1, 1, 3, 1, 1)
            self.sin_cached = emb.sin()[
                None, :, None, None, :
            ].repeat(1, 1, 3, 1, 1)
            # This makes the transformation on v an identity.
            self.cos_cached[:, :, 2, :, :].fill_(1.0)
            self.sin_cached[:, :, 2, :, :].fill_(0.0)

        return self.cos_cached, self.sin_cached
109
+
110
+
111
def rotate_half(x):
    """Rotate the two halves of the last dimension: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    a, b = x[..., :half], x[..., half:]
    return torch.cat((-b, a), dim=-1)
117
+
118
+
119
def apply_rotary_pos_emb(qkv, cos, sin):
    """Apply rotary embeddings to a packed ``(b, s, 3, h, d)`` qkv tensor.

    The cached tables from `Rotary` are ``(1, s, 3, 1, d)``; the flash-attn
    fused kernel expects per-position tables of shape ``(s, d/2)``, so we
    take the first half of the last dim (the second half is the duplicated
    copy built in ``Rotary.forward``). ``apply_rotary_emb_qkv_`` mutates
    qkv in place and leaves v untouched.
    """
    cos = cos[0, :, 0, 0, : cos.shape[-1] // 2]
    sin = sin[0, :, 0, 0, : sin.shape[-1] // 2]
    return flash_attn.layers.rotary.apply_rotary_emb_qkv_(
        qkv, cos, sin
    )
125
+
126
+
127
+ #################################################################################
128
+ # Layers #
129
+ #################################################################################
130
class LayerNorm(nn.Module):
    """Bias-free LayerNorm whose statistics are always computed in fp32."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.ones([dim]))

    def forward(self, x):
        # Disable autocast so normalization runs in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            normed = F.layer_norm(x.float(), [self.dim])
        return normed * self.weight[None, None, :]
140
+
141
+
142
def residual_linear(x, W, x_skip, residual_scale):
    """Return ``x_skip + residual_scale * (x @ W.T)`` via a single addmm."""
    dim_out, dim_in = W.shape
    fused = torch.addmm(
        x_skip.view(-1, dim_out),
        x.view(-1, dim_in),
        W.T,
        alpha=residual_scale,
    )
    return fused.view(*x.shape[:-1], dim_out)
151
+
152
+
153
+ #################################################################################
154
+ # Core Model #
155
+ #################################################################################
156
+
157
+
158
class DDiTBlock(nn.Module):
    """Transformer block: rotary flash-attention + GELU MLP, prenorm style.

    Residual adds go through the fused bias/dropout/scale helpers.
    ``cond_dim`` and the ``c`` forward argument are accepted for interface
    parity with conditioned DiT blocks but no conditioning is applied here.
    """

    def __init__(
        self,
        dim,
        n_heads,
        cond_dim,
        mlp_ratio=4,
        dropout=0.1,
        causal=False,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.causal = causal

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )
        self.dropout2 = nn.Dropout(dropout)
        self.dropout = dropout

    def _get_bias_dropout_scale(self):
        # Pick the TorchScript-fused helper matching the current mode.
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def forward(self, x, rotary_cos_sin, c, seqlens=None):
        """Run attention + MLP on x (b, s, dim); `c` is unused.

        seqlens, when given, is interpreted as per-sequence lengths for
        flash-attn's varlen kernel.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]

        bias_dropout_scale_fn = self._get_bias_dropout_scale()

        # attention operation
        x_skip = x
        x = self.norm1(x)

        qkv = self.attn_qkv(x)
        qkv = rearrange(
            qkv,
            'b s (three h d) -> b s three h d',
            three=3,
            h=self.n_heads,
        )
        # Rotary application is done outside autocast in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            cos, sin = rotary_cos_sin
            qkv = apply_rotary_pos_emb(
                qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)
            )
        # Flatten batch and sequence for the varlen attention kernel.
        qkv = rearrange(qkv, 'b s ... -> (b s) ...')
        if seqlens is None:
            # Uniform lengths: cu_seqlens = [0, s, 2s, ..., b*s], int32.
            cu_seqlens = torch.arange(
                0,
                (batch_size + 1) * seq_len,
                step=seq_len,
                dtype=torch.int32,
                device=qkv.device,
            )
        else:
            # NOTE(review): flash-attn varlen expects int32 cu_seqlens of
            # length batch+1 starting at 0; a bare cumsum has no leading 0
            # and stays int64 — confirm callers pre-shape `seqlens`.
            cu_seqlens = seqlens.cumsum(-1)
        x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
            qkv, cu_seqlens, seq_len, 0.0, causal=self.causal
        )

        x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)

        # Unit scale: the fused helper only does dropout + residual here.
        scale = torch.ones(1, device=x.device, dtype=x.dtype)
        x = bias_dropout_scale_fn(
            self.attn_out(x), None, scale, x_skip, self.dropout
        )

        # mlp operation
        x = bias_dropout_scale_fn(
            self.mlp(self.norm2(x)), None, scale, x, self.dropout
        )
        return x
240
+
241
+
242
class EmbeddingLayer(nn.Module):
    """Token-embedding lookup table with Kaiming-uniform initialization."""

    def __init__(self, dim, vocab_dim):
        super().__init__()
        self.embedding = nn.Parameter(
            torch.empty((vocab_dim, dim))
        )
        torch.nn.init.kaiming_uniform_(
            self.embedding, a=math.sqrt(5)
        )

    def forward(self, x):
        # Advanced indexing: (..., ) int ids -> (..., dim) embeddings.
        return self.embedding[x]
254
+
255
+
256
class DDitFinalLayer(nn.Module):
    """Final projection: LayerNorm followed by a zero-initialized Linear.

    ``cond_dim`` and the ``c`` forward argument are accepted for interface
    parity with conditioned variants but are unused here.
    """

    def __init__(
        self, hidden_size, out_channels, cond_dim, causal=False
    ):
        super().__init__()
        self.causal = causal
        # This module is only used by the autoregressive model.
        assert causal, 'DDitFinalLayer only supports causal=True'

        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        # Zero init so the layer initially outputs zeros (standard DiT trick).
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x, c):
        # `c` (conditioning) is intentionally ignored.
        return self.linear(self.norm_final(x))
271
+
272
+
273
class DDIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
    """Causal DiT-style transformer backbone (vocab embed -> blocks -> head).

    Args:
        config: omegaconf config (or plain dict, converted on entry) with a
            ``model`` section: hidden_size, n_heads, cond_dim, dropout,
            n_blocks, scale_by_sigma, causal.
        vocab_size: size of the token vocabulary.
    """

    def __init__(self, config, vocab_size: int):
        super().__init__()
        if isinstance(config, dict):
            config = omegaconf.OmegaConf.create(config)

        self.config = config
        self.vocab_size = vocab_size
        # Only the causal (autoregressive) configuration is supported here.
        self.causal = getattr(config.model, 'causal', False)
        assert self.causal, 'DDIT requires config.model.causal=True'

        self.vocab_embed = EmbeddingLayer(
            config.model.hidden_size, vocab_size
        )
        self.rotary_emb = Rotary(
            config.model.hidden_size // config.model.n_heads
        )

        self.blocks = nn.ModuleList([
            DDiTBlock(
                config.model.hidden_size,
                config.model.n_heads,
                config.model.cond_dim,
                dropout=config.model.dropout,
                causal=self.causal,
            )
            for _ in range(config.model.n_blocks)
        ])

        self.output_layer = DDitFinalLayer(
            config.model.hidden_size,
            vocab_size,
            config.model.cond_dim,
            causal=self.causal,
        )
        self.scale_by_sigma = config.model.scale_by_sigma

    def _get_bias_dropout_scale(self):
        # Pick the TorchScript-fused helper matching the current mode.
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference
320
+
321
+
322
class AR(DDIT):
    """Autoregressive language model head on top of the DDIT backbone."""

    def __init__(self, config, vocab_size, mask_index):
        super().__init__(config, vocab_size)
        self.mask_index = mask_index
        # Large negative logit used to zero out the mask token's probability.
        self.neg_infinity = -1000.0

    def forward(self, xt, sigma):
        """Forward pass of the denoising model.

        Args:
          xt: int torch.Tensor with shape
            (batch_size, diffusion_model_input_length), token ids.
          sigma: float torch.Tensor with shape
            (batch_size). Unused by the AR model; kept for interface
            parity with the diffusion models.

        Returns:
          log probability with shape
          (batch_size, diffusion_model_input_length, vocab_size)
        """
        x = self.vocab_embed(xt)

        rotary_cos_sin = self.rotary_emb(x)

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, None, seqlens=None)
            output = self.output_layer(x, None)

        # log prob at the mask index = - infinity
        output[:, :, self.mask_index] = self.neg_infinity

        # Normalize the logits such that output.exp() is
        # a probability distribution over vocab_size.
        return output.log_softmax(-1)
models/dimamba.py ADDED
@@ -0,0 +1,1136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import huggingface_hub
6
+ import numpy as np
7
+ import omegaconf
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from causal_conv1d import (
12
+ causal_conv1d_fn,
13
+ causal_conv1d_update,
14
+ )
15
+ from einops import rearrange, repeat
16
+ from mamba_ssm.ops.selective_scan_interface import (
17
+ mamba_inner_fn,
18
+ selective_scan_fn,
19
+ )
20
+ from torch import Tensor
21
+ from transformers import PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import (
23
+ BaseModelOutputWithNoAttention,
24
+ MaskedLMOutput,
25
+ )
26
+
27
+ try:
28
+ from mamba_ssm.ops.triton.layernorm import (
29
+ RMSNorm,
30
+ layer_norm_fn,
31
+ rms_norm_fn,
32
+ )
33
+ except ImportError:
34
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
35
+ from mamba_ssm.ops.triton.selective_state_update import (
36
+ selective_state_update,
37
+ )
38
+
39
+ from models.dit import (
40
+ TimestepEmbedder,
41
+ bias_dropout_add_scale_fused_inference,
42
+ bias_dropout_add_scale_fused_train,
43
+ modulate_fused,
44
+ )
45
+
46
+ # sys.path.append('mamba_wrappers/mamba2')
47
+ # from .mamba2.src.modules.ssd import SSD as Mamba
48
+
49
+
50
class Mamba(nn.Module):
    """Selective state-space (Mamba / S6) mixer layer.

    Adapted from the mamba_ssm reference implementation: input projection,
    depthwise causal conv, selective scan, output projection. Supports a
    fused fast path (``mamba_inner_fn``), a reference path, and single-token
    ``step`` decoding against cached conv/ssm states.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank='auto',
        dt_min=0.001,
        dt_max=0.1,
        dt_init='random',
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,  # required for inference-state caching
        device=None,
        dtype=None,
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == 'auto' else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        # Projects d_model -> 2*d_inner (x and z gate halves).
        self.in_proj = nn.Linear(
            self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
        )

        # Depthwise causal convolution over the sequence dimension.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = 'silu'
        self.act = nn.SiLU()

        # Produces (dt, B, C) from x.
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(
            self.dt_rank, self.d_inner, bias=True, **factory_kwargs
        )

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == 'constant':
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == 'random':
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            'n -> d n',
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D 'skip' parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(
            self.d_inner, self.d_model, bias=bias, **factory_kwargs
        )

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # Mid-generation: decode one token against cached states.
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, 'b l d -> d (b l)'),
            'd (b l) -> b d l',
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), 'd -> d 1')

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat

        if (
            self.use_fast_path
            and causal_conv1d_fn is not None
            and inference_params is None
        ):  # Doesn't support outputting the states
            # Fully fused conv + scan + out-projection kernel.
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )

        else:
            # Reference path: explicit conv, projections, then selective scan.
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(
                    F.pad(x, (self.d_conv - x.shape[-1], 0))
                )  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ['silu', 'swish']
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, 'd 1 w -> d w'),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, 'b d l -> (b l) d'))  # (bl d)
            dt, B, C = torch.split(
                x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
            )
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, 'd (b l) -> b d l', l=seqlen)
            B = rearrange(B, '(b l) dstate -> b dstate l', l=seqlen).contiguous()
            C = rearrange(C, '(b l) dstate -> b dstate l', l=seqlen).contiguous()

            assert self.activation in ['silu', 'swish']

            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )

            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, 'b d l -> b l d')

            out = self.out_proj(y)

        return out

    def step(self, hidden_states, conv_state, ssm_state):
        """Decode a single token, updating conv_state/ssm_state in place.

        hidden_states: (B, 1, D). Returns (out (B, 1, D), conv_state, ssm_state).
        """
        dtype = hidden_states.dtype
        assert (
            hidden_states.shape[1] == 1
        ), 'Only support decoding with 1 token at a time for now'
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            # Reference conv step: shift the window and dot with the kernel.
            conv_state.copy_(
                torch.roll(conv_state, shifts=-1, dims=-1)
            )  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(
                conv_state * rearrange(self.conv1d.weight, 'd 1 w -> d w'), dim=-1
            )  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, 'd 1 w -> d w'),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum('bd,dn->bdn', dt, A))
            dB = torch.einsum('bd,bn->bdn', dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, 'b d -> b d 1') * dB)
            y = torch.einsum('bdn,bn->bd', ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state,
                x,
                dt,
                A,
                B,
                C,
                self.D,
                z=z,
                dt_bias=self.dt_proj.bias,
                dt_softplus=True,
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate fresh zeroed conv/ssm states for incremental decoding."""
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size,
            self.d_model * self.expand,
            self.d_conv,
            device=device,
            dtype=conv_dtype,
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size,
            self.d_model * self.expand,
            self.d_state,
            device=device,
            dtype=ssm_dtype,
        )
        return conv_state, ssm_state

    def _get_states_from_cache(
        self, inference_params, batch_size, initialize_states=False
    ):
        """Fetch (or lazily create) this layer's cached states, keyed by layer_idx."""
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (
                conv_state,
                ssm_state,
            )
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[
                self.layer_idx
            ]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
366
+
367
+
368
class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
        modulate=False,
        t_dim=0,
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        When ``modulate`` is True, an adaLN modulation (shift/scale/gate from
        ``time_embeds`` of width ``t_dim``) is applied around the mixer.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)

        if self.fused_add_norm:
            assert RMSNorm is not None, 'RMSNorm import fails'
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), 'Only LayerNorm and RMSNorm are supported for fused_add_norm'

        # Dropout prob used by the gated residual add in the modulate path.
        # NOTE(review): hard-coded 0.1 rather than configurable — confirm intended.
        self.dropout = 0.1

        self.modulate = modulate
        self.t_dim = t_dim
        if modulate:
            # Zero-init so modulation starts as an identity (DiT-style).
            self.adaLN_modulation = nn.Linear(t_dim,
                                              3 * dim,
                                              bias=True)
            self.adaLN_modulation.weight.data.zero_()
            self.adaLN_modulation.bias.data.zero_()

    def _get_bias_dropout_scale(self):
        # Pick the TorchScript-fused helper matching the current mode.
        return (
            bias_dropout_add_scale_fused_train
            if self.training
            else bias_dropout_add_scale_fused_inference
        )

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        inference_params=None,
        time_embeds=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual)
                if residual is not None
                else hidden_states
            )

            hidden_states = self.norm(
                residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            # Fused add + norm kernel returns both normed states and the
            # updated residual stream.
            fused_add_norm_fn = (
                rms_norm_fn
                if isinstance(self.norm, RMSNorm)
                else layer_norm_fn
            )

            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps)

        if self.modulate and time_embeds is not None:
            # (B, 3*dim) -> broadcast over sequence via [:, None].
            (shift_msa,
             scale_msa,
             gate_msa) = self.adaLN_modulation(
                time_embeds)[:, None].chunk(3, dim=-1)
            hidden_states = modulate_fused(hidden_states,
                                           shift_msa,
                                           scale_msa)

        mixer_out = self.mixer(hidden_states, inference_params=inference_params)

        hidden_states = mixer_out
        if self.modulate and time_embeds is not None:
            # Gated residual add with dropout; only happens on the modulate
            # path — otherwise the caller combines hidden_states + residual.
            bias_dropout_scale_fn = self._get_bias_dropout_scale()
            hidden_states = bias_dropout_scale_fn(
                hidden_states,
                None,
                gate_msa,
                residual,
                self.dropout)

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # Delegate state allocation to the wrapped mixer.
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs)
487
+
488
class BiMambaConfig(PretrainedConfig):
    """Config that extends the original MambaConfig with params relevant to bi-directionality."""

    model_type = 'bimamba'

    def __init__(
        self,
        # From original MambaConfig
        d_model: int = 2560,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        tie_word_embeddings: bool = True,
        # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
        norm_epsilon: float = 1e-5,
        # Used in init_weights
        initializer_cfg: Optional[dict] = None,
        # Caduceus-specific params
        bidirectional: bool = True,
        bidirectional_strategy: Union[str, None] = 'add',
        bidirectional_weight_tie: bool = True,
        # Time-embedding conditioning: strategy name and embedding width.
        temb_strategy: Union[str, None] = None,
        d_temb: int = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.tie_word_embeddings = tie_word_embeddings
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie

        self.temb_strategy = temb_strategy
        self.d_temb = d_temb
535
+
536
+
537
def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    bidirectional=True,
    bidirectional_strategy='add',
    bidirectional_weight_tie=True,
    device=None,
    dtype=None,
    modulate=False,
    d_temb=0,
):
    """Assemble one BiMamba ``Block`` (mixer class + norm class).

    Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
    """
    cfg = {} if ssm_cfg is None else ssm_cfg
    # Pre-bind everything the Block constructor won't pass itself.
    mixer_cls = partial(
        BiMambaWrapper,
        layer_idx=layer_idx,
        bidirectional=bidirectional,
        bidirectional_strategy=bidirectional_strategy,
        bidirectional_weight_tie=bidirectional_weight_tie,
        device=device,
        dtype=dtype,
        **cfg,
    )
    norm_cls = partial(
        RMSNorm if rms_norm else nn.LayerNorm,
        eps=norm_epsilon,
        device=device,
        dtype=dtype,
    )
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
        t_dim=d_temb,
        modulate=modulate,
    )
    block.layer_idx = layer_idx
    return block
588
+
589
+
590
class BiMambaWrapper(nn.Module):
    """Thin wrapper around Mamba to support bi-directionality."""

    def __init__(
        self,
        d_model: int,
        bidirectional: bool = True,
        bidirectional_strategy: Optional[str] = 'add',
        bidirectional_weight_tie: bool = True,
        **mamba_kwargs,
    ):
        super().__init__()
        if bidirectional and bidirectional_strategy is None:
            bidirectional_strategy = 'add'  # Default strategy: `add`
        if bidirectional and bidirectional_strategy not in ['add', 'ew_multiply']:
            raise NotImplementedError(
                f'`{bidirectional_strategy}` strategy for bi-directionality is not implemented!'
            )
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy

        self.mamba_fwd = Mamba(d_model=d_model, **mamba_kwargs)
        self.mamba_rev = None
        if bidirectional:
            self.mamba_rev = Mamba(d_model=d_model, **mamba_kwargs)
            if bidirectional_weight_tie:
                # Tie in and out projections (where most of param count lies).
                self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
                self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
                self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
                self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias

    def forward(self, hidden_states, inference_params=None):
        """Bidirectional-enabled forward pass.

        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        fwd_out = self.mamba_fwd(hidden_states, inference_params=inference_params)
        if not self.bidirectional:
            return fwd_out

        # Run the reverse-direction Mamba on the time-flipped sequence and
        # flip its output back before combining with the forward pass.
        rev_in = torch.flip(hidden_states, dims=(1,))
        rev_out = self.mamba_rev(rev_in, inference_params=inference_params)
        rev_out = torch.flip(rev_out, dims=(1,))

        if self.bidirectional_strategy == 'add':
            return fwd_out + rev_out
        if self.bidirectional_strategy == 'ew_multiply':
            return fwd_out * rev_out
        raise NotImplementedError(
            f'`{self.bidirectional_strategy}` for bi-directionality not implemented!'
        )
660
+
661
+
662
class BiMambaEmbeddings(nn.Module):
    """Token-embedding lookup for the BiMamba backbone."""

    def __init__(
        self,
        config: BiMambaConfig,
        input_dim=None,
        device=None,
        dtype=None,
    ):
        super().__init__()
        # Default the embedding input size to the configured vocabulary size;
        # callers may widen it (e.g. for concatenated timestep features).
        vocab = config.vocab_size if input_dim is None else input_dim
        self.word_embeddings = nn.Embedding(
            vocab, config.d_model, device=device, dtype=dtype)

    def forward(self, input_ids):
        """input_ids: (batch, seqlen) -> embeddings (batch, seqlen, d_model)."""
        return self.word_embeddings(input_ids)
683
+
684
+
685
class BiMambaMixerModel(nn.Module):
    """Stack of bidirectional Mamba blocks with optional timestep conditioning.

    Timestep embeddings are injected according to ``config.temb_strategy``:
    'concat' (prepend along the feature dim, widening the model by d_temb),
    'add' (sum into token embeddings), or an 'adaln'-style strategy (per-block
    modulation; final-norm modulation handled here, fused path only).
    """

    def __init__(
        self,
        config: BiMambaConfig,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.temb_strategy = config.temb_strategy
        self.config = config
        input_dim = config.vocab_size
        d_model = config.d_model
        # 'concat' widens both the embedding input and the working width by d_temb.
        if self.temb_strategy and self.temb_strategy == 'concat':
            input_dim += config.d_temb
            d_model += config.d_temb
        # No conditioning: zero out d_temb so downstream sizes stay consistent.
        if self.temb_strategy is None:
            config.d_temb = 0

        self.fused_add_norm = config.fused_add_norm
        self.residual_in_fp32 = config.residual_in_fp32

        self.embeddings = BiMambaEmbeddings(
            config, input_dim=input_dim, **factory_kwargs)

        # Mamba changes the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        if config.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError('Failed to import Triton LayerNorm / RMSNorm kernels')

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=config.ssm_cfg,
                    norm_epsilon=config.norm_epsilon,
                    rms_norm=config.rms_norm,
                    residual_in_fp32=config.residual_in_fp32,
                    fused_add_norm=config.fused_add_norm,
                    layer_idx=i,
                    bidirectional=config.bidirectional,
                    bidirectional_strategy=config.bidirectional_strategy,
                    bidirectional_weight_tie=config.bidirectional_weight_tie,
                    modulate=True if config.temb_strategy and 'adaln' in config.temb_strategy else False,
                    d_temb=config.d_temb,
                    **factory_kwargs,
                )
                for i in range(config.n_layer)
            ]
        )

        if self.temb_strategy and 'adaln' in self.temb_strategy:
            # Zero-initialized so the final modulation starts as an identity.
            self.adaLN_modulation_final = nn.Linear(
                config.d_temb, 2 * d_model, bias=True
            )
            self.adaLN_modulation_final.weight.data.zero_()
            self.adaLN_modulation_final.bias.data.zero_()

        norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
            d_model, eps=config.norm_epsilon, **factory_kwargs
        )
        self.norm_f = norm_f

    def pre_apply_temb(self, input_embeds, time_embeds):
        """Prepend/add time embeddings to input embeddings at the start of the forward pass.

        Args:
            input_embeds: Input embeddings. (batch, seqlen, d_model)
            time_embeds: Timestep embeddings. (batch, d_temb)
        Returns:
            if self.temb_strategy == 'concat':
                input_embeds: (batch, seqlen, d_model + d_temb)
            if self.temb_strategy == 'add':
                input_embeds: (batch, seqlen, d_model)
        """
        if self.temb_strategy == 'concat':
            input_embeds = torch.cat([time_embeds.unsqueeze(1).tile(
                1, input_embeds.shape[1], 1), input_embeds], axis=-1)
        elif self.temb_strategy == 'add':
            # NOTE(review): in-place `+=` mutates the tensor passed in; fine for
            # freshly-computed embeddings, but mutates caller-owned
            # `inputs_embeds` if one was supplied — confirm intended.
            input_embeds += time_embeds.unsqueeze(1).tile(1, input_embeds.shape[1], 1)
        return input_embeds

    def forward(
        self,
        input_ids,
        inputs_embeds=None,
        output_hidden_states=False,
        time_embeds=None,
    ):
        """Mixer forward.

        Args:
            input_ids: (batch, seqlen) token ids; ignored when `inputs_embeds` is given.
            inputs_embeds: optional precomputed embeddings.
            output_hidden_states: collect per-layer activations.
            time_embeds: (batch, d_temb) conditioning applied per `temb_strategy`.
        Returns:
            (hidden_states, all_hidden_states) tuple.
        """
        all_hidden_states = []
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embeddings(input_ids)
        if (
            time_embeds is not None
            and self.temb_strategy in ['concat', 'add']
        ):
            hidden_states = self.pre_apply_temb(hidden_states, time_embeds)

        residual = None

        for ind, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states.append(hidden_states)
            # TODO: Add support for gradient checkpointing
            layer_out = layer(
                hidden_states, residual, inference_params=None, time_embeds=time_embeds
            )

            # NOTE(review): the second output is bound to `residuals` (plural)
            # while `residual` above stays None for every layer, so no residual
            # stream is threaded between blocks — looks like a typo
            # (`residuals` vs `residual`); confirm against the custom Block's
            # (hidden_states, residual) contract before changing, as trained
            # checkpoints match the current behavior.
            hidden_states, residuals = layer_out

        if not self.fused_add_norm:
            if self.temb_strategy and 'adaln' in self.temb_strategy:
                raise NotImplementedError('adaln only implemented for fused_add_norm')
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # adaln: compute final shift/scale from the timestep embedding.
            if time_embeds is not None and self.temb_strategy and 'adaln' in self.temb_strategy:
                shift, scale = self.adaLN_modulation_final(time_embeds)[:, None].chunk(
                    2, dim=2
                )

            fused_add_norm_fn = (
                rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            )

            # Set prenorm=False here since we don't need the residual
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
            if time_embeds is not None and self.temb_strategy and 'adaln' in self.temb_strategy:
                hidden_states = modulate_fused(hidden_states, shift, scale)

        if output_hidden_states:
            all_hidden_states.append(hidden_states)

        return hidden_states, all_hidden_states
836
+
837
+
838
def cross_entropy(logits, y, ignore_index=-100):
    """Mean token-level cross entropy over flattened logits/targets.

    Positions whose target equals `ignore_index` are excluded.
    """
    vocab = logits.shape[-1]
    return F.cross_entropy(
        logits.view(-1, vocab), y.view(-1), ignore_index=ignore_index)
843
+
844
+
845
def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
    """Weighted cross entropy loss (discounts certain tokens).

    Args:
        logits: (..., vocab) unnormalized scores.
        y: (...) integer targets, flattened alongside logits.
        loss_weights: per-token weights, same number of elements as `y`.
        ignore_index: target value excluded from the loss.
    Returns:
        Scalar loss: sum of per-token CE weighted by normalized weights.
    """
    logits = logits.view(-1, logits.shape[-1])
    y = y.view(-1)
    ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction='none')
    # Clone before zeroing ignored positions: `view`/`reshape` share storage
    # with the caller's tensor, so the previous in-place write mutated the
    # caller's `loss_weights` as a side effect.
    loss_weights = loss_weights.reshape(-1).clone()
    loss_weights[y == ignore_index] = 0.0
    # TODO: Follows GPN implementation, but should we remove weight normalization?
    return (ce * (loss_weights / loss_weights.sum())).sum()
854
+
855
+
856
class BiMambaPreTrainedModel(PreTrainedModel):
    """PreTrainedModel wrapper for BiMamba backbone."""

    config_class = BiMambaConfig
    base_model_prefix = 'bimamba'
    supports_gradient_checkpointing = False
    _no_split_modules = ['BiMambaWrapper']

    def _init_weights(
        self,
        module,
        initializer_range=0.02,  # Now only used for embedding layer.
        **kwargs,
    ):
        """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py"""

        n_layer = self.config.n_layer
        # Per-run overrides from the config; fall back to defaults above.
        initialized_cfg = (
            self.config.initializer_cfg
            if self.config.initializer_cfg is not None
            else {}
        )
        rescale_prenorm_residual = initialized_cfg.get('rescale_prenorm_residual', True)
        initializer_range = initialized_cfg.get('initializer_range', initializer_range)
        n_residuals_per_layer = initialized_cfg.get('n_residuals_per_layer', 1)

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                # Biases flagged `_no_reinit` (e.g. by Mamba internals) are preserved.
                if not getattr(module.bias, '_no_reinit', False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=initializer_range)

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth.
            # > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
            # residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ['out_proj.weight', 'fc2.weight']:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(n_residuals_per_layer * n_layer)
906
+
907
+
908
class BiMamba(BiMambaPreTrainedModel):
    """BiMamba model that can be instantiated using HF patterns."""

    def __init__(self, config: BiMambaConfig, device=None, dtype=None, **kwargs):
        super().__init__(config)

        # Pad the vocabulary up to the next configured multiple, if needed.
        remainder = config.vocab_size % config.pad_vocab_size_multiple
        if remainder != 0:
            config.vocab_size += config.pad_vocab_size_multiple - remainder

        self.config = config
        self.backbone = BiMambaMixerModel(
            config, device=device, dtype=dtype, **kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        time_embeds: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
        """HF-compatible forward method."""
        # Resolve None arguments against the config defaults.
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        hidden_states, all_hidden_states = self.backbone(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            time_embeds=time_embeds,
        )

        if return_dict:
            return BaseModelOutputWithNoAttention(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states if output_hidden_states else None,
            )
        if output_hidden_states:
            return hidden_states, all_hidden_states
        return hidden_states
960
+
961
+
962
class BiMambaForMaskedLM(BiMambaPreTrainedModel):
    """HF-compatible BiMamba model for masked language modeling."""

    def __init__(self, config: BiMambaConfig, device=None, dtype=None, **kwargs):
        super().__init__(config, **kwargs)
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.bimamba = BiMamba(config, **factory_kwargs, **kwargs)
        self.config = config
        self.temb_strategy = config.temb_strategy
        lm_head_in_dim = config.d_model
        # LM head may only take in concatenated timestep embeddings
        # if its weights are not tied to the vocab embedding
        if (
            not config.tie_word_embeddings
            and config.temb_strategy == 'concat'
        ):
            lm_head_in_dim += config.d_temb
        self.lm_head = nn.Linear(
            lm_head_in_dim,
            self.config.vocab_size,  # Use BiMamba config as it might have been updated
            bias=False,
            **factory_kwargs,
        )
        # Initialize weights and apply final processing
        self.post_init()
        if self.config.tie_word_embeddings:
            self.tie_weights()

    def init_weights(self):
        """
        If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
        initialization logic in `_init_weights`.
        """

        # Initialize weights
        self.apply(self._initialize_weights)

        # Tie weights should be skipped when not initializing all weights
        # since from_pretrained(...) calls tie weights anyways

    def post_init(self):
        """
        A method executed at the end of each Transformer model initialization, to execute code that needs the model's
        modules properly initialized (such as weight initialization).
        """
        self.init_weights()
        self._backward_compatibility_gradient_checkpointing()

    def get_input_embeddings(self):
        """Return the backbone's vocabulary embedding module."""
        return self.bimamba.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the backbone's vocabulary embedding module."""
        self.bimamba.backbone.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        """Return the LM head (used by HF weight tying)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Overrides output embeddings."""
        self.lm_head = new_embeddings

    def tie_weights(self):
        """Tie weights."""
        super().tie_weights()

    def get_decoder(self):
        """Get decoder (backbone) for the model."""
        return self.bimamba

    def set_decoder(self, decoder):
        """Set decoder (backbone) for the model."""
        self.bimamba = decoder

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        loss_weights: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        time_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """HF-compatible forward method.

        Computes logits over the vocabulary and, when `labels` is given, a
        (optionally `loss_weights`-weighted) cross-entropy loss that ignores
        `config.pad_token_id` targets.
        """

        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.bimamba(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            time_embeds=time_embeds,
        )
        hidden_states = outputs[0]
        # With tied embeddings + 'concat', the first d_temb feature channels
        # are the prepended timestep features (see pre_apply_temb); strip them
        # so the width matches the tied vocab-embedding head.
        if (
            self.config.tie_word_embeddings
            and time_embeds is not None
            and self.temb_strategy is not None
            and self.temb_strategy == 'concat'
        ):
            hidden_states = hidden_states[:, :, self.config.d_temb:]

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            if loss_weights is not None:
                loss = weighted_cross_entropy(
                    logits, labels, loss_weights, ignore_index=self.config.pad_token_id
                )
            else:
                loss = cross_entropy(
                    logits, labels, ignore_index=self.config.pad_token_id
                )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
1095
+
1096
class DiMamba(nn.Module, huggingface_hub.PyTorchModelHubMixin):
    """Diffusion wrapper: timestep embedder + BiMamba masked-LM backbone."""

    def __init__(self, config, vocab_size: int, pad_token_id: int):
        super().__init__()
        # Accept either an OmegaConf object or a plain dict (e.g. loaded from hub).
        if type(config) == dict:
            config = omegaconf.OmegaConf.create(config)

        self.temb_strategy = config.model.temb_strategy

        # 'add' injects the timestep embedding directly into the hidden stream,
        # so it must match hidden_size; other strategies use cond_dim.
        # No sigma_map is created when the strategy is 'none'.
        if self.temb_strategy == 'add':
            self.sigma_map = TimestepEmbedder(config.model.hidden_size)
        elif self.temb_strategy != 'none':
            self.sigma_map = TimestepEmbedder(config.model.cond_dim)

        mamba_config = BiMambaConfig(
            d_model=config.model.hidden_size,
            n_layer=config.model.n_blocks,
            pad_token_id=pad_token_id,
            vocab_size=vocab_size,
            pad_vocab_size_multiple=1,
            tie_word_embeddings=config.model.tie_word_embeddings,
            temb_strategy=self.temb_strategy,
            d_temb=config.model.cond_dim,
            bidirectional=True)

        self.model = BiMambaForMaskedLM(config=mamba_config)

    def _get_bias_dropout_scale(self):
        # Fused train/inference dropout-add kernels (see dit.py).
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def forward(self, indices, sigma):
        """indices: (B, L) token ids; sigma: (B,) noise levels -> logits."""
        c = None
        # Bug fix: __init__ skips creating `sigma_map` when the strategy is the
        # string 'none', but the old check only tested `is not None`, crashing
        # with AttributeError for temb_strategy == 'none'. Guard both cases.
        if self.temb_strategy is not None and self.temb_strategy != 'none':
            c = F.silu(self.sigma_map(sigma))

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            x = self.model(indices, time_embeds=c).logits

        return x
models/dit.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import typing
3
+
4
+ import flash_attn
5
+ import flash_attn.layers.rotary
6
+ import huggingface_hub
7
+ import omegaconf
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+
13
+ # Flags required to enable jit fusion kernels
14
+ torch._C._jit_set_profiling_mode(False)
15
+ torch._C._jit_set_profiling_executor(False)
16
+ torch._C._jit_override_can_fuse_on_cpu(True)
17
+ torch._C._jit_override_can_fuse_on_gpu(True)
18
+
19
+
20
def bias_dropout_add_scale(
        x: torch.Tensor,
        bias: typing.Optional[torch.Tensor],
        scale: torch.Tensor,
        residual: typing.Optional[torch.Tensor],
        prob: float,
        training: bool) -> torch.Tensor:
    """Fused helper: (x [+ bias]) -> dropout -> * scale -> [+ residual].

    Kept TorchScript-compatible: the scripted fused variants below compile it.
    """
    if bias is None:
        dropped = F.dropout(x, p=prob, training=training)
    else:
        dropped = F.dropout(x + bias, p=prob, training=training)
    out = scale * dropped
    if residual is not None:
        out = residual + out
    return out
35
+
36
+
37
def get_bias_dropout_add_scale(training):
    """Return a closure over `training` for bias_dropout_add_scale."""
    def _bias_dropout_add(x, bias, scale, residual, prob):
        out = bias_dropout_add_scale(x, bias, scale, residual, prob, training)
        return out

    return _bias_dropout_add
43
+
44
+
45
# function overload
def modulate(x: torch.Tensor,
             shift: torch.Tensor,
             scale: torch.Tensor) -> torch.Tensor:
    """AdaLN modulation: x * (1 + scale) + shift (same-shape operands)."""
    gain = 1 + scale
    return x * gain + shift
50
+
51
+
52
@torch.jit.script
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
    """TorchScript-compiled train-mode (dropout active) bias_dropout_add_scale."""
    return bias_dropout_add_scale(
        x, bias, scale, residual, prob, True)
61
+
62
+
63
@torch.jit.script
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
    """TorchScript-compiled inference-mode (dropout disabled) bias_dropout_add_scale."""
    return bias_dropout_add_scale(
        x, bias, scale, residual, prob, False)
72
+
73
+
74
@torch.jit.script
def modulate_fused(x: torch.Tensor,
                   shift: torch.Tensor,
                   scale: torch.Tensor) -> torch.Tensor:
    """TorchScript-compiled wrapper around `modulate` (x * (1 + scale) + shift).

    NOTE: scripted against the typed `modulate` defined above; a second,
    broadcasting `modulate` overload later in this file shadows the Python
    name but does not affect this compiled version.
    """
    return modulate(x, shift, scale)
79
+
80
+
81
class Rotary(torch.nn.Module):
    """Rotary positional-embedding cache.

    Builds cos/sin tables shaped (1, seq_len, 3, 1, dim) so they broadcast
    against packed qkv tensors; the slice for v (index 2 on the qkv axis) is
    forced to cos=1 / sin=0 so the rotation acts as an identity on values.
    Tables are cached and rebuilt only when the sequence length changes.
    """

    def __init__(self, dim, base=10_000):
        super().__init__()
        # Standard RoPE inverse frequencies over even channel indices.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        # Lazily-built caches, refreshed whenever seq_len changes.
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            # dims are: batch, seq_len, qkv, head, dim
            self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1)
            self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1)
            # This makes the transformation on v an identity.
            self.cos_cached[:,:,2,:,:].fill_(1.)
            self.sin_cached[:,:,2,:,:].fill_(0.)

        return self.cos_cached, self.sin_cached
105
+
106
+
107
def rotate_half(x):
    """Rotate the last dim by a quarter turn: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
110
+
111
+
112
def apply_rotary_pos_emb(qkv, cos, sin):
    """Apply rotary embeddings to packed qkv via the flash-attn kernel.

    Slices the (1, S, 3, 1, dim) caches from `Rotary` down to the (S, dim/2)
    layout the flash-attn rotary kernel expects.
    """
    cos_1d = cos[0, :, 0, 0, :cos.shape[-1] // 2]
    sin_1d = sin[0, :, 0, 0, :sin.shape[-1] // 2]
    return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos_1d, sin_1d)
116
+
117
+
118
# function overload
def modulate(x, shift, scale):
    """Broadcasting AdaLN modulation: per-batch (B, D) shift/scale applied
    across the sequence dim of x (B, S, D). Shadows the typed `modulate`
    above for plain Python callers."""
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
121
+
122
+
123
+ #################################################################################
124
+ # Layers #
125
+ #################################################################################
126
class LayerNorm(nn.Module):
    """Bias-free LayerNorm whose normalization runs in fp32.

    The affine scale is applied after the (autocast-disabled) norm, so only
    the normalization itself is forced to full precision.
    """

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones([dim]))
        self.dim = dim

    def forward(self, x):
        # Disable autocast so the norm is computed in fp32 regardless of AMP.
        with torch.cuda.amp.autocast(enabled=False):
            normed = F.layer_norm(x.float(), [self.dim])
        return normed * self.weight[None, None, :]
135
+
136
+
137
def residual_linear(x, W, x_skip, residual_scale):
    """x_skip + residual_scale * W @ x, as a single fused addmm."""
    dim_out, dim_in = W.shape[0], W.shape[1]
    flat = torch.addmm(
        x_skip.view(-1, dim_out),
        x.view(-1, dim_in),
        W.T,
        alpha=residual_scale)
    return flat.view(*x.shape[:-1], dim_out)
145
+
146
+
147
+ #################################################################################
148
+ # Embedding Layers for Timesteps and Class Labels #
149
+ #################################################################################
150
class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps into vector representations.

    A sinusoidal frequency embedding is projected through a two-layer MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True))
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            - math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd target dim: pad with a single zero column.
            zero_col = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, zero_col], dim=-1)
        return embedding

    def forward(self, t):
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_emb)
190
+
191
+
192
class LabelEmbedder(nn.Module):
    """Embeds class labels into vector representations.

    Also handles label dropout for classifier-free guidance.
    """

    def __init__(self, num_classes, cond_size):
        super().__init__()
        # One extra slot is reserved for the "dropped label" index used by
        # classifier-free guidance.
        self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
        self.num_classes = num_classes

    # TODO think of initializing with 0.02 std deviation like in original DiT paper

    def forward(self, labels):
        return self.embedding_table(labels)
207
+
208
+
209
+ #################################################################################
210
+ # Core Model #
211
+ #################################################################################
212
+
213
+
214
class DDiTBlock(nn.Module):
    """DiT-style transformer block: adaLN-modulated self-attention + MLP.

    The conditioning vector `c` produces six modulation tensors
    (shift/scale/gate for both the attention and MLP sub-blocks) via a
    zero-initialized linear layer, so the block starts near identity.
    """

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_ratio * dim, dim, bias=True))
        self.dropout2 = nn.Dropout(dropout)
        self.dropout = dropout

        # Zero-init so the adaLN modulation starts as identity (DiT scheme).
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def _get_bias_dropout_scale(self):
        # Pick the TorchScript-fused variant matching train/eval mode.
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def forward(self, x, rotary_cos_sin, c, seqlens=None):
        """Run one block.

        Args:
            x: (batch, seq_len, dim) activations.
            rotary_cos_sin: (cos, sin) caches from `Rotary`.
            c: (batch, cond_dim) conditioning vector.
            seqlens: optional per-sequence lengths for varlen attention;
                when None, all sequences are treated as full length.
        Returns:
            (batch, seq_len, dim) tensor.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]

        bias_dropout_scale_fn = self._get_bias_dropout_scale()

        (shift_msa, scale_msa, gate_msa, shift_mlp,
         scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        # attention operation
        x_skip = x
        x = modulate_fused(self.norm1(x), shift_msa, scale_msa)

        qkv = self.attn_qkv(x)  # dim -> 3 * dim
        qkv = rearrange(qkv,
                        'b s (three h d) -> b s three h d',
                        three=3,
                        h=self.n_heads)
        # Rotary embedding is applied in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            cos, sin = rotary_cos_sin
            qkv = apply_rotary_pos_emb(
                qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
        qkv = rearrange(qkv, 'b s ... -> (b s) ...')
        if seqlens is None:
            # Full-length sequences: cu_seqlens = [0, S, 2S, ...].
            cu_seqlens = torch.arange(
                0, (batch_size + 1) * seq_len, step=seq_len,
                dtype=torch.int32, device=qkv.device)
        else:
            cu_seqlens = seqlens.cumsum(-1)
        x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
            qkv, cu_seqlens, seq_len, 0., causal=False)

        x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)

        # Residual add with dropout, scaled by the attention gate.
        x = bias_dropout_scale_fn(self.attn_out(x),
                                  None,
                                  gate_msa,
                                  x_skip,
                                  self.dropout)

        # mlp operation
        x = bias_dropout_scale_fn(
            self.mlp(modulate_fused(
                self.norm2(x), shift_mlp, scale_mlp)),
            None, gate_mlp, x, self.dropout)
        return x
289
+
290
+
291
class DDiTBlock_non_pad(nn.Module):
    """DDiTBlock variant that drops padded tokens before varlen flash attention.

    `attnmask` (batch, seq_len) marks real tokens; padded positions are
    removed from the packed qkv stream, attention runs on the concatenated
    real tokens only, and outputs are scattered back into a zero-filled
    buffer at the padded positions.
    """

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_ratio * dim, dim, bias=True))
        self.dropout2 = nn.Dropout(dropout)
        self.dropout = dropout

        # Zero-init so the adaLN modulation starts as identity (DiT scheme).
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def _get_bias_dropout_scale(self):
        # Pick the TorchScript-fused variant matching train/eval mode.
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def forward(self, x, rotary_cos_sin, c, seqlens=None, attnmask=None):
        """Run one block, attending only over positions where `attnmask` is True.

        Args:
            x: (batch, seq_len, dim) activations.
            rotary_cos_sin: (cos, sin) caches from `Rotary`.
            c: (batch, cond_dim) conditioning vector.
            seqlens: ignored — recomputed from `attnmask` below.
            attnmask: (batch, seq_len) boolean mask of real (non-pad) tokens.
        Returns:
            (batch, seq_len, dim) tensor; padded positions carry zero
            attention output (plus the residual/MLP contributions).
        """
        batch_size, seq_len = x.shape[0], x.shape[1]

        bias_dropout_scale_fn = self._get_bias_dropout_scale()

        (shift_msa, scale_msa, gate_msa, shift_mlp,
         scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        # attention operation
        x_skip = x
        x = modulate_fused(self.norm1(x), shift_msa, scale_msa)

        qkv = self.attn_qkv(x)  # dim -> 3 * dim
        qkv = rearrange(qkv,
                        'b s (three h d) -> b s three h d',
                        three=3,
                        h=self.n_heads)
        # NOTE(review): autocast is enabled=True here but enabled=False in
        # DDiTBlock above — confirm the asymmetry is intentional.
        with torch.cuda.amp.autocast(enabled=True):
            cos, sin = rotary_cos_sin
            qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
        qkv = rearrange(qkv, 'b s ... -> (b s) ...')

        # --------------------------------
        # Drop padded tokens: keep only positions where attnmask is True.
        mask_flat = attnmask.reshape(-1)
        qkv = qkv[mask_flat]
        # Per-sequence real-token counts, left-padded with a 0 so the cumsum
        # below yields flash-attn style cu_seqlens starting at 0.
        seqlens = attnmask.sum(dim=1)
        pad_seq_len = torch.zeros(len(seqlens)+1, dtype=torch.int32, device=qkv.device)
        pad_seq_len[1:] = seqlens
        seqlens = pad_seq_len
        # cu_seqlens = pad_seq_len.cumsum(-1)
        # x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
        #     qkv, cu_seqlens, seq_len, 0., causal=False)
        # --------------------------------

        # NOTE(review): `seqlens` is always reassigned above, so this
        # `seqlens is None` branch is unreachable dead code.
        if seqlens is None:
            cu_seqlens = torch.arange(
                0, (batch_size + 1) * seq_len, step=seq_len,
                dtype=torch.int32, device=qkv.device)
        else:
            cu_seqlens = seqlens.cumsum(-1).to(torch.int32)

        # Sanity checks (messages in Chinese): "cu_seqlens minimum must be 0"
        # and "total token count does not match cu_seqlens".
        assert cu_seqlens.min() == 0, "cu_seqlens 最小值必须等于 0"
        assert qkv.size(0) == cu_seqlens[-1], "token 总数和 cu_seqlens 不符"

        x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
            qkv, cu_seqlens, seq_len, 0., causal=False)

        # --------------------------------
        # Scatter attention outputs back to the padded (B*S, h, d) layout;
        # padded positions remain zero.
        out_flat = torch.zeros([batch_size*seq_len, x.shape[1], x.shape[2]]).to(x.device).to(x.dtype)
        out_flat[mask_flat] = x
        x = out_flat
        # --------------------------------

        x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)

        # Residual add with dropout, scaled by the attention gate.
        x = bias_dropout_scale_fn(self.attn_out(x),
                                  None,
                                  gate_msa,
                                  x_skip,
                                  self.dropout)

        # mlp operation
        x = bias_dropout_scale_fn(
            self.mlp(modulate_fused(
                self.norm2(x), shift_mlp, scale_mlp)),
            None, gate_mlp, x, self.dropout)
        return x
386
+
387
+
388
class EmbeddingLayer(nn.Module):
    """Learnable token-embedding table implemented as a raw parameter lookup.

    Stores a (vocab_dim, dim) weight matrix and indexes it directly with
    integer token ids, so the output shape is ``x.shape + (dim,)``.
    """

    def __init__(self, dim, vocab_dim):
        super().__init__()
        weight = torch.empty(vocab_dim, dim)
        # Same default init scheme nn.Linear uses for its weight.
        torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
        self.embedding = nn.Parameter(weight)

    def forward(self, x):
        # Plain advanced indexing: works for any integer-tensor shape.
        return self.embedding[x]
396
+
397
+
398
class DDitFinalLayer(nn.Module):
    """Final adaLN-modulated projection from hidden states to output logits.

    Both the output projection and the conditioning head are zero-initialized,
    so at the start of training the layer emits zeros regardless of input.
    """

    def __init__(self, hidden_size, out_channels, cond_dim):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)

        self.linear = nn.Linear(hidden_size, out_channels)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)

    def forward(self, x, c):
        """Modulate normalized ``x`` by shift/scale derived from ``c``, then project."""
        # (batch, 2*hidden) -> (batch, 1, 2*hidden) so it broadcasts over seq.
        mod = self.adaLN_modulation(c).unsqueeze(1)
        shift, scale = mod.chunk(2, dim=2)
        h = modulate_fused(self.norm_final(x), shift, scale)
        return self.linear(h)
418
+
419
+
420
class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
    """Diffusion transformer over token sequences.

    Embeds integer token ``indices``, conditions a stack of ``DDiTBlock``s
    on a timestep embedding of ``sigma`` via adaLN, and projects back to
    vocabulary logits.
    """

    def __init__(self, config, vocab_size: int):
        super().__init__()
        # Accept a plain mapping (e.g. when reloaded via the hub mixin) as
        # well as an OmegaConf object. isinstance() also covers dict
        # subclasses, unlike the original `type(config) == dict` check.
        if isinstance(config, dict):
            config = omegaconf.OmegaConf.create(config)

        self.config = config
        self.vocab_size = vocab_size

        self.vocab_embed = EmbeddingLayer(config.model.hidden_size,
                                          vocab_size)
        self.sigma_map = TimestepEmbedder(config.model.cond_dim)
        # Rotary dimension is the per-head width.
        self.rotary_emb = Rotary(
            config.model.hidden_size // config.model.n_heads)

        self.blocks = nn.ModuleList([
            DDiTBlock(config.model.hidden_size,
                      config.model.n_heads,
                      config.model.cond_dim,
                      dropout=config.model.dropout)
            for _ in range(config.model.n_blocks)])

        self.output_layer = DDitFinalLayer(
            config.model.hidden_size,
            vocab_size,
            config.model.cond_dim)
        self.scale_by_sigma = config.model.scale_by_sigma

    def _get_bias_dropout_scale(self):
        # Dropout behaves differently in train/eval; select the matching
        # fused implementation.
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def forward(self, indices, sigma):
        """Return (batch, seq_len, vocab_size) logits for ``indices`` at noise level ``sigma``."""
        x = self.vocab_embed(indices)
        c = F.silu(self.sigma_map(sigma))

        rotary_cos_sin = self.rotary_emb(x)

        # Transformer stack and head run under bfloat16 autocast.
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, c, seqlens=None)
            x = self.output_layer(x, c)

        return x
467
+
468
class DIT_non_pad(nn.Module, huggingface_hub.PyTorchModelHubMixin):
    """Variant of ``DIT`` whose blocks skip padding tokens in attention.

    Identical architecture to ``DIT`` except the stack is built from
    ``DDiTBlock_non_pad`` and ``forward`` threads an ``attnmask`` through
    every block so attention runs only over real tokens.
    """

    def __init__(self, config, vocab_size: int):
        super().__init__()
        # Accept a plain mapping as well as an OmegaConf object; isinstance()
        # also covers dict subclasses, unlike `type(config) == dict`.
        if isinstance(config, dict):
            config = omegaconf.OmegaConf.create(config)

        self.config = config
        self.vocab_size = vocab_size

        self.vocab_embed = EmbeddingLayer(config.model.hidden_size,
                                          vocab_size)
        self.sigma_map = TimestepEmbedder(config.model.cond_dim)
        # Rotary dimension is the per-head width.
        self.rotary_emb = Rotary(
            config.model.hidden_size // config.model.n_heads)

        self.blocks = nn.ModuleList([
            DDiTBlock_non_pad(config.model.hidden_size,
                              config.model.n_heads,
                              config.model.cond_dim,
                              dropout=config.model.dropout)
            for _ in range(config.model.n_blocks)])

        self.output_layer = DDitFinalLayer(
            config.model.hidden_size,
            vocab_size,
            config.model.cond_dim)
        self.scale_by_sigma = config.model.scale_by_sigma

    def _get_bias_dropout_scale(self):
        # Dropout behaves differently in train/eval; select the matching
        # fused implementation.
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def forward(self, indices, sigma, attnmask):
        """Return (batch, seq_len, vocab_size) logits; ``attnmask`` marks real tokens.

        Args:
            indices: (batch, seq_len) integer token ids.
            sigma: noise-level tensor fed to the timestep embedder.
            attnmask: (batch, seq_len) mask, truthy at non-padding positions.
        """
        x = self.vocab_embed(indices)
        c = F.silu(self.sigma_map(sigma))

        rotary_cos_sin = self.rotary_emb(x)

        # Transformer stack and head run under bfloat16 autocast.
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, c, seqlens=None,
                          attnmask=attnmask)
            x = self.output_layer(x, c)

        return x
models/ema.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
class ExponentialMovingAverage:
    """Maintains an exponential moving average (EMA) of a set of parameters.

    Shadow copies are kept only for parameters with ``requires_grad=True``;
    ``update`` must be called with the same parameter iterable (same order)
    that was used at construction time.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the result
                of `model.parameters()`.
            decay: The exponential decay, in [0, 1].
            use_num_updates: Whether to warm up the decay based on the number
                of updates (decay grows toward its nominal value early on).

        Raises:
            ValueError: If ``decay`` is outside [0, 1].
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        # Detached clones: EMA weights never receive gradients.
        self.shadow_params = [p.clone().detach()
                              for p in parameters if p.requires_grad]
        self.collected_params = []

    def move_shadow_params_to_device(self, device):
        """Move all shadow parameters to ``device``."""
        self.shadow_params = [p.to(device) for p in self.shadow_params]

    def update(self, parameters):
        """Update the moving averages from the current parameter values.

        Call this every time the parameters change, e.g. right after
        `optimizer.step()`.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same
                set of parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm-up schedule: effective decay is (1+n)/(10+n), capped at
            # the nominal decay, so early averages track the model closely.
            decay = min(decay,
                        (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            parameters = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, parameters):
                # In-place form of: s = decay * s + (1 - decay) * p
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """Copy the moving averages into the given parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to
                be overwritten with the stored moving averages.
        """
        parameters = [p for p in parameters if p.requires_grad]
        # (The redundant per-parameter requires_grad re-check of the original
        # was removed: the list above is already filtered by that predicate.)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def store(self, parameters):
        """Save the current parameter values for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to
                be temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """Restore the parameters saved with the `store` method.

        Useful to validate the model with EMA parameters without affecting
        the original optimization: call `store`, then `copy_to`, run
        validation (or save the model), then call `restore`. A no-op if
        `store` was never called.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to
                be overwritten with the stored values.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        """Return a serializable dict of the EMA state."""
        return dict(decay=self.decay,
                    num_updates=self.num_updates,
                    shadow_params=self.shadow_params)

    def load_state_dict(self, state_dict):
        """Load EMA state previously produced by `state_dict`."""
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = state_dict['shadow_params']