Commit dd3d338
wsntxxn committed
Parent(s): f729a94

Change to Hugging Face calling
Files changed:
- app.py +28 -35
- checkpoints/audiocaps/ckpt.pth +0 -3
- checkpoints/audiocaps/config.yaml +0 -30
- checkpoints/clotho/ckpt.pth +0 -3
- checkpoints/clotho/config.yaml +0 -30
- models/__init__.py +0 -92
- models/base.py +0 -504
- models/cnn_encoder.py +0 -808
- models/eff_latent_encoder.py +0 -347
- models/kd_wrapper.py +0 -226
- models/transformer_decoder.py +0 -214
- models/transformer_model.py +0 -264
- requirements.txt +2 -2
- text_tokenizer.py +0 -107
- utils/model_util.py +0 -186
- utils/train_util.py +0 -117
app.py
CHANGED
@@ -1,25 +1,28 @@
-from pathlib import Path
-import argparse
 from functools import partial
 import gradio as gr
 import torch
 from torchaudio.functional import resample
+from transformers import AutoModel, PreTrainedTokenizerFast
 
-import utils.train_util as train_util
 
-
-def load_model(cfg,
-               ckpt_path,
+def load_model(model_name,
                device):
-    [... removed lines not shown in the rendered diff ...]
+    if model_name == "AudioCaps":
+        model = AutoModel.from_pretrained(
+            "wsntxxn/effb2-trm-audiocaps-captioning",
+            trust_remote_code=True
+        ).to(device)
+        tokenizer = PreTrainedTokenizerFast.from_pretrained(
+            "wsntxxn/audiocaps-simple-tokenizer"
+        )
+    elif model_name == "Clotho":
+        model = AutoModel.from_pretrained(
+            "wsntxxn/effb2-trm-clotho-captioning",
+            trust_remote_code=True
+        ).to(device)
+        tokenizer = PreTrainedTokenizerFast.from_pretrained(
+            "wsntxxn/clotho-simple-tokenizer"
+        )
     return model, tokenizer
 
 
@@ -34,19 +37,13 @@ def infer(file, runner):
     wav = wav.mean(1)
     wav = resample(wav, sr, runner.target_sr)
     wav_len = len(wav)
-    wav = wav.float().unsqueeze(0)
-    input_dict = {
-        "mode": "inference",
-        "wav": wav,
-        "wav_len": [wav_len],
-        "specaug": False,
-        "sample_method": "beam",
-        "beam_size": 3,
-    }
+    wav = wav.float().unsqueeze(0)
     with torch.no_grad():
-        [... removed lines not shown in the rendered diff ...]
+        word_idx = runner.model(
+            audio=wav,
+            audio_length=[wav_len]
+        )[0]
+        cap = runner.tokenizer.decode(word_idx, skip_special_tokens=True)
     return cap
 
 # def input_toggle(input_type):
@@ -59,16 +56,12 @@ class InferRunner:
 
     def __init__(self, model_name):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        [... removed lines not shown in the rendered diff ...]
-        self.model, self.tokenizer = load_model(cfg, exp_dir / "ckpt.pth", self.device)
-        self.target_sr = cfg["target_sr"]
+        self.model, self.tokenizer = load_model(model_name, self.device)
+        self.target_sr = self.model.config.sample_rate
 
     def change_model(self, model_name):
-        [... removed lines not shown in the rendered diff ...]
-        self.model, self.tokenizer = load_model(cfg, exp_dir / "ckpt.pth", self.device)
-        self.target_sr = cfg["target_sr"]
+        self.model, self.tokenizer = load_model(model_name, self.device)
+        self.target_sr = self.model.config.sample_rate
 
 
 def change_model(radio):
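The new app.py no longer builds the model from local config and checkpoint files; everything is pulled from the Hugging Face Hub. Below is a minimal standalone sketch of that calling convention, assuming the Hub repos expose the same interface as in the diff above (audio/audio_length keyword arguments, config.sample_rate); the audio path and the mono mixdown step are illustrative, not part of the commit.

# Standalone sketch (not part of this commit): caption one file with the
# Hub-hosted AudioCaps model. "audio.wav" is a placeholder path.
import torch
import torchaudio
from torchaudio.functional import resample
from transformers import AutoModel, PreTrainedTokenizerFast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(
    "wsntxxn/effb2-trm-audiocaps-captioning", trust_remote_code=True
).to(device)
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "wsntxxn/audiocaps-simple-tokenizer"
)

wav, sr = torchaudio.load("audio.wav")   # (channels, n_samples)
wav = wav.mean(0)                        # mix down to mono
wav = resample(wav, sr, model.config.sample_rate)
with torch.no_grad():
    word_idx = model(
        audio=wav.float().unsqueeze(0).to(device),
        audio_length=[len(wav)],
    )[0]
print(tokenizer.decode(word_idx, skip_special_tokens=True))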
checkpoints/audiocaps/ckpt.pth
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e1c435b1cf05a2b0058dae6f096c4eb4e71c685a19754ed84ea1ee812257434b
-size 55293225
checkpoints/audiocaps/config.yaml
DELETED
@@ -1,30 +0,0 @@
-tokenizer:
-  type: text_tokenizer.DictTokenizer
-  args:
-    max_length: 20
-
-target_sr: 16000
-
-model:
-  args:
-    shared_dim: 1024
-    tchr_dim: 768
-    model:
-      args: {}
-      decoder:
-        args:
-          attn_emb_dim: 1408
-          dropout: 0.2
-          emb_dim: 256
-          fc_emb_dim: 1408
-          nlayers: 2
-          tie_weights: true
-          vocab_size: 4981
-        type: models.transformer_decoder.TransformerDecoder
-      encoder:
-        args:
-          freeze: false
-          pretrained: true
-        type: models.cnn_encoder.EfficientNetB2
-      type: models.transformer_model.TransformerModel
-  type: models.kd_wrapper.ContraEncoderKdWrapper
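The deleted config describes the model as nested type/args entries (a ContraEncoderKdWrapper around a TransformerModel with an EfficientNetB2 encoder and a Transformer decoder). Such entries are usually resolved into objects by a small factory; a hedged sketch of the idea follows. The actual loader lived in the now-deleted utils/train_util.py and likely handled more cases.

# Illustrative only: build an object from a {"type": "module.Class", "args": {...}}
# entry like the ones in config.yaml. The real helper in utils/train_util.py is
# not shown in this commit and may have differed (nested sub-model handling, etc.).
import importlib

def instantiate(spec: dict):
    module_name, cls_name = spec["type"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**spec.get("args", {}))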
checkpoints/clotho/ckpt.pth
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:694c9e7139be7ec5aff2153d1af980d6bc305403a76be0d8940481579ea51483
-size 54651005
checkpoints/clotho/config.yaml
DELETED
@@ -1,30 +0,0 @@
-tokenizer:
-  type: text_tokenizer.DictTokenizer
-  args:
-    max_length: 20
-
-target_sr: 16000
-
-model:
-  args:
-    shared_dim: 1024
-    tchr_dim: 768
-    model:
-      args: {}
-      decoder:
-        args:
-          attn_emb_dim: 1408
-          dropout: 0.2
-          emb_dim: 256
-          fc_emb_dim: 1408
-          nlayers: 2
-          tie_weights: true
-          vocab_size: 4368
-        type: models.transformer_decoder.TransformerDecoder
-      encoder:
-        args:
-          freeze: false
-          pretrained: true
-        type: models.cnn_encoder.EfficientNetB2
-      type: models.transformer_model.TransformerModel
-  type: models.kd_wrapper.ContraEncoderKdWrapper
models/__init__.py
DELETED
@@ -1,92 +0,0 @@
-import numpy as np
-import torch
-import torch.nn as nn
-
-from utils.model_util import max_with_lens, mean_with_lens
-
-
-def embedding_pooling(x, lens, pooling="mean"):
-    if pooling == "max":
-        fc_embs = max_with_lens(x, lens)
-    elif pooling == "mean":
-        fc_embs = mean_with_lens(x, lens)
-    elif pooling == "mean+max":
-        x_mean = mean_with_lens(x, lens)
-        x_max = max_with_lens(x, lens)
-        fc_embs = x_mean + x_max
-    elif pooling == "last":
-        indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
-        # indices: [N, 1, hidden]
-        fc_embs = torch.gather(x, 1, indices).squeeze(1)
-    else:
-        raise Exception(f"pooling method {pooling} not support")
-    return fc_embs
-
-
-class BaseEncoder(nn.Module):
-
-    """
-    Encode the given audio into embedding
-    Base encoder class, cannot be called directly
-    All encoders should inherit from this class
-    """
-
-    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
-        super(BaseEncoder, self).__init__()
-        self.spec_dim = spec_dim
-        self.fc_feat_dim = fc_feat_dim
-        self.attn_feat_dim = attn_feat_dim
-
-
-    def forward(self, x):
-        #########################
-        # Arguments:
-        # `x`: {
-        #     (may contain)
-        #     wav: [batch_size, n_samples],
-        #     spec: [batch_size, n_frames, spec_dim],
-        #     fc: [batch_size, fc_feat_dim],
-        #     attn: [batch_size, attn_max_len, attn_feat_dim],
-        #     attn_len: [batch_size,]
-        #     ......
-        # }
-        #
-        # Returns:
-        # `encoded`: {
-        #     fc_emb: [batch_size, fc_emb_dim],
-        #     attn_emb: [batch_size, attn_max_len, attn_emb_dim],
-        #     attn_emb_lens: [batch_size,]
-        # }
-        #########################
-        raise NotImplementedError
-
-
-class BaseDecoder(nn.Module):
-    """
-    Take word/audio embeddings and output the next word probs
-    """
-    def __init__(self, emb_dim, vocab_size, fc_emb_dim,
-                 attn_emb_dim, dropout=0.2, tie_weights=False):
-        super().__init__()
-        self.emb_dim = emb_dim
-        self.vocab_size = vocab_size
-        self.fc_emb_dim = fc_emb_dim
-        self.attn_emb_dim = attn_emb_dim
-        self.tie_weights = tie_weights
-        self.word_embedding = nn.Embedding(vocab_size, emb_dim)
-        self.in_dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        raise NotImplementedError
-
-    def load_word_embedding(self, weight, freeze=True):
-        embedding = np.load(weight)
-        assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
-        assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
-
-        # embeddings = torch.as_tensor(embeddings).float()
-        # self.word_embeddings.weight = nn.Parameter(embeddings)
-        # for para in self.word_embeddings.parameters():
-        #     para.requires_grad = tune
-        self.word_embedding = nn.Embedding.from_pretrained(embedding,
-                                                            freeze=freeze)
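The embedding_pooling function in the deleted models/__init__.py delegates to max_with_lens and mean_with_lens from utils/model_util.py, which this commit also removes. A minimal sketch of what those helpers presumably compute, masked pooling over only the valid timesteps of each padded sequence; the exact original implementations are not shown in this diff.

# Assumed semantics of the deleted helpers: pool (batch, time, dim) features
# over the first lens[i] timesteps of each sample; lens is a 1-D tensor of lengths.
import torch

def mean_with_lens(x: torch.Tensor, lens: torch.Tensor) -> torch.Tensor:
    mask = torch.arange(x.size(1), device=x.device)[None, :] < lens[:, None].to(x.device)
    return (x * mask.unsqueeze(-1)).sum(dim=1) / lens.to(x).unsqueeze(-1)

def max_with_lens(x: torch.Tensor, lens: torch.Tensor) -> torch.Tensor:
    mask = torch.arange(x.size(1), device=x.device)[None, :] < lens[:, None].to(x.device)
    return x.masked_fill(~mask.unsqueeze(-1), float("-inf")).max(dim=1).values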
models/base.py
DELETED
|
@@ -1,504 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
from typing import Dict
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
|
| 8 |
-
from utils.model_util import mean_with_lens, repeat_tensor
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class CaptionMetaMixin:
|
| 12 |
-
pad_idx = 0
|
| 13 |
-
start_idx = 1
|
| 14 |
-
end_idx = 2
|
| 15 |
-
max_length = 20
|
| 16 |
-
|
| 17 |
-
@classmethod
|
| 18 |
-
def set_index(cls, start_idx, end_idx, pad_idx):
|
| 19 |
-
cls.start_idx = start_idx
|
| 20 |
-
cls.end_idx = end_idx
|
| 21 |
-
cls.pad_idx = pad_idx
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class CaptionModel(nn.Module, CaptionMetaMixin):
|
| 25 |
-
"""
|
| 26 |
-
Encoder-decoder captioning model.
|
| 27 |
-
"""
|
| 28 |
-
|
| 29 |
-
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
| 30 |
-
super().__init__()
|
| 31 |
-
self.encoder = encoder
|
| 32 |
-
self.decoder = decoder
|
| 33 |
-
self.vocab_size = decoder.vocab_size
|
| 34 |
-
self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
|
| 35 |
-
self.inference_forward_keys = ["sample_method", "max_length", "temp"]
|
| 36 |
-
freeze_encoder = kwargs.get("freeze_encoder", False)
|
| 37 |
-
if freeze_encoder:
|
| 38 |
-
for param in self.encoder.parameters():
|
| 39 |
-
param.requires_grad = False
|
| 40 |
-
self.check_decoder_compatibility()
|
| 41 |
-
|
| 42 |
-
def check_decoder_compatibility(self):
|
| 43 |
-
compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders]
|
| 44 |
-
assert isinstance(self.decoder, self.compatible_decoders), \
|
| 45 |
-
f"{self.decoder.__class__.__name__} is incompatible with " \
|
| 46 |
-
f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "
|
| 47 |
-
|
| 48 |
-
def forward(self, input_dict: Dict):
|
| 49 |
-
"""
|
| 50 |
-
input_dict: {
|
| 51 |
-
(required)
|
| 52 |
-
mode: train/inference,
|
| 53 |
-
[spec, spec_len],
|
| 54 |
-
[fc],
|
| 55 |
-
[attn, attn_len],
|
| 56 |
-
[wav, wav_len],
|
| 57 |
-
[sample_method: greedy],
|
| 58 |
-
[temp: 1.0] (in case of no teacher forcing)
|
| 59 |
-
|
| 60 |
-
(optional, mode=train)
|
| 61 |
-
cap,
|
| 62 |
-
cap_len,
|
| 63 |
-
ss_ratio,
|
| 64 |
-
|
| 65 |
-
(optional, mode=inference)
|
| 66 |
-
sample_method: greedy/beam,
|
| 67 |
-
max_length,
|
| 68 |
-
temp,
|
| 69 |
-
beam_size (optional, sample_method=beam),
|
| 70 |
-
n_best (optional, sample_method=beam),
|
| 71 |
-
}
|
| 72 |
-
"""
|
| 73 |
-
encoder_output_dict = self.encoder(input_dict)
|
| 74 |
-
output = self.forward_decoder(input_dict, encoder_output_dict)
|
| 75 |
-
return output
|
| 76 |
-
|
| 77 |
-
def forward_decoder(self, input_dict: Dict, encoder_output_dict: Dict):
|
| 78 |
-
if input_dict["mode"] == "train":
|
| 79 |
-
forward_dict = {
|
| 80 |
-
"mode": "train", "sample_method": "greedy", "temp": 1.0
|
| 81 |
-
}
|
| 82 |
-
for key in self.train_forward_keys:
|
| 83 |
-
forward_dict[key] = input_dict[key]
|
| 84 |
-
forward_dict.update(encoder_output_dict)
|
| 85 |
-
output = self.train_forward(forward_dict)
|
| 86 |
-
elif input_dict["mode"] == "inference":
|
| 87 |
-
forward_dict = {"mode": "inference"}
|
| 88 |
-
default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 }
|
| 89 |
-
for key in self.inference_forward_keys:
|
| 90 |
-
if key in input_dict:
|
| 91 |
-
forward_dict[key] = input_dict[key]
|
| 92 |
-
else:
|
| 93 |
-
forward_dict[key] = default_args[key]
|
| 94 |
-
|
| 95 |
-
if forward_dict["sample_method"] == "beam":
|
| 96 |
-
forward_dict["beam_size"] = input_dict.get("beam_size", 3)
|
| 97 |
-
forward_dict["n_best"] = input_dict.get("n_best", False)
|
| 98 |
-
forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"])
|
| 99 |
-
elif forward_dict["sample_method"] == "dbs":
|
| 100 |
-
forward_dict["beam_size"] = input_dict.get("beam_size", 6)
|
| 101 |
-
forward_dict["group_size"] = input_dict.get("group_size", 3)
|
| 102 |
-
forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
|
| 103 |
-
forward_dict["group_nbest"] = input_dict.get("group_nbest", True)
|
| 104 |
-
|
| 105 |
-
forward_dict.update(encoder_output_dict)
|
| 106 |
-
output = self.inference_forward(forward_dict)
|
| 107 |
-
else:
|
| 108 |
-
raise Exception("mode should be either 'train' or 'inference'")
|
| 109 |
-
output.update(encoder_output_dict)
|
| 110 |
-
return output
|
| 111 |
-
|
| 112 |
-
def prepare_output(self, input_dict):
|
| 113 |
-
output = {}
|
| 114 |
-
batch_size = input_dict["fc_emb"].size(0)
|
| 115 |
-
if input_dict["mode"] == "train":
|
| 116 |
-
max_length = input_dict["cap"].size(1) - 1
|
| 117 |
-
elif input_dict["mode"] == "inference":
|
| 118 |
-
max_length = input_dict["max_length"]
|
| 119 |
-
else:
|
| 120 |
-
raise Exception("mode should be either 'train' or 'inference'")
|
| 121 |
-
device = input_dict["fc_emb"].device
|
| 122 |
-
output["seq"] = torch.full((batch_size, max_length), self.end_idx,
|
| 123 |
-
dtype=torch.long)
|
| 124 |
-
output["logit"] = torch.empty(batch_size, max_length,
|
| 125 |
-
self.vocab_size).to(device)
|
| 126 |
-
output["sampled_logprob"] = torch.zeros(batch_size, max_length)
|
| 127 |
-
output["embed"] = torch.empty(batch_size, max_length,
|
| 128 |
-
self.decoder.d_model).to(device)
|
| 129 |
-
return output
|
| 130 |
-
|
| 131 |
-
def train_forward(self, input_dict):
|
| 132 |
-
if input_dict["ss_ratio"] != 1: # scheduled sampling training
|
| 133 |
-
input_dict["mode"] = "train"
|
| 134 |
-
return self.stepwise_forward(input_dict)
|
| 135 |
-
output = self.seq_forward(input_dict)
|
| 136 |
-
self.train_process(output, input_dict)
|
| 137 |
-
return output
|
| 138 |
-
|
| 139 |
-
def seq_forward(self, input_dict):
|
| 140 |
-
raise NotImplementedError
|
| 141 |
-
|
| 142 |
-
def train_process(self, output, input_dict):
|
| 143 |
-
pass
|
| 144 |
-
|
| 145 |
-
def inference_forward(self, input_dict):
|
| 146 |
-
if input_dict["sample_method"] == "beam":
|
| 147 |
-
return self.beam_search(input_dict)
|
| 148 |
-
elif input_dict["sample_method"] == "dbs":
|
| 149 |
-
return self.diverse_beam_search(input_dict)
|
| 150 |
-
return self.stepwise_forward(input_dict)
|
| 151 |
-
|
| 152 |
-
def stepwise_forward(self, input_dict):
|
| 153 |
-
"""Step-by-step decoding"""
|
| 154 |
-
output = self.prepare_output(input_dict)
|
| 155 |
-
max_length = output["seq"].size(1)
|
| 156 |
-
# start sampling
|
| 157 |
-
for t in range(max_length):
|
| 158 |
-
input_dict["t"] = t
|
| 159 |
-
self.decode_step(input_dict, output)
|
| 160 |
-
if input_dict["mode"] == "inference": # decide whether to stop when sampling
|
| 161 |
-
unfinished_t = output["seq"][:, t] != self.end_idx
|
| 162 |
-
if t == 0:
|
| 163 |
-
unfinished = unfinished_t
|
| 164 |
-
else:
|
| 165 |
-
unfinished *= unfinished_t
|
| 166 |
-
output["seq"][:, t][~unfinished] = self.end_idx
|
| 167 |
-
if unfinished.sum() == 0:
|
| 168 |
-
break
|
| 169 |
-
self.stepwise_process(output)
|
| 170 |
-
return output
|
| 171 |
-
|
| 172 |
-
def decode_step(self, input_dict, output):
|
| 173 |
-
"""Decoding operation of timestep t"""
|
| 174 |
-
decoder_input = self.prepare_decoder_input(input_dict, output)
|
| 175 |
-
# feed to the decoder to get logit
|
| 176 |
-
output_t = self.decoder(decoder_input)
|
| 177 |
-
logit_t = output_t["logit"]
|
| 178 |
-
# assert logit_t.ndim == 3
|
| 179 |
-
if logit_t.size(1) == 1:
|
| 180 |
-
logit_t = logit_t.squeeze(1)
|
| 181 |
-
embed_t = output_t["embed"].squeeze(1)
|
| 182 |
-
elif logit_t.size(1) > 1:
|
| 183 |
-
logit_t = logit_t[:, -1, :]
|
| 184 |
-
embed_t = output_t["embed"][:, -1, :]
|
| 185 |
-
else:
|
| 186 |
-
raise Exception("no logit output")
|
| 187 |
-
# sample the next input word and get the corresponding logit
|
| 188 |
-
sampled = self.sample_next_word(logit_t,
|
| 189 |
-
method=input_dict["sample_method"],
|
| 190 |
-
temp=input_dict["temp"])
|
| 191 |
-
|
| 192 |
-
output_t.update(sampled)
|
| 193 |
-
output_t["t"] = input_dict["t"]
|
| 194 |
-
output_t["logit"] = logit_t
|
| 195 |
-
output_t["embed"] = embed_t
|
| 196 |
-
self.stepwise_process_step(output, output_t)
|
| 197 |
-
|
| 198 |
-
def prepare_decoder_input(self, input_dict, output):
|
| 199 |
-
"""Prepare the inp ut dict for the decoder"""
|
| 200 |
-
raise NotImplementedError
|
| 201 |
-
|
| 202 |
-
def stepwise_process_step(self, output, output_t):
|
| 203 |
-
"""Postprocessing (save output values) after each timestep t"""
|
| 204 |
-
t = output_t["t"]
|
| 205 |
-
output["logit"][:, t, :] = output_t["logit"]
|
| 206 |
-
output["seq"][:, t] = output_t["word"]
|
| 207 |
-
output["sampled_logprob"][:, t] = output_t["probs"]
|
| 208 |
-
output["embed"][:, t, :] = output_t["embed"]
|
| 209 |
-
|
| 210 |
-
def stepwise_process(self, output):
|
| 211 |
-
"""Postprocessing after the whole step-by-step autoregressive decoding"""
|
| 212 |
-
pass
|
| 213 |
-
|
| 214 |
-
def sample_next_word(self, logit, method, temp):
|
| 215 |
-
"""Sample the next word, given probs output by the decoder"""
|
| 216 |
-
logprob = torch.log_softmax(logit, dim=1)
|
| 217 |
-
if method == "greedy":
|
| 218 |
-
sampled_logprob, word = torch.max(logprob.detach(), 1)
|
| 219 |
-
elif method == "gumbel":
|
| 220 |
-
def sample_gumbel(shape, eps=1e-20):
|
| 221 |
-
U = torch.rand(shape).to(logprob.device)
|
| 222 |
-
return -torch.log(-torch.log(U + eps) + eps)
|
| 223 |
-
def gumbel_softmax_sample(logit, temperature):
|
| 224 |
-
y = logit + sample_gumbel(logit.size())
|
| 225 |
-
return torch.log_softmax(y / temperature, dim=-1)
|
| 226 |
-
_logprob = gumbel_softmax_sample(logprob, temp)
|
| 227 |
-
_, word = torch.max(_logprob.data, 1)
|
| 228 |
-
sampled_logprob = logprob.gather(1, word.unsqueeze(-1))
|
| 229 |
-
else:
|
| 230 |
-
logprob = logprob / temp
|
| 231 |
-
if method.startswith("top"):
|
| 232 |
-
top_num = float(method[3:])
|
| 233 |
-
if 0 < top_num < 1: # top-p sampling
|
| 234 |
-
probs = torch.softmax(logit, dim=1)
|
| 235 |
-
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
|
| 236 |
-
_cumsum = sorted_probs.cumsum(1)
|
| 237 |
-
mask = _cumsum < top_num
|
| 238 |
-
mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
|
| 239 |
-
sorted_probs = sorted_probs * mask.to(sorted_probs)
|
| 240 |
-
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
|
| 241 |
-
logprob.scatter_(1, sorted_indices, sorted_probs.log())
|
| 242 |
-
else: # top-k sampling
|
| 243 |
-
k = int(top_num)
|
| 244 |
-
tmp = torch.empty_like(logprob).fill_(float('-inf'))
|
| 245 |
-
topk, indices = torch.topk(logprob, k, dim=1)
|
| 246 |
-
tmp = tmp.scatter(1, indices, topk)
|
| 247 |
-
logprob = tmp
|
| 248 |
-
word = torch.distributions.Categorical(logits=logprob.detach()).sample()
|
| 249 |
-
sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
|
| 250 |
-
word = word.detach().long()
|
| 251 |
-
# sampled_logprob: [N,], word: [N,]
|
| 252 |
-
return {"word": word, "probs": sampled_logprob}
|
| 253 |
-
|
| 254 |
-
def beam_search(self, input_dict):
|
| 255 |
-
output = self.prepare_output(input_dict)
|
| 256 |
-
max_length = input_dict["max_length"]
|
| 257 |
-
beam_size = input_dict["beam_size"]
|
| 258 |
-
if input_dict["n_best"]:
|
| 259 |
-
n_best_size = input_dict["n_best_size"]
|
| 260 |
-
batch_size, max_length = output["seq"].size()
|
| 261 |
-
output["seq"] = torch.full((batch_size, n_best_size, max_length),
|
| 262 |
-
self.end_idx, dtype=torch.long)
|
| 263 |
-
|
| 264 |
-
temp = input_dict["temp"]
|
| 265 |
-
# instance by instance beam seach
|
| 266 |
-
for i in range(output["seq"].size(0)):
|
| 267 |
-
output_i = self.prepare_beamsearch_output(input_dict)
|
| 268 |
-
input_dict["sample_idx"] = i
|
| 269 |
-
for t in range(max_length):
|
| 270 |
-
input_dict["t"] = t
|
| 271 |
-
output_t = self.beamsearch_step(input_dict, output_i)
|
| 272 |
-
#######################################
|
| 273 |
-
# merge with previous beam and select the current max prob beam
|
| 274 |
-
#######################################
|
| 275 |
-
logit_t = output_t["logit"]
|
| 276 |
-
if logit_t.size(1) == 1:
|
| 277 |
-
logit_t = logit_t.squeeze(1)
|
| 278 |
-
elif logit_t.size(1) > 1:
|
| 279 |
-
logit_t = logit_t[:, -1, :]
|
| 280 |
-
else:
|
| 281 |
-
raise Exception("no logit output")
|
| 282 |
-
logprob_t = torch.log_softmax(logit_t, dim=1)
|
| 283 |
-
logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
|
| 284 |
-
logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
|
| 285 |
-
if t == 0: # for the first step, all k seq will have the same probs
|
| 286 |
-
topk_logprob, topk_words = logprob_t[0].topk(
|
| 287 |
-
beam_size, 0, True, True)
|
| 288 |
-
else: # unroll and find top logprob, and their unrolled indices
|
| 289 |
-
topk_logprob, topk_words = logprob_t.view(-1).topk(
|
| 290 |
-
beam_size, 0, True, True)
|
| 291 |
-
topk_words = topk_words.cpu()
|
| 292 |
-
output_i["topk_logprob"] = topk_logprob
|
| 293 |
-
# output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,]
|
| 294 |
-
output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
|
| 295 |
-
rounding_mode='trunc')
|
| 296 |
-
output_i["next_word"] = topk_words % self.vocab_size # [beam_size,]
|
| 297 |
-
if t == 0:
|
| 298 |
-
output_i["seq"] = output_i["next_word"].unsqueeze(1)
|
| 299 |
-
else:
|
| 300 |
-
output_i["seq"] = torch.cat([
|
| 301 |
-
output_i["seq"][output_i["prev_words_beam"]],
|
| 302 |
-
output_i["next_word"].unsqueeze(1)], dim=1)
|
| 303 |
-
|
| 304 |
-
# add finished beams to results
|
| 305 |
-
is_end = output_i["next_word"] == self.end_idx
|
| 306 |
-
if t == max_length - 1:
|
| 307 |
-
is_end.fill_(1)
|
| 308 |
-
|
| 309 |
-
for beam_idx in range(beam_size):
|
| 310 |
-
if is_end[beam_idx]:
|
| 311 |
-
final_beam = {
|
| 312 |
-
"seq": output_i["seq"][beam_idx].clone(),
|
| 313 |
-
"score": output_i["topk_logprob"][beam_idx].item()
|
| 314 |
-
}
|
| 315 |
-
final_beam["score"] = final_beam["score"] / (t + 1)
|
| 316 |
-
output_i["done_beams"].append(final_beam)
|
| 317 |
-
output_i["topk_logprob"][is_end] -= 1000
|
| 318 |
-
|
| 319 |
-
self.beamsearch_process_step(output_i, output_t)
|
| 320 |
-
|
| 321 |
-
self.beamsearch_process(output, output_i, input_dict)
|
| 322 |
-
return output
|
| 323 |
-
|
| 324 |
-
def prepare_beamsearch_output(self, input_dict):
|
| 325 |
-
beam_size = input_dict["beam_size"]
|
| 326 |
-
device = input_dict["fc_emb"].device
|
| 327 |
-
output = {
|
| 328 |
-
"topk_logprob": torch.zeros(beam_size).to(device),
|
| 329 |
-
"seq": None,
|
| 330 |
-
"prev_words_beam": None,
|
| 331 |
-
"next_word": None,
|
| 332 |
-
"done_beams": [],
|
| 333 |
-
}
|
| 334 |
-
return output
|
| 335 |
-
|
| 336 |
-
def beamsearch_step(self, input_dict, output_i):
|
| 337 |
-
decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
|
| 338 |
-
output_t = self.decoder(decoder_input)
|
| 339 |
-
output_t["t"] = input_dict["t"]
|
| 340 |
-
return output_t
|
| 341 |
-
|
| 342 |
-
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
| 343 |
-
raise NotImplementedError
|
| 344 |
-
|
| 345 |
-
def beamsearch_process_step(self, output_i, output_t):
|
| 346 |
-
pass
|
| 347 |
-
|
| 348 |
-
def beamsearch_process(self, output, output_i, input_dict):
|
| 349 |
-
i = input_dict["sample_idx"]
|
| 350 |
-
done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
|
| 351 |
-
if input_dict["n_best"]:
|
| 352 |
-
done_beams = done_beams[:input_dict["n_best_size"]]
|
| 353 |
-
for out_idx, done_beam in enumerate(done_beams):
|
| 354 |
-
seq = done_beam["seq"]
|
| 355 |
-
output["seq"][i][out_idx, :len(seq)] = seq
|
| 356 |
-
else:
|
| 357 |
-
seq = done_beams[0]["seq"]
|
| 358 |
-
output["seq"][i][:len(seq)] = seq
|
| 359 |
-
|
| 360 |
-
def diverse_beam_search(self, input_dict):
|
| 361 |
-
|
| 362 |
-
def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
|
| 363 |
-
local_time = t - divm
|
| 364 |
-
unaug_logprob = logprob.clone()
|
| 365 |
-
|
| 366 |
-
if divm > 0:
|
| 367 |
-
change = torch.zeros(logprob.size(-1))
|
| 368 |
-
for prev_choice in range(divm):
|
| 369 |
-
prev_decisions = seq_table[prev_choice][..., local_time]
|
| 370 |
-
for prev_labels in range(bdash):
|
| 371 |
-
change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))
|
| 372 |
-
|
| 373 |
-
change = change.to(logprob.device)
|
| 374 |
-
logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda
|
| 375 |
-
|
| 376 |
-
return logprob, unaug_logprob
|
| 377 |
-
|
| 378 |
-
output = self.prepare_output(input_dict)
|
| 379 |
-
group_size = input_dict["group_size"]
|
| 380 |
-
batch_size = output["seq"].size(0)
|
| 381 |
-
beam_size = input_dict["beam_size"]
|
| 382 |
-
bdash = beam_size // group_size
|
| 383 |
-
input_dict["bdash"] = bdash
|
| 384 |
-
diversity_lambda = input_dict["diversity_lambda"]
|
| 385 |
-
device = input_dict["fc_emb"].device
|
| 386 |
-
max_length = input_dict["max_length"]
|
| 387 |
-
temp = input_dict["temp"]
|
| 388 |
-
group_nbest = input_dict["group_nbest"]
|
| 389 |
-
batch_size, max_length = output["seq"].size()
|
| 390 |
-
if group_nbest:
|
| 391 |
-
output["seq"] = torch.full((batch_size, beam_size, max_length),
|
| 392 |
-
self.end_idx, dtype=torch.long)
|
| 393 |
-
else:
|
| 394 |
-
output["seq"] = torch.full((batch_size, group_size, max_length),
|
| 395 |
-
self.end_idx, dtype=torch.long)
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
for i in range(batch_size):
|
| 399 |
-
input_dict["sample_idx"] = i
|
| 400 |
-
seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0]
|
| 401 |
-
logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
|
| 402 |
-
done_beams_table = [[] for _ in range(group_size)]
|
| 403 |
-
|
| 404 |
-
output_i = {
|
| 405 |
-
"prev_words_beam": [None for _ in range(group_size)],
|
| 406 |
-
"next_word": [None for _ in range(group_size)],
|
| 407 |
-
"state": [None for _ in range(group_size)]
|
| 408 |
-
}
|
| 409 |
-
|
| 410 |
-
for t in range(max_length + group_size - 1):
|
| 411 |
-
input_dict["t"] = t
|
| 412 |
-
for divm in range(group_size):
|
| 413 |
-
input_dict["divm"] = divm
|
| 414 |
-
if t >= divm and t <= max_length + divm - 1:
|
| 415 |
-
local_time = t - divm
|
| 416 |
-
decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
|
| 417 |
-
output_t = self.decoder(decoder_input)
|
| 418 |
-
output_t["divm"] = divm
|
| 419 |
-
logit_t = output_t["logit"]
|
| 420 |
-
if logit_t.size(1) == 1:
|
| 421 |
-
logit_t = logit_t.squeeze(1)
|
| 422 |
-
elif logit_t.size(1) > 1:
|
| 423 |
-
logit_t = logit_t[:, -1, :]
|
| 424 |
-
else:
|
| 425 |
-
raise Exception("no logit output")
|
| 426 |
-
logprob_t = torch.log_softmax(logit_t, dim=1)
|
| 427 |
-
logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
|
| 428 |
-
logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
|
| 429 |
-
logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
|
| 430 |
-
if local_time == 0: # for the first step, all k seq will have the same probs
|
| 431 |
-
topk_logprob, topk_words = logprob_t[0].topk(
|
| 432 |
-
bdash, 0, True, True)
|
| 433 |
-
else: # unroll and find top logprob, and their unrolled indices
|
| 434 |
-
topk_logprob, topk_words = logprob_t.view(-1).topk(
|
| 435 |
-
bdash, 0, True, True)
|
| 436 |
-
topk_words = topk_words.cpu()
|
| 437 |
-
logprob_table[divm] = topk_logprob
|
| 438 |
-
output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,]
|
| 439 |
-
output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,]
|
| 440 |
-
if local_time > 0:
|
| 441 |
-
seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
|
| 442 |
-
seq_table[divm] = torch.cat([
|
| 443 |
-
seq_table[divm],
|
| 444 |
-
output_i["next_word"][divm].unsqueeze(-1)], -1)
|
| 445 |
-
|
| 446 |
-
is_end = seq_table[divm][:, t-divm] == self.end_idx
|
| 447 |
-
assert seq_table[divm].shape[-1] == t - divm + 1
|
| 448 |
-
if t == max_length + divm - 1:
|
| 449 |
-
is_end.fill_(1)
|
| 450 |
-
for beam_idx in range(bdash):
|
| 451 |
-
if is_end[beam_idx]:
|
| 452 |
-
final_beam = {
|
| 453 |
-
"seq": seq_table[divm][beam_idx].clone(),
|
| 454 |
-
"score": logprob_table[divm][beam_idx].item()
|
| 455 |
-
}
|
| 456 |
-
final_beam["score"] = final_beam["score"] / (t - divm + 1)
|
| 457 |
-
done_beams_table[divm].append(final_beam)
|
| 458 |
-
logprob_table[divm][is_end] -= 1000
|
| 459 |
-
self.dbs_process_step(output_i, output_t)
|
| 460 |
-
done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)]
|
| 461 |
-
if group_nbest:
|
| 462 |
-
done_beams = sum(done_beams_table, [])
|
| 463 |
-
else:
|
| 464 |
-
done_beams = [group_beam[0] for group_beam in done_beams_table]
|
| 465 |
-
for _, done_beam in enumerate(done_beams):
|
| 466 |
-
output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"]
|
| 467 |
-
|
| 468 |
-
return output
|
| 469 |
-
|
| 470 |
-
def prepare_dbs_decoder_input(self, input_dict, output_i):
|
| 471 |
-
raise NotImplementedError
|
| 472 |
-
|
| 473 |
-
def dbs_process_step(self, output_i, output_t):
|
| 474 |
-
pass
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
class CaptionSequenceModel(nn.Module, CaptionMetaMixin):
|
| 478 |
-
|
| 479 |
-
def __init__(self, model, seq_output_size):
|
| 480 |
-
super().__init__()
|
| 481 |
-
self.model = model
|
| 482 |
-
if model.decoder.d_model != seq_output_size:
|
| 483 |
-
self.output_transform = nn.Linear(model.decoder.d_model, seq_output_size)
|
| 484 |
-
else:
|
| 485 |
-
self.output_transform = lambda x: x
|
| 486 |
-
|
| 487 |
-
def forward(self, input_dict):
|
| 488 |
-
output = self.model(input_dict)
|
| 489 |
-
|
| 490 |
-
if input_dict["mode"] == "train":
|
| 491 |
-
lens = input_dict["cap_len"] - 1
|
| 492 |
-
# seq_outputs: [N, d_model]
|
| 493 |
-
elif input_dict["mode"] == "inference":
|
| 494 |
-
if "sample_method" in input_dict and input_dict["sample_method"] == "beam":
|
| 495 |
-
return output
|
| 496 |
-
seq = output["seq"]
|
| 497 |
-
lens = torch.where(seq == self.model.end_idx, torch.zeros_like(seq), torch.ones_like(seq)).sum(dim=1)
|
| 498 |
-
else:
|
| 499 |
-
raise Exception("mode should be either 'train' or 'inference'")
|
| 500 |
-
seq_output = mean_with_lens(output["embed"], lens)
|
| 501 |
-
seq_output = self.output_transform(seq_output)
|
| 502 |
-
output["seq_output"] = seq_output
|
| 503 |
-
return output
|
| 504 |
-
|
models/cnn_encoder.py
DELETED
|
@@ -1,808 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn as nn
|
| 5 |
-
import torch.nn.functional as F
|
| 6 |
-
from torchaudio import transforms
|
| 7 |
-
|
| 8 |
-
from utils.model_util import mean_with_lens, max_with_lens
|
| 9 |
-
from utils.train_util import merge_load_state_dict
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def init_layer(layer):
|
| 13 |
-
"""Initialize a Linear or Convolutional layer. """
|
| 14 |
-
nn.init.xavier_uniform_(layer.weight)
|
| 15 |
-
|
| 16 |
-
if hasattr(layer, 'bias'):
|
| 17 |
-
if layer.bias is not None:
|
| 18 |
-
layer.bias.data.fill_(0.)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def init_bn(bn):
|
| 22 |
-
"""Initialize a Batchnorm layer. """
|
| 23 |
-
bn.bias.data.fill_(0.)
|
| 24 |
-
bn.weight.data.fill_(1.)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class ConvBlock(nn.Module):
|
| 28 |
-
def __init__(self, in_channels, out_channels):
|
| 29 |
-
|
| 30 |
-
super(ConvBlock, self).__init__()
|
| 31 |
-
|
| 32 |
-
self.conv1 = nn.Conv2d(in_channels=in_channels,
|
| 33 |
-
out_channels=out_channels,
|
| 34 |
-
kernel_size=(3, 3), stride=(1, 1),
|
| 35 |
-
padding=(1, 1), bias=False)
|
| 36 |
-
|
| 37 |
-
self.conv2 = nn.Conv2d(in_channels=out_channels,
|
| 38 |
-
out_channels=out_channels,
|
| 39 |
-
kernel_size=(3, 3), stride=(1, 1),
|
| 40 |
-
padding=(1, 1), bias=False)
|
| 41 |
-
|
| 42 |
-
self.bn1 = nn.BatchNorm2d(out_channels)
|
| 43 |
-
self.bn2 = nn.BatchNorm2d(out_channels)
|
| 44 |
-
|
| 45 |
-
self.init_weight()
|
| 46 |
-
|
| 47 |
-
def init_weight(self):
|
| 48 |
-
init_layer(self.conv1)
|
| 49 |
-
init_layer(self.conv2)
|
| 50 |
-
init_bn(self.bn1)
|
| 51 |
-
init_bn(self.bn2)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def forward(self, input, pool_size=(2, 2), pool_type='avg'):
|
| 55 |
-
|
| 56 |
-
x = input
|
| 57 |
-
x = F.relu_(self.bn1(self.conv1(x)))
|
| 58 |
-
x = F.relu_(self.bn2(self.conv2(x)))
|
| 59 |
-
if pool_type == 'max':
|
| 60 |
-
x = F.max_pool2d(x, kernel_size=pool_size)
|
| 61 |
-
elif pool_type == 'avg':
|
| 62 |
-
x = F.avg_pool2d(x, kernel_size=pool_size)
|
| 63 |
-
elif pool_type == 'avg+max':
|
| 64 |
-
x1 = F.avg_pool2d(x, kernel_size=pool_size)
|
| 65 |
-
x2 = F.max_pool2d(x, kernel_size=pool_size)
|
| 66 |
-
x = x1 + x2
|
| 67 |
-
else:
|
| 68 |
-
raise Exception('Incorrect argument!')
|
| 69 |
-
|
| 70 |
-
return x
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
class ConvBlock5x5(nn.Module):
|
| 74 |
-
def __init__(self, in_channels, out_channels):
|
| 75 |
-
|
| 76 |
-
super(ConvBlock5x5, self).__init__()
|
| 77 |
-
|
| 78 |
-
self.conv1 = nn.Conv2d(in_channels=in_channels,
|
| 79 |
-
out_channels=out_channels,
|
| 80 |
-
kernel_size=(5, 5), stride=(1, 1),
|
| 81 |
-
padding=(2, 2), bias=False)
|
| 82 |
-
|
| 83 |
-
self.bn1 = nn.BatchNorm2d(out_channels)
|
| 84 |
-
|
| 85 |
-
self.init_weight()
|
| 86 |
-
|
| 87 |
-
def init_weight(self):
|
| 88 |
-
init_layer(self.conv1)
|
| 89 |
-
init_bn(self.bn1)
|
| 90 |
-
|
| 91 |
-
def forward(self, input, pool_size=(2, 2), pool_type='avg'):
|
| 92 |
-
|
| 93 |
-
x = input
|
| 94 |
-
x = F.relu_(self.bn1(self.conv1(x)))
|
| 95 |
-
if pool_type == 'max':
|
| 96 |
-
x = F.max_pool2d(x, kernel_size=pool_size)
|
| 97 |
-
elif pool_type == 'avg':
|
| 98 |
-
x = F.avg_pool2d(x, kernel_size=pool_size)
|
| 99 |
-
elif pool_type == 'avg+max':
|
| 100 |
-
x1 = F.avg_pool2d(x, kernel_size=pool_size)
|
| 101 |
-
x2 = F.max_pool2d(x, kernel_size=pool_size)
|
| 102 |
-
x = x1 + x2
|
| 103 |
-
else:
|
| 104 |
-
raise Exception('Incorrect argument!')
|
| 105 |
-
|
| 106 |
-
return x
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
class Cnn6Encoder(nn.Module):
|
| 110 |
-
|
| 111 |
-
def __init__(self, sample_rate=32000, freeze=False):
|
| 112 |
-
super().__init__()
|
| 113 |
-
|
| 114 |
-
sr_to_fmax = {
|
| 115 |
-
32000: 14000,
|
| 116 |
-
16000: 8000
|
| 117 |
-
}
|
| 118 |
-
# Logmel spectrogram extractor
|
| 119 |
-
self.melspec_extractor = transforms.MelSpectrogram(
|
| 120 |
-
sample_rate=sample_rate,
|
| 121 |
-
n_fft=32 * sample_rate // 1000,
|
| 122 |
-
win_length=32 * sample_rate // 1000,
|
| 123 |
-
hop_length=10 * sample_rate // 1000,
|
| 124 |
-
f_min=50,
|
| 125 |
-
f_max=sr_to_fmax[sample_rate],
|
| 126 |
-
n_mels=64,
|
| 127 |
-
norm="slaney",
|
| 128 |
-
mel_scale="slaney"
|
| 129 |
-
)
|
| 130 |
-
self.hop_length = 10 * sample_rate // 1000
|
| 131 |
-
self.db_transform = transforms.AmplitudeToDB()
|
| 132 |
-
|
| 133 |
-
self.bn0 = nn.BatchNorm2d(64)
|
| 134 |
-
|
| 135 |
-
self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
|
| 136 |
-
self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
|
| 137 |
-
self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
|
| 138 |
-
self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
|
| 139 |
-
|
| 140 |
-
self.downsample_ratio = 16
|
| 141 |
-
|
| 142 |
-
self.fc1 = nn.Linear(512, 512, bias=True)
|
| 143 |
-
self.fc_emb_size = 512
|
| 144 |
-
self.init_weight()
|
| 145 |
-
self.freeze = freeze
|
| 146 |
-
|
| 147 |
-
def init_weight(self):
|
| 148 |
-
init_bn(self.bn0)
|
| 149 |
-
init_layer(self.fc1)
|
| 150 |
-
|
| 151 |
-
def load_pretrained(self, pretrained, output_fn):
|
| 152 |
-
checkpoint = torch.load(pretrained, map_location="cpu")
|
| 153 |
-
|
| 154 |
-
if "model" in checkpoint:
|
| 155 |
-
state_dict = checkpoint["model"]
|
| 156 |
-
else:
|
| 157 |
-
raise Exception("Unkown checkpoint format")
|
| 158 |
-
|
| 159 |
-
loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
|
| 160 |
-
if self.freeze:
|
| 161 |
-
for name, param in self.named_parameters():
|
| 162 |
-
if name in loaded_keys:
|
| 163 |
-
param.requires_grad = False
|
| 164 |
-
else:
|
| 165 |
-
param.requires_grad = True
|
| 166 |
-
|
| 167 |
-
def forward(self, input_dict):
|
| 168 |
-
waveform = input_dict["wav"]
|
| 169 |
-
wave_length = input_dict["wav_len"]
|
| 170 |
-
specaug = input_dict["specaug"]
|
| 171 |
-
x = self.melspec_extractor(waveform)
|
| 172 |
-
x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
|
| 173 |
-
x = x.transpose(1, 2)
|
| 174 |
-
x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
|
| 175 |
-
|
| 176 |
-
x = x.transpose(1, 3)
|
| 177 |
-
x = self.bn0(x)
|
| 178 |
-
x = x.transpose(1, 3)
|
| 179 |
-
|
| 180 |
-
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
|
| 181 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 182 |
-
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
|
| 183 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 184 |
-
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
|
| 185 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 186 |
-
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
|
| 187 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 188 |
-
|
| 189 |
-
x = torch.mean(x, dim=3)
|
| 190 |
-
attn_emb = x.transpose(1, 2)
|
| 191 |
-
wave_length = torch.as_tensor(wave_length)
|
| 192 |
-
feat_length = torch.div(wave_length, self.hop_length,
|
| 193 |
-
rounding_mode="floor") + 1
|
| 194 |
-
feat_length = torch.div(feat_length, self.downsample_ratio,
|
| 195 |
-
rounding_mode="floor")
|
| 196 |
-
x_max = max_with_lens(attn_emb, feat_length)
|
| 197 |
-
x_mean = mean_with_lens(attn_emb, feat_length)
|
| 198 |
-
x = x_max + x_mean
|
| 199 |
-
x = F.dropout(x, p=0.5, training=self.training)
|
| 200 |
-
x = F.relu_(self.fc1(x))
|
| 201 |
-
fc_emb = F.dropout(x, p=0.5, training=self.training)
|
| 202 |
-
|
| 203 |
-
return {
|
| 204 |
-
"attn_emb": attn_emb,
|
| 205 |
-
"fc_emb": fc_emb,
|
| 206 |
-
"attn_emb_len": feat_length
|
| 207 |
-
}
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
class Cnn10Encoder(nn.Module):
|
| 211 |
-
|
| 212 |
-
def __init__(self, sample_rate=32000, freeze=False):
|
| 213 |
-
super().__init__()
|
| 214 |
-
|
| 215 |
-
sr_to_fmax = {
|
| 216 |
-
32000: 14000,
|
| 217 |
-
16000: 8000
|
| 218 |
-
}
|
| 219 |
-
# Logmel spectrogram extractor
|
| 220 |
-
self.melspec_extractor = transforms.MelSpectrogram(
|
| 221 |
-
sample_rate=sample_rate,
|
| 222 |
-
n_fft=32 * sample_rate // 1000,
|
| 223 |
-
win_length=32 * sample_rate // 1000,
|
| 224 |
-
hop_length=10 * sample_rate // 1000,
|
| 225 |
-
f_min=50,
|
| 226 |
-
f_max=sr_to_fmax[sample_rate],
|
| 227 |
-
n_mels=64,
|
| 228 |
-
norm="slaney",
|
| 229 |
-
mel_scale="slaney"
|
| 230 |
-
)
|
| 231 |
-
self.hop_length = 10 * sample_rate // 1000
|
| 232 |
-
self.db_transform = transforms.AmplitudeToDB()
|
| 233 |
-
|
| 234 |
-
self.bn0 = nn.BatchNorm2d(64)
|
| 235 |
-
|
| 236 |
-
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
|
| 237 |
-
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
|
| 238 |
-
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
|
| 239 |
-
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
|
| 240 |
-
|
| 241 |
-
self.downsample_ratio = 16
|
| 242 |
-
|
| 243 |
-
self.fc1 = nn.Linear(512, 512, bias=True)
|
| 244 |
-
self.fc_emb_size = 512
|
| 245 |
-
self.init_weight()
|
| 246 |
-
self.freeze = freeze
|
| 247 |
-
|
| 248 |
-
def init_weight(self):
|
| 249 |
-
init_bn(self.bn0)
|
| 250 |
-
init_layer(self.fc1)
|
| 251 |
-
|
| 252 |
-
def load_pretrained(self, pretrained, output_fn):
|
| 253 |
-
checkpoint = torch.load(pretrained, map_location="cpu")
|
| 254 |
-
|
| 255 |
-
if "model" in checkpoint:
|
| 256 |
-
state_dict = checkpoint["model"]
|
| 257 |
-
else:
|
| 258 |
-
raise Exception("Unkown checkpoint format")
|
| 259 |
-
|
| 260 |
-
loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
|
| 261 |
-
if self.freeze:
|
| 262 |
-
for name, param in self.named_parameters():
|
| 263 |
-
if name in loaded_keys:
|
| 264 |
-
param.requires_grad = False
|
| 265 |
-
else:
|
| 266 |
-
param.requires_grad = True
|
| 267 |
-
|
| 268 |
-
def forward(self, input_dict):
|
| 269 |
-
waveform = input_dict["wav"]
|
| 270 |
-
wave_length = input_dict["wav_len"]
|
| 271 |
-
specaug = input_dict["specaug"]
|
| 272 |
-
x = self.melspec_extractor(waveform)
|
| 273 |
-
x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
|
| 274 |
-
x = x.transpose(1, 2)
|
| 275 |
-
x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
|
| 276 |
-
|
| 277 |
-
x = x.transpose(1, 3)
|
| 278 |
-
x = self.bn0(x)
|
| 279 |
-
x = x.transpose(1, 3)
|
| 280 |
-
|
| 281 |
-
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
|
| 282 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 283 |
-
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
|
| 284 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 285 |
-
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
|
| 286 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 287 |
-
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
|
| 288 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 289 |
-
|
| 290 |
-
x = torch.mean(x, dim=3)
|
| 291 |
-
attn_emb = x.transpose(1, 2)
|
| 292 |
-
wave_length = torch.as_tensor(wave_length)
|
| 293 |
-
feat_length = torch.div(wave_length, self.hop_length,
|
| 294 |
-
rounding_mode="floor") + 1
|
| 295 |
-
feat_length = torch.div(feat_length, self.downsample_ratio,
|
| 296 |
-
rounding_mode="floor")
|
| 297 |
-
x_max = max_with_lens(attn_emb, feat_length)
|
| 298 |
-
x_mean = mean_with_lens(attn_emb, feat_length)
|
| 299 |
-
x = x_max + x_mean
|
| 300 |
-
x = F.dropout(x, p=0.5, training=self.training)
|
| 301 |
-
x = F.relu_(self.fc1(x))
|
| 302 |
-
fc_emb = F.dropout(x, p=0.5, training=self.training)
|
| 303 |
-
|
| 304 |
-
return {
|
| 305 |
-
"attn_emb": attn_emb,
|
| 306 |
-
"fc_emb": fc_emb,
|
| 307 |
-
"attn_emb_len": feat_length
|
| 308 |
-
}
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
class Cnn14Encoder(nn.Module):
|
| 312 |
-
def __init__(self, sample_rate=32000, freeze=False):
|
| 313 |
-
super().__init__()
|
| 314 |
-
sr_to_fmax = {
|
| 315 |
-
32000: 14000,
|
| 316 |
-
16000: 8000
|
| 317 |
-
}
|
| 318 |
-
# Logmel spectrogram extractor
|
| 319 |
-
self.melspec_extractor = transforms.MelSpectrogram(
|
| 320 |
-
sample_rate=sample_rate,
|
| 321 |
-
n_fft=32 * sample_rate // 1000,
|
| 322 |
-
win_length=32 * sample_rate // 1000,
|
| 323 |
-
hop_length=10 * sample_rate // 1000,
|
| 324 |
-
f_min=50,
|
| 325 |
-
f_max=sr_to_fmax[sample_rate],
|
| 326 |
-
n_mels=64,
|
| 327 |
-
norm="slaney",
|
| 328 |
-
mel_scale="slaney"
|
| 329 |
-
)
|
| 330 |
-
self.hop_length = 10 * sample_rate // 1000
|
| 331 |
-
self.db_transform = transforms.AmplitudeToDB()
|
| 332 |
-
|
| 333 |
-
self.bn0 = nn.BatchNorm2d(64)
|
| 334 |
-
|
| 335 |
-
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
|
| 336 |
-
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
|
| 337 |
-
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
|
| 338 |
-
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
|
| 339 |
-
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
|
| 340 |
-
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
|
| 341 |
-
|
| 342 |
-
self.downsample_ratio = 32
|
| 343 |
-
|
| 344 |
-
self.fc1 = nn.Linear(2048, 2048, bias=True)
|
| 345 |
-
self.fc_emb_size = 2048
|
| 346 |
-
|
| 347 |
-
self.init_weight()
|
| 348 |
-
self.freeze = freeze
|
| 349 |
-
|
| 350 |
-
def init_weight(self):
|
| 351 |
-
init_bn(self.bn0)
|
| 352 |
-
init_layer(self.fc1)
|
| 353 |
-
|
| 354 |
-
def load_pretrained(self, pretrained, output_fn):
|
| 355 |
-
checkpoint = torch.load(pretrained, map_location="cpu")
|
| 356 |
-
|
| 357 |
-
if "model" in checkpoint:
|
| 358 |
-
state_keys = checkpoint["model"].keys()
|
| 359 |
-
backbone = False
|
| 360 |
-
for key in state_keys:
|
| 361 |
-
if key.startswith("backbone."):
|
| 362 |
-
backbone = True
|
| 363 |
-
break
|
| 364 |
-
|
| 365 |
-
if backbone: # COLA
|
| 366 |
-
state_dict = {}
|
| 367 |
-
for key, value in checkpoint["model"].items():
|
| 368 |
-
if key.startswith("backbone."):
|
| 369 |
-
model_key = key.replace("backbone.", "")
|
| 370 |
-
state_dict[model_key] = value
|
| 371 |
-
else: # PANNs
|
| 372 |
-
state_dict = checkpoint["model"]
|
| 373 |
-
elif "state_dict" in checkpoint: # BLAT
|
| 374 |
-
state_dict = checkpoint["state_dict"]
|
| 375 |
-
state_dict_keys = list(filter(
|
| 376 |
-
lambda x: "audio_encoder" in x, state_dict.keys()))
|
| 377 |
-
state_dict = {
|
| 378 |
-
key.replace('audio_encoder.', ''): state_dict[key]
|
| 379 |
-
for key in state_dict_keys
|
| 380 |
-
}
|
| 381 |
-
else:
|
| 382 |
-
raise Exception("Unkown checkpoint format")
|
| 383 |
-
|
| 384 |
-
loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
|
| 385 |
-
if self.freeze:
|
| 386 |
-
for name, param in self.named_parameters():
|
| 387 |
-
if name in loaded_keys:
|
| 388 |
-
param.requires_grad = False
|
| 389 |
-
else:
|
| 390 |
-
param.requires_grad = True
|
| 391 |
-
|
| 392 |
-
def forward(self, input_dict):
|
| 393 |
-
waveform = input_dict["wav"]
|
| 394 |
-
wave_length = input_dict["wav_len"]
|
| 395 |
-
specaug = input_dict["specaug"]
|
| 396 |
-
x = self.melspec_extractor(waveform)
|
| 397 |
-
x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
|
| 398 |
-
x = x.transpose(1, 2)
|
| 399 |
-
x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
|
| 400 |
-
|
| 401 |
-
x = x.transpose(1, 3)
|
| 402 |
-
x = self.bn0(x)
|
| 403 |
-
x = x.transpose(1, 3)
|
| 404 |
-
|
| 405 |
-
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
|
| 406 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 407 |
-
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
|
| 408 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 409 |
-
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
|
| 410 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 411 |
-
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
|
| 412 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 413 |
-
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
|
| 414 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 415 |
-
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
|
| 416 |
-
x = F.dropout(x, p=0.2, training=self.training)
|
| 417 |
-
x = torch.mean(x, dim=3)
|
| 418 |
-
attn_emb = x.transpose(1, 2)
|
| 419 |
-
|
| 420 |
-
wave_length = torch.as_tensor(wave_length)
|
| 421 |
-
feat_length = torch.div(wave_length, self.hop_length,
|
| 422 |
-
rounding_mode="floor") + 1
|
| 423 |
-
feat_length = torch.div(feat_length, self.downsample_ratio,
|
| 424 |
-
rounding_mode="floor")
|
| 425 |
-
x_max = max_with_lens(attn_emb, feat_length)
|
| 426 |
-
x_mean = mean_with_lens(attn_emb, feat_length)
|
| 427 |
-
x = x_max + x_mean
|
| 428 |
-
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        fc_emb = F.dropout(x, p=0.5, training=self.training)

        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }

        return output_dict


class InvertedResidual(nn.Module):

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            _layers = [
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False),
                nn.AvgPool2d(stride),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup)
            ]
            _layers = nn.Sequential(*_layers)
            init_layer(_layers[0])
            init_bn(_layers[2])
            init_layer(_layers[4])
            init_bn(_layers[5])
            self.conv = _layers
        else:
            _layers = [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False),
                nn.AvgPool2d(stride),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup)
            ]
            _layers = nn.Sequential(*_layers)
            init_layer(_layers[0])
            init_bn(_layers[1])
            init_layer(_layers[3])
            init_bn(_layers[5])
            init_layer(_layers[7])
            init_bn(_layers[8])
            self.conv = _layers

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):

    def __init__(self, sample_rate):
        super().__init__()

        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        # Logmel spectrogram extractor
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=32 * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()

        self.bn0 = nn.BatchNorm2d(64)

        width_mult = 1.
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 2],
            [6, 160, 3, 1],
            [6, 320, 1, 1],
        ]

        self.downsample_ratio = 32

        def conv_bn(inp, oup, stride):
            _layers = [
                nn.Conv2d(inp, oup, 3, 1, 1, bias=False),
                nn.AvgPool2d(stride),
                nn.BatchNorm2d(oup),
                nn.ReLU6(inplace=True)
            ]
            _layers = nn.Sequential(*_layers)
            init_layer(_layers[0])
            init_bn(_layers[2])
            return _layers

        def conv_1x1_bn(inp, oup):
            _layers = nn.Sequential(
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU6(inplace=True)
            )
            init_layer(_layers[0])
            init_bn(_layers[1])
            return _layers

        # building first layer
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(1, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        self.fc1 = nn.Linear(1280, 1024, bias=True)

        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)

    def forward(self, input_dict):
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)    # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)  # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = self.features(x)

        x = torch.mean(x, dim=3)
        attn_emb = x.transpose(1, 2)

        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        x_max = max_with_lens(attn_emb, feat_length)
        x_mean = mean_with_lens(attn_emb, feat_length)
        x = x_max + x_mean
        # TODO: the original PANNs code does not have dropout here, why?
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        fc_emb = F.dropout(x, p=0.5, training=self.training)

        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }

        return output_dict


class MobileNetV3(nn.Module):

    def __init__(self,
                 sample_rate,
                 model_name,
                 n_mels=64,
                 win_length=32,
                 pretrained=True,
                 freeze=False,
                 pooling="mean_max_fc"):

        from captioning.models.eff_at_encoder import get_model, NAME_TO_WIDTH

        super().__init__()
        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        self.n_mels = n_mels
        # Logmel spectrogram extractor
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=win_length * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=n_mels,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()

        self.bn0 = nn.BatchNorm2d(n_mels)

        width_mult = NAME_TO_WIDTH(model_name)
        self.features = get_model(model_name=model_name,
                                  pretrained=pretrained,
                                  width_mult=width_mult).features
        self.downsample_ratio = 32

        if pooling == "mean_max_fc":
            self.fc_emb_size = 512
            self.fc1 = nn.Linear(self.features[-1].out_channels, 512, bias=True)
        elif pooling == "mean":
            self.fc_emb_size = self.features[-1].out_channels
        self.init_weight()

        if freeze:
            for param in self.parameters():
                param.requires_grad = False

        self.pooling = pooling

    def init_weight(self):
        init_bn(self.bn0)
        if hasattr(self, "fc1"):
            init_layer(self.fc1)

    def forward(self, input_dict):
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)    # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)  # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = self.features(x)

        x = torch.mean(x, dim=3)
        attn_emb = x.transpose(1, 2)

        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")

        if self.pooling == "mean_max_fc":
            x_max = max_with_lens(attn_emb, feat_length)
            x_mean = mean_with_lens(attn_emb, feat_length)
            x = x_max + x_mean
            x = F.dropout(x, p=0.5, training=self.training)
            x = F.relu_(self.fc1(x))
            fc_emb = F.dropout(x, p=0.5, training=self.training)
        elif self.pooling == "mean":
            fc_emb = mean_with_lens(attn_emb, feat_length)

        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }

        return output_dict


class EfficientNetB2(nn.Module):

    def __init__(self,
                 n_mels: int = 64,
                 win_length: int = 32,
                 hop_length: int = 10,
                 f_min: int = 0,
                 pretrained: bool = False,
                 prune_ratio: float = 0.0,
                 prune_se: bool = True,
                 prune_start_layer: int = 0,
                 prune_method: str = "operator_norm",
                 freeze: bool = False,):
        from models.eff_latent_encoder import get_model, get_pruned_model
        super().__init__()
        sample_rate = 16000
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=win_length * sample_rate // 1000,
            win_length=win_length * sample_rate // 1000,
            hop_length=hop_length * sample_rate // 1000,
            f_min=f_min,
            n_mels=n_mels,
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB(top_db=120)
        if prune_ratio > 0:
            self.backbone = get_pruned_model(pretrained=pretrained,
                                             prune_ratio=prune_ratio,
                                             prune_start_layer=prune_start_layer,
                                             prune_se=prune_se,
                                             prune_method=prune_method)
        else:
            self.backbone = get_model(pretrained=pretrained)
        self.fc_emb_size = self.backbone.eff_net._conv_head.out_channels
        self.downsample_ratio = 32
        if freeze:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input_dict):
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)    # (batch_size, mel_bins, time_steps)

        x = self.backbone(x)
        attn_emb = x

        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        fc_emb = mean_with_lens(attn_emb, feat_length)

        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict


if __name__ == "__main__":
    encoder = MobileNetV3(32000, "mn10_as")
    print(encoder)
    input_dict = {
        "wav": torch.randn(4, 320000),
        "wav_len": torch.tensor([320000, 280000, 160000, 300000]),
        "specaug": True
    }
    output_dict = encoder(input_dict)
    print("attn embed: ", output_dict["attn_emb"].shape)
    print("fc embed: ", output_dict["fc_emb"].shape)
    print("attn embed length: ", output_dict["attn_emb_len"])
models/eff_latent_encoder.py
DELETED
@@ -1,347 +0,0 @@
import os

import torch
import torch.nn as nn
from tqdm import tqdm
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.model import MBConvBlock
from efficientnet_pytorch import utils as efficientnet_utils
from efficientnet_pytorch.utils import (
    round_filters,
    round_repeats,
    get_same_padding_conv2d,
    calculate_output_image_size,
    MemoryEfficientSwish,
)
from einops import rearrange, reduce
from torch.hub import load_state_dict_from_url


model_dir = os.getcwd()


class _EffiNet(nn.Module):
    """A proxy for efficient net models"""
    def __init__(self,
                 blocks_args=None,
                 global_params=None,
                 prune_start_layer: int = 0,
                 prune_se: bool = True,
                 prune_ratio: float = 0.0
                 ) -> None:
        super().__init__()
        if prune_ratio > 0:
            self.eff_net = EfficientNetB2Pruned(blocks_args=blocks_args,
                                                global_params=global_params,
                                                prune_start_layer=prune_start_layer,
                                                prune_se=prune_se,
                                                prune_ratio=prune_ratio)
        else:
            self.eff_net = EfficientNet(blocks_args=blocks_args,
                                        global_params=global_params)

    def forward(self, x: torch.Tensor):
        x = rearrange(x, 'b f t -> b 1 f t')
        x = self.eff_net.extract_features(x)
        return reduce(x, 'b c f t -> b t c', 'mean')


def get_model(pretrained=True) -> _EffiNet:
    blocks_args, global_params = efficientnet_utils.get_model_params(
        'efficientnet-b2', {'include_top': False})
    model = _EffiNet(blocks_args=blocks_args,
                     global_params=global_params)
    model.eff_net._change_in_channels(1)
    if pretrained:
        model_path = os.path.join(model_dir, "effb2.pt")
        if not os.path.exists(model_path):
            state_dict = load_state_dict_from_url(
                'https://github.com/richermans/HEAR2021_EfficientLatent/releases/download/v0.0.1/effb2.pt',
                progress=True,
                model_dir=model_dir)
        else:
            state_dict = torch.load(model_path)
        del_keys = [key for key in state_dict if key.startswith("front_end")]
        for key in del_keys:
            del state_dict[key]
        model.eff_net.load_state_dict(state_dict)
    return model


class MBConvBlockPruned(MBConvBlock):

    def __init__(self, block_args, global_params, image_size=None, prune_ratio=0.5, prune_se=True):
        super(MBConvBlock, self).__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect

        # Expansion phase (Inverted Bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        if self._block_args.expand_ratio != 1:
            oup = int(oup * (1 - prune_ratio))
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
            # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            if prune_se:
                num_squeezed_channels = int(num_squeezed_channels * (1 - prune_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Pointwise convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()


class EfficientNetB2Pruned(EfficientNet):

    def __init__(self, blocks_args=None, global_params=None,
                 prune_start_layer=0, prune_ratio=0.5, prune_se=True):
        super(EfficientNet, self).__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        n_build_blks = 0
        # Stem
        in_channels = 1  # spectrogram

        p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
        out_channels = round_filters(32 * (1 - p),
                                     self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        image_size = calculate_output_image_size(image_size, 2)
        n_build_blks += 1

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
            orig_input_filters = block_args.input_filters
            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters * (1 - p),
                    self._global_params),
                output_filters=round_filters(
                    block_args.output_filters * (1 - p),
                    self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            if n_build_blks == prune_start_layer:
                block_args = block_args._replace(input_filters=round_filters(
                    orig_input_filters,
                    self._global_params)
                )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlockPruned(block_args, self._global_params,
                                                  image_size=image_size, prune_ratio=p,
                                                  prune_se=prune_se))
            n_build_blks += 1

            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlockPruned(block_args,
                                                      self._global_params,
                                                      image_size=image_size,
                                                      prune_ratio=p,
                                                      prune_se=prune_se))
                # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1

        # Head
        in_channels = block_args.output_filters  # output of final block
        p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
        out_channels = round_filters(1280 * (1 - p), self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        if self._global_params.include_top:
            self._dropout = nn.Dropout(self._global_params.dropout_rate)
            self._fc = nn.Linear(out_channels, self._global_params.num_classes)

        # set activation to memory efficient swish by default
        self._swish = MemoryEfficientSwish()


def get_pruned_model(pretrained: bool = True,
                     prune_ratio: float = 0.5,
                     prune_start_layer: int = 0,
                     prune_se: bool = True,
                     prune_method: str = "operator_norm") -> _EffiNet:

    import captioning.models.conv_filter_pruning as pruning_lib

    blocks_args, global_params = efficientnet_utils.get_model_params(
        'efficientnet-b2', {'include_top': False})
    # print("num blocks: ", len(blocks_args))
    # print("block args: ")
    # for block_arg in blocks_args:
    #     print(block_arg)
    model = _EffiNet(blocks_args=blocks_args,
                     global_params=global_params,
                     prune_start_layer=prune_start_layer,
                     prune_se=prune_se,
                     prune_ratio=prune_ratio)

    if prune_method == "operator_norm":
        filter_pruning = pruning_lib.operator_norm_pruning
    elif prune_method == "interspeech":
        filter_pruning = pruning_lib.cs_interspeech
    elif prune_method == "iclr_l1":
        filter_pruning = pruning_lib.iclr_l1
    elif prune_method == "iclr_gm":
        filter_pruning = pruning_lib.iclr_gm
    elif prune_method == "cs_waspaa":
        filter_pruning = pruning_lib.cs_waspaa

    if isinstance(pretrained, str):
        ckpt = torch.load(pretrained, "cpu")
        state_dict = {}
        for key in ckpt["model"].keys():
            if key.startswith("model.encoder.backbone"):
                state_dict[key[len("model.encoder.backbone.eff_net."):]] = ckpt["model"][key]
    elif isinstance(pretrained, bool):
        model_path = os.path.join(model_dir, "effb2.pt")
        if not os.path.exists(model_path):
            state_dict = load_state_dict_from_url(
                'https://github.com/richermans/HEAR2021_EfficientLatent/releases/download/v0.0.1/effb2.pt',
                progress=True,
                model_dir=model_dir)
        else:
            state_dict = torch.load(model_path)
        del_keys = [key for key in state_dict if key.startswith("front_end")]
        for key in del_keys:
            del state_dict[key]

    # load pretrained model with corresponding filters
    # rule:
    # * depthwise_conv: in_ch_idx = out_ch_idx = prev_conv_idx
    mod_dep_path = [
        "_conv_stem",
    ]
    conv_to_bn = {"_conv_stem": "_bn0"}
    for i in range(2):
        mod_dep_path.extend([
            f"_blocks.{i}._depthwise_conv",
            f"_blocks.{i}._se_reduce",
            f"_blocks.{i}._se_expand",
            f"_blocks.{i}._project_conv",
        ])
        conv_to_bn[f"_blocks.{i}._depthwise_conv"] = f"_blocks.{i}._bn1"
        conv_to_bn[f"_blocks.{i}._project_conv"] = f"_blocks.{i}._bn2"

    for i in range(2, 23):
        mod_dep_path.extend([
            f"_blocks.{i}._expand_conv",
            f"_blocks.{i}._depthwise_conv",
            f"_blocks.{i}._se_reduce",
            f"_blocks.{i}._se_expand",
            f"_blocks.{i}._project_conv"
        ])
        conv_to_bn[f"_blocks.{i}._expand_conv"] = f"_blocks.{i}._bn0"
        conv_to_bn[f"_blocks.{i}._depthwise_conv"] = f"_blocks.{i}._bn1"
        conv_to_bn[f"_blocks.{i}._project_conv"] = f"_blocks.{i}._bn2"

    mod_dep_path.append("_conv_head")
    conv_to_bn["_conv_head"] = "_bn1"

    # print(mod_dep_path)
    # print(conv_to_bn)

    key_to_w_b_idx = {}
    model_dict = model.eff_net.state_dict()
    for conv_key in tqdm(mod_dep_path):
        weight = state_dict[f"{conv_key}.weight"]
        ptr_n_filter = weight.size(0)
        model_n_filter = model_dict[f"{conv_key}.weight"].size(0)
        if model_n_filter < ptr_n_filter:
            key_to_w_b_idx[conv_key] = filter_pruning(weight.numpy())[:model_n_filter]
        else:
            key_to_w_b_idx[conv_key] = slice(None)

    pruned_state_dict = {}
    for conv_key, prev_conv_key in zip(mod_dep_path, [None] + mod_dep_path[:-1]):

        for sub_key in ["weight", "bias"]:  # adjust the conv layer
            cur_key = f"{conv_key}.{sub_key}"

            if cur_key not in state_dict:
                continue

            if prev_conv_key is None or conv_key.endswith("_depthwise_conv"):
                conv_in_idx = slice(None)
            else:
                conv_in_idx = key_to_w_b_idx[prev_conv_key]

            # the first pruned layer
            if model_dict[cur_key].ndim > 1 and model_dict[cur_key].size(1) == state_dict[cur_key].size(1):
                conv_in_idx = slice(None)

            if conv_key.endswith("_depthwise_conv"):
                conv_out_idx = key_to_w_b_idx[prev_conv_key]
            else:
                conv_out_idx = key_to_w_b_idx[conv_key]

            # if conv_key == "_blocks.16._se_reduce":
            #     print(len(conv_out_idx), len(conv_in_idx))

            if sub_key == "weight":
                pruned_state_dict[cur_key] = state_dict[cur_key][
                    conv_out_idx, ...][:, conv_in_idx, ...]
            else:
                pruned_state_dict[cur_key] = state_dict[cur_key][
                    conv_out_idx, ...]

        if conv_key in conv_to_bn:  # adjust the corresponding bn layer
            for sub_key in ["weight", "bias", "running_mean", "running_var"]:
                cur_key = f"{conv_to_bn[conv_key]}.{sub_key}"
                if cur_key not in state_dict:
                    continue
                pruned_state_dict[cur_key] = state_dict[cur_key][
                    key_to_w_b_idx[conv_key], ...]

    model.eff_net.load_state_dict(pruned_state_dict)

    return model
models/kd_wrapper.py
DELETED
@@ -1,226 +0,0 @@
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

from models.base import CaptionMetaMixin
from utils.model_util import init


class WmlEncoderKdWrapper(nn.Module, CaptionMetaMixin):

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_layer_to_dims: Dict[str, int],
                 loss_type: str = "mse",):
        super().__init__()
        self.model = model
        self.tchr_layers = list(tchr_layer_to_dims.keys())
        self.stdnt_qv_proj = nn.Linear(model.encoder.fc_emb_size,
                                       2 * shared_dim)
        self.stdnt_qv_proj.apply(init)
        for layer, dim in tchr_layer_to_dims.items():
            self.add_module(f'tchr_kv_proj_{layer}', nn.Linear(dim, 2 * shared_dim))
            getattr(self, f'tchr_kv_proj_{layer}').apply(init)
        if loss_type == "mse":
            self.loss_fn = nn.MSELoss(reduction="none")

    def forward(self, input_dict: Dict):
        output_dict = self.model(input_dict)
        if "tchr_output" in input_dict:
            stdnt_emb = output_dict["fc_emb"]
            stdnt_qv = self.stdnt_qv_proj(stdnt_emb)
            stdnt_q, stdnt_v = torch.chunk(stdnt_qv, 2, dim=-1)

            tchr_output = input_dict["tchr_output"]
            layer_ks, layer_vs = [], []
            for layer in self.tchr_layers:
                layer_kv = getattr(self, f'tchr_kv_proj_{layer}')(tchr_output[layer])
                layer_k, layer_v = torch.chunk(layer_kv, 2, dim=-1)
                layer_ks.append(layer_k)
                layer_vs.append(layer_v)
            layer_ks = torch.stack(layer_ks, dim=1)
            layer_vs = torch.stack(layer_vs, dim=1)
            weights = torch.softmax(stdnt_q.unsqueeze(1) @ layer_ks.transpose(1, 2), dim=-1)
            stdnt_v = repeat(stdnt_v, 'b d -> b n d', n=len(self.tchr_layers))
            loss = self.loss_fn(stdnt_v, layer_vs).mean(dim=-1, keepdim=True)
            loss = (weights @ loss).mean()
            output_dict["enc_kd_loss"] = loss
        return output_dict


class MseEncoderKdWrapper(nn.Module, CaptionMetaMixin):

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 use_tchr_proj: bool = True,
                 l2_norm: bool = False,
                 ):
        super().__init__()
        self.model = model
        self.use_tchr_proj = use_tchr_proj
        if not use_tchr_proj:
            assert shared_dim == tchr_dim
        self.tchr_dim = tchr_dim
        self.l2_norm = l2_norm
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        if use_tchr_proj:
            self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
            self.tchr_proj.apply(init)
        else:
            self.tchr_proj = nn.Identity()

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            if self.use_tchr_proj:
                output_dict = self.model(input_dict)
                stdnt_emb = output_dict["fc_emb"]
            else:
                encoder_output = self.model.encoder(input_dict)
                stdnt_emb = encoder_output["fc_emb"]
                encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
                encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
                output_dict = self.model.forward_decoder(input_dict, encoder_output)
        else:
            output_dict = self.model.encoder(input_dict)
            stdnt_emb = output_dict["fc_emb"]
        if "tchr_output" in input_dict:
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            thcr_emb = self.tchr_proj(tchr_emb)

            if self.l2_norm:
                stdnt_emb = F.normalize(stdnt_emb, dim=-1)
                thcr_emb = F.normalize(thcr_emb, dim=-1)

            loss = F.mse_loss(stdnt_emb, thcr_emb)
            output_dict["enc_kd_loss"] = loss
        return output_dict


class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 ):
        super().__init__()
        self.model = model
        self.tchr_dim = tchr_dim
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
        self.tchr_proj.apply(init)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            output_dict = self.model(input_dict)
        else:
            output_dict = self.model.encoder(input_dict)
        if "tchr_output" in input_dict:
            stdnt_emb = output_dict["fc_emb"]
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            thcr_emb = self.tchr_proj(tchr_emb)

            stdnt_emb = F.normalize(stdnt_emb, dim=-1)
            thcr_emb = F.normalize(thcr_emb, dim=-1)

            unscaled_logit = stdnt_emb @ thcr_emb.transpose(0, 1)
            logit = self.logit_scale * unscaled_logit
            label = torch.arange(logit.shape[0]).to(logit.device)
            loss1 = F.cross_entropy(logit, label)
            loss2 = F.cross_entropy(logit.transpose(0, 1), label)
            loss = (loss1 + loss2) / 2
            output_dict["enc_kd_loss"] = loss
        return output_dict


class ContraMseEncoderKdWrapper(nn.Module, CaptionMetaMixin):

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 use_tchr_proj: bool = True,
                 l2_norm: bool = False,
                 ):
        super().__init__()
        self.model = model
        self.use_tchr_proj = use_tchr_proj
        if not use_tchr_proj:
            assert shared_dim == tchr_dim
        self.tchr_dim = tchr_dim
        self.l2_norm = l2_norm
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        if use_tchr_proj:
            self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
            self.tchr_proj.apply(init)
        else:
            self.tchr_proj = nn.Identity()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            if self.use_tchr_proj:
                output_dict = self.model(input_dict)
                stdnt_emb = output_dict["fc_emb"]
            else:
                encoder_output = self.model.encoder(input_dict)
                stdnt_emb = encoder_output["fc_emb"]
                encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
                encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
                output_dict = self.model.forward_decoder(input_dict, encoder_output)
        else:
            output_dict = self.model.encoder(input_dict)
            stdnt_emb = output_dict["fc_emb"]
        if "tchr_output" in input_dict:
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            thcr_emb = self.tchr_proj(tchr_emb)

            if self.l2_norm:
                stdnt_emb = F.normalize(stdnt_emb, dim=-1)
                thcr_emb = F.normalize(thcr_emb, dim=-1)

            mse_loss = F.mse_loss(stdnt_emb, thcr_emb)

            stdnt_emb = F.normalize(stdnt_emb, dim=-1)
            thcr_emb = F.normalize(thcr_emb, dim=-1)
            unscaled_logit = stdnt_emb @ thcr_emb.transpose(0, 1)
            logit = self.logit_scale * unscaled_logit
            label = torch.arange(logit.shape[0]).to(logit.device)
            loss1 = F.cross_entropy(logit, label)
            loss2 = F.cross_entropy(logit.transpose(0, 1), label)
            cntr_loss = (loss1 + loss2) / 2
            output_dict["enc_kd_loss"] = mse_loss + cntr_loss

        return output_dict
models/transformer_decoder.py
DELETED
@@ -1,214 +0,0 @@
import math

import torch
import torch.nn as nn

from models import BaseDecoder
from utils.model_util import generate_length_mask, PositionalEncoding
from utils.train_util import merge_load_state_dict


class TransformerDecoder(BaseDecoder):

    def __init__(self,
                 emb_dim,
                 vocab_size,
                 fc_emb_dim,
                 attn_emb_dim,
                 dropout,
                 freeze=False,
                 tie_weights=False,
                 **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout=dropout, tie_weights=tie_weights)
        self.d_model = emb_dim
        self.nhead = kwargs.get("nhead", self.d_model // 64)
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)

        self.pos_encoder = PositionalEncoding(self.d_model, dropout)
        layer = nn.TransformerDecoderLayer(d_model=self.d_model,
                                           nhead=self.nhead,
                                           dim_feedforward=self.dim_feedforward,
                                           dropout=dropout)
        self.model = nn.TransformerDecoder(layer, self.nlayers)
        self.classifier = nn.Linear(self.d_model, vocab_size, bias=False)
        if tie_weights:
            self.classifier.weight = self.word_embedding.weight
        self.attn_proj = nn.Sequential(
            nn.Linear(self.attn_emb_dim, self.d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(self.d_model)
        )
        self.init_params()

        self.freeze = freeze
        if freeze:
            for p in self.parameters():
                p.requires_grad = False

    def init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def load_pretrained(self, pretrained, output_fn):
        checkpoint = torch.load(pretrained, map_location="cpu")

        if "model" in checkpoint:
            checkpoint = checkpoint["model"]
        if next(iter(checkpoint)).startswith("decoder."):
            state_dict = {}
            for k, v in checkpoint.items():
                state_dict[k[8:]] = v

        loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
        if self.freeze:
            for name, param in self.named_parameters():
                if name in loaded_keys:
                    param.requires_grad = False
                else:
                    param.requires_grad = True

    def generate_square_subsequent_mask(self, max_length):
        mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, input_dict):
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)  # [T_src, N, emb_dim]
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)  # [N, T, emb_dim]
        embed = embed.transpose(0, 1)  # [T, N, emb_dim]
        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output


class M2TransformerDecoder(BaseDecoder):

    def __init__(self, vocab_size, fc_emb_dim, attn_emb_dim, dropout=0.1, **kwargs):
        super().__init__(attn_emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout=dropout,)
        try:
            from m2transformer.models.transformer import MeshedDecoder
        except:
            raise ImportError("meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`")
        del self.word_embedding
        del self.in_dropout

        self.d_model = attn_emb_dim
        self.nhead = kwargs.get("nhead", self.d_model // 64)
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
        self.model = MeshedDecoder(vocab_size, 100, self.nlayers, 0,
                                   d_model=self.d_model,
                                   h=self.nhead,
                                   d_ff=self.dim_feedforward,
                                   dropout=dropout)
        self.init_params()

    def init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, input_dict):
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_mask = input_dict["attn_emb_mask"]
        word = word.to(attn_emb.device)
        embed, logit = self.model(word, attn_emb, attn_emb_mask)
        output = {
            "embed": embed,
            "logit": logit,
        }
        return output


class EventTransformerDecoder(TransformerDecoder):

    def forward(self, input_dict):
        word = input_dict["word"]  # index of word embeddings
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]
        event_emb = input_dict["event"]  # [N, emb_dim]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)  # [T_src, N, emb_dim]
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)  # [N, T, emb_dim]

        embed = embed.transpose(0, 1)  # [T, N, emb_dim]
        embed += event_emb
        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output


class KeywordProbTransformerDecoder(TransformerDecoder):

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, keyword_classes_num, **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, **kwargs)
        self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
        self.word_keyword_norm = nn.LayerNorm(self.d_model)

    def forward(self, input_dict):
        word = input_dict["word"]  # index of word embeddings
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]
        keyword = input_dict["keyword"]  # [N, keyword_classes_num]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)  # [T_src, N, emb_dim]
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)  # [N, T, emb_dim]

        embed = embed.transpose(0, 1)  # [T, N, emb_dim]
        embed += self.keyword_proj(keyword)
        embed = self.word_keyword_norm(embed)

        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output
models/transformer_model.py
DELETED
@@ -1,264 +0,0 @@
# -*- coding: utf-8 -*-
import random
import torch
import torch.nn as nn

from models.base import CaptionModel
from utils.model_util import repeat_tensor
import models.transformer_decoder


class TransformerModel(CaptionModel):

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                models.transformer_decoder.TransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)

    def seq_forward(self, input_dict):
        cap = input_dict["cap"]
        cap_padding_mask = (cap == self.pad_idx).to(cap.device)
        cap_padding_mask = cap_padding_mask[:, :-1]
        output = self.decoder(
            {
                "word": cap[:, :-1],
                "attn_emb": input_dict["attn_emb"],
                "attn_emb_len": input_dict["attn_emb_len"],
                "cap_padding_mask": cap_padding_mask
            }
        )
        return output

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = {
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_len": input_dict["attn_emb_len"]
        }
        t = input_dict["t"]

        ###############
        # determine input word
        ################
        if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]:  # training, scheduled sampling
            word = input_dict["cap"][:, :t+1]
        else:
            start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
            if t == 0:
                word = start_word
            else:
                word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
        # word: [N, T]
        decoder_input["word"] = word

        cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
        decoder_input["cap_padding_mask"] = cap_padding_mask
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = {}
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]
        ###############
        # prepare attn embeds
        ################
        if t == 0:
            attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
            attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
            output_i["attn_emb"] = attn_emb
            output_i["attn_emb_len"] = attn_emb_len
        decoder_input["attn_emb"] = output_i["attn_emb"]
        decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
        ###############
        # determine input word
        ################
        start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
        if t == 0:
            word = start_word
        else:
            word = torch.cat((start_word, output_i["seq"]), dim=-1)
        decoder_input["word"] = word
        cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
        decoder_input["cap_padding_mask"] = cap_padding_mask

        return decoder_input


class M2TransformerModel(CaptionModel):

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                models.transformer_decoder.M2TransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.check_encoder_compatibility()

    def check_encoder_compatibility(self):
        assert isinstance(self.encoder, models.encoder.M2TransformerEncoder), \
            f"only M2TransformerModel is compatible with {self.__class__.__name__}"

    def seq_forward(self, input_dict):
        cap = input_dict["cap"]
        output = self.decoder(
            {
                "word": cap[:, :-1],
                "attn_emb": input_dict["attn_emb"],
                "attn_emb_mask": input_dict["attn_emb_mask"],
            }
        )
        return output

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = {
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_mask": input_dict["attn_emb_mask"]
        }
        t = input_dict["t"]

        ###############
        # determine input word
        ################
        if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]:  # training, scheduled sampling
            word = input_dict["cap"][:, :t+1]
        else:
            start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
            if t == 0:
                word = start_word
            else:
                word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
        # word: [N, T]
        decoder_input["word"] = word

        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = {}
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]
        ###############
        # prepare attn embeds
        ################
        if t == 0:
            attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
            attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size)
            output_i["attn_emb"] = attn_emb
            output_i["attn_emb_mask"] = attn_emb_mask
        decoder_input["attn_emb"] = output_i["attn_emb"]
        decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]
decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]
|
| 152 |
-
###############
|
| 153 |
-
# determine input word
|
| 154 |
-
################
|
| 155 |
-
start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
|
| 156 |
-
if t == 0:
|
| 157 |
-
word = start_word
|
| 158 |
-
else:
|
| 159 |
-
word = torch.cat((start_word, output_i["seq"]), dim=-1)
|
| 160 |
-
decoder_input["word"] = word
|
| 161 |
-
|
| 162 |
-
return decoder_input
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
class EventEncoder(nn.Module):
|
| 166 |
-
"""
|
| 167 |
-
Encode the Label information in AudioCaps and AudioSet
|
| 168 |
-
"""
|
| 169 |
-
def __init__(self, emb_dim, vocab_size=527):
|
| 170 |
-
super(EventEncoder, self).__init__()
|
| 171 |
-
self.label_embedding = nn.Parameter(
|
| 172 |
-
torch.randn((vocab_size, emb_dim)), requires_grad=True)
|
| 173 |
-
|
| 174 |
-
def forward(self, word_idxs):
|
| 175 |
-
indices = word_idxs / word_idxs.sum(dim=1, keepdim=True)
|
| 176 |
-
embeddings = indices @ self.label_embedding
|
| 177 |
-
return embeddings
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
class EventCondTransformerModel(TransformerModel):
|
| 181 |
-
|
| 182 |
-
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
| 183 |
-
if not hasattr(self, "compatible_decoders"):
|
| 184 |
-
self.compatible_decoders = (
|
| 185 |
-
models.transformer_decoder.EventTransformerDecoder,
|
| 186 |
-
)
|
| 187 |
-
super().__init__(encoder, decoder, **kwargs)
|
| 188 |
-
self.label_encoder = EventEncoder(decoder.emb_dim, 527)
|
| 189 |
-
self.train_forward_keys += ["events"]
|
| 190 |
-
self.inference_forward_keys += ["events"]
|
| 191 |
-
|
| 192 |
-
# def seq_forward(self, input_dict):
|
| 193 |
-
# cap = input_dict["cap"]
|
| 194 |
-
# cap_padding_mask = (cap == self.pad_idx).to(cap.device)
|
| 195 |
-
# cap_padding_mask = cap_padding_mask[:, :-1]
|
| 196 |
-
# output = self.decoder(
|
| 197 |
-
# {
|
| 198 |
-
# "word": cap[:, :-1],
|
| 199 |
-
# "attn_emb": input_dict["attn_emb"],
|
| 200 |
-
# "attn_emb_len": input_dict["attn_emb_len"],
|
| 201 |
-
# "cap_padding_mask": cap_padding_mask
|
| 202 |
-
# }
|
| 203 |
-
# )
|
| 204 |
-
# return output
|
| 205 |
-
|
| 206 |
-
def prepare_decoder_input(self, input_dict, output):
|
| 207 |
-
decoder_input = super().prepare_decoder_input(input_dict, output)
|
| 208 |
-
decoder_input["events"] = self.label_encoder(input_dict["events"])
|
| 209 |
-
return decoder_input
|
| 210 |
-
|
| 211 |
-
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
| 212 |
-
decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
|
| 213 |
-
t = input_dict["t"]
|
| 214 |
-
i = input_dict["sample_idx"]
|
| 215 |
-
beam_size = input_dict["beam_size"]
|
| 216 |
-
if t == 0:
|
| 217 |
-
output_i["events"] = repeat_tensor(self.label_encoder(input_dict["events"])[i], beam_size)
|
| 218 |
-
decoder_input["events"] = output_i["events"]
|
| 219 |
-
return decoder_input
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
class KeywordCondTransformerModel(TransformerModel):
|
| 223 |
-
|
| 224 |
-
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
| 225 |
-
if not hasattr(self, "compatible_decoders"):
|
| 226 |
-
self.compatible_decoders = (
|
| 227 |
-
models.transformer_decoder.KeywordProbTransformerDecoder,
|
| 228 |
-
)
|
| 229 |
-
super().__init__(encoder, decoder, **kwargs)
|
| 230 |
-
self.train_forward_keys += ["keyword"]
|
| 231 |
-
self.inference_forward_keys += ["keyword"]
|
| 232 |
-
|
| 233 |
-
def seq_forward(self, input_dict):
|
| 234 |
-
cap = input_dict["cap"]
|
| 235 |
-
cap_padding_mask = (cap == self.pad_idx).to(cap.device)
|
| 236 |
-
cap_padding_mask = cap_padding_mask[:, :-1]
|
| 237 |
-
keyword = input_dict["keyword"]
|
| 238 |
-
output = self.decoder(
|
| 239 |
-
{
|
| 240 |
-
"word": cap[:, :-1],
|
| 241 |
-
"attn_emb": input_dict["attn_emb"],
|
| 242 |
-
"attn_emb_len": input_dict["attn_emb_len"],
|
| 243 |
-
"keyword": keyword,
|
| 244 |
-
"cap_padding_mask": cap_padding_mask
|
| 245 |
-
}
|
| 246 |
-
)
|
| 247 |
-
return output
|
| 248 |
-
|
| 249 |
-
def prepare_decoder_input(self, input_dict, output):
|
| 250 |
-
decoder_input = super().prepare_decoder_input(input_dict, output)
|
| 251 |
-
decoder_input["keyword"] = input_dict["keyword"]
|
| 252 |
-
return decoder_input
|
| 253 |
-
|
| 254 |
-
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
| 255 |
-
decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
|
| 256 |
-
t = input_dict["t"]
|
| 257 |
-
i = input_dict["sample_idx"]
|
| 258 |
-
beam_size = input_dict["beam_size"]
|
| 259 |
-
if t == 0:
|
| 260 |
-
output_i["keyword"] = repeat_tensor(input_dict["keyword"][i],
|
| 261 |
-
beam_size)
|
| 262 |
-
decoder_input["keyword"] = output_i["keyword"]
|
| 263 |
-
return decoder_input
|
| 264 |
-
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,4 +1,4 @@
+transformers
 efficientnet_pytorch
-PyYAML
 torchaudio
-einops
+einops

text_tokenizer.py
DELETED
@@ -1,107 +0,0 @@
-import pickle
-from pathlib import Path
-
-import numpy as np
-from utils.train_util import pad_sequence
-
-
-class DictTokenizer:
-
-    def __init__(self,
-                 tokenizer_path: str = None,
-                 max_length: int = 20) -> None:
-        self.word2idx = {}
-        self.idx2word = {}
-        self.idx = 0
-        self.add_word("<pad>")
-        self.add_word("<start>")
-        self.add_word("<end>")
-        self.add_word("<unk>")
-        if tokenizer_path is not None and Path(tokenizer_path).exists():
-            state_dict = pickle.load(open(tokenizer_path, "rb"))
-            self.load_state_dict(state_dict)
-            self.loaded = True
-        else:
-            self.loaded = False
-        self.bos, self.eos = self.word2idx["<start>"], self.word2idx["<end>"]
-        self.pad = self.word2idx["<pad>"]
-        self.max_length = max_length
-
-    def add_word(self, word):
-        if not word in self.word2idx:
-            self.word2idx[word] = self.idx
-            self.idx2word[self.idx] = word
-            self.idx += 1
-
-    def encode_word(self, word):
-        if word in self.word2idx:
-            return self.word2idx[word]
-        else:
-            return self.word2idx["<unk>"]
-
-    def __call__(self, texts):
-        assert isinstance(texts, list), "the input must be List[str]"
-        batch_tokens = []
-        for text in texts:
-            tokens = [self.encode_word(token) for token in text.split()][:self.max_length]
-            tokens = [self.bos] + tokens + [self.eos]
-            tokens = np.array(tokens)
-            batch_tokens.append(tokens)
-        caps, cap_lens = pad_sequence(batch_tokens, self.pad)
-        return {
-            "cap": caps,
-            "cap_len": cap_lens
-        }
-
-    def decode(self, batch_token_ids):
-        output = []
-        for token_ids in batch_token_ids:
-            tokens = []
-            for token_id in token_ids:
-                if token_id == self.eos:
-                    break
-                elif token_id == self.bos:
-                    continue
-                tokens.append(self.idx2word[token_id])
-            output.append(" ".join(tokens))
-        return output
-
-    def __len__(self):
-        return len(self.word2idx)
-
-    def state_dict(self):
-        return self.word2idx
-
-    def load_state_dict(self, state_dict):
-        self.word2idx = state_dict
-        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
-        self.idx = len(self.word2idx)
-
-
-class HuggingfaceTokenizer:
-
-    def __init__(self,
-                 model_name_or_path,
-                 max_length) -> None:
-        from transformers import AutoTokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-        self.max_length = max_length
-        self.bos, self.eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
-        self.pad = self.tokenizer.pad_token_id
-        self.loaded = True
-
-    def __call__(self, texts):
-        assert isinstance(texts, list), "the input must be List[str]"
-        batch_token_dict = self.tokenizer(texts,
-                                          padding=True,
-                                          truncation=True,
-                                          max_length=self.max_length,
-                                          return_tensors="pt")
-        batch_token_dict["cap"] = batch_token_dict["input_ids"]
-        cap_lens = batch_token_dict["attention_mask"].sum(dim=1)
-        cap_lens = cap_lens.numpy().astype(np.int32)
-        batch_token_dict["cap_len"] = cap_lens
-        return batch_token_dict
-
-    def decode(self, batch_token_ids):
-        return self.tokenizer.batch_decode(batch_token_ids, skip_special_tokens=True)
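
For context, a minimal sketch of how the removed DictTokenizer was typically driven (the captions and arguments are hypothetical; it assumes a checkout that still contains text_tokenizer.py and utils/train_util.py):

    from text_tokenizer import DictTokenizer

    # An empty vocabulary only knows the special tokens <pad>, <start>, <end>, <unk>.
    tokenizer = DictTokenizer(tokenizer_path=None, max_length=20)
    batch = tokenizer(["a dog barks", "rain falls on a roof"])
    print(batch["cap"].shape, batch["cap_len"])      # padded id matrix [N, T] and per-caption lengths
    print(tokenizer.decode(batch["cap"].tolist()))   # ids back to space-joined words

Captions are split on whitespace, truncated to max_length, wrapped in <start>/<end> and padded; decode reverses the mapping, skips <start> and stops at the first <end>.
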
utils/model_util.py
DELETED
@@ -1,186 +0,0 @@
-import math
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
-
-
-def sort_pack_padded_sequence(input, lengths):
-    sorted_lengths, indices = torch.sort(lengths, descending=True)
-    tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
-    inv_ix = indices.clone()
-    inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)
-    return tmp, inv_ix
-
-def pad_unsort_packed_sequence(input, inv_ix):
-    tmp, _ = pad_packed_sequence(input, batch_first=True)
-    tmp = tmp[inv_ix]
-    return tmp
-
-def pack_wrapper(module, attn_feats, attn_feat_lens):
-    packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
-    if isinstance(module, torch.nn.RNNBase):
-        return pad_unsort_packed_sequence(module(packed)[0], inv_ix)
-    else:
-        return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
-
-def generate_length_mask(lens, max_length=None):
-    lens = torch.as_tensor(lens)
-    N = lens.size(0)
-    if max_length is None:
-        max_length = max(lens)
-    if isinstance(max_length, torch.Tensor):
-        max_length = max_length.item()
-    idxs = torch.arange(max_length).repeat(N).view(N, max_length)
-    idxs = idxs.to(lens.device)
-    mask = (idxs < lens.view(-1, 1))
-    return mask
-
-def mean_with_lens(features, lens):
-    """
-    features: [N, T, ...] (assume the second dimension represents length)
-    lens: [N,]
-    """
-    lens = torch.as_tensor(lens)
-    if max(lens) != features.size(1):
-        max_length = features.size(1)
-        mask = generate_length_mask(lens, max_length)
-    else:
-        mask = generate_length_mask(lens)
-    mask = mask.to(features.device) # [N, T]
-
-    while mask.ndim < features.ndim:
-        mask = mask.unsqueeze(-1)
-    feature_mean = features * mask
-    feature_mean = feature_mean.sum(1)
-    while lens.ndim < feature_mean.ndim:
-        lens = lens.unsqueeze(1)
-    feature_mean = feature_mean / lens.to(features.device)
-    # feature_mean = features * mask.unsqueeze(-1)
-    # feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device)
-    return feature_mean
-
-def max_with_lens(features, lens):
-    """
-    features: [N, T, ...] (assume the second dimension represents length)
-    lens: [N,]
-    """
-    lens = torch.as_tensor(lens)
-    if max(lens) != features.size(1):
-        max_length = features.size(1)
-        mask = generate_length_mask(lens, max_length)
-    else:
-        mask = generate_length_mask(lens)
-    mask = mask.to(features.device) # [N, T]
-
-    feature_max = features.clone()
-    feature_max[~mask] = float("-inf")
-    feature_max, _ = feature_max.max(1)
-    return feature_max
-
-def repeat_tensor(x, n):
-    return x.unsqueeze(0).repeat(n, *([1] * len(x.shape)))
-
-def init(m, method="kaiming"):
-    if isinstance(m, (nn.Conv2d, nn.Conv1d)):
-        if method == "kaiming":
-            nn.init.kaiming_uniform_(m.weight)
-        elif method == "xavier":
-            nn.init.xavier_uniform_(m.weight)
-        else:
-            raise Exception(f"initialization method {method} not supported")
-        if m.bias is not None:
-            nn.init.constant_(m.bias, 0)
-    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
-        nn.init.constant_(m.weight, 1)
-        if m.bias is not None:
-            nn.init.constant_(m.bias, 0)
-    elif isinstance(m, nn.Linear):
-        if method == "kaiming":
-            nn.init.kaiming_uniform_(m.weight)
-        elif method == "xavier":
-            nn.init.xavier_uniform_(m.weight)
-        else:
-            raise Exception(f"initialization method {method} not supported")
-        if m.bias is not None:
-            nn.init.constant_(m.bias, 0)
-    elif isinstance(m, nn.Embedding):
-        if method == "kaiming":
-            nn.init.kaiming_uniform_(m.weight)
-        elif method == "xavier":
-            nn.init.xavier_uniform_(m.weight)
-        else:
-            raise Exception(f"initialization method {method} not supported")
-
-def compute_batch_score(decode_res,
-                        key2refs,
-                        keys,
-                        start_idx,
-                        end_idx,
-                        vocabulary,
-                        scorer):
-    """
-    Args:
-        decode_res: decoding results of model, [N, max_length]
-        key2refs: references of all samples, dict(<key> -> [ref_1, ref_2, ..., ref_n]
-        keys: keys of this batch, used to match decode results and refs
-    Return:
-        scores of this batch, [N,]
-    """
-
-    if scorer is None:
-        from pycocoevalcap.cider.cider import Cider
-        scorer = Cider()
-
-    hypothesis = {}
-    references = {}
-
-    for i in range(len(keys)):
-
-        if keys[i] in hypothesis.keys():
-            continue
-
-        # prepare candidate sentence
-        candidate = []
-        for w_t in decode_res[i]:
-            if w_t == start_idx:
-                continue
-            elif w_t == end_idx:
-                break
-            candidate.append(vocabulary.idx2word[w_t])
-
-        hypothesis[keys[i]] = [" ".join(candidate), ]
-
-        # prepare reference sentences
-        references[keys[i]] = key2refs[keys[i]]
-
-    score, scores = scorer.compute_score(references, hypothesis)
-    key2score = {key: scores[i] for i, key in enumerate(references.keys())}
-    results = np.zeros(decode_res.shape[0])
-    for i in range(decode_res.shape[0]):
-        results[i] = key2score[keys[i]]
-    return results
-
-
-class PositionalEncoding(nn.Module):
-
-    def __init__(self, d_model, dropout=0.1, max_len=100):
-        super(PositionalEncoding, self).__init__()
-        self.dropout = nn.Dropout(p=dropout)
-
-        pe = torch.zeros(max_len, d_model)
-        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2).float() * \
-            (-math.log(10000.0) / d_model))
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0).transpose(0, 1)
-        # self.register_buffer("pe", pe)
-        self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))
-
-    def forward(self, x):
-        # x: [T, N, E]
-        x = x + self.pe[:x.size(0), :]
-        return self.dropout(x)
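
For context, a minimal sketch of the removed masking and pooling helpers (tensor shapes are illustrative; it assumes a checkout that still contains utils/model_util.py):

    import torch
    from utils.model_util import generate_length_mask, mean_with_lens, max_with_lens

    feats = torch.randn(2, 5, 8)            # [N, T, D] padded batch of frame embeddings
    lens = torch.tensor([5, 3])             # valid length of each sequence
    mask = generate_length_mask(lens)       # [N, T] boolean, True where t < lens[n]
    mean_emb = mean_with_lens(feats, lens)  # [N, D] mean over valid timesteps only
    max_emb = max_with_lens(feats, lens)    # [N, D] max over valid timesteps only
    print(mask.shape, mean_emb.shape, max_emb.shape)

These utilities pooled variable-length encoder outputs without letting padded frames leak into the statistics.
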
utils/train_util.py
DELETED
@@ -1,117 +0,0 @@
-import importlib
-import os
-import sys
-from typing import Callable, Dict, Union
-
-import numpy as np
-import yaml
-import torch
-
-
-def merge_a_into_b(a, b):
-    # merge dict a into dict b. values in a will overwrite b.
-    for k, v in a.items():
-        if isinstance(v, dict) and k in b:
-            assert isinstance(
-                b[k], dict
-            ), "Cannot inherit key '{}' from base!".format(k)
-            merge_a_into_b(v, b[k])
-        else:
-            b[k] = v
-
-
-def load_config(config_file):
-    with open(config_file, "r") as reader:
-        config = yaml.load(reader, Loader=yaml.FullLoader)
-    if "inherit_from" in config:
-        base_config_file = config["inherit_from"]
-        base_config_file = os.path.join(
-            os.path.dirname(config_file), base_config_file
-        )
-        assert not os.path.samefile(config_file, base_config_file), \
-            "inherit from itself"
-        base_config = load_config(base_config_file)
-        del config["inherit_from"]
-        merge_a_into_b(config, base_config)
-        return base_config
-    return config
-
-def get_cls_from_str(string, reload=False):
-    module_name, cls_name = string.rsplit(".", 1)
-    if reload:
-        module_imp = importlib.import_module(module_name)
-        importlib.reload(module_imp)
-    return getattr(importlib.import_module(module_name, package=None), cls_name)
-
-def init_obj_from_dict(config, **kwargs):
-    obj_args = config["args"].copy()
-    obj_args.update(kwargs)
-    for k in config:
-        if k not in ["type", "args"] and isinstance(config[k], dict) and k not in kwargs:
-            obj_args[k] = init_obj_from_dict(config[k])
-    try:
-        obj = get_cls_from_str(config["type"])(**obj_args)
-        return obj
-    except Exception as e:
-        print(f"Initializing {config} failed, detailed error stack: ")
-        raise e
-
-def init_model_from_config(config, print_fn=sys.stdout.write):
-    kwargs = {}
-    for k in config:
-        if k not in ["type", "args", "pretrained"]:
-            sub_model = init_model_from_config(config[k], print_fn)
-            if "pretrained" in config[k]:
-                load_pretrained_model(sub_model,
-                                      config[k]["pretrained"],
-                                      print_fn)
-            kwargs[k] = sub_model
-    model = init_obj_from_dict(config, **kwargs)
-    return model
-
-def merge_load_state_dict(state_dict,
-                          model: torch.nn.Module,
-                          output_fn: Callable = sys.stdout.write):
-    model_dict = model.state_dict()
-    pretrained_dict = {}
-    mismatch_keys = []
-    for key, value in state_dict.items():
-        if key in model_dict and model_dict[key].shape == value.shape:
-            pretrained_dict[key] = value
-        else:
-            mismatch_keys.append(key)
-    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}\n")
-    model_dict.update(pretrained_dict)
-    model.load_state_dict(model_dict, strict=True)
-    return pretrained_dict.keys()
-
-
-def load_pretrained_model(model: torch.nn.Module,
-                          pretrained: Union[str, Dict],
-                          output_fn: Callable = sys.stdout.write):
-    if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
-        output_fn(f"pretrained {pretrained} not exist!")
-        return
-
-    if hasattr(model, "load_pretrained"):
-        model.load_pretrained(pretrained, output_fn)
-        return
-
-    if isinstance(pretrained, dict):
-        state_dict = pretrained
-    else:
-        state_dict = torch.load(pretrained, map_location="cpu")
-
-    if "model" in state_dict:
-        state_dict = state_dict["model"]
-
-    merge_load_state_dict(state_dict, model, output_fn)
-
-def pad_sequence(data, pad_value=0):
-    if isinstance(data[0], (np.ndarray, torch.Tensor)):
-        data = [torch.as_tensor(arr) for arr in data]
-    padded_seq = torch.nn.utils.rnn.pad_sequence(data,
-                                                 batch_first=True,
-                                                 padding_value=pad_value)
-    length = np.array([x.shape[0] for x in data])
-    return padded_seq, length
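
For context, a minimal sketch of the removed config merging that backed load_config's "inherit_from" mechanism (the dictionaries are made-up examples; it assumes a checkout that still contains utils/train_util.py and has PyYAML installed, since importing the module pulls in yaml):

    from utils.train_util import merge_a_into_b

    base = {"decoder": {"layers": 2, "dropout": 0.2}}
    override = {"decoder": {"dropout": 0.5}}
    merge_a_into_b(override, base)   # values from `override` win; nested dicts are merged recursively
    print(base)                      # {'decoder': {'layers': 2, 'dropout': 0.5}}

load_config resolves an "inherit_from" path relative to the child config file, loads the base YAML recursively, and then merges the child's overrides on top of the base in the same way.
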