Upload 2 files
- conditional-diffusion.py +149 -0
- diffusion_condition_model.pt +3 -0
conditional-diffusion.py
ADDED
@@ -0,0 +1,149 @@
import torch
from torch.utils.data import DataLoader, Dataset
import torchaudio
import torchvision.transforms as tvt
from denoising_diffusion_pytorch.classifier_free_guidance import Unet, GaussianDiffusion
import glob
import torch.nn as nn
import time, math
from PIL import Image
from diffusers import Mel
import sys
import librosa
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

args = sys.argv[1:]

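# Command-line usage (all arguments optional and positional):
#   python conditional-diffusion.py [image_size] [num_epochs] [batch_size]
# Defaults are image_size=256, num_epochs=10, batch_size=32 (see __main__ below).
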
class Audio(Dataset):
    def __init__(self, folder):
        # resample = tat.Resample(48000)
        self.waveforms = []
        self.labels = []
        print("Loading files...")
        for file in glob.iglob(folder + '/**/*.wav', recursive=True):  # recurse through files
            self.labels.append(int(file.split('/')[-1][0]))  # label is the leading digit of the file name
            waveform, _ = torchaudio.load(file)
            # waveform, _ = librosa.load(file, sr=None)
            self.waveforms.append(waveform)

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, index):
        return self.waveforms[index], self.labels[index]


image_size = 256
if len(args) >= 1:
    image_size = int(args[0])

MEL = Mel(x_res=image_size, y_res=image_size)
img_to_tensor = tvt.PILToTensor()

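# diffusers' Mel helper converts raw audio into image_size x image_size mel-spectrogram
# images (and back to audio), so the diffusion model can treat each spectrogram slice
# as a single-channel image.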
def collate(batch):
    spectros = []
    labels = []
    for waveform, label in batch:
        MEL.load_audio(raw_audio=waveform[0])
        for slice_idx in range(MEL.get_number_of_slices()):
            spectro = MEL.audio_slice_to_image(slice_idx)
            spectro = img_to_tensor(spectro) / 255.0  # scale pixel values to [0, 1]
            # print(spectro.shape)
            # plt.imshow(spectro[0])
            # plt.show()
            # input("continue")
            spectros.append(spectro)
            labels.append(label)

    spectros = torch.stack(spectros)
    labels = torch.tensor(labels)
    # one_hot = nn.functional.one_hot(labels, num_classes=10)  # one-hot vectors for conditional generation
    return spectros.to(device), labels.to(device)


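# The U-Net comes from lucidrains' classifier-free-guidance implementation: num_classes=10
# covers the ten spoken digits and channels=1 matches the single-channel spectrograms.
# GaussianDiffusion is configured to predict x0 directly rather than the added noise.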
def initialize(scheduler=None, batch_size=32):
    model = Unet(
        dim=64,
        num_classes=10,
        dim_mults=(1, 2, 4, 8),
        channels=1
    )
    diffusion = GaussianDiffusion(
        model,
        image_size=image_size,
        timesteps=1000,
        loss_type='l2',
        objective='pred_x0',
        # channels=1,
    )
    diffusion.to(device)

    optim = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    if scheduler:
        scheduler = torch.optim.lr_scheduler.CyclicLR(optim, base_lr=1e-5, max_lr=1e-3, mode="exp_range", cycle_momentum=False)
    return diffusion, optim, scheduler

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

start = time.time()

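# train() runs one diffusion-loss step per batch (conditioned on the digit labels),
# clips the gradient norm to 1, optionally steps the LR scheduler every batch, and
# logs the running loss every 100 batches (and on the last batch). At the end of each
# epoch it samples 8 spectrograms for random digit labels and writes them out as .wav
# and .jpg files; the audio/ and images/ directories are assumed to exist already.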
def train(model, optim, train_dl, batch_size=32, epochs=5, scheduler=None):
    size = len(train_dl.dataset)
    model.train()
    losses = []

    for e in range(epochs):
        batch_loss, batch_counts = 0, 0
        for step, batch in enumerate(train_dl):
            model.zero_grad()
            batch_counts += 1
            spectros, labels = batch
            loss = model(spectros, classes=labels)

            batch_loss += loss.item()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1)
            optim.step()
            if scheduler is not None:
                scheduler.step()

            if (step % 100 == 0 and step != 0) or (step == len(train_dl) - 1):
                to_print = f"{e + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {timeSince(start)} | {step*batch_size:>5d}/{size:>5d}"
                print(to_print)
                losses.append(batch_loss)
                batch_loss, batch_counts = 0, 0

        labels = torch.randint(0, 10, (8,)).to(device)  # random digit labels 0-9 (randint's upper bound is exclusive)
        print(labels)
        samples = model.sample(labels)
        for i, sample in enumerate(samples):
            im = Image.fromarray(sample[0].cpu().numpy() * 255).convert('L')
            audio = torch.tensor([MEL.image_to_audio(im)])
            torchaudio.save(f"audio/sample{e}_{i}_{labels[i]}.wav", audio, 48000)
            im.save(f"images/sample{e}_{i}_{labels[i]}.jpg")
    return losses

if __name__ == "__main__":
    num_epochs = 10
    if len(args) >= 2:
        num_epochs = int(args[1])

    batch_size = 32
    if len(args) >= 3:
        batch_size = int(args[2])

    print(image_size, num_epochs, batch_size)
    model, optim, scheduler = initialize(scheduler=True, batch_size=batch_size)
    train_data = Audio("AudioMNIST/data")
    print("Done Loading")
    train_dl = DataLoader(train_data, batch_size, shuffle=True, collate_fn=collate)
    train(model, optim, train_dl, batch_size, num_epochs, scheduler)
    torch.save(model.state_dict(), "diffusion_condition_model.pt")
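# To reuse the checkpoint later (a minimal sketch, assuming the same initialize() config):
#   model, _, _ = initialize()
#   model.load_state_dict(torch.load("diffusion_condition_model.pt", map_location=device))
#   samples = model.sample(torch.tensor([3], device=device))  # e.g. condition on digit 3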
diffusion_condition_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:25021d8c6a1813ba51f2f7cb9d015b132f7f21d2deaad397fac0d641cdc671cc
size 153739669