# decoder.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from "Attention Is All You Need"."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # dim 2i
        pe[:, 1::2] = torch.cos(position * div_term)  # dim 2i+1
        pe = pe.unsqueeze(1)  # [max_len, 1, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [seq_len, batch_size, d_model]
        x = x + self.pe[:x.size(0)]
        return x


def generate_square_subsequent_mask(sz):
    """Additive causal mask: 0.0 on and below the diagonal, -inf above it."""
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
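
# For reference, generate_square_subsequent_mask(3) produces the additive mask
# below: row i may attend to columns 0..i, while -inf blocks future positions.
#
#   tensor([[0., -inf, -inf],
#           [0.,   0., -inf],
#           [0.,   0.,   0.]])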


class TransformerDecoder(nn.Module):
    """Caption decoder: embeds target tokens and cross-attends over projected
    ViT encoder features via nn.TransformerDecoder."""

    def __init__(self, vocab_size, hidden_dim=512, encoder_dim=768, num_layers=2):
        super(TransformerDecoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.positional_encoding = PositionalEncoding(hidden_dim)
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=8)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)
        # Project ViT encoder output to decoder hidden_dim if needed
        self.encoder_projection = nn.Linear(encoder_dim, hidden_dim)

    def forward(self, input_ids, encoder_outputs, tgt_attention_mask=None):
        # input_ids: [batch, seq_len]; encoder_outputs: [batch, encoder_dim]
        embedded = self.embedding(input_ids).permute(1, 0, 2)  # -> [seq_len, batch, hidden_dim]
        embedded = self.positional_encoding(embedded)
        # Project the pooled encoder features; memory gets a sequence dim of 1
        memory = self.encoder_projection(encoder_outputs).unsqueeze(0)  # [1, batch, hidden_dim]
        tgt_mask = generate_square_subsequent_mask(embedded.size(0)).to(embedded.device)
        if tgt_attention_mask is not None:
            # HF-style attention mask (1 = keep) -> key padding mask (True = pad)
            tgt_key_padding_mask = ~tgt_attention_mask.bool()
        else:
            tgt_key_padding_mask = None
        output = self.transformer_decoder(
            tgt=embedded,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        output = self.fc_out(output).permute(1, 0, 2)  # -> [batch, seq_len, vocab]
        return output

    def generate(
        self,
        encoder_outputs,
        start_token_id=101,   # [CLS] token for BERT
        eos_token_id=102,     # [SEP] token for BERT
        max_length=50,
        mode="greedy",        # "greedy", "beam", "topk", "topp"
        num_beams=3,
        top_k=50,
        top_p=0.95,
        length_penalty=1.0
    ):
        """
        Generate a caption from encoder features using the specified decoding mode.
        """
        device = encoder_outputs.device
        batch_size = encoder_outputs.size(0)
        input_ids = torch.full(
            (batch_size, 1),
            start_token_id,
            dtype=torch.long,
            device=device
        )
        if mode == "beam":
            return self._generate_beam_search(
                encoder_outputs,
                input_ids,
                max_length,
                eos_token_id,
                num_beams,
                length_penalty
            )

        # Greedy or sampling
        generated = input_ids
        for _ in range(max_length):
            logits = self.forward(generated, encoder_outputs)  # (batch, seq_len, vocab)
            next_token_logits = logits[:, -1, :]               # (batch, vocab)
            if mode == "greedy":
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            elif mode == "topk":
                probs = F.softmax(next_token_logits, dim=-1)
                topk_probs, topk_indices = torch.topk(probs, top_k)
                next_token = topk_indices[
                    torch.arange(probs.size(0)),
                    torch.multinomial(topk_probs, num_samples=1).squeeze(-1)
                ].unsqueeze(-1)
            elif mode == "topp":
                probs = F.softmax(next_token_logits, dim=-1)
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                # Keep tokens whose cumulative probability stays within top_p
                sorted_mask = cumulative_probs <= top_p
                sorted_mask[..., 0] = True  # Always keep at least 1 token
                filtered_probs = sorted_probs * sorted_mask
                filtered_probs /= filtered_probs.sum(dim=-1, keepdim=True)
                next_token = sorted_indices[
                    torch.arange(probs.size(0)),
                    torch.multinomial(filtered_probs, num_samples=1).squeeze(-1)
                ].unsqueeze(-1)
            else:
                raise ValueError(f"Unknown mode: {mode}")
            generated = torch.cat((generated, next_token), dim=1)
            if eos_token_id is not None:
                if (next_token == eos_token_id).all():
                    break
        return generated[:, 1:]  # Strip the start token

    def _generate_beam_search(
        self,
        encoder_outputs,
        input_ids,
        max_length=50,
        eos_token_id=102,
        num_beams=3,
        length_penalty=1.0
    ):
        """
        Custom beam search decoder for batch_size = 1.
        """
        device = encoder_outputs.device
        batch_size = encoder_outputs.size(0)
        vocab_size = self.vocab_size
        assert batch_size == 1, "Basic beam search only supports batch size 1 here."
        # Start from a single beam; duplicating the start sequence across
        # num_beams identical beams would make every beam collapse onto the
        # same candidates after the first step.
        beam_sequences = [input_ids]
        beam_scores = torch.zeros(1, device=device)
        finished_sequences = []
        finished_scores = []
        for step in range(max_length):
            all_candidates = []
            # Iterate over however many beams are still active (can be < num_beams)
            for beam_idx in range(len(beam_sequences)):
                seq = beam_sequences[beam_idx]
                score = beam_scores[beam_idx]
                logits = self.forward(seq, encoder_outputs)  # (1, seq_len, vocab)
                next_token_logits = logits[:, -1, :]          # (1, vocab)
                log_probs = F.log_softmax(next_token_logits, dim=-1).squeeze(0)  # (vocab,)
                # Exhaustive expansion over the full vocabulary (simple but slow;
                # a per-beam top-k of size num_beams would select the same survivors)
                for token_id in range(vocab_size):
                    new_seq = torch.cat([seq, torch.tensor([[token_id]], device=device)], dim=1)
                    new_score = score + log_probs[token_id]
                    all_candidates.append((new_seq, new_score))
            # Get top beams
            all_candidates.sort(key=lambda x: x[1], reverse=True)
            beam_sequences = []
            beam_scores = []
            for seq, score in all_candidates[:num_beams]:
                if eos_token_id is not None and seq[0, -1].item() == eos_token_id:
                    finished_sequences.append(seq)
                    finished_scores.append(score)
                else:
                    beam_sequences.append(seq)
                    beam_scores.append(score)
            beam_scores = torch.stack(beam_scores) if beam_scores else torch.tensor([], device=device)
            # Early stopping if all beams ended
            if len(beam_sequences) == 0:
                break
        # If nothing reached EOS, fall back to the unfinished beams
        if not finished_sequences:
            finished_sequences = beam_sequences
            finished_scores = beam_scores
        # Length penalty: normalize each score by sequence length ** length_penalty
        finished_scores = [s / (seq.size(1) ** length_penalty) for seq, s in zip(finished_sequences, finished_scores)]
        # Pick the best-scoring sequence
        best_idx = torch.stack(finished_scores).argmax().item()
        best_seq = finished_sequences[best_idx]
        return best_seq[:, 1:]  # Strip the start token
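

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): a smoke test with
# random tensors standing in for pooled ViT features. The 30522 vocab size and
# the BERT-style [CLS]/[SEP] ids (101/102) used as defaults above are assumed;
# swap in the tokenizer and encoder from the actual training setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = TransformerDecoder(vocab_size=30522)  # assumed BERT wordpiece vocab
    dummy_features = torch.randn(2, 768)            # pooled ViT output: [batch, encoder_dim]
    dummy_ids = torch.randint(0, 30522, (2, 12))    # teacher-forced caption token ids

    # Training-style forward pass: per-position vocabulary logits
    logits = decoder(dummy_ids, dummy_features)
    print(logits.shape)                             # torch.Size([2, 12, 30522])

    # Inference: greedy decoding from a single image's features
    decoder.eval()
    with torch.no_grad():
        caption_ids = decoder.generate(dummy_features[:1], mode="greedy", max_length=10)
    print(caption_ids.shape)                        # [1, <=10] generated token ids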