# -*- coding: utf-8 -*-

"""# shared_subspace_encoder.py"""

from typing import Optional

import torch
from torch import nn

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa

from .mla import MultiheadLatentAttention, RotaryEmbedding
from .feedforward import SubspaceFeedForward
from .shared_space_config import SharedSpaceDecoderConfig

"""`RMSNorm`

From:
https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py

TODO - May not need?
"""

class DeepseekV3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
    """
    Create a normalization layer based on the config norm_type.
    
    Args:
        hidden_size: The dimension to normalize over
        config: Configuration containing norm_type and epsilon values
    
    Returns:
        Either a LayerNorm or RMSNorm layer
    """
    if config.norm_type == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
    elif config.norm_type == "rmsnorm":
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
    else:
        # This should be caught by config validation, but being defensive
        raise ValueError(f"Unknown norm_type: {config.norm_type}")
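
# Usage sketch (illustrative only; not executed anywhere in this module).
# With ``config.norm_type == "rmsnorm"``:
#
#     norm = create_norm_layer(config.hidden_size, config)
#     # -> DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
#     y = norm(x)   # x: [batch, seq_len, hidden_size]; shape is preserved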

"""#### *PreTrainedModel"""

class SharedSpaceDecoderPreTrainedModel(PreTrainedModel):
    """
    The **PreTrainedModel object:
      - Is instantiated when TODO
      - Initializes:
        - TODO
      - Provides access to TODO
      - Executes TODO
    """

    config_class = SharedSpaceDecoderConfig
    base_model_prefix = "model"

    def _init_weights(self, module: nn.Module) -> None:
        """Weight initialization hook used by :class:`PreTrainedModel`.

        ``PreTrainedModel.post_init`` will recursively apply this function to
        every submodule right after construction.  HuggingFace models override
        it so that creating a model from scratch yields the same initialization
        as ``from_pretrained`` when no checkpoint is supplied.

        This initialization covers:
        - ``nn.Linear`` and ``nn.Embedding`` weights drawn from a normal
          distribution with std ``config.initializer_range`` (biases and
          padding rows zeroed)
        - The configurable normalization layers (LayerNorm or RMSNorm), reset
          to unit weight (and zero bias for LayerNorm)
        - Both dense and decomposed vocabulary embeddings, which are plain
          ``nn.Embedding`` / ``nn.Linear`` modules handled by the branches below
        """

        if isinstance(module, nn.Linear):
            # Standard linear layer initialization
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
                
        elif isinstance(module, nn.Embedding):
            # Initialize embeddings with normal distribution
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
                
        elif isinstance(module, DeepseekV3RMSNorm):
            # RMSNorm initialization: weight to 1.0, no bias term
            module.weight.data.fill_(1.0)
            
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm initialization: bias to 0, weight to 1.0
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

"""# β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚

# Classes
"""

"""#### `*Layer`"""

class SharedSpaceDecoderLayer(nn.Module):
    """
    The **Layer object:
      - Is instantiated by :class:`SharedSpaceDecoderModel` for each
        Transformer block in the decoder.
      - Initializes:
        - ``self_attn`` – multi-head latent attention implementing either
          dense or latent projections depending on the configuration.
        - ``ffn`` – a :class:`SubspaceFeedForward` block.
        - Normalization layers (LayerNorm or RMSNorm, per ``config.norm_type``)
          for pre-attention and pre-FFN normalization.
      - Provides access to the attention and feed-forward submodules via the
        attributes ``self_attn`` and ``ffn``.
      - Executes a single decoder block in :meth:`forward`.
    """

    def __init__(self, config: SharedSpaceDecoderConfig, layer_idx: int) -> None:

        super().__init__()

        # Norm applied prior to attention.
        self.attn_input_norm = create_norm_layer(config.hidden_size, config)
        
        # Attention block
        self.self_attn = MultiheadLatentAttention(config, layer_idx)

        # Norm applied prior to FFN
        self.ffn_input_norm = create_norm_layer(config.hidden_size, config)

        # Feed-forward network used after attention
        self.ffn = SubspaceFeedForward(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor], # RoPE embeddings
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:

        # ========================
        #     Self Attention
        # ========================
        residual_strm = hidden_states

        # Normalize the hidden states to create the input to attention.
        attn_input = self.attn_input_norm(hidden_states)

        # Evaluate
        attn_output = self.self_attn(
            attn_input,
            position_embeddings,
            attention_mask,
        )

        # Add the attention output back to the residual stream (the
        # un-normalized hidden_states).
        hidden_states = residual_strm + attn_output

        # ===========================
        #     Feed-Forward Network
        # ===========================
        residual_strm = hidden_states

        # Normalize the updated hidden states prior to the FFN
        ffn_input = self.ffn_input_norm(hidden_states)

        # Evaluate
        ffn_output = self.ffn(ffn_input)

        # Add the output to the un-normalized hidden states.
        hidden_states = residual_strm + ffn_output

        return hidden_states
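
# Illustrative call pattern for a single layer (a sketch, not part of the
# module; assumes ``config`` is a valid SharedSpaceDecoderConfig and ``rope``
# is the RotaryEmbedding constructed by the model below):
#
#     layer = SharedSpaceDecoderLayer(config, layer_idx=0)
#     cos, sin = rope.cos[:seq_len], rope.sin[:seq_len]
#     out = layer(hidden_states, (cos, sin), attention_mask=None)
#     # out: [batch, seq_len, hidden_size], same shape as hidden_states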

"""#### *Model"""

class SharedSpaceDecoderModel(SharedSpaceDecoderPreTrainedModel):
    """
    The **Model object:
      - Initializes:
        - The vocabulary embeddings (and optional decomposition)
        - Position embeddings (calculated in RotaryEmbedding)
        - All of the **Layer objects.
      - Provides interface to vocab embeddings.
      - Executes the whole decoder model in `forward` with causal attention.

      This is the base decoder without the language modeling head.
      Use SubspaceDecoderForCausalLM for language modeling tasks.
    """

    def __init__(self, config: SharedSpaceDecoderConfig) -> None:
        super().__init__(config)

        # ============================
        #    Vocabulary Embeddings
        # ============================
        # Decomposing the vocabulary (if enabled) defines a shared projection
        # which constrains the model to store semantic information (and
        # whatever other static token knowledge) into a limited set of
        # feature directions.

        # If we're decomposing the token embeddings,
        if config.vocab_subspace:

            # Create the embedding table. Vocabulary embeddings are learned
            # in a lower dimensional latent space.
            self.vocab_embed = nn.Embedding(
                config.vocab_size, # Number of tokens
                config.vocab_rank  # Subspace dimension
            )

            # Create the shared up-projection: selected token latents will be
            # projected up to model size.
            # vocab_proj maps [vocab_rank] -> [hidden_size]
            # (its weight tensor is [hidden_size x vocab_rank]).
            self.vocab_proj = nn.Linear(
                config.vocab_rank,  # Size of latents
                config.hidden_size, # Model size
                bias=False
            )

        # Otherwise, for a dense vocabulary,
        else:
            # Create the dense embedding table in model space.
            self.vocab_embed = nn.Embedding(
                config.vocab_size,  # Number of tokens
                config.hidden_size  # Model size
            )

            self.vocab_proj = None

        # =====================
        #   RoPE Embeddings
        # =====================

        # Pre-computes the table of RoPE embeddings, leaving them in
        # GPU memory.
        self.rope = RotaryEmbedding(config)

        # ===================
        #    Create Layers
        # ===================

        layers = []

        # For each layer,
        for i in range(config.num_hidden_layers):
            # Create a **Layer, providing the config and indicating its number.
            layers.append(
                SharedSpaceDecoderLayer(
                    config,
                    layer_idx = i
                )
            )

        # Wrap in torch ModuleList
        self.layers = nn.ModuleList(layers)

        # Run the HuggingFace post-init hook (weight initialization via
        # ``_init_weights``, etc.).
        self.post_init()

    # Agents: Do not define boilerplate helpers, e.g., get/set_input_embeddings


    def embed(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Return token embeddings for input ids.
        This will perform the up projection to model space if the vocabulary is
        decomposed.

        input_ids have shape [batch_size, seq_len]
        """

        # If the vocabulary is decomposed,
        if self.vocab_proj is not None:

            # Retrieve the latents
            #  input_ids: [batch_size, seq_len]
            #          x: [batch_size, seq_len, latent_dim]
            x = self.vocab_embed(input_ids)

            #  Project the latents back to model space and return.
            return self.vocab_proj(x)

        # If the vocabulary is dense,
        else:
            # Just return the embeddings.
            return self.vocab_embed(input_ids)
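
    # Note on the decomposed case: the effective dense embedding table is the
    # low-rank product
    #     E = vocab_embed.weight @ vocab_proj.weight.T
    #         [vocab, rank]      @ [rank, hidden]     -> [vocab, hidden]
    # so ``embed(input_ids)`` is equivalent to indexing into E, while only the
    # two smaller factors are stored and trained.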

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run the full decoder stack with causal attention.

        Inputs:
            input_ids       [batch_size, seq_len]
            attention_mask  [batch_size, seq_len] - 1 for real tokens, 0 for padding

        Returns:
            Final decoder layer output   [batch_size, seq_len, model_size]
        """

        # Retrieve the token embeddings for this sequence.
        # These are model_size, regardless of whether the vocab is decomposed.
        hidden_states = self.embed(input_ids)

        # Retrieve the rotary position embeddings for all of the positions in
        # our current input sequence.

        seq_len = hidden_states.size(1)

        # Retrieves just the ones necessary for the sequence length of the
        # input. These are vectors, two per token. Their length is the
        # number of head dimensions we're applying RoPE to.
        #  Input
        #     cos: [max_seq_len, rope_dims]
        #     sin: [max_seq_len, rope_dims]
        #  Outputs:
        #     R_cos [seq_len, rope_dims]
        #     R_sin [seq_len, rope_dims]
        R_cos = self.rope.cos[:seq_len]
        R_sin = self.rope.sin[:seq_len]


        # ===============================
        #   Attention Mask Conversion
        # ===============================

        """
        use_sdpa_attention_masks = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )
        """

        # Expand the attention mask
        #if use_sdpa_attention_masks and attention_mask.dim() == 2:
        if True:
            # Expand the attention mask for SDPA.
            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
            extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                attention_mask,
                hidden_states.dtype,
                tgt_len = seq_len
            )
            attention_mask = extended_attention_mask


        # Run the model!

        # For each decoder layer,
        for layer_i, layer in enumerate(self.layers):

            # Evaluate the layer
            hidden_states = layer(
                hidden_states,       # Token embeddings
                (R_cos, R_sin),      # Rope embeddings, passed as a tuple.
                attention_mask,      # Attn mask
            )

        # Return the final output of the decoder stack.
        return hidden_states
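

# -----------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only). The constructor arguments of
# SharedSpaceDecoderConfig are defined in shared_space_config.py; the
# no-argument construction below assumes usable defaults exist and may need to
# be adjusted to the real config fields. Because this file uses relative
# imports, run it as a module, e.g. ``python -m <package>.<this_module>``.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    config = SharedSpaceDecoderConfig()          # assumption: defaults exist

    model = SharedSpaceDecoderModel(config)
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    with torch.no_grad():
        hidden = model(input_ids, attention_mask=attention_mask)

    # Expected: [batch_size, seq_len, hidden_size]
    print(hidden.shape)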