import torch.nn as nn

from src.models.gated_deltaproduct import GatedDeltaProductConfig
from src.models.gated_deltaproduct.modeling_gated_deltaproduct import (
    GatedDeltaProductBlock,
)


class GatedDeltaProductEncoder(nn.Module):
    """
    GatedDeltaProduct encoder that wraps a GatedDeltaProductBlock for
    sequence modeling.
    """
    def __init__(
        self,
        layer_idx: int,
        token_embed_dim: int,
        num_heads: int = 4,
        attn_mode: str = "chunk",
        expand_v: float = 1.0,
        use_gate: bool = False,
        use_short_conv: bool = True,
        conv_size: int = 4,
        hidden_ratio: float = 1.0,
        allow_neg_eigval: bool = True,
        use_forget_gate: bool = True,
        num_householder: int = 1,
        **kwargs,
    ):
        super().__init__()
        config = GatedDeltaProductConfig(
            attn_mode=attn_mode,
            hidden_size=token_embed_dim,
            expand_v=expand_v,
            use_gate=use_gate,
            use_short_conv=use_short_conv,
            conv_size=conv_size,
            # Per-head dimension; token_embed_dim must be divisible by num_heads.
            head_dim=token_embed_dim // num_heads,
            hidden_ratio=hidden_ratio,
            num_heads=num_heads,
            allow_neg_eigval=allow_neg_eigval,
            use_forget_gate=use_forget_gate,
            num_householder=num_householder,
        )
        self.encoder_layer = GatedDeltaProductBlock(layer_idx=layer_idx, config=config)

    def forward(self, x, initial_state=None):
        """
        Forward pass through the GatedDeltaProductBlock.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size].
            initial_state: Optional recurrent state used to initialize the block.

        Returns:
            Tuple of (output tensor with the same shape as the input,
            last hidden state returned by the block).
        """
        x, last_hidden_state, _ = self.encoder_layer(
            x, output_attentions=True, initial_state=initial_state
        )
        return x, last_hidden_state
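

# A minimal usage sketch (added for illustration, not part of the original
# module); it assumes token_embed_dim is divisible by num_heads and that the
# block's forward accepts the keyword arguments used above.
if __name__ == "__main__":
    import torch

    encoder = GatedDeltaProductEncoder(layer_idx=0, token_embed_dim=64, num_heads=4)
    x = torch.randn(2, 16, 64)  # [batch_size, seq_len, hidden_size]
    out, state = encoder(x)
    print(out.shape)  # expected: torch.Size([2, 16, 64])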