Update modeling_time_moe.py
modeling_time_moe.py  CHANGED  (+3 -6)
@@ -25,6 +25,7 @@ try:
 except:
     pass

+
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
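For reference, a minimal standalone sketch of what the two `_get_unpad_data` context lines above compute, run on a hypothetical padding mask (the mask values are illustrative, not taken from the repository):

```python
import torch

# Hypothetical padding mask for a batch of 2 sequences (1 = real token, 0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
# tensor([3, 2], dtype=torch.int32)  -> number of real tokens per sequence

indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
# tensor([0, 1, 2, 4, 5])  -> flat positions of the non-padded tokens
```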
@@ -66,7 +67,7 @@ def load_balancing_loss_func(
         The auxiliary loss.
     """
     if gate_logits is None or not isinstance(gate_logits, (tuple, list)) or gate_logits[0] is None:
-        return
+        return 0.0

     compute_device = gate_logits[0].device
     concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
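The switch from a bare `return` to `return 0.0` matters if the auxiliary loss is later scaled and added to the main loss, as is typical for MoE load-balancing losses. A small hypothetical sketch of that downstream arithmetic (the coefficient and the combination step are assumptions for illustration, not part of this diff):

```python
# Hypothetical downstream combination of main and auxiliary losses.
router_aux_loss_coef = 0.02   # illustrative value
main_loss = 1.25              # illustrative value

aux_loss = 0.0                # what the function now returns when gate_logits is unavailable
total_loss = main_loss + router_aux_loss_coef * aux_loss   # 1.25, no error

# With the old bare `return` (i.e. aux_loss = None), the same expression would raise:
#     TypeError: unsupported operand type(s) for *: 'float' and 'NoneType'
```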
@@ -293,7 +294,7 @@ class TimeMoeSparseExpertsLayer(nn.Module):
         """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits
+        # router_logits -> (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)

         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
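The new comment documents the shape of `router_logits`. Below is a self-contained sketch of this routing step, assuming top-k expert selection as in common MoE layers; the sizes and the `torch.topk` renormalization are illustrative assumptions, not taken from this diff:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
batch_size, sequence_length, hidden_dim, n_experts, top_k = 2, 8, 16, 4, 2

gate = torch.nn.Linear(hidden_dim, n_experts, bias=False)
hidden_states = torch.randn(batch_size, sequence_length, hidden_dim)

hidden_states = hidden_states.view(-1, hidden_dim)       # (batch * sequence_length, hidden_dim)
router_logits = gate(hidden_states)                       # (batch * sequence_length, n_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

# A typical next step (not shown in this hunk): keep the top-k experts per token
# and renormalize their weights so they sum to 1.
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
```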
@@ -764,8 +765,6 @@ class TimeMoeModel(TimeMoePreTrainedModel):

     def __init__(self, config: TimeMoeConfig):
         super().__init__(config)
-        # self.padding_idx = config.pad_token_id
-
         self.embed_layer = TimeMoeInputEmbedding(config)
         self.layers = nn.ModuleList(
             [TimeMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -1096,12 +1095,10 @@ class TimeMoeForPrediction(TimeMoePreTrainedModel, TSGenerationMixin):
             shift_labels = labels

             # Calculate loss with mask
-            # losses = self.loss_function(shift_predictions.to(torch.float32), shift_labels.to(torch.float32))
             losses = self.loss_function(shift_predictions, shift_labels)

             if loss_masks is not None:
                 losses = losses * loss_masks
-
                 loss = losses.sum() / loss_masks.sum()
             else:
                 loss = torch.mean(losses)
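In the masked branch, the loss is averaged only over unmasked positions (`losses.sum() / loss_masks.sum()`) rather than over every element. A small numeric sketch with made-up values:

```python
import torch

# Per-position losses for one toy sequence, e.g. from a loss function with reduction='none'.
losses = torch.tensor([0.5, 1.0, 2.0, 4.0])
loss_masks = torch.tensor([1.0, 1.0, 1.0, 0.0])   # last position is padding

loss = (losses * loss_masks).sum() / loss_masks.sum()   # (0.5 + 1.0 + 2.0) / 3 ≈ 1.1667

# An unmasked mean would be skewed by the padded position:
unmasked = losses.mean()                                 # 1.875
```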