# Improved Quality, Synchrony, and Preference Alignment for Joint Audio-Video Generation

This codebase is built upon [JavisDiT](https://github.com/JavisDiT/JavisDiT). Many thanks for their contributions.

## Installation

For CUDA 12.1, you can install the dependencies with the following commands.

```bash
# create a virtual env and activate (conda as an example)
conda create -n javisdit python=3.10
conda activate javisdit
# install torch, torchvision and xformers
pip install -r requirements/requirements-cu121.txt
# install ffmpeg
conda install "ffmpeg<7" -c conda-forge -y
# the default installation is for inference only
pip install -v .
# for development mode, `pip install -v -e .`
# to skip dependencies, `pip install -v -e . --no-deps`
# replace the installed pytorchvideo/funasr source files with our patched versions
export PYTHON_SITE_PACKAGES=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
cp assets/src/pytorchvideo_augmentations.py ${PYTHON_SITE_PACKAGES}/pytorchvideo/transforms/augmentations.py
cp assets/src/funasr_utils_load_utils.py ${PYTHON_SITE_PACKAGES}/funasr/utils/load_utils.py
# (optional but recommended) install flash attention
# set enable_flash_attn=False in config to disable flash attention
pip install packaging ninja
pip install flash-attn --no-build-isolation
```
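A quick sanity check (just a minimal sketch, not part of the repository's tooling) can confirm that PyTorch sees the GPU and that the optional flash-attention build is importable:

```python
# Minimal post-install sanity check (illustrative; not shipped with this repo).
import torch

print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"device: {torch.cuda.get_device_name(0)}")

try:
    import flash_attn  # optional dependency installed above
    print(f"flash-attn {flash_attn.__version__} is available")
except ImportError:
    # If flash-attn is missing, set enable_flash_attn=False in the config.
    print("flash-attn not installed; disable flash attention in the config")
```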
## Training

### Data Preparation

In this project, we use a `.csv` file to manage all the training entries and their attributes for efficient training:

| path | id | relpath | num_frames | height | width | aspect_ratio | fps | resolution | audio_path | audio_fps | text |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| /path/to/xxx.mp4 | xxx | xxx.mp4 | 240 | 480 | 640 | 0.75 | 24 | 307200 | /path/to/xxx.wav | 16000 | yyy |

The required columns may vary across training stages. Detailed instructions for each stage can be found [here](assets/docs/data.md).
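If you need to assemble such a meta file for your own data, the sketch below shows the general idea, assuming paired `.mp4`/`.wav` files. The probing helpers (`cv2`, `soundfile`, `pandas`), the input glob, and the output path are illustrative choices only; the authoritative per-stage column requirements are in [assets/docs/data.md](assets/docs/data.md).

```python
# Rough sketch of building a training meta CSV with the columns listed above.
# Paths and probing libraries are illustrative assumptions, not the repo's own tooling.
import glob
import os

import cv2
import pandas as pd
import soundfile as sf

rows = []
for video_path in glob.glob("data/raw/*.mp4"):
    cap = cv2.VideoCapture(video_path)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    audio_path = video_path.replace(".mp4", ".wav")  # assumes paired audio files
    audio_fps = sf.info(audio_path).samplerate       # e.g. 16000

    rows.append(dict(
        path=os.path.abspath(video_path),
        id=os.path.splitext(os.path.basename(video_path))[0],
        relpath=os.path.basename(video_path),
        num_frames=num_frames, height=height, width=width,
        aspect_ratio=round(height / width, 2), fps=fps,
        resolution=height * width,
        audio_path=os.path.abspath(audio_path), audio_fps=audio_fps,
        text="a caption describing the clip",  # fill in your own captions
    ))

pd.DataFrame(rows).to_csv("data/meta/audio/train_custom.csv", index=False)
```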
### Stage 1 - Audio Pre-Training

In this stage, we perform audio pre-training to initialize the text-to-audio generation capability:

```bash
torchrun --standalone --nproc_per_node 8 \
    scripts/train.py \
    configs/wan2.1/train/stage1_audio.py \
    --data-path data/meta/audio/train_audio.csv
```

The resulting checkpoints will be saved at `runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc/model`, where `0aa`, `0bb`, and `ccc` stand for the run index, epoch, and global step. You can move the checkpoints to `exps/audio_pretrain/` for later use.

```bash
mkdir -p exps/audio_pretrain
mv runs/000-Wan2_1_T2V_1_3B/epoch049-global_step53000 exps/audio_pretrain/
```
### Stage 2 - Audio-Video SFT

In this stage, we fine-tune the model for joint audio-video generation (with LoRA adaptation):

```bash
torchrun --standalone --nproc_per_node 8 \
    scripts/train_prior.py \
    configs/wan2.1/train/stage2_audio_video.py \
    --data-path data/meta/video/train_av_sft.csv
```

The resulting checkpoints will be saved at `runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc` with `model` and `lora` subfolders. You can move the checkpoints to `exps/audio_video_sft/` for later use.

```bash
mkdir -p exps/audio_video_sft
mv runs/000-Wan2_1_T2V_1_3B/epoch001-global_step13000 exps/audio_video_sft/
```
### Stage 3 - Audio-Video DPO

In this stage, we perform Direct Preference Optimization (DPO) to align joint audio-video generation with human preferences, reusing and updating the LoRA parameters learned in the previous stage:
```bash
torchrun --standalone --nproc_per_node 8 \
    scripts/train.py \
    configs/wan2.1/train/stage3_audio_video_dpo.py \
    --data-path data/meta/avdpo/train_av_dpo.csv
```
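The preference objective itself is defined by the stage-3 config and training script. Purely as a hedged illustration of a Diffusion-DPO-style loss (not this repository's exact implementation), the idea is to reward the LoRA-adapted model for denoising the preferred audio-video sample better than a frozen reference model does, relative to the dispreferred sample:

```python
# Illustrative Diffusion-DPO-style preference loss (a sketch, NOT the repo's exact code).
# The `err_*` tensors are per-sample diffusion (denoising) errors, e.g. MSE between
# predicted and true noise, for the preferred ("win") and dispreferred ("lose") pairs.
import torch
import torch.nn.functional as F

def diffusion_dpo_loss(err_policy_win, err_policy_lose,
                       err_ref_win, err_ref_lose, beta=5000.0):
    # How much better the policy fits the preferred sample than the reference does,
    # minus the same quantity for the dispreferred sample.
    diff_win = err_policy_win - err_ref_win
    diff_lose = err_policy_lose - err_ref_lose
    logits = -beta * (diff_win - diff_lose)
    return -F.logsigmoid(logits).mean()
```

Here `beta` and the per-sample errors are placeholders; the actual loss, hyper-parameters, and pairing of preferred/dispreferred samples come from `configs/wan2.1/train/stage3_audio_video_dpo.py` and the DPO meta CSV.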
The resulting checkpoints will also be saved at `runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc` with `model` and `lora` subfolders. You can move the checkpoints to `checkpoints/` for inference and evaluation.

```bash
mv runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc checkpoints/your_model
```
## Inference

Basic command-line inference works as follows:

```bash
resolution=480p  # or 240p
num_frames=65    # 4s
aspect_ratio="9:16"

DATASET="JavisBench"  # or JavisBench-mini
prompt_path="data/eval/JavisBench/${DATASET}.csv"
save_dir="samples/${DATASET}"
model_path="checkpoints/your_model"
ngpus=1

torchrun --standalone --nproc_per_node ${ngpus} \
    scripts/inference.py \
    configs/wan2.1/inference/sample.py \
    --resolution ${resolution} --num-frames ${num_frames} --aspect-ratio ${aspect_ratio} \
    --prompt-path ${prompt_path} --model-path ${model_path} \
    --save-dir ${save_dir} --verbose 1

# (Optional, for evaluation) Extract audios from generated videos
python -m tools.datasets.convert video ${save_dir} --output ${save_dir}/meta.csv
python -m tools.datasets.datautil ${save_dir}/meta.csv --extract-audio --audio-sr 16000
rm -f ${save_dir}/meta*.csv
```

Setting `--verbose 2` displays the progress of each diffusion process. You can also replace `--prompt-path ${prompt_path}` with a single prompt, such as `--prompt "a beautiful waterfall"`, to generate a single sample.
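Before running the evaluation below, it can help to verify that every generated video has a matching extracted audio file. The snippet below is only a convenience sketch and assumes the extracted `.wav` files sit next to the `.mp4` outputs in `${save_dir}`; adjust the paths to whatever layout the extraction step actually produces.

```python
# Quick consistency check over generated samples (illustrative; file layout may differ).
from pathlib import Path

save_dir = Path("samples/JavisBench")
videos = sorted(save_dir.rglob("*.mp4"))
missing = [v for v in videos if not v.with_suffix(".wav").exists()]

print(f"{len(videos)} videos found, {len(missing)} without an extracted .wav")
for v in missing[:10]:
    print("missing audio for:", v.name)
```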
## Evaluation

### Installation

Install the necessary packages:

```bash
pip install -r requirements/requirements-eval.txt
```

Download the meta file and data of [JavisBench](https://huggingface.co/datasets/JavisDiT/JavisBench) and put them into `data/eval/`:

```bash
cd /path/to/JavisDiT
mkdir -p data/eval
huggingface-cli download --repo-type dataset JavisDiT/JavisBench --local-dir data/eval/JavisBench
```

### Evaluation on JavisBench or JavisBench-mini

Run the following command; the results will be saved in `./evaluation_results`. For details, please refer to the [JavisBench documentation](eval/javisbench/README.md).
```bash
MAX_FRAMES=16
IMAGE_SIZE=224
MAX_AUDIO_LEN_S=4.0
# Params to calculate JavisScore
WINDOW_SIZE_S=2.0
WINDOW_OVERLAP_S=1.5

METRICS="all"
RESULTS_DIR="./evaluation_results"

DATASET="JavisBench"  # or JavisBench-mini
INPUT_FILE="data/eval/JavisBench/${DATASET}.csv"
FVD_AVCACHE_PATH="data/eval/JavisBench/cache/fvd_fad/${DATASET}-vanilla-max4s.pt"
INFER_DATA_DIR="samples/${DATASET}"

python -m eval.javisbench.main \
    --input_file "${INPUT_FILE}" \
    --infer_data_dir "${INFER_DATA_DIR}" \
    --output_file "${RESULTS_DIR}/${DATASET}.json" \
    --max_frames ${MAX_FRAMES} \
    --image_size ${IMAGE_SIZE} \
    --max_audio_len_s ${MAX_AUDIO_LEN_S} \
    --window_size_s ${WINDOW_SIZE_S} \
    --window_overlap_s ${WINDOW_OVERLAP_S} \
    --fvd_avcache_path ${FVD_AVCACHE_PATH} \
    --metrics ${METRICS}
```
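As a rough illustration of what the `WINDOW_SIZE_S` and `WINDOW_OVERLAP_S` parameters above control (the authoritative windowing logic lives in `eval.javisbench`), the JavisScore is computed over overlapping temporal windows; with a 2.0 s window and 1.5 s overlap, the stride between windows is 0.5 s:

```python
# Illustrative sliding-window layout for the JavisScore parameters above
# (a sketch of the windowing arithmetic, not the evaluator's actual code).
def sliding_windows(clip_len_s, window_size_s=2.0, window_overlap_s=1.5):
    stride = window_size_s - window_overlap_s  # 0.5 s with the defaults above
    windows, start = [], 0.0
    while start + window_size_s <= clip_len_s + 1e-6:
        windows.append((round(start, 3), round(start + window_size_s, 3)))
        start += stride
    return windows

print(sliding_windows(4.0))  # [(0.0, 2.0), (0.5, 2.5), (1.0, 3.0), (1.5, 3.5), (2.0, 4.0)]
```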