|
|
import os |
|
|
import os.path as osp |
|
|
from glob import glob |
|
|
import math |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import matplotlib.pyplot as plt |
|
|
import plotly.graph_objects as go |
|
|
from tensorboard.backend.event_processing import event_accumulator |
|
|
|
|
|
|
|
|
|
|
|
# Global matplotlib styling shared by every matplotlib figure in this module:
# start from the "classic" style sheet, then switch the font family.
plt.style.use("classic")
plt.rcParams.update({"font.family": "Times New Roman"})
|
|
|
|
|
|
|
|
def smooth(values, weight=0.6):
    """Smooth a sequence with an exponential moving average (EMA).

    Args:
        values: Sequence of numbers (list, tuple or 1-D array) to smooth.
        weight: Smoothing factor, expected in [0, 1]. The previous smoothed
            value is weighted by ``weight`` and the new sample by
            ``1 - weight``, so a larger weight gives a smoother curve.

    Returns:
        numpy array with the same length as ``values``. Its first element
        equals ``values[0]``. An empty input yields an empty array instead of
        raising IndexError.
    """
    smoothed = []
    if len(values) == 0:
        return np.array(smoothed)

    # Seeding with the first sample makes the first output equal values[0].
    last = values[0]
    for v in values:
        last = last * weight + (1 - weight) * v
        smoothed.append(last)
    return np.array(smoothed)
|
|
|
|
|
|
|
|
def read_event_scalar(event_file, tag):
    """Read the scalar series logged under ``tag`` from a TensorBoard event file.

    Args:
        event_file: Path to an ``events.out.tfevents.*`` file.
        tag: Name of the scalar tag to extract.

    Returns:
        Tuple ``(steps, values)`` of two parallel lists.

    Raises:
        ValueError: If ``tag`` is not one of the file's scalar tags. The
            message lists the available scalar tags.
    """
    # NOTE(review): EventAccumulator is built with its default size guidance,
    # which may cap how many scalar events are retained; confirm this is
    # acceptable for long runs.
    accumulator = event_accumulator.EventAccumulator(event_file)
    accumulator.Reload()

    if tag not in accumulator.Tags()["scalars"]:
        raise ValueError(f"Tag {tag} not found in {event_file}, available: {accumulator.Tags()['scalars']}")

    steps, values = [], []
    for record in accumulator.Scalars(tag):
        steps.append(record.step)
        values.append(record.value)
    return steps, values
|
|
|
|
|
|
|
|
def plot_multiple_events(event_files, tags, save_path, fontsize=18,
                         smooth_weight=0.6, prompts_per_step=40):
    """Plot each tag in its own subplot, overlaying one curve per event file.

    Each run is labelled by the directory two levels above its event file,
    e.g. ``<run>/tensorboard/events.out.tfevents.*`` is labelled ``<run>``.
    Curves are EMA-smoothed with :func:`smooth`. For every tag, the run with
    the highest raw (unsmoothed) value is printed.

    Args:
        event_files: Paths to TensorBoard event files, one per run.
        tags: Scalar tags to plot. Each tag gets one row of the figure.
        save_path: Output path for the figure. Its parent directory is
            created if missing.
        fontsize: Base font size for titles, axis labels and ticks.
        smooth_weight: EMA weight passed to :func:`smooth`.
        prompts_per_step: Factor converting a logged step into the number of
            seen prompts shown (in thousands) on the x-axis.

    Raises:
        ValueError: If ``event_files`` is empty.
    """
    if not event_files:
        raise ValueError("plot_multiple_events: no event files given")

    fig = plt.figure(figsize=(6, 8))
    for i, tag in enumerate(tags):
        plt.subplot(len(tags), 1, i + 1)
        # Start from -inf so runs whose values are all below -1 are still reported.
        max_val, max_cfg = -np.inf, None
        for event_file in event_files:
            label = osp.basename(osp.dirname(osp.dirname(event_file)))
            steps, values = read_event_scalar(event_file, tag)
            run_max = np.max(values)
            if max_val < run_max:
                max_val, max_cfg = run_max, label
            plt.plot(steps, smooth(values, weight=smooth_weight), label=label)

        plt.title(tag, fontsize=fontsize)
        plt.xlabel("Seen Prompts", fontsize=fontsize - 2)
        # Tick positions come from the last run's steps (every 10th logged step).
        tick_steps = steps[::10]
        plt.xticks(tick_steps, [f'{step * prompts_per_step / 1000:.1f}k' for step in tick_steps], fontsize=fontsize - 2)
        plt.legend(loc='lower right', fontsize=fontsize - 4)
        plt.grid(True)
        plt.tight_layout()
        print(f"Max {tag}: {max_val:.4f} in {max_cfg}")

    save_dir = osp.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path, dpi=400)
    # Release the figure; pyplot otherwise keeps it alive across calls.
    plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_dpo_implicit_acc_beta():
    """Plot DPO implicit accuracy (audio and video) for the beta-ablation runs.

    Runs are found under ``exps/audio_video/ablation/dpo/beta/<name>_<N>/``.
    They are ordered by the integer ``N`` after the last underscore of the run
    directory, and runs with ``N`` of 300 or 700 are left out. The figure is
    written to ``./debug/plot/implicit_acc_beta.pdf``.
    """
    def run_suffix(path):
        # '<...>/<name>_<N>/tensorboard/events...' -> int(N)
        return int(osp.basename(osp.dirname(osp.dirname(path))).split('_')[-1])

    excluded = {300, 700}
    event_files = glob('exps/audio_video/ablation/dpo/beta/*/tensorboard/events.out.tfevents.*')
    event_files = sorted(event_files, key=run_suffix)
    # Compare the parsed value rather than a '300/' substring, which would
    # also drop unrelated runs such as '*_1300'.
    event_files = [path for path in event_files if run_suffix(path) not in excluded]

    tags = ["implicit_acc_audio", "implicit_acc_video"]
    plot_multiple_events(event_files, tags, save_path='./debug/plot/implicit_acc_beta.pdf')
