File size: 3,725 Bytes
b611e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import matplotlib.pyplot as plt


# -------- Noise and Training --------
def add_noise(x_0, noise, t, num_timesteps=1000.0):
    """Corrupt ``x_0`` by adding ``noise`` scaled linearly with the timestep.

    Computes ``x_0 + noise * (t / num_timesteps)``: at ``t == 0`` the input is
    returned unchanged and at ``t == num_timesteps`` the full ``noise`` is
    added. ``x_0`` itself is never attenuated, so this is a simple additive
    linear schedule rather than a variance-preserving one.

    Args:
        x_0: Clean input tensor.
        noise: Noise tensor broadcastable against ``x_0``.
        t: Timestep (Python number or tensor broadcastable against ``x_0``).
        num_timesteps: Timestep at which the noise weight reaches 1.
            Defaults to 1000.

    Returns:
        The noised tensor.

    Raises:
        ValueError: If ``num_timesteps`` is not strictly positive.
    """
    if num_timesteps <= 0:
        raise ValueError(f"num_timesteps must be positive, got {num_timesteps}")
    return x_0 + noise * (t / float(num_timesteps))


def plot_data(mu, sigma, color, title):
    """Plot a per-epoch loss curve with a shaded ±1 standard-deviation band.

    Args:
        mu: Sequence of per-epoch loss values, drawn as the line. Entries may
            be Python floats or 0-dim tensors/arrays.
        sigma: Sequence of per-epoch standard deviations, same length as
            ``mu``; the band spans ``mu - sigma`` to ``mu + sigma``. Entries
            may be Python floats or 0-dim tensors/arrays.
        color: Any Matplotlib color specification (e.g. ``'b'``, ``'tab:red'``).
        title: Figure title.

    Raises:
        ValueError: If ``mu`` and ``sigma`` have different lengths.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    # float() unwraps 0-dim tensors (e.g. results of torch.std) explicitly
    # instead of relying on numpy's implicit __array__ conversion.
    mean = np.asarray([float(m) for m in mu], dtype=float)
    spread = np.asarray([float(s) for s in sigma], dtype=float)
    if mean.shape != spread.shape:
        raise ValueError(
            f"mu and sigma must have the same length, got {mean.size} and {spread.size}"
        )

    epochs = np.arange(mean.size)
    # color= (rather than a '<c>-' format string) accepts any Matplotlib color.
    plt.plot(epochs, mean, '-', color=color)
    plt.fill_between(epochs, mean - spread, mean + spread, color=color, alpha=0.2)
    # The band is one standard deviation, not the variance.
    plt.legend(['Mean Loss', '±1 Std Dev of Loss'])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(title)
    plt.show()


def _mean_and_std(values):
    """Return ``(mean, sample std)`` of a list of floats as Python floats.

    An empty list yields ``(nan, nan)``. A single value yields a std of
    ``0.0``, because the unbiased estimator used by ``torch.std`` is
    undefined there and would return ``nan``.
    """
    if not values:
        return float('nan'), float('nan')
    if len(values) == 1:
        return float(values[0]), 0.0
    stacked = torch.tensor(values, dtype=torch.float64)
    return stacked.mean().item(), stacked.std().item()


def train(model, conditioner, dataset, epochs=10, *, lr=1e-4, max_grad_norm=1.0,
          num_bond_types=4, detect_anomaly=True):
    """Jointly train ``model`` and ``conditioner``, then plot per-epoch losses.

    For every sample, a random integer timestep ``t`` in [1, 1000] is drawn.
    The node features are corrupted with standard-normal noise via
    ``add_noise``. The model is trained to predict that noise (MSE loss) and
    to classify each edge's bond label (cross-entropy loss). The two losses
    are summed and minimized with Adam, one optimizer step per sample.

    Args:
        model: Called as ``model(x_t, pos, edge_index, t, cond_embed)``; must
            return ``(pred_noise, bond_logits)``.
        conditioner: Called as ``conditioner(labels)``; its output is passed
            to ``model`` as ``cond_embed``.
        dataset: Iterable of samples exposing ``x``, ``pos``, ``edge_index``,
            ``edge_attr`` and ``y`` attributes.
        epochs: Number of passes over ``dataset``.
        lr: Adam learning rate.
        max_grad_norm: Max gradient L2 norm, computed jointly over the
            parameters of both modules.
        num_bond_types: Number of bond classes. Samples with an edge label
            outside ``[0, num_bond_types)`` are skipped.
        detect_anomaly: Run training under autograd anomaly detection. This
            helps locate NaN-producing ops but is slow. The previous global
            setting is restored when training finishes.

    Returns:
        ``(model, conditioner)`` after training.

    Side effects:
        Prints per-epoch summed losses and shows four Matplotlib figures of
        per-sample mean losses with ±1 std bands.
    """
    model.train()
    conditioner.train()
    # Both modules are optimized, so both must also be gradient-clipped.
    params = list(model.parameters()) + list(conditioner.parameters())
    optimizer = torch.optim.Adam(params, lr=lr)
    ce_loss = nn.CrossEntropyLoss()

    # Per-epoch history of the per-sample mean and std for each loss term.
    bond_means, noise_means, total_means = [], [], []
    bond_stds, noise_stds, total_stds = [], [], []

    # Used as a context manager, set_detect_anomaly restores the previous
    # global mode on exit instead of leaving it enabled process-wide.
    with torch.autograd.set_detect_anomaly(detect_anomaly):
        for epoch in range(epochs):
            bond_losses, noise_losses, losses = [], [], []

            for data in dataset:
                x_0, pos, edge_index, labels = data.x, data.pos, data.edge_index, data.y
                edge_attr = data.edge_attr.view(-1)
                # Skip corrupted samples: out-of-range bond labels or NaN features.
                if (torch.any(edge_attr >= num_bond_types) or torch.any(edge_attr < 0)
                        or torch.any(torch.isnan(x_0))):
                    continue
                # CrossEntropyLoss expects integer class indices; no-op if already long.
                edge_attr = edge_attr.long()

                # Upper bound 1000 matches the timestep scale used by add_noise.
                # Create t on the data's device so GPU training works.
                t = torch.tensor([random.randint(1, 1000)], device=x_0.device)
                noise = torch.randn_like(x_0)
                x_t = add_noise(x_0, noise, t)
                cond_embed = conditioner(labels)
                pred_noise, bond_logits = model(x_t, pos, edge_index, t, cond_embed)

                loss_noise = F.mse_loss(pred_noise, noise)
                loss_bond = ce_loss(bond_logits, edge_attr)
                loss = loss_noise + loss_bond

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, max_norm=max_grad_norm)
                optimizer.step()

                bond_losses.append(loss_bond.item())
                noise_losses.append(loss_noise.item())
                losses.append(loss.item())

            for per_sample, means, stds in (
                (bond_losses, bond_means, bond_stds),
                (noise_losses, noise_means, noise_stds),
                (losses, total_means, total_stds),
            ):
                mean, std = _mean_and_std(per_sample)
                means.append(mean)
                stds.append(std)

            total_bond_loss = sum(bond_losses)
            total_noise_loss = sum(noise_losses)
            total_loss = sum(losses)
            print(f"Epoch {epoch}: Loss = {total_loss:.4f}, Noise Loss = {total_noise_loss:.4f}, Bond Loss = {total_bond_loss:.4f}")

    # Plot per-sample means: their scale matches the per-sample std bands,
    # unlike the epoch sums printed above.
    plot_data(mu=bond_means, sigma=bond_stds, color='b', title="Bond Loss")
    plot_data(mu=noise_means, sigma=noise_stds, color='r', title="Noise Loss")
    plot_data(mu=total_means, sigma=total_stds, color='g', title="Total Loss")

    plt.plot(bond_means)
    plt.plot(noise_means)
    plt.plot(total_means)
    plt.legend(['Bond Loss', 'Noise Loss', 'Total Loss'])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Epochs')
    plt.show()
    return model, conditioner


# Generation
def temperature_scaled_softmax(logits, temperature=1.0, dim=0):
    """Return ``softmax(logits / temperature)`` along ``dim``.

    Temperatures below 1 sharpen the distribution; temperatures above 1
    flatten it toward uniform.

    Args:
        logits: Tensor of unnormalized scores.
        temperature: Strictly positive scaling factor.
        dim: Dimension to normalize over. Defaults to 0, the first dimension.
            Pass ``dim=-1`` to normalize each row of batched logits.

    Returns:
        Tensor of the same shape as ``logits`` whose entries along ``dim``
        sum to 1.

    Raises:
        ValueError: If ``temperature`` is not strictly positive. Zero would
            produce inf/NaN, and a negative value would invert the ranking.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    return torch.softmax(logits / temperature, dim=dim)