import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from torch.utils import data
from collections import OrderedDict
from torch.nn.parameter import Parameter
class SincConv(nn.Module):
    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10 ** (mel / 2595) - 1)
    def __init__(self, device, out_channels, kernel_size, in_channels=1, sample_rate=16000,
                 stride=1, padding=0, dilation=1, bias=False, groups=1):
        super(SincConv, self).__init__()

        if in_channels != 1:
            msg = "SincConv only supports one input channel (here, in_channels = %d)" % in_channels
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Force the filter length to be odd (i.e., perfectly symmetric)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1

        self.device = device
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')
        # Initialize the filterbank cut-off frequencies on the Mel scale
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)  # Hz to Mel conversion
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)  # Mel to Hz conversion
        self.mel = filbandwidthsf
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2, (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
    def forward(self, x):
        # Build one ideal band-pass filter per Mel band and taper it with a Hamming window
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            hHigh = (2 * fmax / self.sample_rate) * np.sinc(2 * fmax * self.hsupp / self.sample_rate)
            hLow = (2 * fmin / self.sample_rate) * np.sinc(2 * fmin * self.hsupp / self.sample_rate)
            hideal = hHigh - hLow
            self.band_pass[i, :] = Tensor(np.hamming(self.kernel_size)) * Tensor(hideal)

        band_pass_filter = self.band_pass.to(self.device)
        self.filters = band_pass_filter.view(self.out_channels, 1, self.kernel_size)

        return F.conv1d(x, self.filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        bias=None, groups=1)
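
# Example usage (a sketch, not part of the original module; the parameter values are
# assumptions for illustration). It shows the expected input layout and how the output
# length follows from conv1d with stride=1 and padding=0; an even kernel_size is
# extended to the next odd value (1024 -> 1025 taps).
#
#   sinc = SincConv(device='cpu', out_channels=20, kernel_size=1024)
#   wav = torch.randn(4, 1, 16000)   # (batch, 1 channel, samples) of random audio
#   out = sinc(wav)                  # -> (4, 20, 14976), i.e. 16000 - 1025 + 1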
class Residual_block(nn.Module):
    def __init__(self, nb_filts, first=False):
        super(Residual_block, self).__init__()
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm1d(num_features=nb_filts[0])

        self.lrelu = nn.LeakyReLU(negative_slope=0.3)

        self.conv1 = nn.Conv1d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=3,
                               padding=1,
                               stride=1)
        self.bn2 = nn.BatchNorm1d(num_features=nb_filts[1])
        self.conv2 = nn.Conv1d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               padding=1,
                               kernel_size=3,
                               stride=1)

        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv1d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=0,
                                             kernel_size=1,
                                             stride=1)
        else:
            self.downsample = False

        self.mp = nn.MaxPool1d(3)
    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.lrelu(out)
        else:
            out = x

        out = self.conv1(x)
        out = self.bn2(out)
        out = self.lrelu(out)
        out = self.conv2(out)

        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        out = self.mp(out)
        return out
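
# Example usage (a sketch with assumed filter sizes): nb_filts = [in_channels, out_channels].
# When the two differ, a 1x1 convolution projects the identity branch before the residual
# addition; MaxPool1d(3) then shortens the time axis by a factor of 3.
#
#   block = Residual_block(nb_filts=[20, 128])
#   feat = torch.randn(4, 20, 21192)   # (batch, channels, time)
#   out = block(feat)                  # -> (4, 128, 7064)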
class RawNet(nn.Module):
    def __init__(self, d_args, device):
        super(RawNet, self).__init__()
        self.device = device

        self.Sinc_conv = SincConv(device=self.device,
                                  out_channels=d_args['filts'][0],
                                  kernel_size=d_args['first_conv'],
                                  in_channels=d_args['in_channels'])

        self.first_bn = nn.BatchNorm1d(num_features=d_args['filts'][0])
        self.selu = nn.SELU(inplace=True)

        self.block0 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][1], first=True))
        self.block1 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][1]))
        self.block2 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        d_args['filts'][2][0] = d_args['filts'][2][1]
        self.block3 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        self.block4 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        self.block5 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        self.fc_attention0 = self._make_attention_fc(in_features=d_args['filts'][1][-1],
                                                     l_out_features=d_args['filts'][1][-1])
        self.fc_attention1 = self._make_attention_fc(in_features=d_args['filts'][1][-1],
                                                     l_out_features=d_args['filts'][1][-1])
        self.fc_attention2 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])
        self.fc_attention3 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])
        self.fc_attention4 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])
        self.fc_attention5 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])

        self.bn_before_gru = nn.BatchNorm1d(num_features=d_args['filts'][2][-1])
        self.gru = nn.GRU(input_size=d_args['filts'][2][-1],
                          hidden_size=d_args['gru_node'],
                          num_layers=d_args['nb_gru_layer'],
                          batch_first=True)

        self.fc1_gru = nn.Linear(in_features=d_args['gru_node'],
                                 out_features=d_args['nb_fc_node'])
        self.fc2_gru = nn.Linear(in_features=d_args['nb_fc_node'],
                                 out_features=d_args['nb_classes'], bias=True)

        self.sig = nn.Sigmoid()
        self.logsoftmax = nn.LogSoftmax(dim=1)
    def forward(self, x, y=None):
        nb_samp = x.shape[0]
        len_seq = x.shape[1]
        x = x.view(nb_samp, 1, len_seq)

        x = self.Sinc_conv(x)
        x = F.max_pool1d(torch.abs(x), 3)
        x = self.first_bn(x)
        x = self.selu(x)

        x0 = self.block0(x)
        y0 = self.avgpool(x0).view(x0.size(0), -1)          # (batch, filter)
        y0 = self.fc_attention0(y0)
        y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)  # (batch, filter, 1)
        x = x0 * y0 + y0                                     # (batch, filter, time) x (batch, filter, 1)

        x1 = self.block1(x)
        y1 = self.avgpool(x1).view(x1.size(0), -1)          # (batch, filter)
        y1 = self.fc_attention1(y1)
        y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1)  # (batch, filter, 1)
        x = x1 * y1 + y1                                     # (batch, filter, time) x (batch, filter, 1)

        x2 = self.block2(x)
        y2 = self.avgpool(x2).view(x2.size(0), -1)          # (batch, filter)
        y2 = self.fc_attention2(y2)
        y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)  # (batch, filter, 1)
        x = x2 * y2 + y2                                     # (batch, filter, time) x (batch, filter, 1)

        x3 = self.block3(x)
        y3 = self.avgpool(x3).view(x3.size(0), -1)          # (batch, filter)
        y3 = self.fc_attention3(y3)
        y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1)  # (batch, filter, 1)
        x = x3 * y3 + y3                                     # (batch, filter, time) x (batch, filter, 1)

        x4 = self.block4(x)
        y4 = self.avgpool(x4).view(x4.size(0), -1)          # (batch, filter)
        y4 = self.fc_attention4(y4)
        y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)  # (batch, filter, 1)
        x = x4 * y4 + y4                                     # (batch, filter, time) x (batch, filter, 1)

        x5 = self.block5(x)
        y5 = self.avgpool(x5).view(x5.size(0), -1)          # (batch, filter)
        y5 = self.fc_attention5(y5)
        y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1)  # (batch, filter, 1)
        x = x5 * y5 + y5                                     # (batch, filter, time) x (batch, filter, 1)

        x = self.bn_before_gru(x)
        x = self.selu(x)
        x = x.permute(0, 2, 1)  # (batch, filt, time) -> (batch, time, filt)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]
        x = self.fc1_gru(x)
        x = self.fc2_gru(x)
        output = self.logsoftmax(x)
        print(f"Spec output shape: {output.shape}")
        return output
    def _make_attention_fc(self, in_features, l_out_features):
        l_fc = []
        l_fc.append(nn.Linear(in_features=in_features,
                              out_features=l_out_features))
        return nn.Sequential(*l_fc)

    def _make_layer(self, nb_blocks, nb_filts, first=False):
        layers = []
        for i in range(nb_blocks):
            first = first if i == 0 else False
            layers.append(Residual_block(nb_filts=nb_filts,
                                         first=first))
            if i == 0:
                nb_filts[0] = nb_filts[1]
        return nn.Sequential(*layers)
    def summary(self, input_size, batch_size=-1, device="cuda", print_fn=None):
        if print_fn is None:
            print_fn = print
        model = self

        def register_hook(module):
            def hook(module, input, output):
                class_name = str(module.__class__).split(".")[-1].split("'")[0]
                module_idx = len(summary)
                m_key = "%s-%i" % (class_name, module_idx + 1)
                summary[m_key] = OrderedDict()
                summary[m_key]["input_shape"] = list(input[0].size())
                summary[m_key]["input_shape"][0] = batch_size
                if isinstance(output, (list, tuple)):
                    summary[m_key]["output_shape"] = [
                        [-1] + list(o.size())[1:] for o in output
                    ]
                else:
                    summary[m_key]["output_shape"] = list(output.size())
                    if len(summary[m_key]["output_shape"]) != 0:
                        summary[m_key]["output_shape"][0] = batch_size

                params = 0
                if hasattr(module, "weight") and hasattr(module.weight, "size"):
                    params += torch.prod(torch.LongTensor(list(module.weight.size())))
                    summary[m_key]["trainable"] = module.weight.requires_grad
                if hasattr(module, "bias") and hasattr(module.bias, "size"):
                    params += torch.prod(torch.LongTensor(list(module.bias.size())))
                summary[m_key]["nb_params"] = params

            if (
                not isinstance(module, nn.Sequential)
                and not isinstance(module, nn.ModuleList)
                and not (module == model)
            ):
                hooks.append(module.register_forward_hook(hook))

        device = device.lower()
        assert device in [
            "cuda",
            "cpu",
        ], "Input device is not valid, please specify 'cuda' or 'cpu'"

        if device == "cuda" and torch.cuda.is_available():
            dtype = torch.cuda.FloatTensor
        else:
            dtype = torch.FloatTensor

        if isinstance(input_size, tuple):
            input_size = [input_size]
        x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

        summary = OrderedDict()
        hooks = []
        model.apply(register_hook)
        model(*x)
        for h in hooks:
            h.remove()

        print_fn("----------------------------------------------------------------")
        line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
        print_fn(line_new)
        print_fn("================================================================")
        total_params = 0
        total_output = 0
        trainable_params = 0
        for layer in summary:
            # input_shape, output_shape, trainable, nb_params
            line_new = "{:>20} {:>25} {:>15}".format(
                layer,
                str(summary[layer]["output_shape"]),
                "{0:,}".format(summary[layer]["nb_params"]),
            )
            total_params += summary[layer]["nb_params"]
            total_output += np.prod(summary[layer]["output_shape"])
            if "trainable" in summary[layer]:
                if summary[layer]["trainable"] == True:
                    trainable_params += summary[layer]["nb_params"]
            print_fn(line_new)
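
# Quick smoke test (a sketch; the d_args values below are assumptions mirroring a typical
# RawNet2-style configuration, not necessarily the settings used by this Space). It builds
# the model on CPU and runs one forward pass over random raw waveforms.
if __name__ == "__main__":
    d_args = {
        'first_conv': 1024,                              # sinc filter length in samples (assumed)
        'in_channels': 1,
        'filts': [20, [20, 20], [20, 128], [128, 128]],  # assumed channel plan per block
        'gru_node': 1024,
        'nb_gru_layer': 3,
        'nb_fc_node': 1024,
        'nb_classes': 2,
    }
    model = RawNet(d_args, device='cpu')
    wav = torch.randn(2, 64600)   # (batch, samples); ~4 s of audio at 16 kHz
    scores = model(wav)           # log-softmax scores of shape (2, nb_classes)
    print(scores.shape)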