|
|
|
|
|
|
|
|
def plot_dpo_implicit_acc_lr():
    """Plot DPO implicit accuracy (audio and video) for the learning-rate ablation.

    One event file is taken from each of
    ``exps/audio_video/ablation/dpo/lr/lr_{1e-5,5e-6,1e-6}/tensorboard/``.
    The figure is written to ``./debug/plot/implicit_acc_lr.pdf``.

    Raises:
        FileNotFoundError: If a learning-rate run has no matching event file.
    """
    event_files = []
    for lr in ['1e-5', '5e-6', '1e-6']:
        pattern = f'exps/audio_video/ablation/dpo/lr/lr_{lr}/tensorboard/events.out.tfevents.*'
        # glob order is filesystem-dependent; sort so the same file is always chosen.
        matches = sorted(glob(pattern))
        if not matches:
            raise FileNotFoundError(f"No TensorBoard event file matches {pattern}")
        event_files.append(matches[0])

    tags = ["implicit_acc_audio", "implicit_acc_video"]
    plot_multiple_events(event_files, tags, save_path='./debug/plot/implicit_acc_lr.pdf')
|
|
|
|
|
|
|
|
def plot_perform_radar():
    """Draw a radar chart comparing four models on six audio-video metrics.

    Each metric is min-max normalised with its own hand-picked range, so all
    axes share the radial interval [0, 1]. Three intermediate values per axis
    are written in grey, in the metric's own units. The chart is saved to
    ``./debug/plot/perform_radar.pdf``.
    """
    labels = ["V-Quality", "A-Quality", "TV-Align", "TA-Align", "AV-Align", "AV-Sync"]
    # Per-metric (min, max) used for normalisation; order follows ``labels``.
    ranges = [(0, 3.0), (3.8, 5.2), (2, 3.19), (1, 1.98), (0, 2.8), (0.4, 1.35)]

    # Raw scores per model; order follows ``labels``.
    scores = {
        "JavisDiT": [1.02, 4.28, 2.6, 1.3, 1.8, 0.75],
        "UniVerse-1": [1.52, 4.09, 2.9, 1.1, 1.0, 0.80],
        "Ours": [2.47, 4.91, 3.0, 1.66, 1.92, 1.02],
        "Veo-3": [2.84, 5.11, 3.1, 1.9, 2.6, 1.28],
    }

    def normalize(values, ranges):
        # Min-max scale each metric with its own (min, max) range.
        return [(v - lo) / (hi - lo) for v, (lo, hi) in zip(values, ranges)]

    N = len(labels)
    angles = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist()
    angles += angles[:1]  # repeat the first angle to close each polygon

    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
    fontsize = 16

    ax.set_theta_offset(np.pi / 2)  # first metric at the top
    ax.set_theta_direction(1)       # counter-clockwise

    for name, values in scores.items():
        norm_values = normalize(values, ranges)
        norm_values += norm_values[:1]
        ax.plot(angles, norm_values, label=name, linewidth=1.5 if name in ['Ours', 'Veo-3'] else 1)
        ax.fill(angles, norm_values, alpha=0.1)

    ax.set_ylim(0, 1)
    ax.set_yticks([0.25, 0.5, 0.75])
    ax.set_yticklabels([])

    # Label the 25/50/75% rings on every axis with that metric's raw value.
    for i, angle in enumerate(angles[:-1]):
        rmin, rmax = ranges[i]
        ticks = np.linspace(rmin, rmax, 5)[1:-1]
        for t in ticks:
            rt = (t - rmin) / (rmax - rmin)
            ax.text(angle, rt + 0.03, f"{t:.1f}", ha="center", va="center", fontsize=fontsize - 2, color="gray")

    ax.set_xticks(angles[:-1])
    labels_text = ax.set_xticklabels(labels, fontsize=fontsize, fontweight="bold")

    # Indices 0 and 3 are the top and bottom axes (given the pi/2 offset).
    # The other labels are re-positioned and centred, presumably to move the
    # side labels away from the chart. NOTE(review): confirm the visual
    # effect, since the axis may reset tick-label positions at draw time.
    for i, (label, angle) in enumerate(zip(labels_text, angles[:-1])):
        if i not in [0, 3]:
            label.set_position((angle, -0.12))
            label.set_ha('center')
            label.set_va('center')

    ax.legend(loc="upper right", bbox_to_anchor=(1.17, 1.15), fontsize=fontsize - 2)
    plt.tight_layout()
    os.makedirs('./debug/plot', exist_ok=True)
    plt.savefig('./debug/plot/perform_radar.pdf')
    # Release the figure; pyplot otherwise keeps it alive.
    plt.close(fig)
