"""# β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚

# `feedforward.py`

Regarding dropout:

- I don't see it applied to the MoE in DeepSeek-V3, [here](https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py).

- I don't see it applied in [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L140), either.

So dropout is omitted from this block as well; the hooks below are left commented out.

Norms:

- PyTorch provides a built-in `nn.RMSNorm`, [here](https://docs.pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html), as an alternative to the DeepseekV3RMSNorm implementation below.

## FFN
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from .shared_space_config import SharedSpaceDecoderConfig


def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
    """
    Create a normalization layer based on the config norm_type.
    
    Args:
        hidden_size: The dimension to normalize over
        config: Configuration containing norm_type and epsilon values
    
    Returns:
        Either a LayerNorm or RMSNorm layer
    """
    if config.norm_type == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
    elif config.norm_type == "rmsnorm":
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
    else:
        # This should be caught by config validation, but being defensive
        raise ValueError(f"Unknown norm_type: {config.norm_type}")
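
# A minimal usage sketch for create_norm_layer. Illustrative only: the keyword
# arguments to SharedSpaceDecoderConfig are assumptions, not its real signature.
#
#   cfg = SharedSpaceDecoderConfig(norm_type="rmsnorm", ...)  # hypothetical kwargs
#   norm = create_norm_layer(hidden_size=256, config=cfg)     # -> DeepseekV3RMSNorm(256)
#   y = norm(torch.randn(2, 8, 256))                          # normalizes over the last dim
#
# With norm_type="layernorm" the factory returns nn.LayerNorm(256, eps=cfg.layer_norm_eps).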


# TODO - Find a shared place to put this.
class DeepseekV3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
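
# Note: recent PyTorch releases ship nn.RMSNorm (linked in the module docstring).
# Assuming it is available, it should match DeepseekV3RMSNorm up to dtype handling
# (this class upcasts to float32 before reducing). A rough sanity check, as a sketch:
#
#   ours = DeepseekV3RMSNorm(64, eps=1e-6)
#   ref = nn.RMSNorm(64, eps=1e-6)
#   x = torch.randn(2, 8, 64)
#   torch.testing.assert_close(ours(x), ref(x))  # expected to pass for float32 inputs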

class SubspaceFeedForward(nn.Module):
    """
    Feed-forward block for SharedSpaceDecoder.

    Implements SwiGLU:
        FFN(x) = W_out( SiLU(W_in(x)) βŠ™ W_gate(x) )

    Supports both dense and decomposed MLP variants.

    Dense:
        - W_in:   Linear(hidden_dim β†’ intermediate_dim)
        - W_gate: Linear(hidden_dim β†’ intermediate_dim)
        - W_out:  Linear(intermediate_dim β†’ hidden_dim)

    Decomposed:
        - W_in_shared:   Linear(hidden_dim β†’ rank, bias=False)
        - W_in_shared_norm: norm over the rank dim (LayerNorm or RMSNorm, per config.norm_type)
        - W_in:          Linear(rank β†’ intermediate_dim)
        - W_gate_shared: Linear(hidden_dim β†’ rank, bias=False)
        - W_gate_shared_norm: norm over the rank dim (LayerNorm or RMSNorm, per config.norm_type)
        - W_gate:        Linear(rank β†’ intermediate_dim)
        - W_out:         Linear(intermediate_dim β†’ rank, bias=False)
        - W_out_shared:  Linear(rank β†’ hidden_dim)

    The residual connection and surrounding norms are applied by the caller
    (the decoder layer); dropout is omitted here (see the module notes above).
    """

    def __init__(self, config, layer_idx):
        super().__init__()


        #dropout_prob = config.hidden_dropout_prob # TODO - Style -- don't define variables if only used once.

        # Determine whether this is a dense or decomposed layer.
        # It's dense if either:
        #  - ffn_decompose is disabled (no dense layers at all)
        #  - ffn_decompose is enabled, but this is one of the early dense layers.
        self.is_dense = (not config.ffn_decompose) or (layer_idx < config.num_dense_layers)

        hidden_dim = config.hidden_size
        intermediate_dim = config.intermediate_size # TODO - Find something shorter, and use the same name.

        # If it's one of the dense layers,
        if self.is_dense:
            # === Dense FFN Projections ===
            self.W_in = nn.Linear(hidden_dim, intermediate_dim)
            self.W_gate = nn.Linear(hidden_dim, intermediate_dim)
            self.W_out = nn.Linear(intermediate_dim, hidden_dim)

        # Define weights for the decomposed version.
        else:
            rank = config.ffn_rank

            print("hidden_dim:", hidden_dim)
            print("rank:", rank)

            # === Input Projections ===
            self.W_in_shared = nn.Linear(hidden_dim, rank, bias=False)
            self.W_in_shared_norm = create_norm_layer(rank, config)
            self.W_in = nn.Linear(rank, intermediate_dim, bias=True)

            # === Gate Projections ===
            self.W_gate_shared = nn.Linear(hidden_dim, rank, bias=False)
            self.W_gate_shared_norm = create_norm_layer(rank, config)
            self.W_gate = nn.Linear(rank, intermediate_dim, bias=True)

            # === Output Projection ===
            self.W_out = nn.Linear(intermediate_dim, rank, bias=False)
            # TODO - Could experiment with this.
            #self.W_out_shared_layernorm = DeepseekV3RMSNorm(rank, eps=config.eps)
            self.W_out_shared = nn.Linear(rank, hidden_dim, bias=True)

        # See notes on dropout in the module docstring.
        #self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # === Tensor Dimension Symbols ===
        # B: batch_size     β€” number of samples in the batch
        # T: seq_len        β€” number of tokens per sample
        # D: hidden_dim     β€” model embedding size
        # R: ffn_rank       β€” latent shared subspace dimension
        # D_ff: intermediate_size β€” FFN hidden dimension

        # =========================
        #    Gated Feedforward
        # =========================

        if self.is_dense:
            # =============
            #     Dense
            # =============

            # Input:  x [B, T, D]
            # Output: x_proj [B, T, D_ff]
            x_proj = self.W_in(x)

            # Output: gate [B, T, D_ff]
            gate = self.W_gate(x)

            # SwiGLU nonlinearity
            x = F.silu(x_proj) * gate  # [B, T, D_ff]

            # See notes on dropout
            #x = self.dropout(x)

            # Output: x [B, T, D]
            x = self.W_out(x)

        else:
            # ==================
            #     Decomposed
            # ==================

            # Input:  x [B, T, D]
            # Output: x_proj [B, T, D_ff]
            x_proj = self.W_in(self.W_in_shared_norm(self.W_in_shared(x)))

            # Input:  x [B, T, D]
            # Output: gate [B, T, D_ff]
            gate = self.W_gate(self.W_gate_shared_norm(self.W_gate_shared(x)))

            # SwiGLU nonlinearity
            x = F.silu(x_proj) * gate  # [B, T, D_ff]

            # See notes on dropout
            #x = self.dropout(x)

            # Output: x [B, T, D]
            x = self.W_out_shared(self.W_out(x))


        return x
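

if __name__ == "__main__":
    # Quick shape check for both variants: a sketch, not part of the module's API.
    # The SimpleNamespace below only stands in for the config attributes this file
    # actually reads; the sizes are made up. The real SharedSpaceDecoderConfig should
    # be used in practice. Because of the relative import above, run this as
    # `python -m <package>.feedforward` rather than directly.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        ffn_decompose=True,      # enable decomposed FFN layers
        num_dense_layers=1,      # layer 0 stays dense, later layers are decomposed
        hidden_size=256,         # D
        intermediate_size=512,   # D_ff
        ffn_rank=64,             # R
        norm_type="rmsnorm",
        rms_norm_eps=1e-6,
        layer_norm_eps=1e-5,
    )

    x = torch.randn(2, 8, cfg.hidden_size)  # [B, T, D]

    dense_ffn = SubspaceFeedForward(cfg, layer_idx=0)       # dense branch
    decomposed_ffn = SubspaceFeedForward(cfg, layer_idx=1)  # decomposed branch

    print("dense out:     ", tuple(dense_ffn(x).shape))       # (2, 8, 256)
    print("decomposed out:", tuple(decomposed_ffn(x).shape))  # (2, 8, 256)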