import torch.nn as nn

from src.models.gated_deltaproduct import GatedDeltaProductConfig
from src.models.gated_deltaproduct.modeling_gated_deltaproduct import (
    GatedDeltaProductBlock,
)


class GatedDeltaProductEncoder(nn.Module):
    """
    GatedDeltaProduct encoder that wraps a single GatedDeltaProductBlock for
    sequence modeling.
    """

    def __init__(
        self,
        layer_idx: int,
        token_embed_dim: int,
        num_heads: int = 4,
        attn_mode: str = "chunk",
        expand_v: float = 1.0,
        use_gate: bool = False,
        use_short_conv: bool = True,
        conv_size: int = 4,
        hidden_ratio: float = 1.0,
        allow_neg_eigval: bool = True,
        use_forget_gate: bool = True,
        num_householder: int = 1,
        **kwargs,
    ):
        super().__init__()
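        # head_dim is derived from the embedding width and the number of heads;
        # most other constructor arguments are passed through to the block config unchanged.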
        config = GatedDeltaProductConfig(
            attn_mode=attn_mode,
            hidden_size=token_embed_dim,
            expand_v=expand_v,
            use_gate=use_gate,
            use_short_conv=use_short_conv,
            conv_size=conv_size,
            head_dim=token_embed_dim // num_heads,
            hidden_ratio=hidden_ratio,
            num_heads=num_heads,
            allow_neg_eigval=allow_neg_eigval,
            use_forget_gate=use_forget_gate,
            num_householder=num_householder,
        )

        self.encoder_layer = GatedDeltaProductBlock(layer_idx=layer_idx, config=config)

    def forward(self, x, initial_state=None):
        """
        Forward pass through the GatedDeltaProductBlock.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size].
            initial_state: Optional initial recurrent state passed to the
                underlying block.

        Returns:
            Tuple of the output tensor (same shape as the input) and the
            last hidden state returned by the block.
        """
        x, last_hidden_state, _ = self.encoder_layer(
            x, output_attentions=True, initial_state=initial_state
        )
        return x, last_hidden_state
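

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of any training pipeline). It assumes
    # token_embed_dim is divisible by num_heads and that the block accepts a
    # [batch_size, seq_len, hidden_size] float tensor.
    import torch

    encoder = GatedDeltaProductEncoder(layer_idx=0, token_embed_dim=64, num_heads=4)
    dummy = torch.randn(2, 16, 64)
    out, state = encoder(dummy)
    print(out.shape)  # expected: torch.Size([2, 16, 64])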