"""# ββββββββββββ
# `feedforward.py`
Regarding dropout:
- I don't see it applied to the MoE in DeepSeek-V3, [here](https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py).
- I don't see it applied in [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L140)
Norms:
* nn.RMSNorm [here](https://docs.pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html)
## FFN
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .shared_space_config import SharedSpaceDecoderConfig
def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
"""
Create a normalization layer based on the config norm_type.
Args:
hidden_size: The dimension to normalize over
config: Configuration containing norm_type and epsilon values
Returns:
Either a LayerNorm or RMSNorm layer
"""
if config.norm_type == "layernorm":
return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
elif config.norm_type == "rmsnorm":
return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
else:
# This should be caught by config validation, but being defensive
raise ValueError(f"Unknown norm_type: {config.norm_type}")
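

# Example usage (sketch): the factory just keys off `config.norm_type`.
# Assuming a config with norm_type="rmsnorm" and rms_norm_eps=1e-6, this returns the
# DeepseekV3RMSNorm defined below:
#
#     norm = create_norm_layer(768, config)
#     y = norm(torch.randn(2, 16, 768))   # same shape, normalized over the last dim
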
# TODO - Find a shared place to put this.
class DeepseekV3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
DeepseekV3RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
    def forward(self, hidden_states):
        # Compute the statistics in float32 for numerical stability, then cast the
        # result back to the input dtype at the end.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)

        # RMSNorm: x * rsqrt(mean(x^2) + eps), followed by a learned per-channel scale.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
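

# Note (suggestion only; not used above): the module docstring links PyTorch's built-in
# `nn.RMSNorm`, which computes the same normalization as DeepseekV3RMSNorm (up to dtype
# handling of the intermediate computation). A drop-in alternative would be:
#
#     norm = nn.RMSNorm(hidden_size, eps=config.rms_norm_eps)
#
# Kept as a comment so the behavior of this module is unchanged.
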
class SubspaceFeedForward(nn.Module):
"""
Feed-forward block for SharedSpaceDecoder.
Implements SwiGLU:
FFN(x) = W_out( Swish(W_in(x)) β W_gate(x) ) + residual
Supports both dense and decomposed MLP variants.
Dense:
- W_in: Linear(hidden_dim β intermediate_dim)
- W_gate: Linear(hidden_dim β intermediate_dim)
- W_out: Linear(intermediate_dim β hidden_dim)
Decomposed:
- W_in_shared: Linear(hidden_dim β rank, bias=False)
- W_in_shared_norm: RMSNorm
- W_in: Linear(rank β intermediate_dim)
- W_gate_shared: Linear(hidden_dim β rank, bias=False)
- W_gate_shared_norm: RMSNorm
- W_gate: Linear(rank β intermediate_dim)
- W_out: Linear(intermediate_dim β rank, bias=False)
- W_out_shared: Linear(rank β hidden_dim)
Residual, dropout, and post-norm are handled inside the block.
"""
def __init__(self, config, layer_idx):
super().__init__()
#dropout_prob = config.hidden_dropout_prob # TODO - Style -- don't define variables if only used once.
# Determine whether this is a dense or decomposed layer.
# It's dense if either:
# - ffn_decompose is disabled (no dense layers at all)
# - ffn_decompose is enabled, but this is one of the early dense layers.
self.is_dense = (not config.ffn_decompose) or (layer_idx < config.num_dense_layers)
hidden_dim = config.hidden_size
intermediate_dim = config.intermediate_size # TODO - Find something shorter, and use the same name.
# If it's one of the dense layers,
if self.is_dense:
# === Dense FFN Projections ===
self.W_in = nn.Linear(hidden_dim, intermediate_dim)
self.W_gate = nn.Linear(hidden_dim, intermediate_dim)
self.W_out = nn.Linear(intermediate_dim, hidden_dim)
# Define weights for the decomposed version.
else:
rank = config.ffn_rank
print("hidden_dim:", hidden_dim)
print("rank:", rank)
# === Input Projections ===
self.W_in_shared = nn.Linear(hidden_dim, rank, bias=False)
self.W_in_shared_norm = create_norm_layer(rank, config)
self.W_in = nn.Linear(rank, intermediate_dim, bias=True)
# === Gate Projections ===
self.W_gate_shared = nn.Linear(hidden_dim, rank, bias=False)
self.W_gate_shared_norm = create_norm_layer(rank, config)
self.W_gate = nn.Linear(rank, intermediate_dim, bias=True)
# === Output Projection ===
self.W_out = nn.Linear(intermediate_dim, rank, bias=False)
# TODO - Could experiment with this.
#self.W_out_shared_layernorm = DeepseekV3RMSNorm(rank, eps=config.eps)
self.W_out_shared = nn.Linear(rank, hidden_dim, bias=True)
        # See notes on dropout in the module docstring.
#self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# === Tensor Dimension Symbols ===
# B: batch_size β number of samples in the batch
# T: seq_len β number of tokens per sample
# D: hidden_dim β model embedding size
# R: ffn_rank β latent shared subspace dimension
# D_ff: intermediate_size β FFN hidden dimension
# =========================
# Gated Feedforward
# =========================
if self.is_dense:
# =============
# Dense
# =============
# Input: x [B, T, D]
# Output: x_proj [B, T, D_ff]
x_proj = self.W_in(x)
# Output: gate [B, T, D_ff]
gate = self.W_gate(x)
# SwiGLU nonlinearity
x = F.silu(x_proj) * gate # [B, T, D_ff]
# See notes on dropout
#x = self.dropout(x)
# Output: x [B, T, D]
x = self.W_out(x)
else:
# ==================
# Decomposed
# ==================
# Input: x [B, T, D]
# Output: x_proj [B, T, D_ff]
x_proj = self.W_in(self.W_in_shared_norm(self.W_in_shared(x)))
# Input: x [B, T, D]
# Output: gate [B, T, D_ff]
gate = self.W_gate(self.W_gate_shared_norm(self.W_gate_shared(x)))
# SwiGLU nonlinearity
x = F.silu(x_proj) * gate # [B, T, D_ff]
# See notes on dropout
#x = self.dropout(x)
# Output: x [B, T, D]
x = self.W_out_shared(self.W_out(x))
return x
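

# ======================================================================================
# Minimal usage sketch (not part of the module's API). It uses a stand-in namespace with
# only the config fields this file actually reads; the real SharedSpaceDecoderConfig may
# expose more. Because of the relative import at the top, run it as a module
# (`python -m <your_package>.feedforward`) rather than as a script.
# ======================================================================================
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        hidden_size=64,          # D
        intermediate_size=256,   # D_ff
        ffn_decompose=True,      # enable the decomposed variant past the dense layers
        ffn_rank=16,             # R
        num_dense_layers=1,      # layer 0 stays dense
        norm_type="rmsnorm",
        rms_norm_eps=1e-6,
        layer_norm_eps=1e-5,
    )

    x = torch.randn(2, 8, cfg.hidden_size)                   # [B, T, D]
    dense_ffn = SubspaceFeedForward(cfg, layer_idx=0)        # dense path
    decomposed_ffn = SubspaceFeedForward(cfg, layer_idx=1)   # decomposed path
    print(dense_ffn(x).shape, decomposed_ffn(x).shape)       # both torch.Size([2, 8, 64])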