File size: 1,843 Bytes
c4b87d2 0a58567 c4b87d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import torch.nn as nn
from src.models.gated_deltaproduct import GatedDeltaProductConfig
from src.models.gated_deltaproduct.modeling_gated_deltaproduct import (
GatedDeltaProductBlock,
)
class GatedDeltaProductEncoder(nn.Module):
    """
    Encoder layer wrapping a single ``GatedDeltaProductBlock``.

    The constructor arguments are packed into a ``GatedDeltaProductConfig``
    which is then used to build the block stored as ``self.encoder_layer``.
    """

    def __init__(
        self,
        layer_idx: int,
        token_embed_dim: int,
        num_heads: int = 4,
        attn_mode: str = "chunk",
        expand_v: float = 1.0,
        use_gate: bool = False,
        use_short_conv: bool = True,
        conv_size: int = 4,
        hidden_ratio: float = 1.0,
        allow_neg_eigval: bool = True,
        use_forget_gate: bool = True,
        num_householder: int = 1,
        **kwargs,
    ):
        """
        Build the block configuration and instantiate the block.

        Args:
            layer_idx: Passed to ``GatedDeltaProductBlock`` as ``layer_idx``.
            token_embed_dim: Used as the config ``hidden_size``; also
                determines ``head_dim`` (see below).
            num_heads: Config ``num_heads``. ``head_dim`` is computed as
                ``token_embed_dim // num_heads`` (floor division, so any
                remainder is dropped).
            attn_mode: Config ``attn_mode``.
            expand_v: Config ``expand_v``.
            use_gate: Config ``use_gate``.
            use_short_conv: Config ``use_short_conv``.
            conv_size: Config ``conv_size``.
            hidden_ratio: Config ``hidden_ratio``.
            allow_neg_eigval: Config ``allow_neg_eigval``.
            use_forget_gate: Config ``use_forget_gate``.
            num_householder: Config ``num_householder``.
            **kwargs: Accepted for call-site compatibility and ignored; none
                of these values reach the config or the block.
        """
        super().__init__()

        config = GatedDeltaProductConfig(
            attn_mode=attn_mode,
            hidden_size=token_embed_dim,
            expand_v=expand_v,
            use_gate=use_gate,
            use_short_conv=use_short_conv,
            conv_size=conv_size,
            head_dim=token_embed_dim // num_heads,
            hidden_ratio=hidden_ratio,
            num_heads=num_heads,
            allow_neg_eigval=allow_neg_eigval,
            use_forget_gate=use_forget_gate,
            num_householder=num_householder,
        )
        self.encoder_layer = GatedDeltaProductBlock(layer_idx=layer_idx, config=config)

    def forward(self, x, initial_state=None):
        """
        Forward pass through the GatedDeltaProductBlock.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size]
            initial_state: Optional value forwarded unchanged to the block as
                its ``initial_state`` keyword argument. Defaults to ``None``.

        Returns:
            A 2-tuple ``(x, last_hidden_state)`` holding the first two of the
            three values returned by the block. The block's third return
            value is discarded.
        """
        # The block is always invoked with output_attentions=True; its third
        # return value is not propagated to the caller.
        x, last_hidden_state, _ = self.encoder_layer(
            x,
            output_attentions=True,
            initial_state=initial_state,
        )
        return x, last_hidden_state
|