|
|
|
|
|
|
|
|
def plot_ablation_bar():
    """Draw a grouped bar chart comparing five LoRA configurations on seven metrics.

    Each metric row of ``data`` is min-max normalised with its own range in
    ``data_range``, so all bars share one [0, 1] y-axis (raw y tick labels are
    hidden). The arrow in each task name marks whether lower (↓) or higher
    (↑) is better. The chart is saved to ``./debug/plot/lora_cfg_bar.pdf``.
    """
    tasks = ["FVD ↓", "FAD ↓", "TV-IB ↑", "TA-IB ↑", "AV-IB ↑", "JavisScore ↑", "DeSync ↓"]
    models = [
        "A-LoRA + AV-LoRA (r=64)",
        "A-noLoRA + AV-AttnLoRA (r=64)",
        "A-noLoRA + AV-LoRA (r=64)",
        "A-noLoRA + AV-LoRA (r=32)",
        "A-noLoRA + AV-LoRA (r=128)"
    ]

    # data[t, m] is the raw score of model m on task t
    # (rows follow ``tasks``, columns follow ``models``).
    data = np.array([
        [311.6, 223.1, 221.3, 222.5, 218.6],
        [ 5.80, 5.66, 5.51, 5.54, 5.60],
        [ 16.2, 28.1, 28.3, 28.3, 28.2],
        [ 14.2, 14.7, 15.3, 15.2, 14.7],
        [ 12.6, 18.6, 19.4, 19.2, 18.0],
        [ 9.1, 14.1, 15.1, 14.7, 14.3],
        [ 96.9, 95.8, 90.1, 90.0, 90.1],
    ])
    # Per-task (min, max) used for normalisation; rows follow ``tasks``.
    data_range = np.array([
        [200, 350],
        [5, 6],
        [10, 35],
        [10, 18],
        [10, 25],
        [5, 20],
        [85, 100],
    ])
    norm_data = (data - data_range[:, :1]) / (data_range[:, 1:] - data_range[:, :1])

    x = np.arange(len(tasks))
    bar_width = 0.15
    fontsize = 18
    colors = ["grey", "lightgrey", "royalblue", "cornflowerblue", "lightsteelblue"]
    hatches = [None, None, "//", None, None]

    fig, ax = plt.subplots(figsize=(11.2, 4))
    ax.set_axisbelow(True)  # keep grid lines behind the bars
    ax.yaxis.grid(True, linestyle='--', linewidth=0.7, alpha=0.3, color='gray', zorder=0)

    for i, model in enumerate(models):
        # Offset so the group of len(models) bars is centred on each task's x.
        ax.bar(
            x + i*bar_width - (len(models)-1)/2*bar_width,
            norm_data[:, i],
            width=bar_width,
            label=model,
            color=colors[i],
            hatch=hatches[i],
            edgecolor="white",
            linewidth=0.5
        )

    ax.set_xticks(x)
    ax.set_xlim(x[0] - 0.5, x[-1] + 0.5)
    ax.set_xticklabels(tasks, fontsize=fontsize)
    ax.set_ylim(0, 1)
    ax.set_yticklabels([])

    ax.legend(
        loc="upper center", bbox_to_anchor=(0.5, 1.3),
        ncol=int(math.ceil(len(models) / 2)), frameon=False, fontsize=fontsize - 3
    )

    plt.tight_layout()
    plt.subplots_adjust(left=0.02, right=0.98, top=0.8, bottom=0.1)
    os.makedirs('./debug/plot', exist_ok=True)
    plt.savefig('./debug/plot/lora_cfg_bar.pdf')
    # Release the figure; pyplot otherwise keeps it alive.
    plt.close(fig)
