jingyaogong committed on
Commit 14a5552 · verified · 1 Parent(s): 69d0681

Upload 2 files

Files changed (2)
  1. model_minimind.py +6 -15
  2. model_vlm.py +3 -5
model_minimind.py CHANGED
@@ -194,13 +194,7 @@ class Attention(nn.Module):
         )
 
         if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
-            attn_mask = (
-                None
-                if attention_mask is None
-                else attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool()
-            )
-
-            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
+            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
         else:
             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
             scores = scores + torch.triu(
@@ -445,7 +439,6 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
         self.model = MiniMindModel(self.config)
         self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
         self.model.embed_tokens.weight = self.lm_head.weight
-        self.OUT = CausalLMOutputWithPast()
 
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
@@ -454,7 +447,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
                 use_cache: bool = False,
                 logits_to_keep: Union[int, torch.Tensor] = 0,
                 **args):
-        h, past_kvs, aux_loss = self.model(
+        hidden_states, past_key_values, aux_loss = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
@@ -462,9 +455,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
             **args
         )
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(h[:, slice_indices, :])
-        self.OUT.__setitem__('last_hidden_state', h)
-        self.OUT.__setitem__('logits', logits)
-        self.OUT.__setitem__('aux_loss', aux_loss)
-        self.OUT.__setitem__('past_key_values', past_kvs)
-        return self.OUT
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
+        output.aux_loss = aux_loss
+        return output
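The change drops the shared self.OUT buffer in favor of a CausalLMOutputWithPast built on every call, renames h/past_kvs to hidden_states/past_key_values, and lets the flash path rely on is_causal alone, since that branch already requires the attention mask to be absent or all ones. A minimal sketch of how the new return value can be consumed, assuming a configured MiniMindForCausalLM instance bound to the placeholder name model (the dummy input below is illustrative, not part of the commit):

import torch

input_ids = torch.randint(0, 6400, (1, 16))  # dummy token ids; shape and vocab size are placeholders

with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=True, logits_to_keep=1)

# Each call now builds its own output object instead of mutating self.OUT,
# so repeated or concurrent forward passes no longer overwrite each other's results.
next_token_logits = out.logits[:, -1, :]  # only the kept last-position logits
kv_cache = out.past_key_values            # cache to pass back in on the next step
moe_aux_loss = out.aux_loss               # extra attribute attached by forward()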
 
 
model_vlm.py CHANGED
@@ -162,8 +162,6 @@ class MiniMindVLM(MiniMindForCausalLM):
         )
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
-        self.OUT.__setitem__('last_hidden_state', hidden_states)
-        self.OUT.__setitem__('logits', logits)
-        self.OUT.__setitem__('aux_loss', aux_loss)
-        self.OUT.__setitem__('past_key_values', presents)
-        return self.OUT
+        output = CausalLMOutputWithPast(logits=logits, past_key_values=presents, hidden_states=hidden_states)
+        output.aux_loss = aux_loss
+        return output
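model_vlm.py applies the same refactor, with the vision path's presents as the cache. Since CausalLMOutputWithPast is a transformers ModelOutput, the extra aux_loss field can be attached as a plain attribute after construction; a standalone sketch of that pattern with dummy tensors (shapes are illustrative only):

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

logits = torch.zeros(1, 1, 8)         # dummy (batch, kept positions, vocab)
hidden_states = torch.zeros(1, 4, 8)  # dummy (batch, seq_len, hidden)

output = CausalLMOutputWithPast(logits=logits, past_key_values=None, hidden_states=hidden_states)
output.aux_loss = torch.tensor(0.0)   # extra field, mirroring the new forward()

print(output.logits.shape, output.aux_loss)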