import argparse
import os
import os.path as osp
import pandas as pd
import json
from typing import List, Dict, Literal
from glob import glob
from copy import deepcopy
import math
import logging

# Replace the module-level `logging.warning` helper with a no-op, so every
# `logging.warning(...)` call in the process (including in modules imported
# below) is dropped. `Logger.warning` on logger instances is not affected.
logging.warning = lambda *args, **kwargs: None

import torch
import torch.distributed as dist

from .src.metrics import (
    # quality
    calc_fvd_kvd_fad, calc_audio_quality_score, calc_video_quality_score,
    # alignment
    calc_imagebind_score, calc_clip_score, calc_clap_score, calc_cavp_score,
    # synchrony
    calc_av_align, calc_av_score, calc_desync_score,
    # audio-only
    calc_audio_score
)


class JavisBenchCategory(object):
    """Category taxonomy loaded from a JavisBench category config (JSON).

    The config is a list of aspects; each aspect has an ``'aspect'`` name and a
    ``'categories'`` list whose entries carry a ``'title'``.

    Attributes:
        cfg: path of the loaded config file.
        category_cfg: the raw parsed JSON content.
        aspect_list: aspect names, in file order.
        category_matrix: ``category_matrix[a][c]`` is the title of category
            ``c`` of aspect ``a``.
        aspect_num: number of aspects.
    """

    def __init__(self, cfg: str):
        self.cfg = cfg
        with open(cfg, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.category_cfg = data
        self.category_matrix = [
            [category['title'] for category in aspect['categories']]
            for aspect in data
        ]
        self.aspect_list = [aspect['aspect'] for aspect in data]
        self.aspect_num = len(self.category_matrix)


class JavisEvaluator(object):
    """Compute JavisBench audio-video metrics for the samples listed in a CSV.

    The CSV must provide ``text`` and ``path`` columns; ``audio_text`` and
    ``audio_path`` are optional and fall back to ``text`` / ``path``.
    Prediction paths come either from ``pred_video_path`` / ``pred_audio_path``
    columns or from ``sample_XXXX.{mp4,wav}`` files in ``infer_data_dir``.
    Results are accumulated into (and resumed from) the JSON ``output_file``.

    In multi-process mode (an initialized ``torch.distributed`` process group
    with world size > 1), only the ``'av-reward'`` metric is supported. Rows
    are split evenly by global rank, and per-sample scores are gathered on the
    main process.
    """

    def __init__(self, input_file: str, category_cfg: str, metrics: List[str], output_file: str, **kwargs):
        self.input_file = input_file
        self.df = pd.read_csv(input_file)
        eval_num = kwargs.pop('eval_num', None)
        if eval_num:
            print(f'Evaluate the first {eval_num} samples.')
            # copy() so later column assignments do not operate on a view
            self.df = self.df.iloc[:eval_num].copy()

        # A bare string (e.g. the argparse default 'all') would otherwise be
        # iterated character by character in __call__.
        if isinstance(metrics, str):
            metrics = [metrics]

        # Use the *global* rank for sharding; fall back to a single process
        # when no process group has been initialized (e.g. library use).
        if dist.is_available() and dist.is_initialized():
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
        else:
            self.world_size = 1
            self.rank = 0
        self.is_main_process = self.rank in [0, -1]

        # Row position of this shard's first sample in the full table; used to
        # map rows to the matching `sample_XXXX` prediction files.
        self.sample_offset = 0
        if self.world_size > 1:
            if metrics != ['av-reward']:
                raise NotImplementedError('Only support reward evaluation in multi-gpu mode.')
            nums_per_rank = math.ceil(len(self.df) / self.world_size)
            start = self.rank * nums_per_rank
            self.df = self.df.iloc[start:start + nums_per_rank].copy()
            self.sample_offset = start

        if category_cfg and osp.isfile(category_cfg) and kwargs.get('verbose'):
            self.parse_aspect_dict(category_cfg)
        else:
            self.cat2indices = None

        self.output_file = output_file

        self.total_metrics = [
            'fvd+kvd+fad',                    # quality
            'video-quality',                  # visual quality and motion quality
            'audio-quality',                  # audio quality
            'imagebind-score', 'cxxp-score',  # semantic consistency
            'av-align',                       # av alignment
            'av-score',                       # 'avh-score', 'javis-score'
            'desync',                         # av synchrony
            # 'audio-score',
        ]
        if metrics == ['all']:
            metrics = self.total_metrics
        self.metrics = metrics

        # Maps each metric group to the result keys it produces; used to skip
        # groups whose results already exist and to print summaries.
        self.metric2items = {
            # for general audio-video evaluation
            'fvd+kvd+fad': ['fvd', 'kvd', 'fad'],
            'video-quality': ['visual_quality', 'motion_quality'],
            'audio-quality': ['audio_quality'],
            'imagebind-score': ['ib_tv', 'ib_ta', 'ib_av'],
            'cxxp-score': ['clip_score', 'clap_score', 'cavp_score'],
            'av-align': ['av_align'],
            'av-score': ['avh_score', 'javis_score'],
            'desync': ['desync'],
            # for audio evaluation only
            'audio-score': ['fad', 'quality', 'ib_ta', 'clap'],
            # for audio-video reward calculation
            'av-reward': [
                'visual_quality', 'motion_quality', 'audio_quality',
                'ib_tv', 'ib_ta', 'ib_av', 'desync_scores',
                # TODO: JavisScore?
            ]
        }

        self.exclude = kwargs.pop('exclude', [])
        self.eval_kwargs = kwargs

        self.gather_audio_video_pred()

    def parse_aspect_dict(self, category_cfg: str):
        """Build ``self.cat2indices`` from the per-aspect category columns.

        For aspect ``ai`` the CSV column ``cat{ai}_ind`` holds a single category
        index or a comma-separated list of indices. The result satisfies
        ``cat2indices[ai][ci]`` == positional row indices (within ``self.df``)
        assigned to category ``ci`` of aspect ``ai``. Missing (NaN) cells and
        empty tokens (e.g. from a trailing comma) are skipped.
        """
        self.category = JavisBenchCategory(category_cfg)
        cat2indices: List[List[List[int]]] = []
        for ai in range(self.category.aspect_num):
            index_list = [[] for _ in range(len(self.category.category_matrix[ai]))]
            for pi, cat_str in enumerate(self.df[f'cat{ai}_ind'].tolist()):
                if isinstance(cat_str, str):
                    tokens = [tok for tok in cat_str.split(',') if tok.strip()]
                elif pd.isna(cat_str):
                    continue
                else:
                    tokens = [cat_str]
                for ci in tokens:
                    index_list[int(ci)].append(pi)
            cat2indices.append(index_list)
        self.cat2indices = cat2indices

    @torch.no_grad()
    def __call__(self, *args, **kwds):
        """Run every requested metric group and persist the results.

        Groups whose result keys already exist in ``output_file`` are skipped
        unless ``force_eval`` is set; groups in ``exclude`` are skipped too.
        Per-sample score tables are written next to ``output_file`` as
        ``*_avalign.csv`` / ``*_avreward.csv`` when requested.
        """
        # Use the device chosen via torch.cuda.set_device() (one per rank)
        # instead of always cuda:0.
        if torch.cuda.is_available():
            device = torch.device('cuda', torch.cuda.current_device())
        else:
            device = torch.device('cpu')
        exist_metrics = self.load_metric()
        if self.world_size > 1:
            # All ranks read the existing results before the main process can
            # overwrite them, so every rank makes the same skip decisions.
            dist.barrier()

        prompt_list = self.df['text'].tolist()
        # video_prompt_list = self.df.get('video_text', self.df['text']).tolist()
        video_prompt_list = prompt_list
        audio_prompt_list = self.df.get('audio_text', self.df['text']).tolist()
        gt_video_list = self.df['path'].tolist()
        gt_audio_list = self.df.get('audio_path', self.df['path']).tolist()
        pred_video_list = self.df['pred_video_path'].tolist()
        pred_audio_list = self.df['pred_audio_path'].tolist()

        save_avalign_scores = self.eval_kwargs.get('save_avalign_scores', False)
        save_av_reward = False
        max_length_s = self.eval_kwargs.get('max_audio_len_s', 8.0)

        for metric in self.metrics:
            if not self.eval_kwargs.get('force_eval', False) and \
                    all(item in exist_metrics for item in self.metric2items[metric]):
                print(f'{metric} calculated. skip.')
                continue
            if metric in self.exclude:
                print(f'{metric} excluded. skip.')
                continue

            if metric == 'fvd+kvd+fad':
                mode = self.eval_kwargs.get('fvd_mode', 'vanilla')
                exist_metrics["fvd"], exist_metrics["kvd"], exist_metrics["fad"] = \
                    calc_fvd_kvd_fad(gt_video_list, pred_video_list, gt_audio_list, pred_audio_list,
                                     device, self.cat2indices, mode=mode, **self.eval_kwargs)
                self.write_metric(exist_metrics, metric)

            elif metric == 'video-quality':
                exist_metrics["visual_quality"], exist_metrics["motion_quality"] = \
                    calc_video_quality_score(pred_video_list, video_prompt_list, device, self.cat2indices)
                self.write_metric(exist_metrics, metric)

            elif metric == 'audio-quality':
                audio_sr = self.eval_kwargs.get('audio_sr', 16000)
                exist_metrics["audio_quality"] = calc_audio_quality_score(
                    pred_audio_list, audio_prompt_list, max_length_s, device, self.cat2indices,
                    audio_sr=audio_sr)
                self.write_metric(exist_metrics, metric)

            elif metric == 'imagebind-score':
                exist_metrics["ib_tv"], exist_metrics["ib_ta"], exist_metrics["ib_av"] = \
                    calc_imagebind_score(pred_video_list, pred_audio_list, video_prompt_list,
                                         audio_prompt_list, device, self.cat2indices)
                self.write_metric(exist_metrics, metric)

            elif metric == 'cxxp-score':
                if "clip_score" not in exist_metrics:
                    exist_metrics["clip_score"] = calc_clip_score(
                        pred_video_list, video_prompt_list, device, self.cat2indices)
                if "clap_score" not in exist_metrics:
                    exist_metrics["clap_score"] = calc_clap_score(
                        pred_audio_list, audio_prompt_list, device, self.cat2indices)
                if "cavp_score" not in exist_metrics and 'cavp_score' not in self.exclude:
                    exist_metrics["cavp_score"] = calc_cavp_score(
                        pred_video_list, pred_audio_list, device, self.cat2indices,
                        cavp_config_path=self.eval_kwargs['cavp_config_path'])
                self.write_metric(exist_metrics, metric)

            elif metric == 'av-align':
                ret = calc_av_align(pred_video_list, pred_audio_list, self.cat2indices,
                                    return_score_list=save_avalign_scores)
                if save_avalign_scores:
                    exist_metrics["av_align"], av_align_scores = ret
                    assert len(av_align_scores) == len(self.df)
                    self.df['av_align_scores'] = av_align_scores
                else:
                    exist_metrics["av_align"] = ret
                self.write_metric(exist_metrics, metric)

            elif metric == 'av-score':
                ret = calc_av_score(pred_video_list, pred_audio_list, prompt_list, device, self.cat2indices,
                                    window_size_s=self.eval_kwargs.get("window_size_s", 2.0),
                                    window_overlap_s=self.eval_kwargs.get("window_overlap_s", 1.5),
                                    return_score_list=save_avalign_scores)
                if save_avalign_scores:
                    exist_metrics["avh_score"], exist_metrics["javis_score"], avh_scores, javis_scores = ret
                    assert len(avh_scores) == len(javis_scores) == len(self.df)
                    self.df['avh_scores'] = avh_scores
                    self.df['javis_scores'] = javis_scores
                else:
                    exist_metrics["avh_score"], exist_metrics["javis_score"] = ret
                self.write_metric(exist_metrics, metric)

            elif metric == 'desync':
                ret = calc_desync_score(pred_video_list, pred_audio_list, max_length_s, device,
                                        self.cat2indices, return_score_list=save_avalign_scores)
                if save_avalign_scores:
                    exist_metrics["desync"], desync_scores = ret
                    assert len(desync_scores) == len(self.df)
                    self.df['desync_scores'] = desync_scores
                else:
                    exist_metrics["desync"] = ret
                self.write_metric(exist_metrics, metric)

            elif metric == 'audio-score':
                calc_audio_score(gt_audio_list, pred_audio_list, audio_prompt_list, device,
                                 exist_metrics=exist_metrics, **self.eval_kwargs)
                self.write_metric(exist_metrics, metric)

            elif metric == 'av-reward':
                # Per-sample scores are stored as df columns; the JSON summary
                # holds their means (over this process's rows).
                if 'visual_quality' not in self.exclude:
                    ret = calc_video_quality_score(pred_video_list, video_prompt_list, device,
                                                   self.cat2indices, return_score_list=True)
                    self.df['visual_quality'], self.df['motion_quality'] = ret[2:]  # list
                    save_av_reward = True
                    exist_metrics['visual_quality'] = float(self.df['visual_quality'].mean())
                    exist_metrics['motion_quality'] = float(self.df['motion_quality'].mean())
                if 'audio_quality' not in self.exclude:
                    audio_sr = self.eval_kwargs.get('audio_sr', 16000)
                    ret = calc_audio_quality_score(pred_audio_list, audio_prompt_list, max_length_s, device,
                                                   self.cat2indices, audio_sr=audio_sr, return_score_list=True)
                    self.df['audio_quality'] = ret[1]  # list
                    save_av_reward = True
                    exist_metrics['audio_quality'] = float(self.df['audio_quality'].mean())
                if 'ib_tv' not in self.exclude:
                    ret = calc_imagebind_score(pred_video_list, pred_audio_list, video_prompt_list,
                                               audio_prompt_list, device, self.cat2indices,
                                               return_score_list=True)
                    self.df["ib_tv"], self.df["ib_ta"], self.df["ib_av"] = ret[3:]  # list
                    save_av_reward = True
                    exist_metrics['ib_tv'] = float(self.df['ib_tv'].mean())
                    exist_metrics['ib_ta'] = float(self.df['ib_ta'].mean())
                    exist_metrics['ib_av'] = float(self.df['ib_av'].mean())
                if 'desync' not in self.exclude:
                    ret = calc_desync_score(pred_video_list, pred_audio_list, max_length_s, device,
                                            self.cat2indices, return_score_list=True)
                    self.df['desync_scores'] = ret[1]  # list
                    save_av_reward = True
                    exist_metrics['desync_scores'] = float(self.df['desync_scores'].mean())
                self.write_metric(exist_metrics, metric)

        if self.world_size > 1:
            # TODO: gather scores per rank for general eval
            obj_list = [None for _ in range(self.world_size)]
            dist.all_gather_object(obj_list, self.df)
            if self.is_main_process:
                self.df = pd.concat(obj_list)
                if save_av_reward:
                    # The means computed in the loop only covered this rank's
                    # shard; recompute them over the gathered table.
                    for item in self.metric2items['av-reward']:
                        if item in self.df.columns:
                            exist_metrics[item] = float(self.df[item].mean())
                    self.write_metric(exist_metrics, 'av-reward')

        if self.is_main_process:
            if save_avalign_scores:
                save_path = osp.splitext(self.output_file)[0] + '_avalign.csv'
                self.df.to_csv(save_path, index=False)
            if save_av_reward:
                save_path = osp.splitext(self.output_file)[0] + '_avreward.csv'
                self.df.to_csv(save_path, index=False)

    def write_metric(self, metric: dict, metric_type: str):
        """Print the items of ``metric_type`` and dump all of ``metric`` as JSON.

        Only the main process prints and writes, so ranks never race on
        ``output_file``. Dict-valued scores are summarized by their
        ``'overall'`` entry when printed.
        """
        if not self.is_main_process:
            return
        # dirname is '' for a bare filename; makedirs('') would raise.
        os.makedirs(osp.dirname(self.output_file) or '.', exist_ok=True)
        for item in self.metric2items[metric_type]:
            if item not in metric:
                print(f'{item}: NOT FOUND', end='; ')
                continue
            score = metric[item]
            if isinstance(score, dict):
                score = score['overall']
            print(f'{item}: {score:.4f}', end='; ')
        print()
        with open(self.output_file, 'w', encoding='utf-8') as f:
            json.dump(metric, f, indent=4, ensure_ascii=False)

    def load_metric(self):
        """Return previously saved results from ``output_file`` ({} if absent or empty)."""
        metric = {}
        if osp.exists(self.output_file) and osp.getsize(self.output_file) > 0:
            with open(self.output_file, 'r', encoding='utf-8') as f:
                metric = json.load(f)
        return metric

    def gather_audio_video_pred(self):
        """Fill the ``pred_video_path`` / ``pred_audio_path`` columns.

        Without ``infer_data_dir``, the columns must already exist in the CSV,
        or (with ``eval_gt``) the ground-truth paths are evaluated as
        predictions. Otherwise, predictions are read from
        ``sample_XXXX.{wav,mp4}`` in ``infer_data_dir``, numbered by each row's
        position in the full (unsharded) table.
        """
        infer_data_dir = self.eval_kwargs.get('infer_data_dir')
        if not infer_data_dir:
            if not self.eval_kwargs.get('eval_gt', False):
                assert 'pred_video_path' in self.df and 'pred_audio_path' in self.df
            else:
                assert 'fvd+kvd+fad' not in self.metrics
                self.df['pred_video_path'] = self.df['path']
                # Same fallback as in __call__: audio lives in the video file
                # when no separate audio_path column exists.
                self.df['pred_audio_path'] = self.df.get('audio_path', self.df['path'])
            return

        assert osp.isdir(infer_data_dir), infer_data_dir
        audio_only = self.metrics == ['audio-score']
        sample_ids = range(self.sample_offset, self.sample_offset + len(self.df))
        pred_audio_list = [f'{infer_data_dir}/sample_{i:04d}.wav' for i in sample_ids]
        assert all(osp.exists(path) for path in pred_audio_list)
        if audio_only:
            pred_video_list = [''] * len(self.df)
            self.df['text'] = self.df['audio_text']
            self.df['path'] = self.df['audio_path']
        else:
            pred_video_list = [f'{infer_data_dir}/sample_{i:04d}.mp4' for i in sample_ids]
            assert all(osp.exists(path) for path in pred_video_list)

        self.df['pred_video_path'] = pred_video_list
        self.df['pred_audio_path'] = pred_audio_list


def run_eval(args):
    """Set up (optional) distributed execution and run the evaluator on ``args``."""
    # torchrun exports WORLD_SIZE/LOCAL_RANK; without them run as a single
    # process instead of failing on missing rendezvous env variables.
    if 'WORLD_SIZE' in os.environ:
        dist.init_process_group(backend="nccl")
    if torch.cuda.is_available():
        local_rank = int(os.environ.get("LOCAL_RANK", 0)) % torch.cuda.device_count()
        torch.cuda.set_device(local_rank)
    print(f"Start evaluation on {args.infer_data_dir}")
    evaluator = JavisEvaluator(**vars(args))
    evaluator()
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, default=None, help="path to input csv file", required=True)
    parser.add_argument("--infer_data_dir", type=str, default=None, help="directory to audio-video inference results")
    parser.add_argument("--output_file", type=str, default=None, help="path to output json file", required=True)
    parser.add_argument("--category_cfg", type=str, default='./eval/javisbench/configs/category.json')
    # default must be a list: nargs='+' yields lists, and the evaluator
    # compares against ['all'].
    parser.add_argument("--metrics", type=str, nargs='+', default=['all'], help="metrics to calculate, default as `all`")
    parser.add_argument("--exclude", type=str, nargs='+', default=[], help="skipping specific metric calculation")
    parser.add_argument("--verbose", action='store_true', default=False, help="whether to present category-specific score list")
    parser.add_argument("--num_workers", type=int, default=8, help="number of workers for data loading")
    # parameters for evaluation
    parser.add_argument("--max_frames", type=int, default=24, help="size of the input video")
    parser.add_argument("--max_audio_len_s", type=float, default=None, help="maximum length of the audio")
    parser.add_argument("--video_fps", type=int, default=24, help="frame rate of the input video")
    parser.add_argument("--audio_sr", type=int, default=16000, help="sampling rate of the audio")
    parser.add_argument("--image_size", type=int, default=224, help="size of the input image")
    parser.add_argument("--eval_num", type=int, default=None, help="number of videos to evaluate")
    parser.add_argument("--fvd_avcache_path", type=str, default=None, help="path to the audio-video cache file for FVD/KVD/FAD evaluation")
    parser.add_argument("--fvd_mode", type=str, default='vanilla', choices=['vanilla', 'mmdiffusion'], help="mode of fvd calculation, `vanilla` or `mmdiffusion`")
    parser.add_argument("--force_eval", action='store_true', default=False, help="whether to evaluate scores even if existing")
    parser.add_argument("--eval_gt", action='store_true', default=False, help="whether to evaluate ground-truth audio-video pairs")
    # hyper-parameters for metrics
    parser.add_argument("--window_size_s", type=float, default=2.0, help="JavisScore window size")
    parser.add_argument("--window_overlap_s", type=float, default=1.5, help="JavisScore overlap size")
    parser.add_argument("--cavp_config_path", type=str, default='./eval/javisbench/configs/Stage1_CAVP.yaml', help="path to the CAVP model config file")
    parser.add_argument("--save_avalign_scores", action='store_true', default=False, help="whether to return score list for AV-Align evaluation")
    args = parser.parse_args()

    os.makedirs('./checkpoints', exist_ok=True)
    output_dir = osp.dirname(args.output_file)
    # `or '.'`: a bare filename has an empty dirname, which makedirs rejects.
    os.makedirs(output_dir or '.', exist_ok=True)
    # osp.join avoids producing '/cache/...' at the filesystem root when
    # output_dir is empty.
    cache_dir = osp.join(output_dir, 'cache', osp.basename(osp.splitext(args.output_file)[0]))
    setattr(args, "cache_dir", cache_dir)

    run_eval(args)