|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import random |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
|
|
|
def add_noise(x_0, noise, t, num_steps=1000.0):
    """Corrupt ``x_0`` with ``noise`` scaled linearly by the timestep.

    Computes ``x_0 + noise * (t / num_steps)``, so the noise weight grows from
    0 at ``t = 0`` to 1 at ``t = num_steps``.

    Args:
        x_0: Clean input tensor.
        noise: Noise tensor broadcastable against ``x_0``.
        t: Timestep (int or tensor broadcastable against ``noise``; the
            training loop passes a 1-element tensor).
        num_steps: Length of the noise schedule; defaults to the previously
            hard-coded 1000.

    Returns:
        The noised tensor ``x_t``.

    NOTE(review): the clean signal is not attenuated as ``t`` grows (unlike a
    variance-preserving diffusion schedule); confirm this is intended.
    """
    return x_0 + noise * (t / num_steps)
|
|
|
|
|
|
|
|
def plot_data(mu, sigma, color, title):
    """Plot a per-epoch mean curve with a shaded mean +/- sigma band, then show it.

    Args:
        mu: Sequence of per-epoch mean values (anything ``np.asarray`` accepts).
        sigma: Sequence of per-epoch spread values (the training loop passes
            standard deviations), same length as ``mu``. The shaded band spans
            ``mu - sigma`` to ``mu + sigma``.
        color: Any matplotlib color spec, e.g. ``'b'`` or ``'tab:blue'``.
        title: Plot title.

    Raises:
        ValueError: If ``mu`` and ``sigma`` do not have the same shape.
    """
    mean = np.asarray(mu, dtype=float)
    spread = np.asarray(sigma, dtype=float)
    # Refuse to silently broadcast e.g. a length-1 sigma across every epoch.
    if mean.shape != spread.shape:
        raise ValueError(
            f"mu and sigma must have the same shape, got {mean.shape} and {spread.shape}"
        )

    epochs = np.arange(len(mean))

    # Attach labels to the artists themselves so the legend stays correct even
    # if the current axes already contain other artists.
    plt.plot(epochs, mean, color=color, linestyle='-', label='Mean Loss')
    plt.fill_between(epochs, mean - spread, mean + spread, color=color, alpha=0.2,
                     label='+/- 1 Std of Loss')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(title)
    plt.show()
|
|
|
|
|
|
|
|
def _loss_mean_std(values):
    """Summarize one epoch's per-step losses as ``(mean, std)`` Python floats.

    Returns ``(nan, nan)`` for an empty list (every sample was skipped). For a
    single value, the std is ``0.0``; the unbiased ``torch.std`` would
    otherwise return NaN with a warning.
    """
    if not values:
        return float("nan"), float("nan")
    samples = torch.tensor(values, dtype=torch.float64)
    spread = samples.std().item() if samples.numel() > 1 else 0.0
    return samples.mean().item(), spread


