import argparse
import os
import os.path as osp
import pandas as pd
import json
from typing import List, Dict, Literal
from glob import glob
from copy import deepcopy
import math
import logging
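# Silence all logging.warning calls so evaluation output stays readable.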
logging.warning = lambda *args, **kwargs: None
import torch
import torch.distributed as dist
from .src.metrics import (
# quality
calc_fvd_kvd_fad, calc_audio_quality_score, calc_video_quality_score,
# alignment
calc_imagebind_score, calc_clip_score, calc_clap_score, calc_cavp_score,
# synchrony
calc_av_align, calc_av_score, calc_desync_score,
# audio-only
calc_audio_score
)
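# Holds the JavisBench aspect/category taxonomy loaded from a JSON config,
# used to break scores down per category when --verbose is set.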
class JavisBenchCategory(object):
def __init__(self, cfg: str):
self.cfg = cfg
with open(cfg, 'r') as f:
data = json.load(f)
category_matrix = []
for aspect in data:
category_list = []
for category in aspect['categories']:
category_list.append(category['title'])
category_matrix.append(category_list)
self.category_cfg = data
self.aspect_list = [aspect['aspect'] for aspect in data]
self.category_matrix = category_matrix
self.aspect_num = len(self.category_matrix)
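# Runs the selected metrics over a CSV of prompts, ground-truth paths, and predicted audio/video paths,
# resuming from an existing output JSON and writing aggregated (and optionally per-category) scores.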
class JavisEvaluator(object):
def __init__(self, input_file: str, category_cfg: str, metrics: List[str], output_file: str, **kwargs):
self.input_file = input_file
self.df = pd.read_csv(input_file)
eval_num = kwargs.pop('eval_num', None)
if eval_num:
print(f'Evaluate the first {eval_num} samples.')
self.df = self.df.iloc[:eval_num]
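        # Distributed setup: the rank is taken modulo the local GPU count, so the sharding below effectively assumes a single node.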
self.world_size = dist.get_world_size()
self.rank = dist.get_rank() % torch.cuda.device_count()
self.is_main_process = self.rank in [0, -1]
if self.world_size > 1:
if metrics != ['av-reward']:
raise NotImplementedError('Only support reward evaluation in multi-gpu mode.')
nums_per_rank = math.ceil(len(self.df) / self.world_size)
start = self.rank * nums_per_rank
self.df = self.df.iloc[start:start+nums_per_rank]
if category_cfg and osp.isfile(category_cfg) and kwargs.get('verbose'):
self.parse_aspect_dict(category_cfg)
else:
self.cat2indices = None
self.output_file = output_file
self.total_metrics = [
'fvd+kvd+fad', # quality
'video-quality', # visual quality and motion quality
'audio-quality', # audio quality
'imagebind-score', 'cxxp-score', # semantic consistency
'av-align', # av alignment
'av-score', #'avh-score', 'javis-score'
'desync', # av synchrony
# 'audio-score',
]
if metrics == ['all']:
metrics = self.total_metrics
self.metrics = metrics
self.metric2items = {
# for general audio-video evaluation
'fvd+kvd+fad': ['fvd', 'kvd', 'fad'],
'video-quality': ['visual_quality', 'motion_quality'],
'audio-quality': ['audio_quality'],
'imagebind-score': ['ib_tv', 'ib_ta', 'ib_av'],
'cxxp-score': ['clip_score', 'clap_score', 'cavp_score'],
'av-align': ['av_align'],
'av-score': ['avh_score', 'javis_score'],
'desync': ['desync'],
# for audio evaluation only
'audio-score': ['fad', 'quality', 'ib_ta', 'clap'],
# for audio-video reward calculation
'av-reward': [
'visual_quality', 'motion_quality', 'audio_quality',
'ib_tv', 'ib_ta', 'ib_av',
'desync_scores', # TODO: JavisScore?
]
}
self.exclude = kwargs.pop('exclude', [])
self.eval_kwargs = kwargs
self.gather_audio_video_pred()
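    # Build, for each aspect, a list mapping every category index to the row indices of samples tagged with it.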
def parse_aspect_dict(self, category_cfg:str):
self.category = JavisBenchCategory(category_cfg)
cat2indices: List[List[List[int]]] = []
for ai in range(self.category.aspect_num):
index_list = [[] for _ in range(len(self.category.category_matrix[ai]))]
for pi, cat_str in enumerate(self.df[f'cat{ai}_ind'].tolist()):
for ci in (cat_str.split(',') if isinstance(cat_str, str) else [cat_str]):
index_list[int(ci)].append(pi)
cat2indices.append(index_list)
self.cat2indices = cat2indices
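    # Compute every requested metric, skipping items already present in the output file unless --force_eval is set.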
@torch.no_grad()
def __call__(self, *args, **kwds):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
exist_metrics = self.load_metric()
prompt_list = self.df['text'].tolist()
# video_prompt_list = self.df.get('video_text', self.df['text']).tolist()
video_prompt_list = prompt_list
audio_prompt_list = self.df.get('audio_text', self.df['text']).tolist()
gt_video_list = self.df['path'].tolist()
gt_audio_list = self.df.get('audio_path', self.df['path']).tolist()
pred_video_list = self.df['pred_video_path'].tolist()
pred_audio_list = self.df['pred_audio_path'].tolist()
save_avalign_scores = self.eval_kwargs.get('save_avalign_scores', False)
save_av_reward = False
        # argparse always supplies 'max_audio_len_s' (possibly None), so fall back to 8.0 s explicitly
        max_length_s = self.eval_kwargs.get('max_audio_len_s') or 8.0
for metric in self.metrics:
if not self.eval_kwargs.get('force_eval', False) and \
all(item in exist_metrics for item in self.metric2items[metric]):
print(f'{metric} calculated. skip.')
continue
if metric in self.exclude:
print(f'{metric} excluded. skip.')
continue
if metric == 'fvd+kvd+fad':
mode = self.eval_kwargs.get('fvd_mode', 'vanilla')
exist_metrics["fvd"], exist_metrics["kvd"], exist_metrics["fad"] = \
calc_fvd_kvd_fad(gt_video_list, pred_video_list, gt_audio_list, pred_audio_list,
device, self.cat2indices, mode=mode, **self.eval_kwargs)
self.write_metric(exist_metrics, metric)
elif metric == 'video-quality':
exist_metrics["visual_quality"], exist_metrics["motion_quality"] = \
calc_video_quality_score(pred_video_list, video_prompt_list, device, self.cat2indices)
self.write_metric(exist_metrics, metric)
elif metric == 'audio-quality':
audio_sr = self.eval_kwargs.get('audio_sr', 16000)
exist_metrics["audio_quality"] = calc_audio_quality_score(
pred_audio_list, audio_prompt_list, max_length_s, device, self.cat2indices, audio_sr=audio_sr)
self.write_metric(exist_metrics, metric)
elif metric == 'imagebind-score':
exist_metrics["ib_tv"], exist_metrics["ib_ta"], exist_metrics["ib_av"] = \
calc_imagebind_score(pred_video_list, pred_audio_list, video_prompt_list, audio_prompt_list,
device, self.cat2indices)
self.write_metric(exist_metrics, metric)
elif metric == 'cxxp-score':
if "clip_score" not in exist_metrics:
exist_metrics["clip_score"] = calc_clip_score(pred_video_list, video_prompt_list, device, self.cat2indices)
if "clap_score" not in exist_metrics:
exist_metrics["clap_score"] = calc_clap_score(pred_audio_list, audio_prompt_list, device, self.cat2indices)
if "cavp_score" not in exist_metrics and 'cavp_score' not in self.exclude:
exist_metrics["cavp_score"] = calc_cavp_score(pred_video_list, pred_audio_list, device, self.cat2indices,
cavp_config_path=self.eval_kwargs['cavp_config_path'])
self.write_metric(exist_metrics, metric)
elif metric == 'av-align':
ret = calc_av_align(pred_video_list, pred_audio_list, self.cat2indices, return_score_list=save_avalign_scores)
if save_avalign_scores:
exist_metrics["av_align"], av_align_scores = ret
assert len(av_align_scores) == len(self.df)
self.df['av_align_scores'] = av_align_scores
else:
exist_metrics["av_align"] = ret
self.write_metric(exist_metrics, metric)
elif metric == 'av-score':
ret = calc_av_score(pred_video_list, pred_audio_list, prompt_list, device, self.cat2indices,
window_size_s=self.eval_kwargs.get("window_size_s", 2.0),
window_overlap_s=self.eval_kwargs.get("window_overlap_s", 1.5),
return_score_list=save_avalign_scores)
if save_avalign_scores:
exist_metrics["avh_score"], exist_metrics["javis_score"], avh_scores, javis_scores = ret
assert len(avh_scores) == len(javis_scores) == len(self.df)
self.df['avh_scores'] = avh_scores
self.df['javis_scores'] = javis_scores
else:
exist_metrics["avh_score"], exist_metrics["javis_score"] = ret
self.write_metric(exist_metrics, metric)
elif metric == 'desync':
ret = calc_desync_score(pred_video_list, pred_audio_list, max_length_s, device,
self.cat2indices, return_score_list=save_avalign_scores)
if save_avalign_scores:
exist_metrics["desync"], desync_scores = ret
assert len(desync_scores) == len(self.df)
self.df['desync_scores'] = desync_scores
else:
exist_metrics["desync"] = ret
self.write_metric(exist_metrics, metric)
elif metric == 'audio-score':
calc_audio_score(gt_audio_list, pred_audio_list, audio_prompt_list, device,
exist_metrics=exist_metrics, **self.eval_kwargs)
self.write_metric(exist_metrics, metric)
elif metric == 'av-reward':
if 'visual_quality' not in self.exclude:
ret = calc_video_quality_score(pred_video_list, video_prompt_list, device,
self.cat2indices, return_score_list=True)
self.df['visual_quality'], self.df['motion_quality'] = ret[2:] # list
save_av_reward = True
exist_metrics['visual_quality'] = float(self.df['visual_quality'].mean())
exist_metrics['motion_quality'] = float(self.df['motion_quality'].mean())
if 'audio_quality' not in self.exclude:
audio_sr = self.eval_kwargs.get('audio_sr', 16000)
ret = calc_audio_quality_score(pred_audio_list, audio_prompt_list, max_length_s, device,
self.cat2indices, audio_sr=audio_sr, return_score_list=True)
self.df['audio_quality'] = ret[1] # list
save_av_reward = True
exist_metrics['audio_quality'] = float(self.df['audio_quality'].mean())
if 'ib_tv' not in self.exclude:
ret = calc_imagebind_score(pred_video_list, pred_audio_list, video_prompt_list, audio_prompt_list,
device, self.cat2indices, return_score_list=True)
self.df["ib_tv"], self.df["ib_ta"], self.df["ib_av"] = ret[3:] # list
save_av_reward = True
exist_metrics['ib_tv'] = float(self.df['ib_tv'].mean())
exist_metrics['ib_ta'] = float(self.df['ib_ta'].mean())
exist_metrics['ib_av'] = float(self.df['ib_av'].mean())
if 'desync' not in self.exclude:
ret = calc_desync_score(pred_video_list, pred_audio_list, max_length_s, device,
self.cat2indices, return_score_list=True)
self.df['desync_scores'] = ret[1] # list
save_av_reward = True
exist_metrics['desync_scores'] = float(self.df['desync_scores'].mean())
self.write_metric(exist_metrics, metric)
if self.world_size > 1:
# TODO: gather scores per rank for general eval
obj_list = [None for _ in range(self.world_size)]
dist.all_gather_object(obj_list, self.df)
if self.is_main_process:
self.df = pd.concat(obj_list)
if self.is_main_process:
if save_avalign_scores:
save_path = osp.splitext(self.output_file)[0] + '_avalign.csv'
self.df.to_csv(save_path, index=False)
if save_av_reward:
save_path = osp.splitext(self.output_file)[0] + '_avreward.csv'
self.df.to_csv(save_path, index=False)
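    # Print the items belonging to one metric group and persist the full metric dict to the output JSON file.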
def write_metric(self, metric:dict, metric_type:str):
os.makedirs(osp.dirname(self.output_file), exist_ok=True)
for item in self.metric2items[metric_type]:
if item not in metric:
print(f'{item}: NOT FOUND', end='; ')
continue
score = metric[item]
if isinstance(score, dict):
score = score['overall']
print(f'{item}: {score:.4f}', end='; ')
print()
with open(self.output_file, 'w+') as f:
json.dump(metric, f, indent=4, ensure_ascii=False)
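    # Reload previously computed metrics from the output file (if any) so finished items can be skipped.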
def load_metric(self):
metric = {}
if osp.exists(self.output_file) and osp.getsize(self.output_file) > 0:
with open(self.output_file, 'r') as f:
metric = json.load(f)
return metric
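    # Resolve predicted audio/video paths: either taken from the CSV (or ground truth with --eval_gt),
    # or gathered as sample_XXXX.wav / sample_XXXX.mp4 files under --infer_data_dir.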
def gather_audio_video_pred(self):
infer_data_dir = self.eval_kwargs['infer_data_dir']
if not infer_data_dir:
if not self.eval_kwargs['eval_gt']:
assert 'pred_video_path' in self.df and 'pred_audio_path' in self.df
else:
assert 'fvd+kvd+fad' not in self.metrics
self.df['pred_video_path'] = self.df['path']
self.df['pred_audio_path'] = self.df['audio_path']
return
assert osp.isdir(infer_data_dir), infer_data_dir
audio_only = self.metrics == ['audio-score']
sample_num = len(self.df)
if audio_only:
pred_audio_list = [f'{infer_data_dir}/sample_{i:04d}.wav' for i in range(sample_num)]
pred_video_list = [''] * sample_num
assert all(osp.exists(path) for path in pred_audio_list)
self.df['text'] = self.df['audio_text']
self.df['path'] = self.df['audio_path']
else:
pred_audio_list = [f'{infer_data_dir}/sample_{i:04d}.wav' for i in range(sample_num)]
pred_video_list = [f'{infer_data_dir}/sample_{i:04d}.mp4' for i in range(sample_num)]
assert all(osp.exists(path) for path in pred_audio_list)
assert all(osp.exists(path) for path in pred_video_list)
self.df['pred_video_path'] = pred_video_list
self.df['pred_audio_path'] = pred_audio_list
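# Entry point for distributed launchers (e.g. torchrun): initialise NCCL, pin each process to its local GPU, then run the evaluator.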
def run_eval(args):
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"]) % torch.cuda.device_count()
torch.cuda.set_device(local_rank)
print(f"Start evaluation on {args.infer_data_dir}")
evaluator = JavisEvaluator(**vars(args))
evaluator()
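# Example launch (illustrative; adjust the module path and file locations to your setup):
#   torchrun --nproc_per_node=1 -m eval.javisbench.eval_javisbench \
#       --input_file ./data/javisbench.csv \
#       --infer_data_dir ./outputs/my_model \
#       --output_file ./results/my_model/metrics.json \
#       --metrics all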
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, default=None, help="path to input csv file", required=True)
parser.add_argument("--infer_data_dir", type=str, default=None, help="directory to audio-video inference results")
parser.add_argument("--output_file", type=str, default=None, help="path to output json file", required=True)
parser.add_argument("--category_cfg", type=str, default='./eval/javisbench/configs/category.json')
parser.add_argument("--metrics", type=str, nargs='+', default='all', help="metrics to calculate, default as `all`")
parser.add_argument("--exclude", type=str, nargs='+', default=[], help="skipping specific metric calculation")
parser.add_argument("--verbose", action='store_true', default=False, help="whether to present category-specific score list")
parser.add_argument("--num_workers", type=int, default=8, help="number of workers for data loading")
# parameters for evaluation
parser.add_argument("--max_frames", type=int, default=24, help="size of the input video")
parser.add_argument("--max_audio_len_s", type=float, default=None, help="maximum length of the audio")
parser.add_argument("--video_fps", type=int, default=24, help="frame rate of the input video")
parser.add_argument("--audio_sr", type=int, default=16000, help="sampling rate of the audio")
parser.add_argument("--image_size", type=int, default=224, help="size of the input image")
parser.add_argument("--eval_num", type=int, default=None, help="number of videos to evaluate")
parser.add_argument("--fvd_avcache_path", type=str, default=None, help="path to the audio-video cache file for FVD/KVD/FAD evaluation")
parser.add_argument("--fvd_mode", type=str, default='vanilla', choices=['vanilla', 'mmdiffusion'], help="mode of fvd calculation, `video` or `audio`")
parser.add_argument("--force_eval", action='store_true', default=False, help="whether to evaluate scores even if existing")
parser.add_argument("--eval_gt", action='store_true', default=False, help="whether to evaluate ground-truth audio-video pairs")
# hyper-parameters for metrics
parser.add_argument("--window_size_s", type=float, default=2.0, help="JavisScore window size")
parser.add_argument("--window_overlap_s", type=float, default=1.5, help="JavisScore overlap size")
parser.add_argument("--cavp_config_path", type=str, default='./eval/javisbench/configs/Stage1_CAVP.yaml', help="JavisScore overlap size")
parser.add_argument("--save_avalign_scores", action='store_true', default=False, help="whether to return score list for AV-Align evaluation")
args = parser.parse_args()
os.makedirs('./checkpoints', exist_ok=True)
os.makedirs(osp.dirname(args.output_file), exist_ok=True)
cache_dir = f'{osp.dirname(args.output_file)}/cache/{osp.basename(osp.splitext(args.output_file)[0])}'
setattr(args, "cache_dir", cache_dir)
run_eval(args)