lirannoc commited on
Commit
c251aa5
·
verified ·
1 Parent(s): fc20137

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +2 -2
modeling_super_linear.py CHANGED
@@ -474,7 +474,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
474
  labels: Optional[torch.Tensor] = None,
475
  **kwargs,) -> CausalLMOutputWithCrossAttentions:
476
 
477
-
478
  if inputs_embeds is None:
479
  raise ValueError("Pass the time‑series as `inputs_embeds`")
480
 
@@ -482,7 +482,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
482
  x_enc = inputs_embeds
483
 
484
  # backbone returns (B, pred_len, C)
485
- preds = self.backbone(x_enc)
486
  return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
487
 
488
 
 
474
  labels: Optional[torch.Tensor] = None,
475
  **kwargs,) -> CausalLMOutputWithCrossAttentions:
476
 
477
+
478
  if inputs_embeds is None:
479
  raise ValueError("Pass the time‑series as `inputs_embeds`")
480
 
 
482
  x_enc = inputs_embeds
483
 
484
  # backbone returns (B, pred_len, C)
485
+ preds = self.backbone(x_enc, pred_len=kwargs.get("pred_len", default_value))
486
  return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
487
 
488