lirannoc commited on
Commit
41441d8
·
verified ·
1 Parent(s): c251aa5

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +2 -5
modeling_super_linear.py CHANGED
@@ -392,7 +392,6 @@ class Model(nn.Module):
392
  - Prediction tensor
393
  - (Optional) Expert selection probabilities if get_prob is True
394
  """
395
- print(pred_len)
396
  if pred_len is None:
397
  pred_len = self.train_pred_len
398
 
@@ -421,8 +420,6 @@ class Model(nn.Module):
421
  else:
422
  out = self.moe(x)
423
 
424
- print(pred_len)
425
- print(self.train_pred_len)
426
  if self.train_pred_len < pred_len:
427
  outputs = [out]
428
  ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
@@ -474,7 +471,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
474
  labels: Optional[torch.Tensor] = None,
475
  **kwargs,) -> CausalLMOutputWithCrossAttentions:
476
 
477
-
478
  if inputs_embeds is None:
479
  raise ValueError("Pass the time‑series as `inputs_embeds`")
480
 
@@ -482,7 +479,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
482
  x_enc = inputs_embeds
483
 
484
  # backbone returns (B, pred_len, C)
485
- preds = self.backbone(x_enc, pred_len=kwargs.get("pred_len", default_value))
486
  return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
487
 
488
 
 
392
  - Prediction tensor
393
  - (Optional) Expert selection probabilities if get_prob is True
394
  """
 
395
  if pred_len is None:
396
  pred_len = self.train_pred_len
397
 
 
420
  else:
421
  out = self.moe(x)
422
 
 
 
423
  if self.train_pred_len < pred_len:
424
  outputs = [out]
425
  ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
 
471
  labels: Optional[torch.Tensor] = None,
472
  **kwargs,) -> CausalLMOutputWithCrossAttentions:
473
 
474
+
475
  if inputs_embeds is None:
476
  raise ValueError("Pass the time‑series as `inputs_embeds`")
477
 
 
479
  x_enc = inputs_embeds
480
 
481
  # backbone returns (B, pred_len, C)
482
+ preds = self.backbone(x_enc, pred_len=kwargs.get("pred_len", None))
483
  return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
484
 
485