|
|
|
|
|
|
|
|
def plot_data_filtering():
    """Draw the data-filtering pipeline as a plotly Sankey diagram.

    Nodes are placed explicitly via ``node_x`` / ``node_y``. Link widths come
    from ``values``; presumably these are percentages of the raw pool, since
    the two links out of "Raw Videos" sum to 100 (verify). The diagram is
    written to ``./debug/plot/data_filtering.pdf``. ``write_image`` needs
    plotly's static-image export backend to be installed.
    """
    labels = [
        "Raw Videos",       # 0
        "Clean Videos",     # 1
        "Speech Videos",    # 2
        "HQ Videos",        # 3
        "Scoring Filters",  # 4
        "SFT",              # 5
        "DPO",              # 6
    ]
    colors = ["purple", "blue", "grey", "darkgreen", "grey", "gold", "brown"]

    # Node positions in plotly's normalised [0, 1] coordinates; order follows ``labels``.
    node_x = [0.01, 0.33, 0.33, 0.66, 0.66, 0.99, 0.99]
    node_y = [0.3, 0.2, 0.7, 0.1, 0.45, 0.1, 0.3]

    # Links as parallel lists of (source node, target node, width).
    sources = [0, 0, 1, 1, 3, 3]
    targets = [1, 2, 3, 4, 5, 6]
    values = [66, 34, 33, 33, 29, 4]
    link_colors = ['rgba(128, 128, 128, 0.3)'] * len(sources)

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=25,
            thickness=30,
            line=dict(color="black", width=0.5),
            label=labels,
            color=colors,
            x=node_x,
            y=node_y
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            color=link_colors
        )
    )])

    # The font size is set once, inside ``font`` (the separate font_size=16 was redundant).
    fig.update_layout(
        height=600,
        font=dict(
            family="Times New Roman",
            size=16,
            color="black"
        )
    )

    os.makedirs('./debug/plot', exist_ok=True)
    fig.write_image('./debug/plot/data_filtering.pdf')
|
|
|
|
|
|
|
|
def stat_audio_data_dist(data_root='/mnt/HithinkOmniSSD/user_workspace/liukai4/datasets/JavisDiT/train/audio',
                         csv_name='JavisDiT_train_audio_v1.csv'):
    """Print how many CSV rows belong to each audio subset, largest first.

    A "subset" is any sub-directory of ``data_root``. A row counts toward a
    subset when its ``audio_path`` contains the subset name as a literal
    substring. Each output line has the form ``<subset> <count> <share>``.

    Args:
        data_root: Directory holding the subset folders and the CSV file.
        csv_name: Name of the CSV inside ``data_root``. It must have an
            ``audio_path`` column.
    """
    df = pd.read_csv(osp.join(data_root, csv_name))

    # Sort names first so ties in the final ranking come out deterministically
    # (os.listdir order is arbitrary).
    subsets = sorted(name for name in os.listdir(data_root) if osp.isdir(osp.join(data_root, name)))

    # regex=False: subset names such as 'a.b' or 'x+y' are matched literally.
    # na=False: missing paths count as "no match" instead of propagating NaN.
    # NOTE(review): this is still a substring match, so a subset whose name is
    # contained in another's path (e.g. 'foo' vs 'foo_v2') is counted for both.
    stat = np.array(
        [int(df['audio_path'].str.contains(subset, regex=False, na=False).sum()) for subset in subsets],
        dtype=int,
    )

    order = np.argsort(-stat, kind='stable')  # descending count, stable on ties
    subsets = [subsets[i] for i in order]
    stat = stat[order]

    total = stat.sum()
    rel_stat = stat / total if total > 0 else np.zeros(len(stat))
    for subset, count, frac in zip(subsets, stat, rel_stat):
        print(f'{subset:<10} {count:>5} {frac:.2%}')
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: only the LoRA ablation bar chart is generated by
    # default. Call any other plot_* / stat_* function from this module here
    # to produce the other figures.
    plot_ablation_bar()
|
|
|