"""# β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚
# `feedforward.py`
Regarding dropout:
- I don't see it applied to the MoE in DeepSeek-V3, [here](https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py).
- I don't see it applied in [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L140) either.
Norms:
* nn.RMSNorm [here](https://docs.pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html)
## FFN
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .shared_space_config import SharedSpaceDecoderConfig
def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
"""
Create a normalization layer based on the config norm_type.
Args:
hidden_size: The dimension to normalize over
config: Configuration containing norm_type and epsilon values
Returns:
Either a LayerNorm or RMSNorm layer
"""
if config.norm_type == "layernorm":
return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
elif config.norm_type == "rmsnorm":
return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
else:
# This should be caught by config validation, but being defensive
raise ValueError(f"Unknown norm_type: {config.norm_type}")
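# Usage sketch for create_norm_layer (hedged: assumes a config with
# norm_type="rmsnorm" and rms_norm_eps=1e-6; the field values are illustrative):
#   norm = create_norm_layer(hidden_size=768, config=config)
#   # -> DeepseekV3RMSNorm with a weight of shape [768] and eps=1e-6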
# TODO - Find a shared place to put this.
class DeepseekV3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
DeepseekV3RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
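# Sanity-check sketch: aside from the float32 upcast, DeepseekV3RMSNorm should match
# torch.nn.RMSNorm (available in PyTorch >= 2.4) with the same eps:
#   x = torch.randn(2, 4, 8)
#   torch.allclose(DeepseekV3RMSNorm(8)(x), nn.RMSNorm(8, eps=1e-6)(x), atol=1e-6)  # expected True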
class SubspaceFeedForward(nn.Module):
"""
Feed-forward block for SharedSpaceDecoder.
Implements SwiGLU:
FFN(x) = W_out( Swish(W_in(x)) βŠ™ W_gate(x) )
Supports both dense and decomposed MLP variants.
Dense:
- W_in: Linear(hidden_dim β†’ intermediate_dim)
- W_gate: Linear(hidden_dim β†’ intermediate_dim)
- W_out: Linear(intermediate_dim β†’ hidden_dim)
Decomposed:
- W_in_shared: Linear(hidden_dim β†’ rank, bias=False)
- W_in_shared_norm: RMSNorm (or LayerNorm, per config)
- W_in: Linear(rank β†’ intermediate_dim)
- W_gate_shared: Linear(hidden_dim β†’ rank, bias=False)
- W_gate_shared_norm: RMSNorm (or LayerNorm, per config)
- W_gate: Linear(rank β†’ intermediate_dim)
- W_out: Linear(intermediate_dim β†’ rank, bias=False)
- W_out_shared: Linear(rank β†’ hidden_dim)
The residual connection and surrounding norms are applied by the caller;
dropout is intentionally omitted (see the module notes above).
"""
def __init__(self, config, layer_idx):
super().__init__()
#dropout_prob = config.hidden_dropout_prob # TODO - Style -- don't define variables if only used once.
# Determine whether this is a dense or decomposed layer.
# It's dense if either:
# - ffn_decompose is disabled (no decomposed layers at all)
# - ffn_decompose is enabled, but this is one of the early dense layers.
self.is_dense = (not config.ffn_decompose) or (layer_idx < config.num_dense_layers)
hidden_dim = config.hidden_size
intermediate_dim = config.intermediate_size # TODO - Find something shorter, and use the same name.
# If it's one of the dense layers,
if self.is_dense:
# === Dense FFN Projections ===
self.W_in = nn.Linear(hidden_dim, intermediate_dim)
self.W_gate = nn.Linear(hidden_dim, intermediate_dim)
self.W_out = nn.Linear(intermediate_dim, hidden_dim)
# Define weights for the decomposed version.
else:
rank = config.ffn_rank
# === Input Projections ===
self.W_in_shared = nn.Linear(hidden_dim, rank, bias=False)
self.W_in_shared_norm = create_norm_layer(rank, config)
self.W_in = nn.Linear(rank, intermediate_dim, bias=True)
# === Gate Projections ===
self.W_gate_shared = nn.Linear(hidden_dim, rank, bias=False)
self.W_gate_shared_norm = create_norm_layer(rank, config)
self.W_gate = nn.Linear(rank, intermediate_dim, bias=True)
# === Output Projection ===
self.W_out = nn.Linear(intermediate_dim, rank, bias=False)
# TODO - Could experiment with this.
#self.W_out_shared_layernorm = DeepseekV3RMSNorm(rank, eps=config.eps)
self.W_out_shared = nn.Linear(rank, hidden_dim, bias=True)
# See the module notes on dropout
#self.dropout = nn.Dropout(config.hidden_dropout_prob)
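# Rough weight-count comparison (sketch, ignoring biases; the example numbers
# are illustrative, not tied to any particular config):
#   dense:      3 * D * D_ff
#   decomposed: 3 * R * (D + D_ff)
# e.g. D=768, D_ff=3072, R=192 -> ~2.2M decomposed weights vs ~7.1M dense.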
def forward(self, x: torch.Tensor) -> torch.Tensor:
# === Tensor Dimension Symbols ===
# B: batch_size β€” number of samples in the batch
# T: seq_len β€” number of tokens per sample
# D: hidden_dim β€” model embedding size
# R: ffn_rank β€” latent shared subspace dimension
# D_ff: intermediate_size β€” FFN hidden dimension
# =========================
# Gated Feedforward
# =========================
if self.is_dense:
# =============
# Dense
# =============
# Input: x [B, T, D]
# Output: x_proj [B, T, D_ff]
x_proj = self.W_in(x)
# Output: gate [B, T, D_ff]
gate = self.W_gate(x)
# SwiGLU nonlinearity
x = F.silu(x_proj) * gate # [B, T, D_ff]
# See notes on dropout
#x = self.dropout(x)
# Output: x [B, T, D]
x = self.W_out(x)
else:
# ==================
# Decomposed
# ==================
# Input: x [B, T, D]
# Down-project to the shared subspace [B, T, R], normalize, then up-project: x_proj [B, T, D_ff]
x_proj = self.W_in(self.W_in_shared_norm(self.W_in_shared(x)))
# Input: x [B, T, D]
# Same pattern for the gate: [B, T, R] -> gate [B, T, D_ff]
gate = self.W_gate(self.W_gate_shared_norm(self.W_gate_shared(x)))
# SwiGLU nonlinearity
x = F.silu(x_proj) * gate # [B, T, D_ff]
# See notes on dropout
#x = self.dropout(x)
# Project down to the shared subspace [B, T, R], then back out: x [B, T, D]
x = self.W_out_shared(self.W_out(x))
return x
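# Usage sketch (hedged: assumes SharedSpaceDecoderConfig exposes the fields
# referenced above -- hidden_size, intermediate_size, ffn_decompose, ffn_rank,
# num_dense_layers, norm_type, and the norm epsilons; the constructor call is
# illustrative, not the confirmed signature):
#   config = SharedSpaceDecoderConfig(
#       hidden_size=768, intermediate_size=3072,
#       ffn_decompose=True, ffn_rank=192, num_dense_layers=2,
#       norm_type="rmsnorm", rms_norm_eps=1e-6,
#   )
#   ffn = SubspaceFeedForward(config, layer_idx=5)   # layer_idx >= num_dense_layers -> decomposed
#   y = ffn(torch.randn(2, 16, config.hidden_size))  # -> [2, 16, 768]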