def train(model, conditioner, dataset, epochs=10, *, lr=1e-4, num_bond_types=4,
          max_grad_norm=1.0, detect_anomaly=True):
    """Jointly train a noise/bond predictor and its label conditioner.

    For every sample, a random integer timestep ``t`` in [1, 1000] is drawn.
    The node features ``x`` are noised with ``add_noise``. The model is
    trained to predict that noise (MSE) and to classify the flattened
    ``edge_attr`` bond labels (cross-entropy), with one optimizer step per
    sample. Per-epoch loss totals are printed. Per-epoch means with +/- 1 std
    bands are plotted after training.

    Args:
        model: Module called as ``model(x_t, pos, edge_index, t, cond_embed)``
            and returning ``(pred_noise, bond_logits)``.
        conditioner: Module mapping a sample's ``y`` labels to the conditioning
            embedding passed to ``model``.
        dataset: Iterable of samples exposing ``x``, ``pos``, ``edge_index``,
            ``edge_attr`` and ``y`` (presumably PyG ``Data`` objects -- confirm).
        epochs: Number of passes over ``dataset``.
        lr: Adam learning rate.
        num_bond_types: Number of bond classes. Samples with any ``edge_attr``
            outside ``[0, num_bond_types)`` are skipped.
        max_grad_norm: Gradient-norm clipping threshold, applied to the
            parameters of both ``model`` and ``conditioner``.
        detect_anomaly: Enable autograd anomaly detection while training
            (slow, useful for hunting NaNs). The previous global setting is
            restored when training finishes.

    Returns:
        The ``(model, conditioner)`` pair, updated in place.
    """
    model.train()
    conditioner.train()

    # Optimize and clip the same parameter set, so the conditioner's gradients
    # are bounded too (previously only the model's were clipped).
    params = list(model.parameters()) + list(conditioner.parameters())
    optimizer = torch.optim.Adam(params, lr=lr)
    ce_loss = nn.CrossEntropyLoss()

    names = ("bond", "noise", "total")
    # Per-epoch history for each loss component.
    means = {name: [] for name in names}
    stds = {name: [] for name in names}

    # Used as a context manager, set_detect_anomaly restores the previous global
    # mode on exit instead of leaving anomaly detection on for the whole process.
    with torch.autograd.set_detect_anomaly(detect_anomaly):
        for epoch in range(epochs):
            # Per-step loss values for this epoch, keyed like `means` / `stds`.
            step_losses = {name: [] for name in names}

            for data in dataset:
                x_0, pos, edge_index, labels = data.x, data.pos, data.edge_index, data.y
                edge_attr = data.edge_attr.reshape(-1)

                # Skip samples whose bond labels are not valid class indices for
                # CrossEntropyLoss, or whose node features contain NaN.
                if (torch.any(edge_attr >= num_bond_types) or torch.any(edge_attr < 0)
                        or torch.any(torch.isnan(x_0))):
                    continue

                # Build t on the data's device so add_noise and the model never
                # mix a CPU timestep with accelerator tensors.
                t = torch.tensor([random.randint(1, 1000)], device=x_0.device)
                noise = torch.randn_like(x_0)
                x_t = add_noise(x_0, noise, t)

                cond_embed = conditioner(labels)
                pred_noise, bond_logits = model(x_t, pos, edge_index, t, cond_embed)

                loss_noise = F.mse_loss(pred_noise, noise)
                loss_bond = ce_loss(bond_logits, edge_attr)
                loss = loss_noise + loss_bond

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, max_norm=max_grad_norm)
                optimizer.step()

                step_losses["bond"].append(loss_bond.item())
                step_losses["noise"].append(loss_noise.item())
                step_losses["total"].append(loss.item())

            # Plot per-step means so the curve is on the same scale as the std
            # band (plotting epoch sums against a per-step std was inconsistent).
            for name in names:
                mean, std = _loss_mean_std(step_losses[name])
                means[name].append(mean)
                stds[name].append(std)

            # The console log keeps reporting per-epoch sums, as before.
            total_bond_loss = sum(step_losses["bond"])
            total_noise_loss = sum(step_losses["noise"])
            total_loss = sum(step_losses["total"])
            print(f"Epoch {epoch}: Loss = {total_loss:.4f}, Noise Loss = {total_noise_loss:.4f}, Bond Loss = {total_bond_loss:.4f}")

    plot_data(mu=means["bond"], sigma=stds["bond"], color='b', title="Bond Loss")
    plot_data(mu=means["noise"], sigma=stds["noise"], color='r', title="Noise Loss")
    plot_data(mu=means["total"], sigma=stds["total"], color='g', title="Total Loss")

    # Overlay the three mean curves on a single figure.
    for name in names:
        plt.plot(means[name])
    plt.legend(['Bond Loss', 'Noise Loss', 'Total Loss'])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Epochs')
    plt.show()

    return model, conditioner
|
|
|
|
|
|
|
|
|
|
|
def temperature_scaled_softmax(logits, temperature=1.0, dim=0):
    """Return ``softmax(logits / temperature)`` normalized along ``dim``.

    Temperatures above 1 flatten the distribution; temperatures below 1
    sharpen it.

    Args:
        logits: Tensor of unnormalized scores.
        temperature: Positive scalar, or a tensor broadcastable against
            ``logits``.
        dim: Dimension to normalize over. Defaults to 0, the previously
            hard-coded value.

    Returns:
        Tensor of the same shape as ``logits`` whose slices along ``dim`` sum
        to 1.

    Raises:
        ValueError: If ``temperature`` is a Python number <= 0. Zero would
            produce NaNs and a negative value would silently invert the
            ranking.
    """
    if isinstance(temperature, (int, float)) and temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    return torch.softmax(logits / temperature, dim=dim)
|
|
|