Upload 2 files
- model_minimind.py  +6 -15
- model_vlm.py  +3 -5
model_minimind.py
@@ -194,13 +194,7 @@ class Attention(nn.Module):
             )
 
         if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
-            attn_mask = (
-                None
-                if attention_mask is None
-                else attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool()
-            )
-
-            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
+            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
         else:
             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
             scores = scores + torch.triu(
@@ -445,7 +439,6 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
         self.model = MiniMindModel(self.config)
         self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
         self.model.embed_tokens.weight = self.lm_head.weight
-        self.OUT = CausalLMOutputWithPast()
 
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
@@ -454,7 +447,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
                 use_cache: bool = False,
                 logits_to_keep: Union[int, torch.Tensor] = 0,
                 **args):
-        h, past_kvs, aux_loss = self.model(
+        hidden_states, past_key_values, aux_loss = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
@@ -462,9 +455,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
             **args
         )
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(h[:, slice_indices, :])
-        self.OUT.__setitem__('last_hidden_state', h)
-        self.OUT.__setitem__('logits', logits)
-        self.OUT.__setitem__('aux_loss', aux_loss)
-        self.OUT.__setitem__('past_key_values', past_kvs)
-        return self.OUT
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
+        output.aux_loss = aux_loss
+        return output
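Note on the attention change: F.scaled_dot_product_attention raises an error when a non-None attn_mask is combined with is_causal=True, and the guard already restricts this branch to the case where attention_mask is absent or all ones, so the causal flag alone is enough. A minimal sketch (dummy tensors, dropout off, not part of the commit) checking that the trimmed fast path agrees with the manual fallback kept in the else branch:

import math
import torch
import torch.nn.functional as F

# Dummy tensors shaped (batch, heads, seq_len, head_dim), as in Attention.forward.
bsz, n_heads, seq_len, head_dim = 2, 4, 8, 16
xq = torch.randn(bsz, n_heads, seq_len, head_dim)
xk = torch.randn(bsz, n_heads, seq_len, head_dim)
xv = torch.randn(bsz, n_heads, seq_len, head_dim)

# Fast path after the change: causal flag only, no explicit mask.
fast = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)

# Manual fallback from the else branch: scaled scores plus an upper-triangular
# -inf mask, softmax, then the weighted sum over values.
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(head_dim)
scores = scores + torch.triu(
    torch.full((seq_len, seq_len), float('-inf')), diagonal=1
)
slow = F.softmax(scores, dim=-1) @ xv

assert torch.allclose(fast, slow, atol=1e-5)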
model_vlm.py
@@ -162,8 +162,6 @@ class MiniMindVLM(MiniMindForCausalLM):
         )
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
-        self.OUT.__setitem__('last_hidden_state', hidden_states)
-        self.OUT.__setitem__('logits', logits)
-        self.OUT.__setitem__('aux_loss', aux_loss)
-        self.OUT.__setitem__('past_key_values', presents)
-        return self.OUT
+        output = CausalLMOutputWithPast(logits=logits, past_key_values=presents, hidden_states=hidden_states)
+        output.aux_loss = aux_loss
+        return output
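Both forward methods now build a fresh CausalLMOutputWithPast per call instead of mutating the shared self.OUT attribute, so repeated or interleaved calls cannot clobber each other's results. A minimal usage sketch, assuming the classes are importable as below and that a MiniMindConfig with workable defaults exists (both are assumptions, not shown in the diff):

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast
from model_minimind import MiniMindConfig, MiniMindForCausalLM  # assumed import path

config = MiniMindConfig()                  # assumed: default hyperparameters
model = MiniMindForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))  # dummy prompt tokens

with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=True)

# Each call returns its own CausalLMOutputWithPast; no state leaks across calls.
assert isinstance(out, CausalLMOutputWithPast)
logits = out.logits                    # (1, 16, vocab_size) when logits_to_keep=0
past_key_values = out.past_key_values  # KV cache for incremental decoding
aux_loss = out.aux_loss                # extra attribute attached by the new code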