Update modeling_super_linear.py
Browse files- modeling_super_linear.py +2 -5
modeling_super_linear.py
CHANGED
|
@@ -392,7 +392,6 @@ class Model(nn.Module):
|
|
| 392 |
- Prediction tensor
|
| 393 |
- (Optional) Expert selection probabilities if get_prob is True
|
| 394 |
"""
|
| 395 |
-
print(pred_len)
|
| 396 |
if pred_len is None:
|
| 397 |
pred_len = self.train_pred_len
|
| 398 |
|
|
@@ -421,8 +420,6 @@ class Model(nn.Module):
|
|
| 421 |
else:
|
| 422 |
out = self.moe(x)
|
| 423 |
|
| 424 |
-
print(pred_len)
|
| 425 |
-
print(self.train_pred_len)
|
| 426 |
if self.train_pred_len < pred_len:
|
| 427 |
outputs = [out]
|
| 428 |
ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
|
|
@@ -474,7 +471,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 474 |
labels: Optional[torch.Tensor] = None,
|
| 475 |
**kwargs,) -> CausalLMOutputWithCrossAttentions:
|
| 476 |
|
| 477 |
-
|
| 478 |
if inputs_embeds is None:
|
| 479 |
raise ValueError("Pass the time‑series as `inputs_embeds`")
|
| 480 |
|
|
@@ -482,7 +479,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 482 |
x_enc = inputs_embeds
|
| 483 |
|
| 484 |
# backbone returns (B, pred_len, C)
|
| 485 |
-
preds = self.backbone(x_enc, pred_len=kwargs.get("pred_len",
|
| 486 |
return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
|
| 487 |
|
| 488 |
|
|
|
|
| 392 |
- Prediction tensor
|
| 393 |
- (Optional) Expert selection probabilities if get_prob is True
|
| 394 |
"""
|
|
|
|
| 395 |
if pred_len is None:
|
| 396 |
pred_len = self.train_pred_len
|
| 397 |
|
|
|
|
| 420 |
else:
|
| 421 |
out = self.moe(x)
|
| 422 |
|
|
|
|
|
|
|
| 423 |
if self.train_pred_len < pred_len:
|
| 424 |
outputs = [out]
|
| 425 |
ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
|
|
|
|
| 471 |
labels: Optional[torch.Tensor] = None,
|
| 472 |
**kwargs,) -> CausalLMOutputWithCrossAttentions:
|
| 473 |
|
| 474 |
+
|
| 475 |
if inputs_embeds is None:
|
| 476 |
raise ValueError("Pass the time‑series as `inputs_embeds`")
|
| 477 |
|
|
|
|
| 479 |
x_enc = inputs_embeds
|
| 480 |
|
| 481 |
# backbone returns (B, pred_len, C)
|
| 482 |
+
preds = self.backbone(x_enc, pred_len=kwargs.get("pred_len", None))
|
| 483 |
return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
|
| 484 |
|
| 485 |
|