Upload 208 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +7 -0
- LatentSync/.gitattributes +42 -0
- LatentSync/Colab.ipynb +48 -0
- LatentSync/LICENSE +201 -0
- LatentSync/README.md +109 -0
- LatentSync/app.py +267 -0
- LatentSync/apt.txt +2 -0
- LatentSync/assets/demo1_audio.wav +3 -0
- LatentSync/assets/demo1_video.mp4 +3 -0
- LatentSync/assets/demo2_audio.wav +3 -0
- LatentSync/assets/demo2_video.mp4 +3 -0
- LatentSync/assets/demo3_audio.wav +3 -0
- LatentSync/assets/demo3_video.mp4 +3 -0
- LatentSync/assets/framework.png +3 -0
- LatentSync/configs/audio.yaml +23 -0
- LatentSync/configs/scheduler_config.json +13 -0
- LatentSync/configs/syncnet/syncnet_16_latent.yaml +46 -0
- LatentSync/configs/syncnet/syncnet_16_pixel.yaml +45 -0
- LatentSync/configs/syncnet/syncnet_25_pixel.yaml +45 -0
- LatentSync/configs/unet/first_stage.yaml +103 -0
- LatentSync/configs/unet/second_stage.yaml +103 -0
- LatentSync/data_processing_pipeline.sh +9 -0
- LatentSync/eval/detectors/README.md +3 -0
- LatentSync/eval/detectors/__init__.py +1 -0
- LatentSync/eval/detectors/s3fd/__init__.py +61 -0
- LatentSync/eval/detectors/s3fd/box_utils.py +221 -0
- LatentSync/eval/detectors/s3fd/nets.py +174 -0
- LatentSync/eval/draw_syncnet_lines.py +70 -0
- LatentSync/eval/eval_fvd.py +96 -0
- LatentSync/eval/eval_sync_conf.py +77 -0
- LatentSync/eval/eval_sync_conf.sh +2 -0
- LatentSync/eval/eval_syncnet_acc.py +118 -0
- LatentSync/eval/eval_syncnet_acc.sh +3 -0
- LatentSync/eval/fvd.py +56 -0
- LatentSync/eval/hyper_iqa.py +343 -0
- LatentSync/eval/inference_videos.py +37 -0
- LatentSync/eval/syncnet/__init__.py +1 -0
- LatentSync/eval/syncnet/syncnet.py +113 -0
- LatentSync/eval/syncnet/syncnet_eval.py +220 -0
- LatentSync/eval/syncnet_detect.py +251 -0
- LatentSync/inference.sh +9 -0
- LatentSync/latentsync/data/syncnet_dataset.py +153 -0
- LatentSync/latentsync/data/unet_dataset.py +164 -0
- LatentSync/latentsync/models/attention.py +492 -0
- LatentSync/latentsync/models/motion_module.py +332 -0
- LatentSync/latentsync/models/resnet.py +234 -0
- LatentSync/latentsync/models/syncnet.py +233 -0
- LatentSync/latentsync/models/syncnet_wav2lip.py +90 -0
- LatentSync/latentsync/models/unet.py +528 -0
- LatentSync/latentsync/models/unet_blocks.py +903 -0
.gitattributes
CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+LatentSync/assets/demo1_audio.wav filter=lfs diff=lfs merge=lfs -text
+LatentSync/assets/demo1_video.mp4 filter=lfs diff=lfs merge=lfs -text
+LatentSync/assets/demo2_audio.wav filter=lfs diff=lfs merge=lfs -text
+LatentSync/assets/demo2_video.mp4 filter=lfs diff=lfs merge=lfs -text
+LatentSync/assets/demo3_audio.wav filter=lfs diff=lfs merge=lfs -text
+LatentSync/assets/demo3_video.mp4 filter=lfs diff=lfs merge=lfs -text
+LatentSync/assets/framework.png filter=lfs diff=lfs merge=lfs -text
LatentSync/.gitattributes
ADDED
@@ -0,0 +1,42 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
assets/demo1_video.mp4 filter=lfs diff=lfs merge=lfs -text
assets/demo2_video.mp4 filter=lfs diff=lfs merge=lfs -text
assets/demo3_video.mp4 filter=lfs diff=lfs merge=lfs -text
assets/demo1_audio.wav filter=lfs diff=lfs merge=lfs -text
assets/demo2_audio.wav filter=lfs diff=lfs merge=lfs -text
assets/demo3_audio.wav filter=lfs diff=lfs merge=lfs -text
assets/framework.png filter=lfs diff=lfs merge=lfs -text
LatentSync/Colab.ipynb
ADDED
@@ -0,0 +1,48 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GCpiCPg8h5r3",
        "collapsed": true
      },
      "outputs": [],
      "source": [
        "#@title ⚙️ Cài đặt\n",
        "!git clone https://huggingface.co/spaces/LTTEAM/LatentSync\n",
        "%cd LatentSync\n",
        "!pip install -r requirements.txt"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title ⌛️ Chạy\n",
        "!python app.py"
      ],
      "metadata": {
        "id": "E9rJM-F4iVTU",
        "collapsed": true
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
LatentSync/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
LatentSync/README.md
ADDED
@@ -0,0 +1,109 @@
---
title: LatentSync - Đồng bộ môi bằng AI
emoji: 🎤
colorFrom: green
colorTo: yellow
sdk: gradio
sdk_version: 5.34.0
app_file: app.py
pinned: true
---

# LatentSync - AI Lip Sync Technology

[](https://huggingface.co/spaces/LTTEAM/LatentSync)
[](https://www.facebook.com/groups/622526090937760)

## 🌟 Giới thiệu / Introduction

**LatentSync** là công nghệ đồng bộ hóa chuyển động môi sử dụng mô hình Diffusion tiên tiến, cho phép tạo chuyển động môi tự nhiên từ âm thanh đầu vào.

**LatentSync** is an advanced lip-sync technology using Diffusion models to generate natural lip movements from input audio.

## 🚀 Công nghệ / Technology

```python
# Kiến trúc chính / Core Architecture
pipeline = LipsyncPipeline(
    vae=AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse"),
    audio_encoder=Audio2Feature(model_path="whisper/small.pt"),
    unet=UNet3DConditionModel.from_config(config),
    scheduler=DDIMScheduler()
)
```

**Công nghệ chính / Key Technologies:**
- 🧠 UNet 3D Condition Model
- 🔊 Whisper Audio Encoder
- 🌀 Latent Diffusion
- ⚡ GPU Acceleration

## 📚 Cách sử dụng / How to Use

1. Tải lên video chứa khuôn mặt / Upload face video
2. Tải lên file âm thanh / Upload audio file
3. Nhấn "Chạy đồng bộ" / Click "Run Sync"
4. Chờ kết quả / Wait for processing

```bash
# Chạy local / Run locally
git clone https://huggingface.co/spaces/LTTEAM/LatentSync
cd LatentSync
pip install -r requirements.txt
python app.py
```

## 🌐 Demo Online

[](https://huggingface.co/spaces/LTTEAM/LatentSync)

## 👨💻 Tác giả & Cộng đồng / Author & Community

**Tác giả / Author:**
[Lý Trần](https://github.com/lytrann)

**Cộng đồng / Community:**
[LTTEAM Facebook Group](https://www.facebook.com/groups/622526090937760)

**Hỗ trợ / Support:**
[](https://www.facebook.com/groups/622526090937760)

## 📜 Giấy phép / License

```text
Copyright 2023 LTTEAM

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
```

---

🔥 **Đóng góp / Contributions welcome!**
💡 **Báo lỗi / Report issues:** [Issues](https://huggingface.co/spaces/LTTEAM/LatentSync/discussions)

## Key Features of this README:

1. **Bilingual Presentation**: Vietnamese and English for wider accessibility
2. **Technical Highlights**:
   - Code block showing core architecture
   - Badges for easy navigation
   - Clear technology stack

3. **Community Focus**:
   - Author information
   - Community links
   - Support channels

4. **Visual Appeal**:
   - Emoji usage
   - Colorful badges
   - Clear section separation

5. **Practical Information**:
   - Usage instructions
   - Local setup guide
   - License information

This README will display beautifully on your Hugging Face Space while effectively communicating all key information to users.
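For readers who want to call the model from a script rather than through the Gradio UI, the following is a minimal sketch distilled from `app.py` in this upload. It assumes the same checkpoint layout the app downloads (`checkpoints/latentsync_unet.pt`, `checkpoints/whisper/tiny.pt`) and the `configs/unet/second_stage.yaml` config; the exact `LipsyncPipeline` call signature comes from `app.py`, since `latentsync/pipelines/lipsync_pipeline.py` itself is not among the 50 files shown here.

```python
import torch
from omegaconf import OmegaConf
from diffusers import AutoencoderKL, DDIMScheduler
from latentsync.models.unet import UNet3DConditionModel
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.whisper.audio2feature import Audio2Feature

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Config and scheduler exactly as app.py loads them.
config = OmegaConf.load("configs/unet/second_stage.yaml")
scheduler = DDIMScheduler.from_pretrained("configs")

# cross_attention_dim is 384 in second_stage.yaml, so app.py picks whisper tiny.
audio_encoder = Audio2Feature(model_path="checkpoints/whisper/tiny.pt",
                              device=device, num_frames=config.data.num_frames)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
unet, _ = UNet3DConditionModel.from_pretrained(
    OmegaConf.to_container(config.model), "checkpoints/latentsync_unet.pt", device=device)
unet = unet.to(dtype=dtype)

pipeline = LipsyncPipeline(vae=vae, audio_encoder=audio_encoder,
                           unet=unet, scheduler=scheduler).to(device)

# Run on the bundled demo assets; output paths are illustrative.
pipeline(video_path="assets/demo1_video.mp4", audio_path="assets/demo1_audio.wav",
         video_out_path="output.mp4", video_mask_path="output_mask.mp4",
         num_frames=config.data.num_frames,
         num_inference_steps=config.run.inference_steps,
         guidance_scale=1.0, weight_dtype=dtype,
         width=config.data.resolution, height=config.data.resolution)
```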
LatentSync/app.py
ADDED
@@ -0,0 +1,267 @@
import os
import uuid
import shutil
import tempfile
import gradio as gr
import torch
from moviepy.editor import VideoFileClip
from pydub import AudioSegment
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from diffusers import AutoencoderKL, DDIMScheduler
from latentsync.models.unet import UNet3DConditionModel
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.whisper.audio2feature import Audio2Feature
from accelerate.utils import set_seed

AUTHOR = "Lý Trần"
COMMUNITY = "LTTEAM"
COMMUNITY_LINK = "https://www.facebook.com/groups/622526090937760"
REPO_ID = "LTTEAM/Nhep_Mieng"

os.makedirs("checkpoints", exist_ok=True)
snapshot_download(
    repo_id=REPO_ID,
    local_dir="./checkpoints"
)


def process_video(input_video_path: str, temp_dir: str = "temp_video") -> str:
    os.makedirs(temp_dir, exist_ok=True)
    video = VideoFileClip(input_video_path)
    output_path = os.path.join(
        temp_dir, f"crop_{os.path.basename(input_video_path)}"
    )
    if video.duration > 10:
        video = video.subclip(0, 10)
    video.write_videofile(output_path, codec="libx264", audio_codec="aac")
    return output_path


def process_audio(input_audio_path: str, temp_dir: str) -> str:
    os.makedirs(temp_dir, exist_ok=True)
    audio = AudioSegment.from_file(input_audio_path)
    max_ms = 8 * 1000
    if len(audio) > max_ms:
        audio = audio[:max_ms]
    output_path = os.path.join(temp_dir, "trim_audio.wav")
    audio.export(output_path, format="wav")
    return output_path


def main(video_path, audio_path, progress=gr.Progress(track_tqdm=True)):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Chạy trên device: {device}")
    space_id = os.environ.get("SPACE_ID", "")
    is_shared_ui = "fffiloni/LatentSync" in space_id

    # If running on the shared UI, stage the inputs in a temp dir and trim them
    temp_dir = None
    if is_shared_ui:
        temp_dir = tempfile.mkdtemp()
        video_path = process_video(video_path, temp_dir)
        audio_path = process_audio(audio_path, temp_dir)

    # Load the config and checkpoints
    config = OmegaConf.load("configs/unet/second_stage.yaml")
    unet_ckpt = "checkpoints/latentsync_unet.pt"
    scheduler = DDIMScheduler.from_pretrained("configs")

    # Choose the Whisper model based on cross_attention_dim
    dim = config.model.cross_attention_dim
    if dim == 768:
        whisper_ckpt = "checkpoints/whisper/small.pt"
    elif dim == 384:
        whisper_ckpt = "checkpoints/whisper/tiny.pt"
    else:
        raise NotImplementedError("cross_attention_dim phải là 768 hoặc 384")

    # Create the audio encoder
    audio_encoder = Audio2Feature(
        model_path=whisper_ckpt,
        device=device,
        num_frames=config.data.num_frames
    )

    # Load the VAE
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sd-vae-ft-mse",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    )
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0

    # Load the UNet
    unet, _ = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        unet_ckpt,
        device=device
    )
    # Cast to the appropriate dtype
    unet = unet.to(dtype=torch.float16) if device == "cuda" else unet.to(dtype=torch.float32)

    # Build the pipeline and move it to the device
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=scheduler,
    ).to(device)

    # Seed setup
    seed = -1
    if seed != -1:
        set_seed(seed)
    else:
        torch.seed()
    print(f"[INFO] Seed khởi tạo: {torch.initial_seed()}")

    # Run the pipeline
    output_id = uuid.uuid4().hex
    result_path = f"output_{output_id}.mp4"
    pipeline(
        video_path=video_path,
        audio_path=audio_path,
        video_out_path=result_path,
        video_mask_path=result_path.replace(".mp4", "_mask.mp4"),
        num_frames=config.data.num_frames,
        num_inference_steps=config.run.inference_steps,
        guidance_scale=1.0,
        weight_dtype=torch.float16 if device == "cuda" else torch.float32,
        width=config.data.resolution,
        height=config.data.resolution,
    )

    # Clean up the temp dir if one was created
    if is_shared_ui and temp_dir and os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)

    return result_path


custom_css = """
:root {
    --primary: #4CAF50;
    --secondary: #8BC34A;
    --accent: #FFC107;
    --dark: #1E1E1E;
    --light: #F5F5F5;
}

body {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    background-color: var(--light);
}

div#main-container {
    margin: 0 auto;
    max-width: 900px;
    background: white;
    padding: 2rem;
    border-radius: 12px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.1);
}

h1 {
    color: var(--primary);
    border-bottom: 2px solid var(--secondary);
    padding-bottom: 0.5rem;
}

.gr-button {
    background: var(--primary) !important;
    color: white !important;
    border: none !important;
    padding: 0.75rem 1.5rem !important;
    border-radius: 8px !important;
    font-weight: 600 !important;
    transition: all 0.3s ease !important;
}

.gr-button:hover {
    background: var(--secondary) !important;
    transform: translateY(-2px) !important;
    box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
}

.gr-box {
    border-radius: 8px !important;
    border: 1px solid #e0e0e0 !important;
}

footer {
    text-align: center;
    margin-top: 2rem;
    color: #666;
    font-size: 0.9rem;
}

.example-container {
    background: #f9f9f9;
    padding: 1rem;
    border-radius: 8px;
    margin-top: 1rem;
}
"""

with gr.Blocks(css=custom_css, title="LatentSync - Đồng bộ môi bằng AI") as demo:
    with gr.Column(elem_id="main-container"):
        # Header
        gr.Markdown(f"# 🎤 LatentSync - Đồng bộ môi bằng AI")
        gr.Markdown(f"**Tác giả:** {AUTHOR} | **Cộng đồng:** [{COMMUNITY}]({COMMUNITY_LINK})")

        # Introduction
        with gr.Accordion("ℹ️ Giới thiệu ứng dụng", open=False):
            gr.Markdown("""
            Ứng dụng sử dụng mô hình AI tiên tiến để đồng bộ chuyển động môi trong video với âm thanh đầu vào.

            **Cách sử dụng:**
            1. Tải lên video chứa khuôn mặt cần đồng bộ môi
            2. Tải lên file âm thanh hoặc ghi âm trực tiếp
            3. Nhấn nút "Chạy đồng bộ" và chờ kết quả

            **Lưu ý:**
            - Video nên có khuôn mặt rõ ràng, ánh sáng tốt
            - Âm thanh cần rõ ràng, không nhiễu
            - Thời gian xử lý phụ thuộc vào độ dài video và cấu hình máy
            """)

        # Input/Output
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🎥 Đầu vào")
                video_in = gr.Video(label="Video đầu vào (MP4)", format="mp4", interactive=True)
                audio_in = gr.Audio(label="Âm thanh đầu vào", type="filepath", interactive=True)
                with gr.Row():
                    btn = gr.Button("🚀 Chạy đồng bộ", variant="primary")
                    clear_btn = gr.Button("🔄 Xóa hết")

            with gr.Column():
                gr.Markdown("### 📼 Kết quả")
                video_out = gr.Video(label="Video kết quả", interactive=False)
                with gr.Row():
                    download_btn = gr.Button("💾 Tải xuống")

        # Example presets (bug fixed here)
        with gr.Accordion("📂 Ví dụ mẫu", open=True):
            gr.Examples(
                examples=[
                    ["assets/demo1_video.mp4", "assets/demo1_audio.wav"],
                    ["assets/demo2_video.mp4", "assets/demo2_audio.wav"],
                    ["assets/demo3_video.mp4", "assets/demo3_audio.wav"],
                ],
                inputs=[video_in, audio_in],
                outputs=[video_out],
                fn=main,  # main processing function wired in
                label="Nhấn vào ví dụ để thử ngay",
                # cache_examples=True  # disabled: caching would need extra configuration
            )

        # Footer
        gr.Markdown(f"""
        ---
        *Ứng dụng được phát triển bởi {AUTHOR} và cộng đồng {COMMUNITY}*
        *Phiên bản 1.0 | [Tham gia nhóm]({COMMUNITY_LINK}) để cập nhật và hỗ trợ*
        """)

    # Event handlers
    btn.click(fn=main, inputs=[video_in, audio_in], outputs=[video_out])
    clear_btn.click(lambda: [None, None, None], outputs=[video_in, audio_in, video_out])
    download_btn.click(lambda x: x, inputs=[video_out], outputs=[video_out])

demo.launch(
    share=True,
    show_error=True,
    server_name="0.0.0.0",
    server_port=int(os.environ.get("PORT", 7860)),
)
LatentSync/apt.txt
ADDED
@@ -0,0 +1,2 @@
ffmpeg
libgl1
LatentSync/assets/demo1_audio.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4f7dd2112dbdc0bece5ee6f26553a4867b65740eb53187ecc8b1a3c1618b2405
size 307278
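The demo assets in this upload are stored as Git LFS pointer files like the one above (until `git lfs pull` replaces them with the real media). A minimal sketch, assuming the three-line `version`/`oid`/`size` format shown, that reads a pointer into a dict:

```python
from pathlib import Path

def read_lfs_pointer(path: str) -> dict:
    """Parse a Git LFS pointer file into a {key: value} dict."""
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        if key and value:
            fields[key] = value
    return fields

# Prints the sha256 oid and the size in bytes recorded for the real asset.
pointer = read_lfs_pointer("LatentSync/assets/demo1_audio.wav")
print(pointer["oid"], pointer["size"])
```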
LatentSync/assets/demo1_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ed2dd1e2001aa605c3f2d77672a8af4ed55e427a85c55d408adfc3d5076bc872
size 1240008
LatentSync/assets/demo2_audio.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4916574779fb975367ddcb1f12597205ae15ea8aeaa61ad92d2c1c5d719c3607
size 634958
LatentSync/assets/demo2_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8c3f10288e0642e587a95c0040e6966f8f6b7e003c3a17b572f72472b896d8ff
size 1772492
LatentSync/assets/demo3_audio.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d5014567b03d35e0bd813a3725c3129a99722497cd4cf8e036d2c304530ea432
size 593998
LatentSync/assets/demo3_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cfa177b2a44f7809f606285c120e270d526caa50d708ec95e0f614d220970e0f
size 2112370
LatentSync/assets/framework.png
ADDED
Git LFS Details
LatentSync/configs/audio.yaml
ADDED
@@ -0,0 +1,23 @@
audio:
  num_mels: 80 # Number of mel-spectrogram channels and local conditioning dimensionality
  rescale: true # Whether to rescale audio prior to preprocessing
  rescaling_max: 0.9 # Rescaling value
  use_lws:
    false # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
          # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
          # Does not work if n_fft is not a multiple of hop_size!!
  n_fft: 800 # Extra window size is filled with 0 paddings to match this parameter
  hop_size: 200 # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
  win_size: 800 # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
  sample_rate: 16000 # 16000Hz (corresponding to librispeech) (sox --i <filename>)
  frame_shift_ms: null
  signal_normalization: true
  allow_clipping_in_normalization: true
  symmetric_mels: true
  max_abs_value: 4.0
  preemphasize: true # whether to apply filter
  preemphasis: 0.97 # filter coefficient.
  min_level_db: -100
  ref_level_db: 20
  fmin: 55
  fmax: 7600
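For reference, here is a minimal mel-spectrogram sketch using the `n_fft`/`hop_size`/`win_size`/`sample_rate`/mel values from `configs/audio.yaml`. It assumes librosa is installed; it is only an illustration of these parameters, not the repo's own preprocessing (which additionally applies pre-emphasis and dB normalization per the config).

```python
import librosa
import numpy as np

# Values taken from configs/audio.yaml
SAMPLE_RATE = 16000
N_FFT = 800        # 50 ms window at 16 kHz
HOP_SIZE = 200     # 12.5 ms hop at 16 kHz
WIN_SIZE = 800
NUM_MELS = 80
FMIN, FMAX = 55, 7600

def mel_spectrogram(wav_path: str) -> np.ndarray:
    # librosa resamples the waveform to SAMPLE_RATE on load
    wav, _ = librosa.load(wav_path, sr=SAMPLE_RATE)
    mel = librosa.feature.melspectrogram(
        y=wav, sr=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_SIZE,
        win_length=WIN_SIZE, n_mels=NUM_MELS, fmin=FMIN, fmax=FMAX,
    )
    return librosa.power_to_db(mel)  # shape (80, num_frames)
```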
LatentSync/configs/scheduler_config.json
ADDED
@@ -0,0 +1,13 @@
{
  "_class_name": "DDIMScheduler",
  "_diffusers_version": "0.6.0.dev0",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "num_train_timesteps": 1000,
  "set_alpha_to_one": false,
  "steps_offset": 1,
  "trained_betas": null,
  "skip_prk_steps": true
}
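This is the file that `app.py` picks up when it calls `DDIMScheduler.from_pretrained("configs")`: diffusers looks for `scheduler_config.json` inside the given directory. A short sketch, assuming it runs from the LatentSync repo root:

```python
from diffusers import DDIMScheduler

# "configs" is the directory containing scheduler_config.json shown above.
scheduler = DDIMScheduler.from_pretrained("configs")
print(scheduler.config.beta_schedule)        # "scaled_linear"
print(scheduler.config.num_train_timesteps)  # 1000
```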
LatentSync/configs/syncnet/syncnet_16_latent.yaml
ADDED
@@ -0,0 +1,46 @@
model:
  audio_encoder: # input (1, 80, 52)
    in_channels: 1
    block_out_channels: [32, 64, 128, 256, 512, 1024]
    downsample_factors: [[2, 1], 2, 2, 2, 2, [2, 3]]
    attn_blocks: [0, 0, 0, 0, 0, 0]
    dropout: 0.0
  visual_encoder: # input (64, 32, 32)
    in_channels: 64
    block_out_channels: [64, 128, 256, 256, 512, 1024]
    downsample_factors: [2, 2, 2, 1, 2, 2]
    attn_blocks: [0, 0, 0, 0, 0, 0]
    dropout: 0.0

ckpt:
  resume_ckpt_path: ""
  inference_ckpt_path: ""
  save_ckpt_steps: 2500

data:
  train_output_dir: output/syncnet
  num_val_samples: 1200
  batch_size: 120 # 40
  num_workers: 11 # 11
  latent_space: true
  num_frames: 16
  resolution: 256
  train_fileslist: ""
  train_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/train
  val_fileslist: ""
  val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
  audio_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel_new
  lower_half: false
  pretrained_audio_model_path: facebook/wav2vec2-large-xlsr-53
  audio_sample_rate: 16000
  video_fps: 25

optimizer:
  lr: 1e-5
  max_grad_norm: 1.0

run:
  max_train_steps: 10000000
  validation_steps: 2500
  mixed_precision_training: true
  seed: 42
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
audio_encoder: # input (1, 80, 52)
|
| 3 |
+
in_channels: 1
|
| 4 |
+
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
|
| 5 |
+
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
|
| 6 |
+
attn_blocks: [0, 0, 0, 0, 0, 0, 0]
|
| 7 |
+
dropout: 0.0
|
| 8 |
+
visual_encoder: # input (48, 128, 256)
|
| 9 |
+
in_channels: 48
|
| 10 |
+
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
|
| 11 |
+
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
|
| 12 |
+
attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
|
| 13 |
+
dropout: 0.0
|
| 14 |
+
|
| 15 |
+
ckpt:
|
| 16 |
+
resume_ckpt_path: ""
|
| 17 |
+
inference_ckpt_path: checkpoints/latentsync_syncnet.pt
|
| 18 |
+
save_ckpt_steps: 2500
|
| 19 |
+
|
| 20 |
+
data:
|
| 21 |
+
train_output_dir: debug/syncnet
|
| 22 |
+
num_val_samples: 2048
|
| 23 |
+
batch_size: 128 # 128
|
| 24 |
+
num_workers: 11 # 11
|
| 25 |
+
latent_space: false
|
| 26 |
+
num_frames: 16
|
| 27 |
+
resolution: 256
|
| 28 |
+
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/all_data_v6.txt
|
| 29 |
+
train_data_dir: ""
|
| 30 |
+
val_fileslist: ""
|
| 31 |
+
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
|
| 32 |
+
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel_new
|
| 33 |
+
lower_half: true
|
| 34 |
+
audio_sample_rate: 16000
|
| 35 |
+
video_fps: 25
|
| 36 |
+
|
| 37 |
+
optimizer:
|
| 38 |
+
lr: 1e-5
|
| 39 |
+
max_grad_norm: 1.0
|
| 40 |
+
|
| 41 |
+
run:
|
| 42 |
+
max_train_steps: 10000000
|
| 43 |
+
validation_steps: 2500
|
| 44 |
+
mixed_precision_training: true
|
| 45 |
+
seed: 42
|
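In these SyncNet configs, each encoder lists one output-channel count, one downsample factor, and one attention flag per block, so the three lists must stay the same length. A small sanity-check sketch (assuming it runs from the repo root with omegaconf installed):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/syncnet/syncnet_16_pixel.yaml")

# The per-block lists of each encoder must line up one-to-one.
for name in ("audio_encoder", "visual_encoder"):
    enc = cfg.model[name]
    assert len(enc.block_out_channels) == len(enc.downsample_factors) == len(enc.attn_blocks), name
    print(name, "blocks:", len(enc.block_out_channels))  # 7 audio blocks, 8 visual blocks
```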
LatentSync/configs/syncnet/syncnet_25_pixel.yaml
ADDED
@@ -0,0 +1,45 @@
model:
  audio_encoder: # input (1, 80, 80)
    in_channels: 1
    block_out_channels: [64, 128, 256, 256, 512, 1024]
    downsample_factors: [2, 2, 2, 2, 2, 2]
    dropout: 0.0
  visual_encoder: # input (75, 128, 256)
    in_channels: 75
    block_out_channels: [128, 128, 256, 256, 512, 512, 1024, 1024]
    downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
    dropout: 0.0

ckpt:
  resume_ckpt_path: ""
  inference_ckpt_path: ""
  save_ckpt_steps: 2500

data:
  train_output_dir: debug/syncnet
  num_val_samples: 2048
  batch_size: 64 # 64
  num_workers: 11 # 11
  latent_space: false
  num_frames: 25
  resolution: 256
  train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/hdtf_vox_avatars_ads_affine.txt
  # /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/hdtf_voxceleb_avatars_affine.txt
  train_data_dir: ""
  val_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/vox_affine_val.txt
  # /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/voxceleb_val.txt
  val_data_dir: ""
  audio_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
  lower_half: true
  pretrained_audio_model_path: facebook/wav2vec2-large-xlsr-53
  audio_sample_rate: 16000
  video_fps: 25

optimizer:
  lr: 1e-5
  max_grad_norm: 1.0

run:
  max_train_steps: 10000000
  mixed_precision_training: true
  seed: 42
LatentSync/configs/unet/first_stage.yaml
ADDED
@@ -0,0 +1,103 @@
data:
  syncnet_config_path: configs/syncnet/syncnet_16_pixel.yaml
  train_output_dir: debug/unet
  train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/all_data_v6.txt
  train_data_dir: ""
  audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/whisper_new
  audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel_new

  val_video_path: assets/demo1_video.mp4
  val_audio_path: assets/demo1_audio.wav
  batch_size: 8 # 8
  num_workers: 11 # 11
  num_frames: 16
  resolution: 256
  mask: fix_mask
  audio_sample_rate: 16000
  video_fps: 25

ckpt:
  resume_ckpt_path: checkpoints/latentsync_unet.pt
  save_ckpt_steps: 5000

run:
  pixel_space_supervise: false
  use_syncnet: false
  sync_loss_weight: 0.05 # 1/283
  perceptual_loss_weight: 0.1 # 0.1
  recon_loss_weight: 1 # 1
  guidance_scale: 1.0 # 1.5 or 1.0
  trepa_loss_weight: 10
  inference_steps: 20
  seed: 1247
  use_mixed_noise: true
  mixed_noise_alpha: 1 # 1
  mixed_precision_training: true
  enable_gradient_checkpointing: false
  enable_xformers_memory_efficient_attention: true
  max_train_steps: 10000000
  max_train_epochs: -1

optimizer:
  lr: 1e-5
  scale_lr: false
  max_grad_norm: 1.0
  lr_scheduler: constant
  lr_warmup_steps: 0

model:
  act_fn: silu
  add_audio_layer: true
  custom_audio_layer: false
  audio_condition_method: cross_attn # Choose between [cross_attn, group_norm]
  attention_head_dim: 8
  block_out_channels: [320, 640, 1280, 1280]
  center_input_sample: false
  cross_attention_dim: 384
  down_block_types:
    [
      "CrossAttnDownBlock3D",
      "CrossAttnDownBlock3D",
      "CrossAttnDownBlock3D",
      "DownBlock3D",
    ]
  mid_block_type: UNetMidBlock3DCrossAttn
  up_block_types:
    [
      "UpBlock3D",
      "CrossAttnUpBlock3D",
      "CrossAttnUpBlock3D",
      "CrossAttnUpBlock3D",
    ]
  downsample_padding: 1
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 13 # 49
  layers_per_block: 2
  mid_block_scale_factor: 1
  norm_eps: 1e-5
  norm_num_groups: 32
  out_channels: 4 # 16
  sample_size: 64
  resnet_time_scale_shift: default # Choose between [default, scale_shift]
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false

  # We don't actually use the motion module in the final version of LatentSync.
  # When we started the project we used the AnimateDiff codebase and tried the motion module, but the results were poor.
  # We decided to leave the code here for possible future use.
  use_motion_module: false
  motion_module_resolutions: [1, 2, 4, 8]
  motion_module_mid_block: false
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
      - Temporal_Self
      - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 16
    temporal_attention_dim_div: 1
    zero_initialize: true
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
syncnet_config_path: configs/syncnet/syncnet_16_pixel.yaml
|
| 3 |
+
train_output_dir: debug/unet
|
| 4 |
+
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/all_data_v6.txt
|
| 5 |
+
train_data_dir: ""
|
| 6 |
+
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/whisper_new
|
| 7 |
+
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel_new
|
| 8 |
+
|
| 9 |
+
val_video_path: assets/demo1_video.mp4
|
| 10 |
+
val_audio_path: assets/demo1_audio.wav
|
| 11 |
+
batch_size: 2 # 8
|
| 12 |
+
num_workers: 11 # 11
|
| 13 |
+
num_frames: 16
|
| 14 |
+
resolution: 256
|
| 15 |
+
mask: fix_mask
|
| 16 |
+
audio_sample_rate: 16000
|
| 17 |
+
video_fps: 25
|
| 18 |
+
|
| 19 |
+
ckpt:
|
| 20 |
+
resume_ckpt_path: checkpoints/latentsync_unet.pt
|
| 21 |
+
save_ckpt_steps: 5000
|
| 22 |
+
|
| 23 |
+
run:
|
| 24 |
+
pixel_space_supervise: true
|
| 25 |
+
use_syncnet: true
|
| 26 |
+
sync_loss_weight: 0.05 # 1/283
|
| 27 |
+
perceptual_loss_weight: 0.1 # 0.1
|
| 28 |
+
recon_loss_weight: 1 # 1
|
| 29 |
+
guidance_scale: 1.0 # 1.5 or 1.0
|
| 30 |
+
trepa_loss_weight: 10
|
| 31 |
+
inference_steps: 20
|
| 32 |
+
seed: 1247
|
| 33 |
+
use_mixed_noise: true
|
| 34 |
+
mixed_noise_alpha: 1 # 1
|
| 35 |
+
mixed_precision_training: true
|
| 36 |
+
enable_gradient_checkpointing: false
|
| 37 |
+
enable_xformers_memory_efficient_attention: true
|
| 38 |
+
max_train_steps: 10000000
|
| 39 |
+
max_train_epochs: -1
|
| 40 |
+
|
| 41 |
+
optimizer:
|
| 42 |
+
lr: 1e-5
|
| 43 |
+
scale_lr: false
|
| 44 |
+
max_grad_norm: 1.0
|
| 45 |
+
lr_scheduler: constant
|
| 46 |
+
lr_warmup_steps: 0
|
| 47 |
+
|
| 48 |
+
model:
|
| 49 |
+
act_fn: silu
|
| 50 |
+
add_audio_layer: true
|
| 51 |
+
custom_audio_layer: false
|
| 52 |
+
audio_condition_method: cross_attn # Choose between [cross_attn, group_norm]
|
| 53 |
+
attention_head_dim: 8
|
| 54 |
+
block_out_channels: [320, 640, 1280, 1280]
|
| 55 |
+
center_input_sample: false
|
| 56 |
+
cross_attention_dim: 384
|
| 57 |
+
down_block_types:
|
| 58 |
+
[
|
| 59 |
+
"CrossAttnDownBlock3D",
|
| 60 |
+
"CrossAttnDownBlock3D",
|
| 61 |
+
"CrossAttnDownBlock3D",
|
| 62 |
+
"DownBlock3D",
|
| 63 |
+
]
|
| 64 |
+
mid_block_type: UNetMidBlock3DCrossAttn
|
| 65 |
+
up_block_types:
|
| 66 |
+
[
|
| 67 |
+
"UpBlock3D",
|
| 68 |
+
"CrossAttnUpBlock3D",
|
| 69 |
+
"CrossAttnUpBlock3D",
|
| 70 |
+
"CrossAttnUpBlock3D",
|
| 71 |
+
]
|
| 72 |
+
downsample_padding: 1
|
| 73 |
+
flip_sin_to_cos: true
|
| 74 |
+
freq_shift: 0
|
| 75 |
+
in_channels: 13 # 49
|
| 76 |
+
layers_per_block: 2
|
| 77 |
+
mid_block_scale_factor: 1
|
| 78 |
+
norm_eps: 1e-5
|
| 79 |
+
norm_num_groups: 32
|
| 80 |
+
out_channels: 4 # 16
|
| 81 |
+
sample_size: 64
|
| 82 |
+
resnet_time_scale_shift: default # Choose between [default, scale_shift]
|
| 83 |
+
unet_use_cross_frame_attention: false
|
| 84 |
+
unet_use_temporal_attention: false
|
| 85 |
+
|
| 86 |
+
# Actually we don't use the motion module in the final version of LatentSync
|
| 87 |
+
# When we started the project, we used the codebase of AnimateDiff and tried motion module, the results are poor
|
| 88 |
+
# We decied to leave the code here for possible future usage
|
| 89 |
+
use_motion_module: false
|
| 90 |
+
motion_module_resolutions: [1, 2, 4, 8]
|
| 91 |
+
motion_module_mid_block: false
|
| 92 |
+
motion_module_decoder_only: false
|
| 93 |
+
motion_module_type: Vanilla
|
| 94 |
+
motion_module_kwargs:
|
| 95 |
+
num_attention_heads: 8
|
| 96 |
+
num_transformer_block: 1
|
| 97 |
+
attention_block_types:
|
| 98 |
+
- Temporal_Self
|
| 99 |
+
- Temporal_Self
|
| 100 |
+
temporal_position_encoding: true
|
| 101 |
+
temporal_position_encoding_max_len: 16
|
| 102 |
+
temporal_attention_dim_div: 1
|
| 103 |
+
zero_initialize: true
|
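The two UNet stage configs are nearly identical; the differences are confined to the `data` and `run` sections. A short sketch (assuming both files are read from the repo root) that prints the scalar keys that differ:

```python
from omegaconf import OmegaConf

first = OmegaConf.load("configs/unet/first_stage.yaml")
second = OmegaConf.load("configs/unet/second_stage.yaml")

# Compare the scalar-valued sections and report keys whose values changed.
for section in ("data", "run"):
    for key in first[section]:
        if first[section][key] != second[section][key]:
            print(f"{section}.{key}: {first[section][key]} -> {second[section][key]}")

# Expected: data.batch_size 8 -> 2, run.pixel_space_supervise False -> True,
# run.use_syncnet False -> True (per the two files shown above).
```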
LatentSync/data_processing_pipeline.sh
ADDED
@@ -0,0 +1,9 @@
#!/bin/bash

python -m preprocess.data_processing_pipeline \
    --total_num_workers 20 \
    --per_gpu_num_workers 20 \
    --resolution 256 \
    --sync_conf_threshold 3 \
    --temp_dir temp \
    --input_dir /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/raw
LatentSync/eval/detectors/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
# Face detector
|
| 2 |
+
|
| 3 |
+
This face detector is adapted from `https://github.com/cs-giung/face-detection-pytorch`.
|
LatentSync/eval/detectors/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
from .s3fd import S3FD
|
LatentSync/eval/detectors/s3fd/__init__.py
ADDED
|
@@ -0,0 +1,61 @@
|
| 1 |
+
import time
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import torch
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
from .nets import S3FDNet
|
| 7 |
+
from .box_utils import nms_
|
| 8 |
+
|
| 9 |
+
PATH_WEIGHT = 'checkpoints/auxiliary/sfd_face.pth'
|
| 10 |
+
img_mean = np.array([104., 117., 123.])[:, np.newaxis, np.newaxis].astype('float32')
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class S3FD():
|
| 14 |
+
|
| 15 |
+
def __init__(self, device='cuda'):
|
| 16 |
+
|
| 17 |
+
tstamp = time.time()
|
| 18 |
+
self.device = device
|
| 19 |
+
|
| 20 |
+
print('[S3FD] loading with', self.device)
|
| 21 |
+
self.net = S3FDNet(device=self.device).to(self.device)
|
| 22 |
+
state_dict = torch.load(PATH_WEIGHT, map_location=self.device)
|
| 23 |
+
self.net.load_state_dict(state_dict)
|
| 24 |
+
self.net.eval()
|
| 25 |
+
print('[S3FD] finished loading (%.4f sec)' % (time.time() - tstamp))
|
| 26 |
+
|
| 27 |
+
def detect_faces(self, image, conf_th=0.8, scales=[1]):
|
| 28 |
+
|
| 29 |
+
w, h = image.shape[1], image.shape[0]
|
| 30 |
+
|
| 31 |
+
bboxes = np.empty(shape=(0, 5))
|
| 32 |
+
|
| 33 |
+
with torch.no_grad():
|
| 34 |
+
for s in scales:
|
| 35 |
+
scaled_img = cv2.resize(image, dsize=(0, 0), fx=s, fy=s, interpolation=cv2.INTER_LINEAR)
|
| 36 |
+
|
| 37 |
+
scaled_img = np.swapaxes(scaled_img, 1, 2)
|
| 38 |
+
scaled_img = np.swapaxes(scaled_img, 1, 0)
|
| 39 |
+
scaled_img = scaled_img[[2, 1, 0], :, :]
|
| 40 |
+
scaled_img = scaled_img.astype('float32')
|
| 41 |
+
scaled_img -= img_mean
|
| 42 |
+
scaled_img = scaled_img[[2, 1, 0], :, :]
|
| 43 |
+
x = torch.from_numpy(scaled_img).unsqueeze(0).to(self.device)
|
| 44 |
+
y = self.net(x)
|
| 45 |
+
|
| 46 |
+
detections = y.data
|
| 47 |
+
scale = torch.Tensor([w, h, w, h])
|
| 48 |
+
|
| 49 |
+
for i in range(detections.size(1)):
|
| 50 |
+
j = 0
|
| 51 |
+
while detections[0, i, j, 0] > conf_th:
|
| 52 |
+
score = detections[0, i, j, 0]
|
| 53 |
+
pt = (detections[0, i, j, 1:] * scale).cpu().numpy()
|
| 54 |
+
bbox = (pt[0], pt[1], pt[2], pt[3], score)
|
| 55 |
+
bboxes = np.vstack((bboxes, bbox))
|
| 56 |
+
j += 1
|
| 57 |
+
|
| 58 |
+
keep = nms_(bboxes, 0.1)
|
| 59 |
+
bboxes = bboxes[keep]
|
| 60 |
+
|
| 61 |
+
return bboxes
|
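Below is a hedged usage sketch (not part of the diff) for the S3FD wrapper defined above. It assumes the face-detection weights referenced by PATH_WEIGHT (checkpoints/auxiliary/sfd_face.pth) are already downloaded; the image file name is illustrative.

```python
# Minimal sketch: run the S3FD face detector on a single frame.
# Assumptions: sfd_face.pth exists at the hard-coded PATH_WEIGHT, and the
# script is run from the repo root so the package import works.
import cv2
from eval.detectors import S3FD

detector = S3FD(device="cuda")

image = cv2.imread("example_frame.jpg")              # hypothetical input frame
bboxes = detector.detect_faces(image, conf_th=0.8)   # (N, 5) array
for x1, y1, x2, y2, score in bboxes:
    print(f"face ({x1:.0f},{y1:.0f})-({x2:.0f},{y2:.0f}) conf={score:.2f}")
```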
LatentSync/eval/detectors/s3fd/box_utils.py
ADDED
|
@@ -0,0 +1,221 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
from itertools import product as product
|
| 3 |
+
import torch
|
| 4 |
+
from torch.autograd import Function
|
| 5 |
+
import warnings
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def nms_(dets, thresh):
|
| 9 |
+
"""
|
| 10 |
+
Courtesy of Ross Girshick
|
| 11 |
+
[https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py]
|
| 12 |
+
"""
|
| 13 |
+
x1 = dets[:, 0]
|
| 14 |
+
y1 = dets[:, 1]
|
| 15 |
+
x2 = dets[:, 2]
|
| 16 |
+
y2 = dets[:, 3]
|
| 17 |
+
scores = dets[:, 4]
|
| 18 |
+
|
| 19 |
+
areas = (x2 - x1) * (y2 - y1)
|
| 20 |
+
order = scores.argsort()[::-1]
|
| 21 |
+
|
| 22 |
+
keep = []
|
| 23 |
+
while order.size > 0:
|
| 24 |
+
i = order[0]
|
| 25 |
+
keep.append(int(i))
|
| 26 |
+
xx1 = np.maximum(x1[i], x1[order[1:]])
|
| 27 |
+
yy1 = np.maximum(y1[i], y1[order[1:]])
|
| 28 |
+
xx2 = np.minimum(x2[i], x2[order[1:]])
|
| 29 |
+
yy2 = np.minimum(y2[i], y2[order[1:]])
|
| 30 |
+
|
| 31 |
+
w = np.maximum(0.0, xx2 - xx1)
|
| 32 |
+
h = np.maximum(0.0, yy2 - yy1)
|
| 33 |
+
inter = w * h
|
| 34 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
| 35 |
+
|
| 36 |
+
inds = np.where(ovr <= thresh)[0]
|
| 37 |
+
order = order[inds + 1]
|
| 38 |
+
|
| 39 |
+
return np.array(keep).astype(np.int32)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def decode(loc, priors, variances):
|
| 43 |
+
"""Decode locations from predictions using priors to undo
|
| 44 |
+
the encoding we did for offset regression at train time.
|
| 45 |
+
Args:
|
| 46 |
+
loc (tensor): location predictions for loc layers,
|
| 47 |
+
Shape: [num_priors,4]
|
| 48 |
+
priors (tensor): Prior boxes in center-offset form.
|
| 49 |
+
Shape: [num_priors,4].
|
| 50 |
+
variances: (list[float]) Variances of priorboxes
|
| 51 |
+
Return:
|
| 52 |
+
decoded bounding box predictions
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
boxes = torch.cat((
|
| 56 |
+
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
| 57 |
+
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
| 58 |
+
boxes[:, :2] -= boxes[:, 2:] / 2
|
| 59 |
+
boxes[:, 2:] += boxes[:, :2]
|
| 60 |
+
return boxes
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def nms(boxes, scores, overlap=0.5, top_k=200):
|
| 64 |
+
"""Apply non-maximum suppression at test time to avoid detecting too many
|
| 65 |
+
overlapping bounding boxes for a given object.
|
| 66 |
+
Args:
|
| 67 |
+
boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
|
| 68 |
+
scores: (tensor) The class predscores for the img, Shape:[num_priors].
|
| 69 |
+
overlap: (float) The overlap thresh for suppressing unnecessary boxes.
|
| 70 |
+
top_k: (int) The Maximum number of box preds to consider.
|
| 71 |
+
Return:
|
| 72 |
+
The indices of the kept boxes with respect to num_priors.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
keep = scores.new(scores.size(0)).zero_().long()
|
| 76 |
+
if boxes.numel() == 0:
|
| 77 |
+
return keep, 0
|
| 78 |
+
x1 = boxes[:, 0]
|
| 79 |
+
y1 = boxes[:, 1]
|
| 80 |
+
x2 = boxes[:, 2]
|
| 81 |
+
y2 = boxes[:, 3]
|
| 82 |
+
area = torch.mul(x2 - x1, y2 - y1)
|
| 83 |
+
v, idx = scores.sort(0) # sort in ascending order
|
| 84 |
+
# I = I[v >= 0.01]
|
| 85 |
+
idx = idx[-top_k:] # indices of the top-k largest vals
|
| 86 |
+
xx1 = boxes.new()
|
| 87 |
+
yy1 = boxes.new()
|
| 88 |
+
xx2 = boxes.new()
|
| 89 |
+
yy2 = boxes.new()
|
| 90 |
+
w = boxes.new()
|
| 91 |
+
h = boxes.new()
|
| 92 |
+
|
| 93 |
+
# keep = torch.Tensor()
|
| 94 |
+
count = 0
|
| 95 |
+
while idx.numel() > 0:
|
| 96 |
+
i = idx[-1] # index of current largest val
|
| 97 |
+
# keep.append(i)
|
| 98 |
+
keep[count] = i
|
| 99 |
+
count += 1
|
| 100 |
+
if idx.size(0) == 1:
|
| 101 |
+
break
|
| 102 |
+
idx = idx[:-1] # remove kept element from view
|
| 103 |
+
# load bboxes of next highest vals
|
| 104 |
+
with warnings.catch_warnings():
|
| 105 |
+
# Ignore UserWarning within this block
|
| 106 |
+
warnings.simplefilter("ignore", category=UserWarning)
|
| 107 |
+
torch.index_select(x1, 0, idx, out=xx1)
|
| 108 |
+
torch.index_select(y1, 0, idx, out=yy1)
|
| 109 |
+
torch.index_select(x2, 0, idx, out=xx2)
|
| 110 |
+
torch.index_select(y2, 0, idx, out=yy2)
|
| 111 |
+
# store element-wise max with next highest score
|
| 112 |
+
xx1 = torch.clamp(xx1, min=x1[i])
|
| 113 |
+
yy1 = torch.clamp(yy1, min=y1[i])
|
| 114 |
+
xx2 = torch.clamp(xx2, max=x2[i])
|
| 115 |
+
yy2 = torch.clamp(yy2, max=y2[i])
|
| 116 |
+
w.resize_as_(xx2)
|
| 117 |
+
h.resize_as_(yy2)
|
| 118 |
+
w = xx2 - xx1
|
| 119 |
+
h = yy2 - yy1
|
| 120 |
+
# check sizes of xx1 and xx2.. after each iteration
|
| 121 |
+
w = torch.clamp(w, min=0.0)
|
| 122 |
+
h = torch.clamp(h, min=0.0)
|
| 123 |
+
inter = w * h
|
| 124 |
+
# IoU = i / (area(a) + area(b) - i)
|
| 125 |
+
rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
|
| 126 |
+
union = (rem_areas - inter) + area[i]
|
| 127 |
+
IoU = inter / union # store result in iou
|
| 128 |
+
# keep only elements with an IoU <= overlap
|
| 129 |
+
idx = idx[IoU.le(overlap)]
|
| 130 |
+
return keep, count
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class Detect(object):
|
| 134 |
+
|
| 135 |
+
def __init__(self, num_classes=2,
|
| 136 |
+
top_k=750, nms_thresh=0.3, conf_thresh=0.05,
|
| 137 |
+
variance=[0.1, 0.2], nms_top_k=5000):
|
| 138 |
+
|
| 139 |
+
self.num_classes = num_classes
|
| 140 |
+
self.top_k = top_k
|
| 141 |
+
self.nms_thresh = nms_thresh
|
| 142 |
+
self.conf_thresh = conf_thresh
|
| 143 |
+
self.variance = variance
|
| 144 |
+
self.nms_top_k = nms_top_k
|
| 145 |
+
|
| 146 |
+
def forward(self, loc_data, conf_data, prior_data):
|
| 147 |
+
|
| 148 |
+
num = loc_data.size(0)
|
| 149 |
+
num_priors = prior_data.size(0)
|
| 150 |
+
|
| 151 |
+
conf_preds = conf_data.view(num, num_priors, self.num_classes).transpose(2, 1)
|
| 152 |
+
batch_priors = prior_data.view(-1, num_priors, 4).expand(num, num_priors, 4)
|
| 153 |
+
batch_priors = batch_priors.contiguous().view(-1, 4)
|
| 154 |
+
|
| 155 |
+
decoded_boxes = decode(loc_data.view(-1, 4), batch_priors, self.variance)
|
| 156 |
+
decoded_boxes = decoded_boxes.view(num, num_priors, 4)
|
| 157 |
+
|
| 158 |
+
output = torch.zeros(num, self.num_classes, self.top_k, 5)
|
| 159 |
+
|
| 160 |
+
for i in range(num):
|
| 161 |
+
boxes = decoded_boxes[i].clone()
|
| 162 |
+
conf_scores = conf_preds[i].clone()
|
| 163 |
+
|
| 164 |
+
for cl in range(1, self.num_classes):
|
| 165 |
+
c_mask = conf_scores[cl].gt(self.conf_thresh)
|
| 166 |
+
scores = conf_scores[cl][c_mask]
|
| 167 |
+
|
| 168 |
+
if scores.dim() == 0:
|
| 169 |
+
continue
|
| 170 |
+
l_mask = c_mask.unsqueeze(1).expand_as(boxes)
|
| 171 |
+
boxes_ = boxes[l_mask].view(-1, 4)
|
| 172 |
+
ids, count = nms(boxes_, scores, self.nms_thresh, self.nms_top_k)
|
| 173 |
+
count = count if count < self.top_k else self.top_k
|
| 174 |
+
|
| 175 |
+
output[i, cl, :count] = torch.cat((scores[ids[:count]].unsqueeze(1), boxes_[ids[:count]]), 1)
|
| 176 |
+
|
| 177 |
+
return output
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class PriorBox(object):
|
| 181 |
+
|
| 182 |
+
def __init__(self, input_size, feature_maps,
|
| 183 |
+
variance=[0.1, 0.2],
|
| 184 |
+
min_sizes=[16, 32, 64, 128, 256, 512],
|
| 185 |
+
steps=[4, 8, 16, 32, 64, 128],
|
| 186 |
+
clip=False):
|
| 187 |
+
|
| 188 |
+
super(PriorBox, self).__init__()
|
| 189 |
+
|
| 190 |
+
self.imh = input_size[0]
|
| 191 |
+
self.imw = input_size[1]
|
| 192 |
+
self.feature_maps = feature_maps
|
| 193 |
+
|
| 194 |
+
self.variance = variance
|
| 195 |
+
self.min_sizes = min_sizes
|
| 196 |
+
self.steps = steps
|
| 197 |
+
self.clip = clip
|
| 198 |
+
|
| 199 |
+
def forward(self):
|
| 200 |
+
mean = []
|
| 201 |
+
for k, fmap in enumerate(self.feature_maps):
|
| 202 |
+
feath = fmap[0]
|
| 203 |
+
featw = fmap[1]
|
| 204 |
+
for i, j in product(range(feath), range(featw)):
|
| 205 |
+
f_kw = self.imw / self.steps[k]
|
| 206 |
+
f_kh = self.imh / self.steps[k]
|
| 207 |
+
|
| 208 |
+
cx = (j + 0.5) / f_kw
|
| 209 |
+
cy = (i + 0.5) / f_kh
|
| 210 |
+
|
| 211 |
+
s_kw = self.min_sizes[k] / self.imw
|
| 212 |
+
s_kh = self.min_sizes[k] / self.imh
|
| 213 |
+
|
| 214 |
+
mean += [cx, cy, s_kw, s_kh]
|
| 215 |
+
|
| 216 |
+
output = torch.FloatTensor(mean).view(-1, 4)
|
| 217 |
+
|
| 218 |
+
if self.clip:
|
| 219 |
+
output.clamp_(max=1, min=0)
|
| 220 |
+
|
| 221 |
+
return output
|
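As a quick sanity check on the nms_ helper above (purely illustrative, not part of the diff): with two strongly overlapping boxes and one distant box, only the higher-scoring overlapping box survives at an IoU threshold of 0.5.

```python
# Minimal check of nms_: boxes are rows of (x1, y1, x2, y2, score).
import numpy as np
from eval.detectors.s3fd.box_utils import nms_

dets = np.array([
    [10.0, 10.0, 50.0, 50.0, 0.90],      # box A
    [12.0, 12.0, 52.0, 52.0, 0.80],      # box B, IoU with A ~0.82 -> suppressed
    [100.0, 100.0, 140.0, 140.0, 0.70],  # box C, disjoint -> kept
])
keep = nms_(dets, thresh=0.5)
print(keep)        # [0 2]
print(dets[keep])  # boxes A and C
```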
LatentSync/eval/detectors/s3fd/nets.py
ADDED
|
@@ -0,0 +1,174 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.nn.init as init
|
| 5 |
+
from .box_utils import Detect, PriorBox
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class L2Norm(nn.Module):
|
| 9 |
+
|
| 10 |
+
def __init__(self, n_channels, scale):
|
| 11 |
+
super(L2Norm, self).__init__()
|
| 12 |
+
self.n_channels = n_channels
|
| 13 |
+
self.gamma = scale or None
|
| 14 |
+
self.eps = 1e-10
|
| 15 |
+
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
| 16 |
+
self.reset_parameters()
|
| 17 |
+
|
| 18 |
+
def reset_parameters(self):
|
| 19 |
+
init.constant_(self.weight, self.gamma)
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
|
| 23 |
+
x = torch.div(x, norm)
|
| 24 |
+
out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x
|
| 25 |
+
return out
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class S3FDNet(nn.Module):
|
| 29 |
+
|
| 30 |
+
def __init__(self, device='cuda'):
|
| 31 |
+
super(S3FDNet, self).__init__()
|
| 32 |
+
self.device = device
|
| 33 |
+
|
| 34 |
+
self.vgg = nn.ModuleList([
|
| 35 |
+
nn.Conv2d(3, 64, 3, 1, padding=1),
|
| 36 |
+
nn.ReLU(inplace=True),
|
| 37 |
+
nn.Conv2d(64, 64, 3, 1, padding=1),
|
| 38 |
+
nn.ReLU(inplace=True),
|
| 39 |
+
nn.MaxPool2d(2, 2),
|
| 40 |
+
|
| 41 |
+
nn.Conv2d(64, 128, 3, 1, padding=1),
|
| 42 |
+
nn.ReLU(inplace=True),
|
| 43 |
+
nn.Conv2d(128, 128, 3, 1, padding=1),
|
| 44 |
+
nn.ReLU(inplace=True),
|
| 45 |
+
nn.MaxPool2d(2, 2),
|
| 46 |
+
|
| 47 |
+
nn.Conv2d(128, 256, 3, 1, padding=1),
|
| 48 |
+
nn.ReLU(inplace=True),
|
| 49 |
+
nn.Conv2d(256, 256, 3, 1, padding=1),
|
| 50 |
+
nn.ReLU(inplace=True),
|
| 51 |
+
nn.Conv2d(256, 256, 3, 1, padding=1),
|
| 52 |
+
nn.ReLU(inplace=True),
|
| 53 |
+
nn.MaxPool2d(2, 2, ceil_mode=True),
|
| 54 |
+
|
| 55 |
+
nn.Conv2d(256, 512, 3, 1, padding=1),
|
| 56 |
+
nn.ReLU(inplace=True),
|
| 57 |
+
nn.Conv2d(512, 512, 3, 1, padding=1),
|
| 58 |
+
nn.ReLU(inplace=True),
|
| 59 |
+
nn.Conv2d(512, 512, 3, 1, padding=1),
|
| 60 |
+
nn.ReLU(inplace=True),
|
| 61 |
+
nn.MaxPool2d(2, 2),
|
| 62 |
+
|
| 63 |
+
nn.Conv2d(512, 512, 3, 1, padding=1),
|
| 64 |
+
nn.ReLU(inplace=True),
|
| 65 |
+
nn.Conv2d(512, 512, 3, 1, padding=1),
|
| 66 |
+
nn.ReLU(inplace=True),
|
| 67 |
+
nn.Conv2d(512, 512, 3, 1, padding=1),
|
| 68 |
+
nn.ReLU(inplace=True),
|
| 69 |
+
nn.MaxPool2d(2, 2),
|
| 70 |
+
|
| 71 |
+
nn.Conv2d(512, 1024, 3, 1, padding=6, dilation=6),
|
| 72 |
+
nn.ReLU(inplace=True),
|
| 73 |
+
nn.Conv2d(1024, 1024, 1, 1),
|
| 74 |
+
nn.ReLU(inplace=True),
|
| 75 |
+
])
|
| 76 |
+
|
| 77 |
+
self.L2Norm3_3 = L2Norm(256, 10)
|
| 78 |
+
self.L2Norm4_3 = L2Norm(512, 8)
|
| 79 |
+
self.L2Norm5_3 = L2Norm(512, 5)
|
| 80 |
+
|
| 81 |
+
self.extras = nn.ModuleList([
|
| 82 |
+
nn.Conv2d(1024, 256, 1, 1),
|
| 83 |
+
nn.Conv2d(256, 512, 3, 2, padding=1),
|
| 84 |
+
nn.Conv2d(512, 128, 1, 1),
|
| 85 |
+
nn.Conv2d(128, 256, 3, 2, padding=1),
|
| 86 |
+
])
|
| 87 |
+
|
| 88 |
+
self.loc = nn.ModuleList([
|
| 89 |
+
nn.Conv2d(256, 4, 3, 1, padding=1),
|
| 90 |
+
nn.Conv2d(512, 4, 3, 1, padding=1),
|
| 91 |
+
nn.Conv2d(512, 4, 3, 1, padding=1),
|
| 92 |
+
nn.Conv2d(1024, 4, 3, 1, padding=1),
|
| 93 |
+
nn.Conv2d(512, 4, 3, 1, padding=1),
|
| 94 |
+
nn.Conv2d(256, 4, 3, 1, padding=1),
|
| 95 |
+
])
|
| 96 |
+
|
| 97 |
+
self.conf = nn.ModuleList([
|
| 98 |
+
nn.Conv2d(256, 4, 3, 1, padding=1),
|
| 99 |
+
nn.Conv2d(512, 2, 3, 1, padding=1),
|
| 100 |
+
nn.Conv2d(512, 2, 3, 1, padding=1),
|
| 101 |
+
nn.Conv2d(1024, 2, 3, 1, padding=1),
|
| 102 |
+
nn.Conv2d(512, 2, 3, 1, padding=1),
|
| 103 |
+
nn.Conv2d(256, 2, 3, 1, padding=1),
|
| 104 |
+
])
|
| 105 |
+
|
| 106 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 107 |
+
self.detect = Detect()
|
| 108 |
+
|
| 109 |
+
def forward(self, x):
|
| 110 |
+
size = x.size()[2:]
|
| 111 |
+
sources = list()
|
| 112 |
+
loc = list()
|
| 113 |
+
conf = list()
|
| 114 |
+
|
| 115 |
+
for k in range(16):
|
| 116 |
+
x = self.vgg[k](x)
|
| 117 |
+
s = self.L2Norm3_3(x)
|
| 118 |
+
sources.append(s)
|
| 119 |
+
|
| 120 |
+
for k in range(16, 23):
|
| 121 |
+
x = self.vgg[k](x)
|
| 122 |
+
s = self.L2Norm4_3(x)
|
| 123 |
+
sources.append(s)
|
| 124 |
+
|
| 125 |
+
for k in range(23, 30):
|
| 126 |
+
x = self.vgg[k](x)
|
| 127 |
+
s = self.L2Norm5_3(x)
|
| 128 |
+
sources.append(s)
|
| 129 |
+
|
| 130 |
+
for k in range(30, len(self.vgg)):
|
| 131 |
+
x = self.vgg[k](x)
|
| 132 |
+
sources.append(x)
|
| 133 |
+
|
| 134 |
+
# apply extra layers and cache source layer outputs
|
| 135 |
+
for k, v in enumerate(self.extras):
|
| 136 |
+
x = F.relu(v(x), inplace=True)
|
| 137 |
+
if k % 2 == 1:
|
| 138 |
+
sources.append(x)
|
| 139 |
+
|
| 140 |
+
# apply multibox head to source layers
|
| 141 |
+
loc_x = self.loc[0](sources[0])
|
| 142 |
+
conf_x = self.conf[0](sources[0])
|
| 143 |
+
|
| 144 |
+
max_conf, _ = torch.max(conf_x[:, 0:3, :, :], dim=1, keepdim=True)
|
| 145 |
+
conf_x = torch.cat((max_conf, conf_x[:, 3:, :, :]), dim=1)
|
| 146 |
+
|
| 147 |
+
loc.append(loc_x.permute(0, 2, 3, 1).contiguous())
|
| 148 |
+
conf.append(conf_x.permute(0, 2, 3, 1).contiguous())
|
| 149 |
+
|
| 150 |
+
for i in range(1, len(sources)):
|
| 151 |
+
x = sources[i]
|
| 152 |
+
conf.append(self.conf[i](x).permute(0, 2, 3, 1).contiguous())
|
| 153 |
+
loc.append(self.loc[i](x).permute(0, 2, 3, 1).contiguous())
|
| 154 |
+
|
| 155 |
+
features_maps = []
|
| 156 |
+
for i in range(len(loc)):
|
| 157 |
+
feat = []
|
| 158 |
+
feat += [loc[i].size(1), loc[i].size(2)]
|
| 159 |
+
features_maps += [feat]
|
| 160 |
+
|
| 161 |
+
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
|
| 162 |
+
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
|
| 163 |
+
|
| 164 |
+
with torch.no_grad():
|
| 165 |
+
self.priorbox = PriorBox(size, features_maps)
|
| 166 |
+
self.priors = self.priorbox.forward()
|
| 167 |
+
|
| 168 |
+
output = self.detect.forward(
|
| 169 |
+
loc.view(loc.size(0), -1, 4),
|
| 170 |
+
self.softmax(conf.view(conf.size(0), -1, 2)),
|
| 171 |
+
self.priors.type(type(x.data)).to(self.device)
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
return output
|
LatentSync/eval/draw_syncnet_lines.py
ADDED
|
@@ -0,0 +1,70 @@
|
| 1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Chart:
|
| 20 |
+
def __init__(self):
|
| 21 |
+
self.loss_list = []
|
| 22 |
+
|
| 23 |
+
def add_ckpt(self, ckpt_path, line_name):
|
| 24 |
+
ckpt = torch.load(ckpt_path, map_location="cpu")
|
| 25 |
+
train_step_list = ckpt["train_step_list"]
|
| 26 |
+
train_loss_list = ckpt["train_loss_list"]
|
| 27 |
+
val_step_list = ckpt["val_step_list"]
|
| 28 |
+
val_loss_list = ckpt["val_loss_list"]
|
| 29 |
+
val_step_list = [val_step_list[0]] + val_step_list[4::5]
|
| 30 |
+
val_loss_list = [val_loss_list[0]] + val_loss_list[4::5]
|
| 31 |
+
self.loss_list.append((line_name, train_step_list, train_loss_list, val_step_list, val_loss_list))
|
| 32 |
+
|
| 33 |
+
def draw(self, save_path, plot_val=True):
|
| 34 |
+
# Global settings
|
| 35 |
+
plt.rcParams["font.size"] = 14
|
| 36 |
+
plt.rcParams["font.family"] = "serif"
|
| 37 |
+
plt.rcParams["font.sans-serif"] = ["Arial", "DejaVu Sans", "Lucida Grande"]
|
| 38 |
+
plt.rcParams["font.serif"] = ["Times New Roman", "DejaVu Serif"]
|
| 39 |
+
|
| 40 |
+
# Creating the plot
|
| 41 |
+
plt.figure(figsize=(7.766, 4.8)) # Golden ratio
|
| 42 |
+
for loss in self.loss_list:
|
| 43 |
+
if plot_val:
|
| 44 |
+
(line,) = plt.plot(loss[1], loss[2], label=loss[0], linewidth=0.5, alpha=0.5)
|
| 45 |
+
line_color = line.get_color()
|
| 46 |
+
plt.plot(loss[3], loss[4], linewidth=1.5, color=line_color)
|
| 47 |
+
else:
|
| 48 |
+
plt.plot(loss[1], loss[2], label=loss[0], linewidth=1)
|
| 49 |
+
plt.xlabel("Step")
|
| 50 |
+
plt.ylabel("Loss")
|
| 51 |
+
legend = plt.legend()
|
| 52 |
+
# legend = plt.legend(loc='upper right', bbox_to_anchor=(1, 0.82))
|
| 53 |
+
|
| 54 |
+
# Adjust the linewidth of legend
|
| 55 |
+
for line in legend.get_lines():
|
| 56 |
+
line.set_linewidth(2)
|
| 57 |
+
|
| 58 |
+
plt.savefig(save_path, transparent=True)
|
| 59 |
+
plt.close()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if __name__ == "__main__":
|
| 63 |
+
chart = Chart()
|
| 64 |
+
# chart.add_ckpt("output/syncnet/train-2024_10_25-18:14:43/checkpoints/checkpoint-10000.pt", "w/ self-attn")
|
| 65 |
+
# chart.add_ckpt("output/syncnet/train-2024_10_25-18:21:59/checkpoints/checkpoint-10000.pt", "w/o self-attn")
|
| 66 |
+
chart.add_ckpt("output/syncnet/train-2024_10_24-21:03:11/checkpoints/checkpoint-10000.pt", "Dim 512")
|
| 67 |
+
chart.add_ckpt("output/syncnet/train-2024_10_25-18:21:59/checkpoints/checkpoint-10000.pt", "Dim 2048")
|
| 68 |
+
chart.add_ckpt("output/syncnet/train-2024_10_24-22:37:04/checkpoints/checkpoint-10000.pt", "Dim 4096")
|
| 69 |
+
chart.add_ckpt("output/syncnet/train-2024_10_25-02:30:17/checkpoints/checkpoint-10000.pt", "Dim 6144")
|
| 70 |
+
chart.draw("ablation.pdf", plot_val=True)
|
LatentSync/eval/eval_fvd.py
ADDED
|
@@ -0,0 +1,96 @@
|
| 1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import mediapipe as mp
|
| 16 |
+
import cv2
|
| 17 |
+
from decord import VideoReader
|
| 18 |
+
from einops import rearrange
|
| 19 |
+
import os
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
import tqdm
|
| 23 |
+
from eval.fvd import compute_our_fvd
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class FVD:
|
| 27 |
+
def __init__(self, resolution=(224, 224)):
|
| 28 |
+
self.face_detector = mp.solutions.face_detection.FaceDetection(model_selection=0, min_detection_confidence=0.5)
|
| 29 |
+
self.resolution = resolution
|
| 30 |
+
|
| 31 |
+
def detect_face(self, image):
|
| 32 |
+
height, width = image.shape[:2]
|
| 33 |
+
# Process the image and detect faces.
|
| 34 |
+
results = self.face_detector.process(image)
|
| 35 |
+
|
| 36 |
+
if not results.detections: # Face not detected
|
| 37 |
+
raise Exception("Face not detected")
|
| 38 |
+
|
| 39 |
+
detection = results.detections[0] # Only use the first face in the image
|
| 40 |
+
bounding_box = detection.location_data.relative_bounding_box
|
| 41 |
+
xmin = int(bounding_box.xmin * width)
|
| 42 |
+
ymin = int(bounding_box.ymin * height)
|
| 43 |
+
face_width = int(bounding_box.width * width)
|
| 44 |
+
face_height = int(bounding_box.height * height)
|
| 45 |
+
|
| 46 |
+
# Crop the image to the bounding box.
|
| 47 |
+
xmin = max(0, xmin)
|
| 48 |
+
ymin = max(0, ymin)
|
| 49 |
+
xmax = min(width, xmin + face_width)
|
| 50 |
+
ymax = min(height, ymin + face_height)
|
| 51 |
+
image = image[ymin:ymax, xmin:xmax]
|
| 52 |
+
|
| 53 |
+
return image
|
| 54 |
+
|
| 55 |
+
def detect_video(self, video_path, real: bool = True):
|
| 56 |
+
vr = VideoReader(video_path)
|
| 57 |
+
video_frames = vr[20:36].asnumpy()  # Use 16 consecutive frames (indices 20-35)
|
| 58 |
+
vr.seek(0) # avoid memory leak
|
| 59 |
+
faces = []
|
| 60 |
+
for frame in video_frames:
|
| 61 |
+
face = self.detect_face(frame)
|
| 62 |
+
face = cv2.resize(face, (self.resolution[1], self.resolution[0]), interpolation=cv2.INTER_AREA)
|
| 63 |
+
faces.append(face)
|
| 64 |
+
|
| 65 |
+
if len(faces) != 16:
|
| 66 |
+
return None
|
| 67 |
+
faces = np.stack(faces, axis=0) # (f, h, w, c)
|
| 68 |
+
faces = torch.from_numpy(faces)
|
| 69 |
+
return faces
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def eval_fvd(real_videos_dir, fake_videos_dir):
|
| 73 |
+
fvd = FVD()
|
| 74 |
+
real_features_list = []
|
| 75 |
+
fake_features_list = []
|
| 76 |
+
for file in tqdm.tqdm(os.listdir(fake_videos_dir)):
|
| 77 |
+
if file.endswith(".mp4"):
|
| 78 |
+
real_video_path = os.path.join(real_videos_dir, file.replace("_out.mp4", ".mp4"))
|
| 79 |
+
fake_video_path = os.path.join(fake_videos_dir, file)
|
| 80 |
+
real_features = fvd.detect_video(real_video_path, real=True)
|
| 81 |
+
fake_features = fvd.detect_video(fake_video_path, real=False)
|
| 82 |
+
if real_features is None or fake_features is None:
|
| 83 |
+
continue
|
| 84 |
+
real_features_list.append(real_features)
|
| 85 |
+
fake_features_list.append(fake_features)
|
| 86 |
+
|
| 87 |
+
real_features = torch.stack(real_features_list) / 255.0
|
| 88 |
+
fake_features = torch.stack(fake_features_list) / 255.0
|
| 89 |
+
print(compute_our_fvd(real_features, fake_features, device="cpu"))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
real_videos_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/segmented/cross"
|
| 94 |
+
fake_videos_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/segmented/latentsync_cross"
|
| 95 |
+
|
| 96 |
+
eval_fvd(real_videos_dir, fake_videos_dir)
|
LatentSync/eval/eval_sync_conf.py
ADDED
|
@@ -0,0 +1,77 @@
|
| 1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import os
|
| 17 |
+
import tqdm
|
| 18 |
+
from statistics import fmean
|
| 19 |
+
from eval.syncnet import SyncNetEval
|
| 20 |
+
from eval.syncnet_detect import SyncNetDetector
|
| 21 |
+
from latentsync.utils.util import red_text
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def syncnet_eval(syncnet, syncnet_detector, video_path, temp_dir, detect_results_dir="detect_results"):
|
| 26 |
+
syncnet_detector(video_path=video_path, min_track=50)
|
| 27 |
+
crop_videos = os.listdir(os.path.join(detect_results_dir, "crop"))
|
| 28 |
+
if crop_videos == []:
|
| 29 |
+
raise Exception(red_text(f"Face not detected in {video_path}"))
|
| 30 |
+
av_offset_list = []
|
| 31 |
+
conf_list = []
|
| 32 |
+
for video in crop_videos:
|
| 33 |
+
av_offset, _, conf = syncnet.evaluate(
|
| 34 |
+
video_path=os.path.join(detect_results_dir, "crop", video), temp_dir=temp_dir
|
| 35 |
+
)
|
| 36 |
+
av_offset_list.append(av_offset)
|
| 37 |
+
conf_list.append(conf)
|
| 38 |
+
av_offset = int(fmean(av_offset_list))
|
| 39 |
+
conf = fmean(conf_list)
|
| 40 |
+
print(f"Input video: {video_path}\nSyncNet confidence: {conf:.2f}\nAV offset: {av_offset}")
|
| 41 |
+
return av_offset, conf
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main():
|
| 45 |
+
parser = argparse.ArgumentParser(description="SyncNet")
|
| 46 |
+
parser.add_argument("--initial_model", type=str, default="checkpoints/auxiliary/syncnet_v2.model", help="")
|
| 47 |
+
parser.add_argument("--video_path", type=str, default=None, help="")
|
| 48 |
+
parser.add_argument("--videos_dir", type=str, default="/root/processed")
|
| 49 |
+
parser.add_argument("--temp_dir", type=str, default="temp", help="")
|
| 50 |
+
|
| 51 |
+
args = parser.parse_args()
|
| 52 |
+
|
| 53 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 54 |
+
|
| 55 |
+
syncnet = SyncNetEval(device=device)
|
| 56 |
+
syncnet.loadParameters(args.initial_model)
|
| 57 |
+
|
| 58 |
+
syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")
|
| 59 |
+
|
| 60 |
+
if args.video_path is not None:
|
| 61 |
+
syncnet_eval(syncnet, syncnet_detector, args.video_path, args.temp_dir)
|
| 62 |
+
else:
|
| 63 |
+
sync_conf_list = []
|
| 64 |
+
video_names = sorted([f for f in os.listdir(args.videos_dir) if f.endswith(".mp4")])
|
| 65 |
+
for video_name in tqdm.tqdm(video_names):
|
| 66 |
+
try:
|
| 67 |
+
_, conf = syncnet_eval(
|
| 68 |
+
syncnet, syncnet_detector, os.path.join(args.videos_dir, video_name), args.temp_dir
|
| 69 |
+
)
|
| 70 |
+
sync_conf_list.append(conf)
|
| 71 |
+
except Exception as e:
|
| 72 |
+
print(e)
|
| 73 |
+
print(f"The average sync confidence is {fmean(sync_conf_list):.02f}")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
main()
|
LatentSync/eval/eval_sync_conf.sh
ADDED
|
@@ -0,0 +1,2 @@
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
python -m eval.eval_sync_conf --video_path "RD_Radio1_000_006_out.mp4"
|
LatentSync/eval/eval_syncnet_acc.py
ADDED
|
@@ -0,0 +1,118 @@
|
| 1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
from tqdm.auto import tqdm
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from latentsync.models.syncnet import SyncNet
|
| 21 |
+
from latentsync.data.syncnet_dataset import SyncNetDataset
|
| 22 |
+
from diffusers import AutoencoderKL
|
| 23 |
+
from omegaconf import OmegaConf
|
| 24 |
+
from accelerate.utils import set_seed
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main(config):
|
| 28 |
+
set_seed(config.run.seed)
|
| 29 |
+
|
| 30 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 31 |
+
|
| 32 |
+
if config.data.latent_space:
|
| 33 |
+
vae = AutoencoderKL.from_pretrained(
|
| 34 |
+
"runwayml/stable-diffusion-inpainting", subfolder="vae", revision="fp16", torch_dtype=torch.float16
|
| 35 |
+
)
|
| 36 |
+
vae.requires_grad_(False)
|
| 37 |
+
vae.to(device)
|
| 38 |
+
|
| 39 |
+
# Dataset and Dataloader setup
|
| 40 |
+
dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config)
|
| 41 |
+
|
| 42 |
+
test_dataloader = torch.utils.data.DataLoader(
|
| 43 |
+
dataset,
|
| 44 |
+
batch_size=config.data.batch_size,
|
| 45 |
+
shuffle=False,
|
| 46 |
+
num_workers=config.data.num_workers,
|
| 47 |
+
drop_last=False,
|
| 48 |
+
worker_init_fn=dataset.worker_init_fn,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Model
|
| 52 |
+
syncnet = SyncNet(OmegaConf.to_container(config.model)).to(device)
|
| 53 |
+
|
| 54 |
+
print(f"Load checkpoint from: {config.ckpt.inference_ckpt_path}")
|
| 55 |
+
checkpoint = torch.load(config.ckpt.inference_ckpt_path, map_location=device)
|
| 56 |
+
|
| 57 |
+
syncnet.load_state_dict(checkpoint["state_dict"])
|
| 58 |
+
syncnet.to(dtype=torch.float16)
|
| 59 |
+
syncnet.requires_grad_(False)
|
| 60 |
+
syncnet.eval()
|
| 61 |
+
|
| 62 |
+
global_step = 0
|
| 63 |
+
num_val_batches = config.data.num_val_samples // config.data.batch_size
|
| 64 |
+
progress_bar = tqdm(range(0, num_val_batches), initial=0, desc="Testing accuracy")
|
| 65 |
+
|
| 66 |
+
num_correct_preds = 0
|
| 67 |
+
num_total_preds = 0
|
| 68 |
+
|
| 69 |
+
while True:
|
| 70 |
+
for step, batch in enumerate(test_dataloader):
|
| 71 |
+
### >>>> Test >>>> ###
|
| 72 |
+
|
| 73 |
+
frames = batch["frames"].to(device, dtype=torch.float16)
|
| 74 |
+
audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
|
| 75 |
+
y = batch["y"].to(device, dtype=torch.float16).squeeze(1)
|
| 76 |
+
|
| 77 |
+
if config.data.latent_space:
|
| 78 |
+
frames = rearrange(frames, "b f c h w -> (b f) c h w")
|
| 79 |
+
|
| 80 |
+
with torch.no_grad():
|
| 81 |
+
frames = vae.encode(frames).latent_dist.sample() * 0.18215
|
| 82 |
+
|
| 83 |
+
frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
|
| 84 |
+
else:
|
| 85 |
+
frames = rearrange(frames, "b f c h w -> b (f c) h w")
|
| 86 |
+
|
| 87 |
+
if config.data.lower_half:
|
| 88 |
+
height = frames.shape[2]
|
| 89 |
+
frames = frames[:, :, height // 2 :, :]
|
| 90 |
+
|
| 91 |
+
with torch.no_grad():
|
| 92 |
+
vision_embeds, audio_embeds = syncnet(frames, audio_samples)
|
| 93 |
+
|
| 94 |
+
sims = nn.functional.cosine_similarity(vision_embeds, audio_embeds)
|
| 95 |
+
|
| 96 |
+
preds = (sims > 0.5).to(dtype=torch.float16)
|
| 97 |
+
num_correct_preds += (preds == y).sum().item()
|
| 98 |
+
num_total_preds += len(sims)
|
| 99 |
+
|
| 100 |
+
progress_bar.update(1)
|
| 101 |
+
global_step += 1
|
| 102 |
+
|
| 103 |
+
if global_step >= num_val_batches:
|
| 104 |
+
progress_bar.close()
|
| 105 |
+
print(f"Accuracy score: {num_correct_preds / num_total_preds*100:.2f}%")
|
| 106 |
+
return
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == "__main__":
|
| 110 |
+
parser = argparse.ArgumentParser(description="Code to test the accuracy of expert lip-sync discriminator")
|
| 111 |
+
|
| 112 |
+
parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_latent.yaml")
|
| 113 |
+
args = parser.parse_args()
|
| 114 |
+
|
| 115 |
+
# Load a configuration file
|
| 116 |
+
config = OmegaConf.load(args.config_path)
|
| 117 |
+
|
| 118 |
+
main(config)
|
LatentSync/eval/eval_syncnet_acc.sh
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
python -m eval.eval_syncnet_acc --config_path "configs/syncnet/syncnet_16_pixel.yaml"
|
LatentSync/eval/fvd.py
ADDED
|
@@ -0,0 +1,56 @@
|
| 1 |
+
# Adapted from https://github.com/universome/fvd-comparison/blob/master/our_fvd.py
|
| 2 |
+
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
import scipy
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
|
| 10 |
+
mu_gen, sigma_gen = compute_stats(feats_fake)
|
| 11 |
+
mu_real, sigma_real = compute_stats(feats_real)
|
| 12 |
+
|
| 13 |
+
m = np.square(mu_gen - mu_real).sum()
|
| 14 |
+
s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
|
| 15 |
+
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
|
| 16 |
+
|
| 17 |
+
return float(fid)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 21 |
+
mu = feats.mean(axis=0) # [d]
|
| 22 |
+
sigma = np.cov(feats, rowvar=False) # [d, d]
|
| 23 |
+
|
| 24 |
+
return mu, sigma
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@torch.no_grad()
|
| 28 |
+
def compute_our_fvd(videos_fake: np.ndarray, videos_real: np.ndarray, device: str = "cuda") -> float:
|
| 29 |
+
i3d_path = "checkpoints/auxiliary/i3d_torchscript.pt"
|
| 30 |
+
i3d_kwargs = dict(
|
| 31 |
+
rescale=False, resize=False, return_features=True
|
| 32 |
+
) # Return raw features before the softmax layer.
|
| 33 |
+
|
| 34 |
+
with open(i3d_path, "rb") as f:
|
| 35 |
+
i3d_model = torch.jit.load(f).eval().to(device)
|
| 36 |
+
|
| 37 |
+
videos_fake = videos_fake.permute(0, 4, 1, 2, 3).to(device)
|
| 38 |
+
videos_real = videos_real.permute(0, 4, 1, 2, 3).to(device)
|
| 39 |
+
|
| 40 |
+
feats_fake = i3d_model(videos_fake, **i3d_kwargs).cpu().numpy()
|
| 41 |
+
feats_real = i3d_model(videos_real, **i3d_kwargs).cpu().numpy()
|
| 42 |
+
|
| 43 |
+
return compute_fvd(feats_fake, feats_real)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def main():
|
| 47 |
+
# input shape: (b, f, h, w, c)
|
| 48 |
+
videos_fake = torch.rand(10, 16, 224, 224, 3)
|
| 49 |
+
videos_real = torch.rand(10, 16, 224, 224, 3)
|
| 50 |
+
|
| 51 |
+
our_fvd_result = compute_our_fvd(videos_fake, videos_real)
|
| 52 |
+
print(f"[FVD scores] Ours: {our_fvd_result}")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if __name__ == "__main__":
|
| 56 |
+
main()
|
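The compute_fvd helper above is the standard Fréchet distance between Gaussians fitted to the I3D features of the fake and real videos; written out, the quantity it returns is

```latex
\mathrm{FVD}
  = \lVert \mu_{\mathrm{gen}} - \mu_{\mathrm{real}} \rVert_2^2
  + \operatorname{Tr}\!\bigl(\Sigma_{\mathrm{gen}} + \Sigma_{\mathrm{real}}
      - 2\,(\Sigma_{\mathrm{gen}}\Sigma_{\mathrm{real}})^{1/2}\bigr)
```

with the means and covariances computed by compute_stats.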
LatentSync/eval/hyper_iqa.py
ADDED
|
@@ -0,0 +1,343 @@
|
| 1 |
+
# Adapted from https://github.com/SSL92/hyperIQA/blob/master/models.py
|
| 2 |
+
|
| 3 |
+
import torch as torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from torch.nn import init
|
| 7 |
+
import math
|
| 8 |
+
import torch.utils.model_zoo as model_zoo
|
| 9 |
+
|
| 10 |
+
model_urls = {
|
| 11 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
| 12 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
| 13 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
| 14 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
| 15 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class HyperNet(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
Hyper network for learning perceptual rules.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
lda_out_channels: local distortion aware module output size.
|
| 25 |
+
hyper_in_channels: input feature channels for hyper network.
|
| 26 |
+
target_in_size: input vector size for target network.
|
| 27 |
+
target_fc(i)_size: fully connection layer size of target network.
|
| 28 |
+
feature_size: input feature map width/height for hyper network.
|
| 29 |
+
|
| 30 |
+
Note:
|
| 31 |
+
For the sizes to match, the input args must satisfy: 'target_fc(i)_size * target_fc(i+1)_size' must be divisible by 'feature_size ^ 2'.
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
def __init__(self, lda_out_channels, hyper_in_channels, target_in_size, target_fc1_size, target_fc2_size, target_fc3_size, target_fc4_size, feature_size):
|
| 35 |
+
super(HyperNet, self).__init__()
|
| 36 |
+
|
| 37 |
+
self.hyperInChn = hyper_in_channels
|
| 38 |
+
self.target_in_size = target_in_size
|
| 39 |
+
self.f1 = target_fc1_size
|
| 40 |
+
self.f2 = target_fc2_size
|
| 41 |
+
self.f3 = target_fc3_size
|
| 42 |
+
self.f4 = target_fc4_size
|
| 43 |
+
self.feature_size = feature_size
|
| 44 |
+
|
| 45 |
+
self.res = resnet50_backbone(lda_out_channels, target_in_size, pretrained=True)
|
| 46 |
+
|
| 47 |
+
self.pool = nn.AdaptiveAvgPool2d((1, 1))
|
| 48 |
+
|
| 49 |
+
# Conv layers for resnet output features
|
| 50 |
+
self.conv1 = nn.Sequential(
|
| 51 |
+
nn.Conv2d(2048, 1024, 1, padding=(0, 0)),
|
| 52 |
+
nn.ReLU(inplace=True),
|
| 53 |
+
nn.Conv2d(1024, 512, 1, padding=(0, 0)),
|
| 54 |
+
nn.ReLU(inplace=True),
|
| 55 |
+
nn.Conv2d(512, self.hyperInChn, 1, padding=(0, 0)),
|
| 56 |
+
nn.ReLU(inplace=True)
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# Hyper network part, conv for generating target fc weights, fc for generating target fc biases
|
| 60 |
+
self.fc1w_conv = nn.Conv2d(self.hyperInChn, int(self.target_in_size * self.f1 / feature_size ** 2), 3, padding=(1, 1))
|
| 61 |
+
self.fc1b_fc = nn.Linear(self.hyperInChn, self.f1)
|
| 62 |
+
|
| 63 |
+
self.fc2w_conv = nn.Conv2d(self.hyperInChn, int(self.f1 * self.f2 / feature_size ** 2), 3, padding=(1, 1))
|
| 64 |
+
self.fc2b_fc = nn.Linear(self.hyperInChn, self.f2)
|
| 65 |
+
|
| 66 |
+
self.fc3w_conv = nn.Conv2d(self.hyperInChn, int(self.f2 * self.f3 / feature_size ** 2), 3, padding=(1, 1))
|
| 67 |
+
self.fc3b_fc = nn.Linear(self.hyperInChn, self.f3)
|
| 68 |
+
|
| 69 |
+
self.fc4w_conv = nn.Conv2d(self.hyperInChn, int(self.f3 * self.f4 / feature_size ** 2), 3, padding=(1, 1))
|
| 70 |
+
self.fc4b_fc = nn.Linear(self.hyperInChn, self.f4)
|
| 71 |
+
|
| 72 |
+
self.fc5w_fc = nn.Linear(self.hyperInChn, self.f4)
|
| 73 |
+
self.fc5b_fc = nn.Linear(self.hyperInChn, 1)
|
| 74 |
+
|
| 75 |
+
# initialize
|
| 76 |
+
for i, m_name in enumerate(self._modules):
|
| 77 |
+
if i > 2:
|
| 78 |
+
nn.init.kaiming_normal_(self._modules[m_name].weight.data)
|
| 79 |
+
|
| 80 |
+
def forward(self, img):
|
| 81 |
+
feature_size = self.feature_size
|
| 82 |
+
|
| 83 |
+
res_out = self.res(img)
|
| 84 |
+
|
| 85 |
+
# input vector for target net
|
| 86 |
+
target_in_vec = res_out['target_in_vec'].reshape(-1, self.target_in_size, 1, 1)
|
| 87 |
+
|
| 88 |
+
# input features for hyper net
|
| 89 |
+
hyper_in_feat = self.conv1(res_out['hyper_in_feat']).reshape(-1, self.hyperInChn, feature_size, feature_size)
|
| 90 |
+
|
| 91 |
+
# generating target net weights & biases
|
| 92 |
+
target_fc1w = self.fc1w_conv(hyper_in_feat).reshape(-1, self.f1, self.target_in_size, 1, 1)
|
| 93 |
+
target_fc1b = self.fc1b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f1)
|
| 94 |
+
|
| 95 |
+
target_fc2w = self.fc2w_conv(hyper_in_feat).reshape(-1, self.f2, self.f1, 1, 1)
|
| 96 |
+
target_fc2b = self.fc2b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f2)
|
| 97 |
+
|
| 98 |
+
target_fc3w = self.fc3w_conv(hyper_in_feat).reshape(-1, self.f3, self.f2, 1, 1)
|
| 99 |
+
target_fc3b = self.fc3b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f3)
|
| 100 |
+
|
| 101 |
+
target_fc4w = self.fc4w_conv(hyper_in_feat).reshape(-1, self.f4, self.f3, 1, 1)
|
| 102 |
+
target_fc4b = self.fc4b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f4)
|
| 103 |
+
|
| 104 |
+
target_fc5w = self.fc5w_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, 1, self.f4, 1, 1)
|
| 105 |
+
target_fc5b = self.fc5b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, 1)
|
| 106 |
+
|
| 107 |
+
out = {}
|
| 108 |
+
out['target_in_vec'] = target_in_vec
|
| 109 |
+
out['target_fc1w'] = target_fc1w
|
| 110 |
+
out['target_fc1b'] = target_fc1b
|
| 111 |
+
out['target_fc2w'] = target_fc2w
|
| 112 |
+
out['target_fc2b'] = target_fc2b
|
| 113 |
+
out['target_fc3w'] = target_fc3w
|
| 114 |
+
out['target_fc3b'] = target_fc3b
|
| 115 |
+
out['target_fc4w'] = target_fc4w
|
| 116 |
+
out['target_fc4b'] = target_fc4b
|
| 117 |
+
out['target_fc5w'] = target_fc5w
|
| 118 |
+
out['target_fc5b'] = target_fc5b
|
| 119 |
+
|
| 120 |
+
return out
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class TargetNet(nn.Module):
|
| 124 |
+
"""
|
| 125 |
+
Target network for quality prediction.
|
| 126 |
+
"""
|
| 127 |
+
def __init__(self, paras):
|
| 128 |
+
super(TargetNet, self).__init__()
|
| 129 |
+
self.l1 = nn.Sequential(
|
| 130 |
+
TargetFC(paras['target_fc1w'], paras['target_fc1b']),
|
| 131 |
+
nn.Sigmoid(),
|
| 132 |
+
)
|
| 133 |
+
self.l2 = nn.Sequential(
|
| 134 |
+
TargetFC(paras['target_fc2w'], paras['target_fc2b']),
|
| 135 |
+
nn.Sigmoid(),
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
self.l3 = nn.Sequential(
|
| 139 |
+
TargetFC(paras['target_fc3w'], paras['target_fc3b']),
|
| 140 |
+
nn.Sigmoid(),
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
self.l4 = nn.Sequential(
|
| 144 |
+
TargetFC(paras['target_fc4w'], paras['target_fc4b']),
|
| 145 |
+
nn.Sigmoid(),
|
| 146 |
+
TargetFC(paras['target_fc5w'], paras['target_fc5b']),
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
def forward(self, x):
|
| 150 |
+
q = self.l1(x)
|
| 151 |
+
# q = F.dropout(q)
|
| 152 |
+
q = self.l2(q)
|
| 153 |
+
q = self.l3(q)
|
| 154 |
+
q = self.l4(q).squeeze()
|
| 155 |
+
return q
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class TargetFC(nn.Module):
|
| 159 |
+
"""
|
| 160 |
+
Fully-connected operations for the target net
|
| 161 |
+
|
| 162 |
+
Note:
|
| 163 |
+
Weights & biases are different for different images in a batch,
|
| 164 |
+
thus here we use group convolution for calculating images in a batch with individual weights & biases.
|
| 165 |
+
"""
|
| 166 |
+
def __init__(self, weight, bias):
|
| 167 |
+
super(TargetFC, self).__init__()
|
| 168 |
+
self.weight = weight
|
| 169 |
+
self.bias = bias
|
| 170 |
+
|
| 171 |
+
def forward(self, input_):
|
| 172 |
+
|
| 173 |
+
input_re = input_.reshape(-1, input_.shape[0] * input_.shape[1], input_.shape[2], input_.shape[3])
|
| 174 |
+
weight_re = self.weight.reshape(self.weight.shape[0] * self.weight.shape[1], self.weight.shape[2], self.weight.shape[3], self.weight.shape[4])
|
| 175 |
+
bias_re = self.bias.reshape(self.bias.shape[0] * self.bias.shape[1])
|
| 176 |
+
out = F.conv2d(input=input_re, weight=weight_re, bias=bias_re, groups=self.weight.shape[0])
|
| 177 |
+
|
| 178 |
+
return out.reshape(input_.shape[0], self.weight.shape[1], input_.shape[2], input_.shape[3])
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class Bottleneck(nn.Module):
|
| 182 |
+
expansion = 4
|
| 183 |
+
|
| 184 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 185 |
+
super(Bottleneck, self).__init__()
|
| 186 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
| 187 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 188 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
| 189 |
+
padding=1, bias=False)
|
| 190 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 191 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
| 192 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
| 193 |
+
self.relu = nn.ReLU(inplace=True)
|
| 194 |
+
self.downsample = downsample
|
| 195 |
+
self.stride = stride
|
| 196 |
+
|
| 197 |
+
def forward(self, x):
|
| 198 |
+
residual = x
|
| 199 |
+
|
| 200 |
+
out = self.conv1(x)
|
| 201 |
+
out = self.bn1(out)
|
| 202 |
+
out = self.relu(out)
|
| 203 |
+
|
| 204 |
+
out = self.conv2(out)
|
| 205 |
+
out = self.bn2(out)
|
| 206 |
+
out = self.relu(out)
|
| 207 |
+
|
| 208 |
+
out = self.conv3(out)
|
| 209 |
+
out = self.bn3(out)
|
| 210 |
+
|
| 211 |
+
if self.downsample is not None:
|
| 212 |
+
residual = self.downsample(x)
|
| 213 |
+
|
| 214 |
+
out += residual
|
| 215 |
+
out = self.relu(out)
|
| 216 |
+
|
| 217 |
+
return out
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class ResNetBackbone(nn.Module):
|
| 221 |
+
|
| 222 |
+
def __init__(self, lda_out_channels, in_chn, block, layers, num_classes=1000):
|
| 223 |
+
super(ResNetBackbone, self).__init__()
|
| 224 |
+
self.inplanes = 64
|
| 225 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
| 226 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 227 |
+
self.relu = nn.ReLU(inplace=True)
|
| 228 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 229 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 230 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
| 231 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
| 232 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
| 233 |
+
|
| 234 |
+
# local distortion aware module
|
| 235 |
+
self.lda1_pool = nn.Sequential(
|
| 236 |
+
nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False),
|
| 237 |
+
nn.AvgPool2d(7, stride=7),
|
| 238 |
+
)
|
| 239 |
+
self.lda1_fc = nn.Linear(16 * 64, lda_out_channels)
|
| 240 |
+
|
| 241 |
+
self.lda2_pool = nn.Sequential(
|
| 242 |
+
nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False),
|
| 243 |
+
nn.AvgPool2d(7, stride=7),
|
| 244 |
+
)
|
| 245 |
+
self.lda2_fc = nn.Linear(32 * 16, lda_out_channels)
|
| 246 |
+
|
| 247 |
+
self.lda3_pool = nn.Sequential(
|
| 248 |
+
nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False),
|
| 249 |
+
nn.AvgPool2d(7, stride=7),
|
| 250 |
+
)
|
| 251 |
+
self.lda3_fc = nn.Linear(64 * 4, lda_out_channels)
|
| 252 |
+
|
| 253 |
+
self.lda4_pool = nn.AvgPool2d(7, stride=7)
|
| 254 |
+
self.lda4_fc = nn.Linear(2048, in_chn - lda_out_channels * 3)
|
| 255 |
+
|
| 256 |
+
for m in self.modules():
|
| 257 |
+
if isinstance(m, nn.Conv2d):
|
| 258 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 259 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
| 260 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 261 |
+
m.weight.data.fill_(1)
|
| 262 |
+
m.bias.data.zero_()
|
| 263 |
+
|
| 264 |
+
# initialize
|
| 265 |
+
nn.init.kaiming_normal_(self.lda1_pool._modules['0'].weight.data)
|
| 266 |
+
nn.init.kaiming_normal_(self.lda2_pool._modules['0'].weight.data)
|
| 267 |
+
nn.init.kaiming_normal_(self.lda3_pool._modules['0'].weight.data)
|
| 268 |
+
nn.init.kaiming_normal_(self.lda1_fc.weight.data)
|
| 269 |
+
nn.init.kaiming_normal_(self.lda2_fc.weight.data)
|
| 270 |
+
nn.init.kaiming_normal_(self.lda3_fc.weight.data)
|
| 271 |
+
nn.init.kaiming_normal_(self.lda4_fc.weight.data)
|
| 272 |
+
|
| 273 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
| 274 |
+
downsample = None
|
| 275 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 276 |
+
downsample = nn.Sequential(
|
| 277 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
| 278 |
+
kernel_size=1, stride=stride, bias=False),
|
| 279 |
+
nn.BatchNorm2d(planes * block.expansion),
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
layers = []
|
| 283 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
| 284 |
+
self.inplanes = planes * block.expansion
|
| 285 |
+
for i in range(1, blocks):
|
| 286 |
+
layers.append(block(self.inplanes, planes))
|
| 287 |
+
|
| 288 |
+
return nn.Sequential(*layers)
|
| 289 |
+
|
| 290 |
+
def forward(self, x):
|
| 291 |
+
x = self.conv1(x)
|
| 292 |
+
x = self.bn1(x)
|
| 293 |
+
x = self.relu(x)
|
| 294 |
+
x = self.maxpool(x)
|
| 295 |
+
x = self.layer1(x)
|
| 296 |
+
|
| 297 |
+
# the same effect as lda operation in the paper, but save much more memory
|
| 298 |
+
lda_1 = self.lda1_fc(self.lda1_pool(x).reshape(x.size(0), -1))
|
| 299 |
+
x = self.layer2(x)
|
| 300 |
+
lda_2 = self.lda2_fc(self.lda2_pool(x).reshape(x.size(0), -1))
|
| 301 |
+
x = self.layer3(x)
|
| 302 |
+
lda_3 = self.lda3_fc(self.lda3_pool(x).reshape(x.size(0), -1))
|
| 303 |
+
x = self.layer4(x)
|
| 304 |
+
lda_4 = self.lda4_fc(self.lda4_pool(x).reshape(x.size(0), -1))
|
| 305 |
+
|
| 306 |
+
vec = torch.cat((lda_1, lda_2, lda_3, lda_4), 1)
|
| 307 |
+
|
| 308 |
+
out = {}
|
| 309 |
+
out['hyper_in_feat'] = x
|
| 310 |
+
out['target_in_vec'] = vec
|
| 311 |
+
|
| 312 |
+
return out
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def resnet50_backbone(lda_out_channels, in_chn, pretrained=False, **kwargs):
|
| 316 |
+
"""Constructs a ResNet-50 model_hyper.
|
| 317 |
+
|
| 318 |
+
Args:
|
| 319 |
+
pretrained (bool): If True, returns a model_hyper pre-trained on ImageNet
|
| 320 |
+
"""
|
| 321 |
+
model = ResNetBackbone(lda_out_channels, in_chn, Bottleneck, [3, 4, 6, 3], **kwargs)
|
| 322 |
+
if pretrained:
|
| 323 |
+
save_model = model_zoo.load_url(model_urls['resnet50'])
|
| 324 |
+
model_dict = model.state_dict()
|
| 325 |
+
state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
|
| 326 |
+
model_dict.update(state_dict)
|
| 327 |
+
model.load_state_dict(model_dict)
|
| 328 |
+
else:
|
| 329 |
+
model.apply(weights_init_xavier)
|
| 330 |
+
return model
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def weights_init_xavier(m):
|
| 334 |
+
classname = m.__class__.__name__
|
| 335 |
+
# print(classname)
|
| 336 |
+
# if isinstance(m, nn.Conv2d):
|
| 337 |
+
if classname.find('Conv') != -1:
|
| 338 |
+
init.kaiming_normal_(m.weight.data)
|
| 339 |
+
elif classname.find('Linear') != -1:
|
| 340 |
+
init.kaiming_normal_(m.weight.data)
|
| 341 |
+
elif classname.find('BatchNorm2d') != -1:
|
| 342 |
+
init.uniform_(m.weight.data, 1.0, 0.02)
|
| 343 |
+
init.constant_(m.bias.data, 0.0)
|
LatentSync/eval/inference_videos.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import subprocess
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def inference_video_from_dir(input_dir, output_dir, unet_config_path, ckpt_path):
|
| 21 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 22 |
+
video_names = sorted([f for f in os.listdir(input_dir) if f.endswith(".mp4")])
|
| 23 |
+
for video_name in tqdm(video_names):
|
| 24 |
+
video_path = os.path.join(input_dir, video_name)
|
| 25 |
+
audio_path = os.path.join(input_dir, video_name.replace(".mp4", "_audio.wav"))
|
| 26 |
+
video_out_path = os.path.join(output_dir, video_name.replace(".mp4", "_out.mp4"))
|
| 27 |
+
inference_command = f"python inference.py --unet_config_path {unet_config_path} --video_path {video_path} --audio_path {audio_path} --video_out_path {video_out_path} --inference_ckpt_path {ckpt_path} --seed 1247"
|
| 28 |
+
subprocess.run(inference_command, shell=True)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/segmented/cross"
|
| 33 |
+
output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/segmented/latentsync_cross"
|
| 34 |
+
unet_config_path = "configs/unet/unet_latent_16_diffusion.yaml"
|
| 35 |
+
ckpt_path = "output/unet/train-2024_10_08-16:23:43/checkpoints/checkpoint-1920000.pt"
|
| 36 |
+
|
| 37 |
+
inference_video_from_dir(input_dir, output_dir, unet_config_path, ckpt_path)
|
LatentSync/eval/syncnet/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .syncnet_eval import SyncNetEval
|
LatentSync/eval/syncnet/syncnet.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/joonson/syncnet_python/blob/master/SyncNetModel.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def save(model, filename):
|
| 8 |
+
with open(filename, "wb") as f:
|
| 9 |
+
torch.save(model, f)
|
| 10 |
+
print("%s saved." % filename)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load(filename):
|
| 14 |
+
net = torch.load(filename)
|
| 15 |
+
return net
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class S(nn.Module):
|
| 19 |
+
def __init__(self, num_layers_in_fc_layers=1024):
|
| 20 |
+
super(S, self).__init__()
|
| 21 |
+
|
| 22 |
+
self.__nFeatures__ = 24
|
| 23 |
+
self.__nChs__ = 32
|
| 24 |
+
self.__midChs__ = 32
|
| 25 |
+
|
| 26 |
+
self.netcnnaud = nn.Sequential(
|
| 27 |
+
nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
|
| 28 |
+
nn.BatchNorm2d(64),
|
| 29 |
+
nn.ReLU(inplace=True),
|
| 30 |
+
nn.MaxPool2d(kernel_size=(1, 1), stride=(1, 1)),
|
| 31 |
+
nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
|
| 32 |
+
nn.BatchNorm2d(192),
|
| 33 |
+
nn.ReLU(inplace=True),
|
| 34 |
+
nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 2)),
|
| 35 |
+
nn.Conv2d(192, 384, kernel_size=(3, 3), padding=(1, 1)),
|
| 36 |
+
nn.BatchNorm2d(384),
|
| 37 |
+
nn.ReLU(inplace=True),
|
| 38 |
+
nn.Conv2d(384, 256, kernel_size=(3, 3), padding=(1, 1)),
|
| 39 |
+
nn.BatchNorm2d(256),
|
| 40 |
+
nn.ReLU(inplace=True),
|
| 41 |
+
nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1)),
|
| 42 |
+
nn.BatchNorm2d(256),
|
| 43 |
+
nn.ReLU(inplace=True),
|
| 44 |
+
nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),
|
| 45 |
+
nn.Conv2d(256, 512, kernel_size=(5, 4), padding=(0, 0)),
|
| 46 |
+
nn.BatchNorm2d(512),
|
| 47 |
+
nn.ReLU(),
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
self.netfcaud = nn.Sequential(
|
| 51 |
+
nn.Linear(512, 512),
|
| 52 |
+
nn.BatchNorm1d(512),
|
| 53 |
+
nn.ReLU(),
|
| 54 |
+
nn.Linear(512, num_layers_in_fc_layers),
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
self.netfclip = nn.Sequential(
|
| 58 |
+
nn.Linear(512, 512),
|
| 59 |
+
nn.BatchNorm1d(512),
|
| 60 |
+
nn.ReLU(),
|
| 61 |
+
nn.Linear(512, num_layers_in_fc_layers),
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
self.netcnnlip = nn.Sequential(
|
| 65 |
+
nn.Conv3d(3, 96, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=0),
|
| 66 |
+
nn.BatchNorm3d(96),
|
| 67 |
+
nn.ReLU(inplace=True),
|
| 68 |
+
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2)),
|
| 69 |
+
nn.Conv3d(96, 256, kernel_size=(1, 5, 5), stride=(1, 2, 2), padding=(0, 1, 1)),
|
| 70 |
+
nn.BatchNorm3d(256),
|
| 71 |
+
nn.ReLU(inplace=True),
|
| 72 |
+
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
|
| 73 |
+
nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
|
| 74 |
+
nn.BatchNorm3d(256),
|
| 75 |
+
nn.ReLU(inplace=True),
|
| 76 |
+
nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
|
| 77 |
+
nn.BatchNorm3d(256),
|
| 78 |
+
nn.ReLU(inplace=True),
|
| 79 |
+
nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
|
| 80 |
+
nn.BatchNorm3d(256),
|
| 81 |
+
nn.ReLU(inplace=True),
|
| 82 |
+
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2)),
|
| 83 |
+
nn.Conv3d(256, 512, kernel_size=(1, 6, 6), padding=0),
|
| 84 |
+
nn.BatchNorm3d(512),
|
| 85 |
+
nn.ReLU(inplace=True),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def forward_aud(self, x):
|
| 89 |
+
|
| 90 |
+
mid = self.netcnnaud(x)
|
| 91 |
+
# N x ch x 24 x M
|
| 92 |
+
mid = mid.view((mid.size()[0], -1))
|
| 93 |
+
# N x (ch x 24)
|
| 94 |
+
out = self.netfcaud(mid)
|
| 95 |
+
|
| 96 |
+
return out
|
| 97 |
+
|
| 98 |
+
def forward_lip(self, x):
|
| 99 |
+
|
| 100 |
+
mid = self.netcnnlip(x)
|
| 101 |
+
mid = mid.view((mid.size()[0], -1))
|
| 102 |
+
# N x (ch x 24)
|
| 103 |
+
out = self.netfclip(mid)
|
| 104 |
+
|
| 105 |
+
return out
|
| 106 |
+
|
| 107 |
+
def forward_lipfeat(self, x):
|
| 108 |
+
|
| 109 |
+
mid = self.netcnnlip(x)
|
| 110 |
+
out = mid.view((mid.size()[0], -1))
|
| 111 |
+
# N x (ch x 24)
|
| 112 |
+
|
| 113 |
+
return out
|
LatentSync/eval/syncnet/syncnet_eval.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/joonson/syncnet_python/blob/master/SyncNetInstance.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy
|
| 5 |
+
import time, pdb, argparse, subprocess, os, math, glob
|
| 6 |
+
import cv2
|
| 7 |
+
import python_speech_features
|
| 8 |
+
|
| 9 |
+
from scipy import signal
|
| 10 |
+
from scipy.io import wavfile
|
| 11 |
+
from .syncnet import S
|
| 12 |
+
from shutil import rmtree
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ==================== Get OFFSET ====================
|
| 16 |
+
|
| 17 |
+
# Video 25 FPS, Audio 16000HZ
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def calc_pdist(feat1, feat2, vshift=10):
|
| 21 |
+
win_size = vshift * 2 + 1
|
| 22 |
+
|
| 23 |
+
feat2p = torch.nn.functional.pad(feat2, (0, 0, vshift, vshift))
|
| 24 |
+
|
| 25 |
+
dists = []
|
| 26 |
+
|
| 27 |
+
for i in range(0, len(feat1)):
|
| 28 |
+
|
| 29 |
+
dists.append(
|
| 30 |
+
torch.nn.functional.pairwise_distance(feat1[[i], :].repeat(win_size, 1), feat2p[i : i + win_size, :])
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
return dists
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ==================== MAIN DEF ====================
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SyncNetEval(torch.nn.Module):
|
| 40 |
+
def __init__(self, dropout=0, num_layers_in_fc_layers=1024, device="cpu"):
|
| 41 |
+
super().__init__()
|
| 42 |
+
|
| 43 |
+
self.__S__ = S(num_layers_in_fc_layers=num_layers_in_fc_layers).to(device)
|
| 44 |
+
self.device = device
|
| 45 |
+
|
| 46 |
+
def evaluate(self, video_path, temp_dir="temp", batch_size=20, vshift=15):
|
| 47 |
+
|
| 48 |
+
self.__S__.eval()
|
| 49 |
+
|
| 50 |
+
# ========== ==========
|
| 51 |
+
# Convert files
|
| 52 |
+
# ========== ==========
|
| 53 |
+
|
| 54 |
+
if os.path.exists(temp_dir):
|
| 55 |
+
rmtree(temp_dir)
|
| 56 |
+
|
| 57 |
+
os.makedirs(temp_dir)
|
| 58 |
+
|
| 59 |
+
# temp_video_path = os.path.join(temp_dir, "temp.mp4")
|
| 60 |
+
# command = f"ffmpeg -loglevel error -nostdin -y -i {video_path} -vf scale='224:224' {temp_video_path}"
|
| 61 |
+
# subprocess.call(command, shell=True)
|
| 62 |
+
|
| 63 |
+
command = (
|
| 64 |
+
f"ffmpeg -loglevel error -nostdin -y -i {video_path} -f image2 {os.path.join(temp_dir, '%06d.jpg')}"
|
| 65 |
+
)
|
| 66 |
+
subprocess.call(command, shell=True, stdout=None)
|
| 67 |
+
|
| 68 |
+
command = f"ffmpeg -loglevel error -nostdin -y -i {video_path} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {os.path.join(temp_dir, 'audio.wav')}"
|
| 69 |
+
subprocess.call(command, shell=True, stdout=None)
|
| 70 |
+
|
| 71 |
+
# ========== ==========
|
| 72 |
+
# Load video
|
| 73 |
+
# ========== ==========
|
| 74 |
+
|
| 75 |
+
images = []
|
| 76 |
+
|
| 77 |
+
flist = glob.glob(os.path.join(temp_dir, "*.jpg"))
|
| 78 |
+
flist.sort()
|
| 79 |
+
|
| 80 |
+
for fname in flist:
|
| 81 |
+
img_input = cv2.imread(fname)
|
| 82 |
+
img_input = cv2.resize(img_input, (224, 224)) # HARD CODED, CHANGE BEFORE RELEASE
|
| 83 |
+
images.append(img_input)
|
| 84 |
+
|
| 85 |
+
im = numpy.stack(images, axis=3)
|
| 86 |
+
im = numpy.expand_dims(im, axis=0)
|
| 87 |
+
im = numpy.transpose(im, (0, 3, 4, 1, 2))
|
| 88 |
+
|
| 89 |
+
imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
|
| 90 |
+
|
| 91 |
+
# ========== ==========
|
| 92 |
+
# Load audio
|
| 93 |
+
# ========== ==========
|
| 94 |
+
|
| 95 |
+
sample_rate, audio = wavfile.read(os.path.join(temp_dir, "audio.wav"))
|
| 96 |
+
mfcc = zip(*python_speech_features.mfcc(audio, sample_rate))
|
| 97 |
+
mfcc = numpy.stack([numpy.array(i) for i in mfcc])
|
| 98 |
+
|
| 99 |
+
cc = numpy.expand_dims(numpy.expand_dims(mfcc, axis=0), axis=0)
|
| 100 |
+
cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float())
|
| 101 |
+
|
| 102 |
+
# ========== ==========
|
| 103 |
+
# Check audio and video input length
|
| 104 |
+
# ========== ==========
|
| 105 |
+
|
| 106 |
+
# if (float(len(audio)) / 16000) != (float(len(images)) / 25):
|
| 107 |
+
# print(
|
| 108 |
+
# "WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."
|
| 109 |
+
# % (float(len(audio)) / 16000, float(len(images)) / 25)
|
| 110 |
+
# )
|
| 111 |
+
|
| 112 |
+
min_length = min(len(images), math.floor(len(audio) / 640))
|
| 113 |
+
|
| 114 |
+
# ========== ==========
|
| 115 |
+
# Generate video and audio feats
|
| 116 |
+
# ========== ==========
|
| 117 |
+
|
| 118 |
+
lastframe = min_length - 5
|
| 119 |
+
im_feat = []
|
| 120 |
+
cc_feat = []
|
| 121 |
+
|
| 122 |
+
tS = time.time()
|
| 123 |
+
for i in range(0, lastframe, batch_size):
|
| 124 |
+
|
| 125 |
+
im_batch = [imtv[:, :, vframe : vframe + 5, :, :] for vframe in range(i, min(lastframe, i + batch_size))]
|
| 126 |
+
im_in = torch.cat(im_batch, 0)
|
| 127 |
+
im_out = self.__S__.forward_lip(im_in.to(self.device))
|
| 128 |
+
im_feat.append(im_out.data.cpu())
|
| 129 |
+
|
| 130 |
+
cc_batch = [
|
| 131 |
+
cct[:, :, :, vframe * 4 : vframe * 4 + 20] for vframe in range(i, min(lastframe, i + batch_size))
|
| 132 |
+
]
|
| 133 |
+
cc_in = torch.cat(cc_batch, 0)
|
| 134 |
+
cc_out = self.__S__.forward_aud(cc_in.to(self.device))
|
| 135 |
+
cc_feat.append(cc_out.data.cpu())
|
| 136 |
+
|
| 137 |
+
im_feat = torch.cat(im_feat, 0)
|
| 138 |
+
cc_feat = torch.cat(cc_feat, 0)
|
| 139 |
+
|
| 140 |
+
# ========== ==========
|
| 141 |
+
# Compute offset
|
| 142 |
+
# ========== ==========
|
| 143 |
+
|
| 144 |
+
dists = calc_pdist(im_feat, cc_feat, vshift=vshift)
|
| 145 |
+
mean_dists = torch.mean(torch.stack(dists, 1), 1)
|
| 146 |
+
|
| 147 |
+
min_dist, minidx = torch.min(mean_dists, 0)
|
| 148 |
+
|
| 149 |
+
av_offset = vshift - minidx
|
| 150 |
+
conf = torch.median(mean_dists) - min_dist
|
| 151 |
+
|
| 152 |
+
fdist = numpy.stack([dist[minidx].numpy() for dist in dists])
|
| 153 |
+
# fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15)
|
| 154 |
+
fconf = torch.median(mean_dists).numpy() - fdist
|
| 155 |
+
framewise_conf = signal.medfilt(fconf, kernel_size=9)
|
| 156 |
+
|
| 157 |
+
# numpy.set_printoptions(formatter={"float": "{: 0.3f}".format})
|
| 158 |
+
rmtree(temp_dir)
|
| 159 |
+
return av_offset.item(), min_dist.item(), conf.item()
|
| 160 |
+
|
| 161 |
+
def extract_feature(self, opt, videofile):
|
| 162 |
+
|
| 163 |
+
self.__S__.eval()
|
| 164 |
+
|
| 165 |
+
# ========== ==========
|
| 166 |
+
# Load video
|
| 167 |
+
# ========== ==========
|
| 168 |
+
cap = cv2.VideoCapture(videofile)
|
| 169 |
+
|
| 170 |
+
frame_num = 1
|
| 171 |
+
images = []
|
| 172 |
+
while frame_num:
|
| 173 |
+
frame_num += 1
|
| 174 |
+
ret, image = cap.read()
|
| 175 |
+
if ret == 0:
|
| 176 |
+
break
|
| 177 |
+
|
| 178 |
+
images.append(image)
|
| 179 |
+
|
| 180 |
+
im = numpy.stack(images, axis=3)
|
| 181 |
+
im = numpy.expand_dims(im, axis=0)
|
| 182 |
+
im = numpy.transpose(im, (0, 3, 4, 1, 2))
|
| 183 |
+
|
| 184 |
+
imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
|
| 185 |
+
|
| 186 |
+
# ========== ==========
|
| 187 |
+
# Generate video feats
|
| 188 |
+
# ========== ==========
|
| 189 |
+
|
| 190 |
+
lastframe = len(images) - 4
|
| 191 |
+
im_feat = []
|
| 192 |
+
|
| 193 |
+
tS = time.time()
|
| 194 |
+
for i in range(0, lastframe, opt.batch_size):
|
| 195 |
+
|
| 196 |
+
im_batch = [
|
| 197 |
+
imtv[:, :, vframe : vframe + 5, :, :] for vframe in range(i, min(lastframe, i + opt.batch_size))
|
| 198 |
+
]
|
| 199 |
+
im_in = torch.cat(im_batch, 0)
|
| 200 |
+
im_out = self.__S__.forward_lipfeat(im_in.to(self.device))
|
| 201 |
+
im_feat.append(im_out.data.cpu())
|
| 202 |
+
|
| 203 |
+
im_feat = torch.cat(im_feat, 0)
|
| 204 |
+
|
| 205 |
+
# ========== ==========
|
| 206 |
+
# Compute offset
|
| 207 |
+
# ========== ==========
|
| 208 |
+
|
| 209 |
+
print("Compute time %.3f sec." % (time.time() - tS))
|
| 210 |
+
|
| 211 |
+
return im_feat
|
| 212 |
+
|
| 213 |
+
def loadParameters(self, path):
|
| 214 |
+
loaded_state = torch.load(path, map_location=lambda storage, loc: storage)
|
| 215 |
+
|
| 216 |
+
self_state = self.__S__.state_dict()
|
| 217 |
+
|
| 218 |
+
for name, param in loaded_state.items():
|
| 219 |
+
|
| 220 |
+
self_state[name].copy_(param)
|
LatentSync/eval/syncnet_detect.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/joonson/syncnet_python/blob/master/run_pipeline.py
|
| 2 |
+
|
| 3 |
+
import os, pdb, subprocess, glob, cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
from shutil import rmtree
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from scenedetect.video_manager import VideoManager
|
| 9 |
+
from scenedetect.scene_manager import SceneManager
|
| 10 |
+
from scenedetect.stats_manager import StatsManager
|
| 11 |
+
from scenedetect.detectors import ContentDetector
|
| 12 |
+
|
| 13 |
+
from scipy.interpolate import interp1d
|
| 14 |
+
from scipy.io import wavfile
|
| 15 |
+
from scipy import signal
|
| 16 |
+
|
| 17 |
+
from eval.detectors import S3FD
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SyncNetDetector:
|
| 21 |
+
def __init__(self, device, detect_results_dir="detect_results"):
|
| 22 |
+
self.s3f_detector = S3FD(device=device)
|
| 23 |
+
self.detect_results_dir = detect_results_dir
|
| 24 |
+
|
| 25 |
+
def __call__(self, video_path: str, min_track=50, scale=False):
|
| 26 |
+
crop_dir = os.path.join(self.detect_results_dir, "crop")
|
| 27 |
+
video_dir = os.path.join(self.detect_results_dir, "video")
|
| 28 |
+
frames_dir = os.path.join(self.detect_results_dir, "frames")
|
| 29 |
+
temp_dir = os.path.join(self.detect_results_dir, "temp")
|
| 30 |
+
|
| 31 |
+
# ========== DELETE EXISTING DIRECTORIES ==========
|
| 32 |
+
if os.path.exists(crop_dir):
|
| 33 |
+
rmtree(crop_dir)
|
| 34 |
+
|
| 35 |
+
if os.path.exists(video_dir):
|
| 36 |
+
rmtree(video_dir)
|
| 37 |
+
|
| 38 |
+
if os.path.exists(frames_dir):
|
| 39 |
+
rmtree(frames_dir)
|
| 40 |
+
|
| 41 |
+
if os.path.exists(temp_dir):
|
| 42 |
+
rmtree(temp_dir)
|
| 43 |
+
|
| 44 |
+
# ========== MAKE NEW DIRECTORIES ==========
|
| 45 |
+
|
| 46 |
+
os.makedirs(crop_dir)
|
| 47 |
+
os.makedirs(video_dir)
|
| 48 |
+
os.makedirs(frames_dir)
|
| 49 |
+
os.makedirs(temp_dir)
|
| 50 |
+
|
| 51 |
+
# ========== CONVERT VIDEO AND EXTRACT FRAMES ==========
|
| 52 |
+
|
| 53 |
+
if scale:
|
| 54 |
+
scaled_video_path = os.path.join(video_dir, "scaled.mp4")
|
| 55 |
+
command = f"ffmpeg -loglevel error -y -nostdin -i {video_path} -vf scale='224:224' {scaled_video_path}"
|
| 56 |
+
subprocess.run(command, shell=True)
|
| 57 |
+
video_path = scaled_video_path
|
| 58 |
+
|
| 59 |
+
command = f"ffmpeg -y -nostdin -loglevel error -i {video_path} -qscale:v 2 -async 1 -r 25 {os.path.join(video_dir, 'video.mp4')}"
|
| 60 |
+
subprocess.run(command, shell=True, stdout=None)
|
| 61 |
+
|
| 62 |
+
command = f"ffmpeg -y -nostdin -loglevel error -i {os.path.join(video_dir, 'video.mp4')} -qscale:v 2 -f image2 {os.path.join(frames_dir, '%06d.jpg')}"
|
| 63 |
+
subprocess.run(command, shell=True, stdout=None)
|
| 64 |
+
|
| 65 |
+
command = f"ffmpeg -y -nostdin -loglevel error -i {os.path.join(video_dir, 'video.mp4')} -ac 1 -vn -acodec pcm_s16le -ar 16000 {os.path.join(video_dir, 'audio.wav')}"
|
| 66 |
+
subprocess.run(command, shell=True, stdout=None)
|
| 67 |
+
|
| 68 |
+
faces = self.detect_face(frames_dir)
|
| 69 |
+
|
| 70 |
+
scene = self.scene_detect(video_dir)
|
| 71 |
+
|
| 72 |
+
# Face tracking
|
| 73 |
+
alltracks = []
|
| 74 |
+
|
| 75 |
+
for shot in scene:
|
| 76 |
+
if shot[1].frame_num - shot[0].frame_num >= min_track:
|
| 77 |
+
alltracks.extend(self.track_face(faces[shot[0].frame_num : shot[1].frame_num], min_track=min_track))
|
| 78 |
+
|
| 79 |
+
# Face crop
|
| 80 |
+
for ii, track in enumerate(alltracks):
|
| 81 |
+
self.crop_video(track, os.path.join(crop_dir, "%05d" % ii), frames_dir, 25, temp_dir, video_dir)
|
| 82 |
+
|
| 83 |
+
rmtree(temp_dir)
|
| 84 |
+
|
| 85 |
+
def scene_detect(self, video_dir):
|
| 86 |
+
video_manager = VideoManager([os.path.join(video_dir, "video.mp4")])
|
| 87 |
+
stats_manager = StatsManager()
|
| 88 |
+
scene_manager = SceneManager(stats_manager)
|
| 89 |
+
# Add ContentDetector algorithm (constructor takes detector options like threshold).
|
| 90 |
+
scene_manager.add_detector(ContentDetector())
|
| 91 |
+
base_timecode = video_manager.get_base_timecode()
|
| 92 |
+
|
| 93 |
+
video_manager.set_downscale_factor()
|
| 94 |
+
|
| 95 |
+
video_manager.start()
|
| 96 |
+
|
| 97 |
+
scene_manager.detect_scenes(frame_source=video_manager)
|
| 98 |
+
|
| 99 |
+
scene_list = scene_manager.get_scene_list(base_timecode)
|
| 100 |
+
|
| 101 |
+
if scene_list == []:
|
| 102 |
+
scene_list = [(video_manager.get_base_timecode(), video_manager.get_current_timecode())]
|
| 103 |
+
|
| 104 |
+
return scene_list
|
| 105 |
+
|
| 106 |
+
def track_face(self, scenefaces, num_failed_det=25, min_track=50, min_face_size=100):
|
| 107 |
+
|
| 108 |
+
iouThres = 0.5 # Minimum IOU between consecutive face detections
|
| 109 |
+
tracks = []
|
| 110 |
+
|
| 111 |
+
while True:
|
| 112 |
+
track = []
|
| 113 |
+
for framefaces in scenefaces:
|
| 114 |
+
for face in framefaces:
|
| 115 |
+
if track == []:
|
| 116 |
+
track.append(face)
|
| 117 |
+
framefaces.remove(face)
|
| 118 |
+
elif face["frame"] - track[-1]["frame"] <= num_failed_det:
|
| 119 |
+
iou = bounding_box_iou(face["bbox"], track[-1]["bbox"])
|
| 120 |
+
if iou > iouThres:
|
| 121 |
+
track.append(face)
|
| 122 |
+
framefaces.remove(face)
|
| 123 |
+
continue
|
| 124 |
+
else:
|
| 125 |
+
break
|
| 126 |
+
|
| 127 |
+
if track == []:
|
| 128 |
+
break
|
| 129 |
+
elif len(track) > min_track:
|
| 130 |
+
|
| 131 |
+
framenum = np.array([f["frame"] for f in track])
|
| 132 |
+
bboxes = np.array([np.array(f["bbox"]) for f in track])
|
| 133 |
+
|
| 134 |
+
frame_i = np.arange(framenum[0], framenum[-1] + 1)
|
| 135 |
+
|
| 136 |
+
bboxes_i = []
|
| 137 |
+
for ij in range(0, 4):
|
| 138 |
+
interpfn = interp1d(framenum, bboxes[:, ij])
|
| 139 |
+
bboxes_i.append(interpfn(frame_i))
|
| 140 |
+
bboxes_i = np.stack(bboxes_i, axis=1)
|
| 141 |
+
|
| 142 |
+
if (
|
| 143 |
+
max(np.mean(bboxes_i[:, 2] - bboxes_i[:, 0]), np.mean(bboxes_i[:, 3] - bboxes_i[:, 1]))
|
| 144 |
+
> min_face_size
|
| 145 |
+
):
|
| 146 |
+
tracks.append({"frame": frame_i, "bbox": bboxes_i})
|
| 147 |
+
|
| 148 |
+
return tracks
|
| 149 |
+
|
| 150 |
+
def detect_face(self, frames_dir, facedet_scale=0.25):
|
| 151 |
+
flist = glob.glob(os.path.join(frames_dir, "*.jpg"))
|
| 152 |
+
flist.sort()
|
| 153 |
+
|
| 154 |
+
dets = []
|
| 155 |
+
|
| 156 |
+
for fidx, fname in enumerate(flist):
|
| 157 |
+
image = cv2.imread(fname)
|
| 158 |
+
|
| 159 |
+
image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 160 |
+
bboxes = self.s3f_detector.detect_faces(image_np, conf_th=0.9, scales=[facedet_scale])
|
| 161 |
+
|
| 162 |
+
dets.append([])
|
| 163 |
+
for bbox in bboxes:
|
| 164 |
+
dets[-1].append({"frame": fidx, "bbox": (bbox[:-1]).tolist(), "conf": bbox[-1]})
|
| 165 |
+
|
| 166 |
+
return dets
|
| 167 |
+
|
| 168 |
+
def crop_video(self, track, cropfile, frames_dir, frame_rate, temp_dir, video_dir, crop_scale=0.4):
|
| 169 |
+
|
| 170 |
+
flist = glob.glob(os.path.join(frames_dir, "*.jpg"))
|
| 171 |
+
flist.sort()
|
| 172 |
+
|
| 173 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 174 |
+
vOut = cv2.VideoWriter(cropfile + "t.mp4", fourcc, frame_rate, (224, 224))
|
| 175 |
+
|
| 176 |
+
dets = {"x": [], "y": [], "s": []}
|
| 177 |
+
|
| 178 |
+
for det in track["bbox"]:
|
| 179 |
+
|
| 180 |
+
dets["s"].append(max((det[3] - det[1]), (det[2] - det[0])) / 2)
|
| 181 |
+
dets["y"].append((det[1] + det[3]) / 2) # crop center x
|
| 182 |
+
dets["x"].append((det[0] + det[2]) / 2) # crop center y
|
| 183 |
+
|
| 184 |
+
# Smooth detections
|
| 185 |
+
dets["s"] = signal.medfilt(dets["s"], kernel_size=13)
|
| 186 |
+
dets["x"] = signal.medfilt(dets["x"], kernel_size=13)
|
| 187 |
+
dets["y"] = signal.medfilt(dets["y"], kernel_size=13)
|
| 188 |
+
|
| 189 |
+
for fidx, frame in enumerate(track["frame"]):
|
| 190 |
+
|
| 191 |
+
cs = crop_scale
|
| 192 |
+
|
| 193 |
+
bs = dets["s"][fidx] # Detection box size
|
| 194 |
+
bsi = int(bs * (1 + 2 * cs)) # Pad videos by this amount
|
| 195 |
+
|
| 196 |
+
image = cv2.imread(flist[frame])
|
| 197 |
+
|
| 198 |
+
frame = np.pad(image, ((bsi, bsi), (bsi, bsi), (0, 0)), "constant", constant_values=(110, 110))
|
| 199 |
+
my = dets["y"][fidx] + bsi # BBox center Y
|
| 200 |
+
mx = dets["x"][fidx] + bsi # BBox center X
|
| 201 |
+
|
| 202 |
+
face = frame[int(my - bs) : int(my + bs * (1 + 2 * cs)), int(mx - bs * (1 + cs)) : int(mx + bs * (1 + cs))]
|
| 203 |
+
|
| 204 |
+
vOut.write(cv2.resize(face, (224, 224)))
|
| 205 |
+
|
| 206 |
+
audiotmp = os.path.join(temp_dir, "audio.wav")
|
| 207 |
+
audiostart = (track["frame"][0]) / frame_rate
|
| 208 |
+
audioend = (track["frame"][-1] + 1) / frame_rate
|
| 209 |
+
|
| 210 |
+
vOut.release()
|
| 211 |
+
|
| 212 |
+
# ========== CROP AUDIO FILE ==========
|
| 213 |
+
|
| 214 |
+
command = "ffmpeg -y -nostdin -loglevel error -i %s -ss %.3f -to %.3f %s" % (
|
| 215 |
+
os.path.join(video_dir, "audio.wav"),
|
| 216 |
+
audiostart,
|
| 217 |
+
audioend,
|
| 218 |
+
audiotmp,
|
| 219 |
+
)
|
| 220 |
+
output = subprocess.run(command, shell=True, stdout=None)
|
| 221 |
+
|
| 222 |
+
sample_rate, audio = wavfile.read(audiotmp)
|
| 223 |
+
|
| 224 |
+
# ========== COMBINE AUDIO AND VIDEO FILES ==========
|
| 225 |
+
|
| 226 |
+
command = "ffmpeg -y -nostdin -loglevel error -i %st.mp4 -i %s -c:v copy -c:a aac %s.mp4" % (
|
| 227 |
+
cropfile,
|
| 228 |
+
audiotmp,
|
| 229 |
+
cropfile,
|
| 230 |
+
)
|
| 231 |
+
output = subprocess.run(command, shell=True, stdout=None)
|
| 232 |
+
|
| 233 |
+
os.remove(cropfile + "t.mp4")
|
| 234 |
+
|
| 235 |
+
return {"track": track, "proc_track": dets}
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def bounding_box_iou(boxA, boxB):
|
| 239 |
+
xA = max(boxA[0], boxB[0])
|
| 240 |
+
yA = max(boxA[1], boxB[1])
|
| 241 |
+
xB = min(boxA[2], boxB[2])
|
| 242 |
+
yB = min(boxA[3], boxB[3])
|
| 243 |
+
|
| 244 |
+
interArea = max(0, xB - xA) * max(0, yB - yA)
|
| 245 |
+
|
| 246 |
+
boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
|
| 247 |
+
boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
|
| 248 |
+
|
| 249 |
+
iou = interArea / float(boxAArea + boxBArea - interArea)
|
| 250 |
+
|
| 251 |
+
return iou
|
LatentSync/inference.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
python -m scripts.inference \
|
| 4 |
+
--unet_config_path "configs/unet/second_stage.yaml" \
|
| 5 |
+
--inference_ckpt_path "checkpoints/latentsync_unet.pt" \
|
| 6 |
+
--guidance_scale 1.0 \
|
| 7 |
+
--video_path "assets/demo1_video.mp4" \
|
| 8 |
+
--audio_path "assets/demo1_audio.wav" \
|
| 9 |
+
--video_out_path "video_out.mp4"
|
LatentSync/latentsync/data/syncnet_dataset.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import numpy as np
|
| 17 |
+
from torch.utils.data import Dataset
|
| 18 |
+
import torch
|
| 19 |
+
import random
|
| 20 |
+
from ..utils.util import gather_video_paths_recursively
|
| 21 |
+
from ..utils.image_processor import ImageProcessor
|
| 22 |
+
from ..utils.audio import melspectrogram
|
| 23 |
+
import math
|
| 24 |
+
|
| 25 |
+
from decord import AudioReader, VideoReader, cpu
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SyncNetDataset(Dataset):
|
| 29 |
+
def __init__(self, data_dir: str, fileslist: str, config):
|
| 30 |
+
if fileslist != "":
|
| 31 |
+
with open(fileslist) as file:
|
| 32 |
+
self.video_paths = [line.rstrip() for line in file]
|
| 33 |
+
elif data_dir != "":
|
| 34 |
+
self.video_paths = gather_video_paths_recursively(data_dir)
|
| 35 |
+
else:
|
| 36 |
+
raise ValueError("data_dir and fileslist cannot be both empty")
|
| 37 |
+
|
| 38 |
+
self.resolution = config.data.resolution
|
| 39 |
+
self.num_frames = config.data.num_frames
|
| 40 |
+
|
| 41 |
+
self.mel_window_length = math.ceil(self.num_frames / 5 * 16)
|
| 42 |
+
|
| 43 |
+
self.audio_sample_rate = config.data.audio_sample_rate
|
| 44 |
+
self.video_fps = config.data.video_fps
|
| 45 |
+
self.audio_samples_length = int(
|
| 46 |
+
config.data.audio_sample_rate // config.data.video_fps * config.data.num_frames
|
| 47 |
+
)
|
| 48 |
+
self.image_processor = ImageProcessor(resolution=config.data.resolution, mask="half")
|
| 49 |
+
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
|
| 50 |
+
os.makedirs(self.audio_mel_cache_dir, exist_ok=True)
|
| 51 |
+
|
| 52 |
+
def __len__(self):
|
| 53 |
+
return len(self.video_paths)
|
| 54 |
+
|
| 55 |
+
def read_audio(self, video_path: str):
|
| 56 |
+
ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
|
| 57 |
+
original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
|
| 58 |
+
return torch.from_numpy(original_mel)
|
| 59 |
+
|
| 60 |
+
def crop_audio_window(self, original_mel, start_index):
|
| 61 |
+
start_idx = int(80.0 * (start_index / float(self.video_fps)))
|
| 62 |
+
end_idx = start_idx + self.mel_window_length
|
| 63 |
+
return original_mel[:, start_idx:end_idx].unsqueeze(0)
|
| 64 |
+
|
| 65 |
+
def get_frames(self, video_reader: VideoReader):
|
| 66 |
+
total_num_frames = len(video_reader)
|
| 67 |
+
|
| 68 |
+
start_idx = random.randint(0, total_num_frames - self.num_frames)
|
| 69 |
+
frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
|
| 70 |
+
|
| 71 |
+
while True:
|
| 72 |
+
wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
|
| 73 |
+
# wrong_start_idx = random.randint(
|
| 74 |
+
# max(0, start_idx - 25), min(total_num_frames - self.num_frames, start_idx + 25)
|
| 75 |
+
# )
|
| 76 |
+
if wrong_start_idx == start_idx:
|
| 77 |
+
continue
|
| 78 |
+
# if wrong_start_idx >= start_idx - self.num_frames and wrong_start_idx <= start_idx + self.num_frames:
|
| 79 |
+
# continue
|
| 80 |
+
wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
|
| 81 |
+
break
|
| 82 |
+
|
| 83 |
+
frames = video_reader.get_batch(frames_index).asnumpy()
|
| 84 |
+
wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()
|
| 85 |
+
|
| 86 |
+
return frames, wrong_frames, start_idx
|
| 87 |
+
|
| 88 |
+
def worker_init_fn(self, worker_id):
|
| 89 |
+
# Initialize the face mesh object in each worker process,
|
| 90 |
+
# because the face mesh object cannot be called in subprocesses
|
| 91 |
+
self.worker_id = worker_id
|
| 92 |
+
# setattr(self, f"image_processor_{worker_id}", ImageProcessor(self.resolution, self.mask))
|
| 93 |
+
|
| 94 |
+
def __getitem__(self, idx):
|
| 95 |
+
# image_processor = getattr(self, f"image_processor_{self.worker_id}")
|
| 96 |
+
while True:
|
| 97 |
+
try:
|
| 98 |
+
idx = random.randint(0, len(self) - 1)
|
| 99 |
+
|
| 100 |
+
# Get video file path
|
| 101 |
+
video_path = self.video_paths[idx]
|
| 102 |
+
|
| 103 |
+
vr = VideoReader(video_path, ctx=cpu(self.worker_id))
|
| 104 |
+
|
| 105 |
+
if len(vr) < 2 * self.num_frames:
|
| 106 |
+
continue
|
| 107 |
+
|
| 108 |
+
frames, wrong_frames, start_idx = self.get_frames(vr)
|
| 109 |
+
|
| 110 |
+
mel_cache_path = os.path.join(
|
| 111 |
+
self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
if os.path.isfile(mel_cache_path):
|
| 115 |
+
try:
|
| 116 |
+
original_mel = torch.load(mel_cache_path)
|
| 117 |
+
except Exception as e:
|
| 118 |
+
print(f"{type(e).__name__} - {e} - {mel_cache_path}")
|
| 119 |
+
os.remove(mel_cache_path)
|
| 120 |
+
original_mel = self.read_audio(video_path)
|
| 121 |
+
torch.save(original_mel, mel_cache_path)
|
| 122 |
+
else:
|
| 123 |
+
original_mel = self.read_audio(video_path)
|
| 124 |
+
torch.save(original_mel, mel_cache_path)
|
| 125 |
+
|
| 126 |
+
mel = self.crop_audio_window(original_mel, start_idx)
|
| 127 |
+
|
| 128 |
+
if mel.shape[-1] != self.mel_window_length:
|
| 129 |
+
continue
|
| 130 |
+
|
| 131 |
+
if random.choice([True, False]):
|
| 132 |
+
y = torch.ones(1).float()
|
| 133 |
+
chosen_frames = frames
|
| 134 |
+
else:
|
| 135 |
+
y = torch.zeros(1).float()
|
| 136 |
+
chosen_frames = wrong_frames
|
| 137 |
+
|
| 138 |
+
chosen_frames = self.image_processor.process_images(chosen_frames)
|
| 139 |
+
# chosen_frames, _, _ = image_processor.prepare_masks_and_masked_images(
|
| 140 |
+
# chosen_frames, affine_transform=True
|
| 141 |
+
# )
|
| 142 |
+
|
| 143 |
+
vr.seek(0) # avoid memory leak
|
| 144 |
+
break
|
| 145 |
+
|
| 146 |
+
except Exception as e: # Handle the exception of face not detcted
|
| 147 |
+
print(f"{type(e).__name__} - {e} - {video_path}")
|
| 148 |
+
if "vr" in locals():
|
| 149 |
+
vr.seek(0) # avoid memory leak
|
| 150 |
+
|
| 151 |
+
sample = dict(frames=chosen_frames, audio_samples=mel, y=y)
|
| 152 |
+
|
| 153 |
+
return sample
|
LatentSync/latentsync/data/unet_dataset.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import numpy as np
|
| 17 |
+
from torch.utils.data import Dataset
|
| 18 |
+
import torch
|
| 19 |
+
import random
|
| 20 |
+
import cv2
|
| 21 |
+
from ..utils.image_processor import ImageProcessor, load_fixed_mask
|
| 22 |
+
from ..utils.audio import melspectrogram
|
| 23 |
+
from decord import AudioReader, VideoReader, cpu
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class UNetDataset(Dataset):
|
| 27 |
+
def __init__(self, train_data_dir: str, config):
|
| 28 |
+
if config.data.train_fileslist != "":
|
| 29 |
+
with open(config.data.train_fileslist) as file:
|
| 30 |
+
self.video_paths = [line.rstrip() for line in file]
|
| 31 |
+
elif train_data_dir != "":
|
| 32 |
+
self.video_paths = []
|
| 33 |
+
for file in os.listdir(train_data_dir):
|
| 34 |
+
if file.endswith(".mp4"):
|
| 35 |
+
self.video_paths.append(os.path.join(train_data_dir, file))
|
| 36 |
+
else:
|
| 37 |
+
raise ValueError("data_dir and fileslist cannot be both empty")
|
| 38 |
+
|
| 39 |
+
self.resolution = config.data.resolution
|
| 40 |
+
self.num_frames = config.data.num_frames
|
| 41 |
+
|
| 42 |
+
if self.num_frames == 16:
|
| 43 |
+
self.mel_window_length = 52
|
| 44 |
+
elif self.num_frames == 5:
|
| 45 |
+
self.mel_window_length = 16
|
| 46 |
+
else:
|
| 47 |
+
raise NotImplementedError("Only support 16 and 5 frames now")
|
| 48 |
+
|
| 49 |
+
self.audio_sample_rate = config.data.audio_sample_rate
|
| 50 |
+
self.video_fps = config.data.video_fps
|
| 51 |
+
self.mask = config.data.mask
|
| 52 |
+
self.mask_image = load_fixed_mask(self.resolution)
|
| 53 |
+
self.load_audio_data = config.model.add_audio_layer and config.run.use_syncnet
|
| 54 |
+
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
|
| 55 |
+
os.makedirs(self.audio_mel_cache_dir, exist_ok=True)
|
| 56 |
+
|
| 57 |
+
def __len__(self):
|
| 58 |
+
return len(self.video_paths)
|
| 59 |
+
|
| 60 |
+
def read_audio(self, video_path: str):
|
| 61 |
+
ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
|
| 62 |
+
original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
|
| 63 |
+
return torch.from_numpy(original_mel)
|
| 64 |
+
|
| 65 |
+
def crop_audio_window(self, original_mel, start_index):
|
| 66 |
+
start_idx = int(80.0 * (start_index / float(self.video_fps)))
|
| 67 |
+
end_idx = start_idx + self.mel_window_length
|
| 68 |
+
return original_mel[:, start_idx:end_idx].unsqueeze(0)
|
| 69 |
+
|
| 70 |
+
def get_frames(self, video_reader: VideoReader):
|
| 71 |
+
total_num_frames = len(video_reader)
|
| 72 |
+
|
| 73 |
+
start_idx = random.randint(self.num_frames // 2, total_num_frames - self.num_frames - self.num_frames // 2)
|
| 74 |
+
frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
|
| 75 |
+
|
| 76 |
+
while True:
|
| 77 |
+
wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
|
| 78 |
+
if wrong_start_idx > start_idx - self.num_frames and wrong_start_idx < start_idx + self.num_frames:
|
| 79 |
+
continue
|
| 80 |
+
wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
|
| 81 |
+
break
|
| 82 |
+
|
| 83 |
+
frames = video_reader.get_batch(frames_index).asnumpy()
|
| 84 |
+
wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()
|
| 85 |
+
|
| 86 |
+
return frames, wrong_frames, start_idx
|
| 87 |
+
|
| 88 |
+
def worker_init_fn(self, worker_id):
|
| 89 |
+
# Initialize the face mesh object in each worker process,
|
| 90 |
+
# because the face mesh object cannot be called in subprocesses
|
| 91 |
+
self.worker_id = worker_id
|
| 92 |
+
setattr(
|
| 93 |
+
self,
|
| 94 |
+
f"image_processor_{worker_id}",
|
| 95 |
+
ImageProcessor(self.resolution, self.mask, mask_image=self.mask_image),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
def __getitem__(self, idx):
|
| 99 |
+
image_processor = getattr(self, f"image_processor_{self.worker_id}")
|
| 100 |
+
while True:
|
| 101 |
+
try:
|
| 102 |
+
idx = random.randint(0, len(self) - 1)
|
| 103 |
+
|
| 104 |
+
# Get video file path
|
| 105 |
+
video_path = self.video_paths[idx]
|
| 106 |
+
|
| 107 |
+
vr = VideoReader(video_path, ctx=cpu(self.worker_id))
|
| 108 |
+
|
| 109 |
+
if len(vr) < 3 * self.num_frames:
|
| 110 |
+
continue
|
| 111 |
+
|
| 112 |
+
continuous_frames, ref_frames, start_idx = self.get_frames(vr)
|
| 113 |
+
|
| 114 |
+
if self.load_audio_data:
|
| 115 |
+
mel_cache_path = os.path.join(
|
| 116 |
+
self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
if os.path.isfile(mel_cache_path):
|
| 120 |
+
try:
|
| 121 |
+
original_mel = torch.load(mel_cache_path)
|
| 122 |
+
except Exception as e:
|
| 123 |
+
print(f"{type(e).__name__} - {e} - {mel_cache_path}")
|
| 124 |
+
os.remove(mel_cache_path)
|
| 125 |
+
original_mel = self.read_audio(video_path)
|
| 126 |
+
torch.save(original_mel, mel_cache_path)
|
| 127 |
+
else:
|
| 128 |
+
original_mel = self.read_audio(video_path)
|
| 129 |
+
torch.save(original_mel, mel_cache_path)
|
| 130 |
+
|
| 131 |
+
mel = self.crop_audio_window(original_mel, start_idx)
|
| 132 |
+
|
| 133 |
+
if mel.shape[-1] != self.mel_window_length:
|
| 134 |
+
continue
|
| 135 |
+
else:
|
| 136 |
+
mel = []
|
| 137 |
+
|
| 138 |
+
gt, masked_gt, mask = image_processor.prepare_masks_and_masked_images(
|
| 139 |
+
continuous_frames, affine_transform=False
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
if self.mask == "fix_mask":
|
| 143 |
+
ref, _, _ = image_processor.prepare_masks_and_masked_images(ref_frames, affine_transform=False)
|
| 144 |
+
else:
|
| 145 |
+
ref = image_processor.process_images(ref_frames)
|
| 146 |
+
vr.seek(0) # avoid memory leak
|
| 147 |
+
break
|
| 148 |
+
|
| 149 |
+
except Exception as e: # Handle the exception of face not detcted
|
| 150 |
+
print(f"{type(e).__name__} - {e} - {video_path}")
|
| 151 |
+
if "vr" in locals():
|
| 152 |
+
vr.seek(0) # avoid memory leak
|
| 153 |
+
|
| 154 |
+
sample = dict(
|
| 155 |
+
gt=gt,
|
| 156 |
+
masked_gt=masked_gt,
|
| 157 |
+
ref=ref,
|
| 158 |
+
mel=mel,
|
| 159 |
+
mask=mask,
|
| 160 |
+
video_path=video_path,
|
| 161 |
+
start_idx=start_idx,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
return sample
|
LatentSync/latentsync/models/attention.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from turtle import forward
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
from diffusers.utils import BaseOutput
|
| 14 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 15 |
+
from diffusers.models.attention import Attention as CrossAttention, FeedForward, AdaLayerNorm
|
| 16 |
+
|
| 17 |
+
from einops import rearrange, repeat
|
| 18 |
+
from .utils import zero_module
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class Transformer3DModelOutput(BaseOutput):
|
| 23 |
+
sample: torch.FloatTensor
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if is_xformers_available():
|
| 27 |
+
import xformers
|
| 28 |
+
import xformers.ops
|
| 29 |
+
else:
|
| 30 |
+
xformers = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Transformer3DModel(ModelMixin, ConfigMixin):
|
| 34 |
+
@register_to_config
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
num_attention_heads: int = 16,
|
| 38 |
+
attention_head_dim: int = 88,
|
| 39 |
+
in_channels: Optional[int] = None,
|
| 40 |
+
num_layers: int = 1,
|
| 41 |
+
dropout: float = 0.0,
|
| 42 |
+
norm_num_groups: int = 32,
|
| 43 |
+
cross_attention_dim: Optional[int] = None,
|
| 44 |
+
attention_bias: bool = False,
|
| 45 |
+
activation_fn: str = "geglu",
|
| 46 |
+
num_embeds_ada_norm: Optional[int] = None,
|
| 47 |
+
use_linear_projection: bool = False,
|
| 48 |
+
only_cross_attention: bool = False,
|
| 49 |
+
upcast_attention: bool = False,
|
| 50 |
+
use_motion_module: bool = False,
|
| 51 |
+
unet_use_cross_frame_attention=None,
|
| 52 |
+
unet_use_temporal_attention=None,
|
| 53 |
+
add_audio_layer=False,
|
| 54 |
+
audio_condition_method="cross_attn",
|
| 55 |
+
custom_audio_layer: bool = False,
|
| 56 |
+
):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.use_linear_projection = use_linear_projection
|
| 59 |
+
self.num_attention_heads = num_attention_heads
|
| 60 |
+
self.attention_head_dim = attention_head_dim
|
| 61 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 62 |
+
|
| 63 |
+
# Define input layers
|
| 64 |
+
self.in_channels = in_channels
|
| 65 |
+
|
| 66 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
| 67 |
+
if use_linear_projection:
|
| 68 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
| 69 |
+
else:
|
| 70 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
| 71 |
+
|
| 72 |
+
if not custom_audio_layer:
|
| 73 |
+
# Define transformers blocks
|
| 74 |
+
self.transformer_blocks = nn.ModuleList(
|
| 75 |
+
[
|
| 76 |
+
BasicTransformerBlock(
|
| 77 |
+
inner_dim,
|
| 78 |
+
num_attention_heads,
|
| 79 |
+
attention_head_dim,
|
| 80 |
+
dropout=dropout,
|
| 81 |
+
cross_attention_dim=cross_attention_dim,
|
| 82 |
+
activation_fn=activation_fn,
|
| 83 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
| 84 |
+
attention_bias=attention_bias,
|
| 85 |
+
only_cross_attention=only_cross_attention,
|
| 86 |
+
upcast_attention=upcast_attention,
|
| 87 |
+
use_motion_module=use_motion_module,
|
| 88 |
+
                        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                        unet_use_temporal_attention=unet_use_temporal_attention,
                        add_audio_layer=add_audio_layer,
                        custom_audio_layer=custom_audio_layer,
                        audio_condition_method=audio_condition_method,
                    )
                    for d in range(num_layers)
                ]
            )
        else:
            self.transformer_blocks = nn.ModuleList(
                [
                    AudioTransformerBlock(
                        inner_dim,
                        num_attention_heads,
                        attention_head_dim,
                        dropout=dropout,
                        cross_attention_dim=cross_attention_dim,
                        activation_fn=activation_fn,
                        num_embeds_ada_norm=num_embeds_ada_norm,
                        attention_bias=attention_bias,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        use_motion_module=use_motion_module,
                        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                        unet_use_temporal_attention=unet_use_temporal_attention,
                        add_audio_layer=add_audio_layer,
                    )
                    for d in range(num_layers)
                ]
            )

        # 4. Define output layers
        if use_linear_projection:
            self.proj_out = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

        if custom_audio_layer:
            self.proj_out = zero_module(self.proj_out)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        # No need to do this for audio input, because different audio samples are independent
        # encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length,
            )

        # Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_motion_module: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        add_audio_layer=False,
        custom_audio_layer=False,
        audio_condition_method="cross_attn",
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention
        self.use_motion_module = use_motion_module
        self.add_audio_layer = add_audio_layer

        # SC-Attn
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
        else:
            self.attn1 = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        # Cross-Attn
        if add_audio_layer and audio_condition_method == "cross_attn" and not custom_audio_layer:
            self.audio_cross_attn = AudioCrossAttn(
                dim=dim,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                attention_bias=attention_bias,
                upcast_attention=upcast_attention,
                num_embeds_ada_norm=num_embeds_ada_norm,
                use_ada_layer_norm=self.use_ada_layer_norm,
                zero_proj_out=False,
            )
        else:
            self.audio_cross_attn = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

        # Temp-Attn
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        if self.audio_cross_attn is not None:
            self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
                use_memory_efficient_attention_xformers
            )
        # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(
        self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
    ):
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        # if self.only_cross_attention:
        #     hidden_states = (
        #         self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
        #     )
        # else:
        #     hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states

        # pdb.set_trace()
        if self.unet_use_cross_frame_attention:
            hidden_states = (
                self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
                + hidden_states
            )
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.audio_cross_attn is not None and encoder_hidden_states is not None:
            hidden_states = self.audio_cross_attn(
                hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention
        if self.unet_use_temporal_attention:
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states


class AudioTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_motion_module: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        add_audio_layer=False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention
        self.use_motion_module = use_motion_module
        self.add_audio_layer = add_audio_layer

        # SC-Attn
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
        else:
            self.attn1 = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        self.audio_cross_attn = AudioCrossAttn(
            dim=dim,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            dropout=dropout,
            attention_bias=attention_bias,
            upcast_attention=upcast_attention,
            num_embeds_ada_norm=num_embeds_ada_norm,
            use_ada_layer_norm=self.use_ada_layer_norm,
            zero_proj_out=False,
        )

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        if self.audio_cross_attn is not None:
            self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
                use_memory_efficient_attention_xformers
            )
        # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(
        self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
    ):
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        # pdb.set_trace()
        if self.unet_use_cross_frame_attention:
            hidden_states = (
                self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
                + hidden_states
            )
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.audio_cross_attn is not None and encoder_hidden_states is not None:
            hidden_states = self.audio_cross_attn(
                hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        return hidden_states


class AudioCrossAttn(nn.Module):
    def __init__(
        self,
        dim,
        cross_attention_dim,
        num_attention_heads,
        attention_head_dim,
        dropout,
        attention_bias,
        upcast_attention,
        num_embeds_ada_norm,
        use_ada_layer_norm,
        zero_proj_out=False,
    ):
        super().__init__()

        self.norm = AdaLayerNorm(dim, num_embeds_ada_norm) if use_ada_layer_norm else nn.LayerNorm(dim)
        self.attn = CrossAttention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
        )

        if zero_proj_out:
            self.proj_out = zero_module(nn.Linear(dim, dim))

        self.zero_proj_out = zero_proj_out
        self.use_ada_layer_norm = use_ada_layer_norm

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
        previous_hidden_states = hidden_states
        hidden_states = self.norm(hidden_states, timestep) if self.use_ada_layer_norm else self.norm(hidden_states)

        if encoder_hidden_states.dim() == 4:
            encoder_hidden_states = rearrange(encoder_hidden_states, "b f n d -> (b f) n d")

        hidden_states = self.attn(
            hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
        )

        if self.zero_proj_out:
            hidden_states = self.proj_out(hidden_states)
        return hidden_states + previous_hidden_states
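For reference, a minimal standalone sketch (not part of the repository files) of the reshape pattern used by Transformer3DModel.forward and AudioCrossAttn above: frames are folded into the batch axis, attention conditions each frame on its own audio tokens, and the tensor is unfolded back. The toy dimensions and the plain nn.MultiheadAttention stand-in for diffusers' CrossAttention are assumptions for illustration only.

# Illustrative sketch only: toy dims, nn.MultiheadAttention stands in for diffusers' CrossAttention.
import torch
from torch import nn
from einops import rearrange

b, c, f, h, w = 1, 8, 4, 16, 16            # batch, channels, frames, height, width (toy values)
hidden_states = torch.randn(b, c, f, h, w)
audio_embeds = torch.randn(b * f, 50, 8)   # per-frame audio tokens: (b*f, tokens, dim), assumed shape

# Fold frames into the batch, exactly like "b c f h w -> (b f) c h w" above
x = rearrange(hidden_states, "b c f h w -> (b f) c h w")
x = rearrange(x, "bf c h w -> bf (h w) c")  # flatten spatial dims into a token axis

attn = nn.MultiheadAttention(embed_dim=c, num_heads=2, kdim=8, vdim=8, batch_first=True)
out, _ = attn(query=x, key=audio_embeds, value=audio_embeds)  # audio conditions each frame independently
out = out + x                                                  # residual connection, as in AudioCrossAttn

# Unfold back to video layout, mirroring "(b f) c h w -> b c f h w"
out = rearrange(out, "(b f) (h w) c -> b c f h w", f=f, h=h, w=w)
print(out.shape)  # torch.Size([1, 8, 4, 16, 16])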
LatentSync/latentsync/models/motion_module.py
ADDED
@@ -0,0 +1,332 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py

# Actually we don't use the motion module in the final version of LatentSync
# When we started the project, we used the codebase of AnimateDiff and tried the motion module
# But the results were poor, and we decided to leave the code here for possible future usage

from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import Attention as CrossAttention, FeedForward

from einops import rearrange, repeat
import math
from .utils import zero_module


@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(
            in_channels=in_channels,
            **motion_module_kwargs,
        )
    else:
        raise ValueError


class VanillaTemporalModule(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        temporal_attention_dim_div=1,
        zero_initialize=True,
    ):
        super().__init__()

        self.temporal_transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )

        if zero_initialize:
            self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)

    def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
        hidden_states = input_tensor
        hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)

        output = hidden_states
        return output


class TemporalTransformer3DModel(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,
        num_layers,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                TemporalTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    attention_block_types=attention_block_types,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
                for d in range(num_layers)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
        hidden_states = self.proj_in(hidden_states)

        # Transformer Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length
            )

        # output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual
        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)

        return output


class TemporalTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()

        attention_blocks = []
        norms = []

        for block_name in attention_block_types:
            attention_blocks.append(
                VersatileAttention(
                    attention_mode=block_name.split("_")[0],
                    cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
                    query_dim=dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    dropout=dropout,
                    bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
            )
            norms.append(nn.LayerNorm(dim))

        self.attention_blocks = nn.ModuleList(attention_blocks)
        self.norms = nn.ModuleList(norms)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        for attention_block, norm in zip(self.attention_blocks, self.norms):
            norm_hidden_states = norm(hidden_states)
            hidden_states = (
                attention_block(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
                    video_length=video_length,
                )
                + hidden_states
            )

        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states

        output = hidden_states
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.0, max_len=24):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class VersatileAttention(CrossAttention):
    def __init__(
        self,
        attention_mode=None,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert attention_mode == "Temporal"

        self.attention_mode = attention_mode
        self.is_cross_attention = kwargs["cross_attention_dim"] is not None

        self.pos_encoder = (
            PositionalEncoding(kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len)
            if (temporal_position_encoding and attention_mode == "Temporal")
            else None
        )

    def extra_repr(self):
        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        batch_size, sequence_length, _ = hidden_states.shape

        if self.attention_mode == "Temporal":
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)

            if self.pos_encoder is not None:
                hidden_states = self.pos_encoder(hidden_states)

            encoder_hidden_states = (
                repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
                if encoder_hidden_states is not None
                else encoder_hidden_states
            )
        else:
            raise NotImplementedError

        # encoder_hidden_states = encoder_hidden_states

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]
        query = self.reshape_heads_to_batch_dim(query)

        if self.added_kv_proj_dim is not None:
            raise NotImplementedError

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(query, key, value, attention_mask)
            else:
                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)

        if self.attention_mode == "Temporal":
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
LatentSync/latentsync/models/resnet.py
ADDED
@@ -0,0 +1,234 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange


class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class InflatedGroupNorm(nn.GroupNorm):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class Upsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            raise NotImplementedError
        elif use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # if self.use_conv:
        #     if self.name == "conv":
        #         hidden_states = self.conv(hidden_states)
        #     else:
        #         hidden_states = self.Conv2d_0(hidden_states)
        hidden_states = self.conv(hidden_states)

        return hidden_states


class Downsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            raise NotImplementedError

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            raise NotImplementedError

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states


class ResnetBlock3D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
        use_inflated_groupnorm=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        assert use_inflated_groupnorm != None
        if use_inflated_groupnorm:
            self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            time_emb_proj_out_channels = out_channels
            # if self.time_embedding_norm == "default":
            #     time_emb_proj_out_channels = out_channels
            # elif self.time_embedding_norm == "scale_shift":
            #     time_emb_proj_out_channels = out_channels * 2
            # else:
            #     raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "scale_shift":
            self.double_len_linear = torch.nn.Linear(time_emb_proj_out_channels, 2 * time_emb_proj_out_channels)
        else:
            self.double_len_linear = None

        if use_inflated_groupnorm:
            self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            if temb.dim() == 2:
                # input (1, 1280)
                temb = self.time_emb_proj(self.nonlinearity(temb))
                temb = temb[:, :, None, None, None]  # unsqueeze
            else:
                # input (1, 1280, 16)
                temb = temb.permute(0, 2, 1)
                temb = self.time_emb_proj(self.nonlinearity(temb))
                if self.double_len_linear is not None:
                    temb = self.double_len_linear(self.nonlinearity(temb))
                temb = temb.permute(0, 2, 1)
                temb = temb[:, :, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
LatentSync/latentsync/models/syncnet.py
ADDED
@@ -0,0 +1,233 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
from einops import rearrange
from torch.nn import functional as F
from ..utils.util import cosine_loss

import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.attention import CrossAttention, FeedForward
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange


class SyncNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.audio_encoder = DownEncoder2D(
            in_channels=config["audio_encoder"]["in_channels"],
            block_out_channels=config["audio_encoder"]["block_out_channels"],
            downsample_factors=config["audio_encoder"]["downsample_factors"],
            dropout=config["audio_encoder"]["dropout"],
            attn_blocks=config["audio_encoder"]["attn_blocks"],
        )

        self.visual_encoder = DownEncoder2D(
            in_channels=config["visual_encoder"]["in_channels"],
            block_out_channels=config["visual_encoder"]["block_out_channels"],
            downsample_factors=config["visual_encoder"]["downsample_factors"],
            dropout=config["visual_encoder"]["dropout"],
            attn_blocks=config["visual_encoder"]["attn_blocks"],
        )

        self.eval()

    def forward(self, image_sequences, audio_sequences):
        vision_embeds = self.visual_encoder(image_sequences)  # (b, c, 1, 1)
        audio_embeds = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)

        vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1)  # (b, c)
        audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1)  # (b, c)

        # Make them unit vectors
        vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
        audio_embeds = F.normalize(audio_embeds, p=2, dim=1)

        return vision_embeds, audio_embeds


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        eps: float = 1e-6,
        act_fn: str = "silu",
        downsample_factor=2,
    ):
        super().__init__()

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "silu":
            self.act_fn = nn.SiLU()

        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

        if isinstance(downsample_factor, list):
            downsample_factor = tuple(downsample_factor)

        if downsample_factor == 1:
            self.downsample_conv = None
        else:
            self.downsample_conv = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
            )
            self.pad = (0, 1, 0, 1)
            if isinstance(downsample_factor, tuple):
                if downsample_factor[0] == 1:
                    self.pad = (0, 1, 1, 1)  # The padding order is from back to front
                elif downsample_factor[1] == 1:
                    self.pad = (1, 1, 0, 1)

    def forward(self, input_tensor):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        hidden_states = self.conv1(hidden_states)
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        hidden_states += input_tensor

        if self.downsample_conv is not None:
            hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
            hidden_states = self.downsample_conv(hidden_states)

        return hidden_states


class AttentionBlock2D(nn.Module):
    def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
        super().__init__()
        if not is_xformers_available():
            raise ModuleNotFoundError(
                "You have to install xformers to enable memory efficient attention", name="xformers"
            )
        # inner_dim = dim_head * heads
        self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
        self.norm2 = nn.LayerNorm(query_dim)
        self.norm3 = nn.LayerNorm(query_dim)

        self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")

        self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
        self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)

        self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
        self.attn._use_memory_efficient_attention_xformers = True

    def forward(self, hidden_states):
        assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."

        batch, channel, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.conv_in(hidden_states)
        hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")

        norm_hidden_states = self.norm2(hidden_states)
        hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
        hidden_states = self.conv_out(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states


class DownEncoder2D(nn.Module):
    def __init__(
        self,
        in_channels=4 * 16,
        block_out_channels=[64, 128, 256, 256],
        downsample_factors=[2, 2, 2, 2],
        layers_per_block=2,
        norm_num_groups=32,
        attn_blocks=[1, 1, 1, 1],
        dropout: float = 0.0,
        act_fn="silu",
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        # in
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # down
        self.down_blocks = nn.ModuleList([])

        output_channels = block_out_channels[0]
        for i, block_out_channel in enumerate(block_out_channels):
            input_channels = output_channels
            output_channels = block_out_channel
            # is_final_block = i == len(block_out_channels) - 1

            down_block = ResnetBlock2D(
                in_channels=input_channels,
                out_channels=output_channels,
                downsample_factor=downsample_factors[i],
                norm_num_groups=norm_num_groups,
                dropout=dropout,
                act_fn=act_fn,
            )

            self.down_blocks.append(down_block)

            if attn_blocks[i] == 1:
                attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
                self.down_blocks.append(attention_block)

        # out
        self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.act_fn_out = nn.ReLU()

    def forward(self, hidden_states):
        hidden_states = self.conv_in(hidden_states)

        # down
        for down_block in self.down_blocks:
            hidden_states = down_block(hidden_states)

        # post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.act_fn_out(hidden_states)

        return hidden_states
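Since SyncNet.forward above returns L2-normalized audio and visual embeddings, the sync score is simply their cosine similarity. The cosine_loss imported at the top of the file lives in latentsync/utils/util.py, so the objective shown here is only a hedged sketch of a typical SyncNet-style loss, not the repository's exact implementation.

# Illustrative sketch only: scoring unit-norm embeddings such as those returned by SyncNet.forward.
import torch
import torch.nn.functional as F

vision_embeds = F.normalize(torch.randn(8, 2048), dim=1)  # stand-ins for the (b, c) outputs above
audio_embeds = F.normalize(torch.randn(8, 2048), dim=1)
y = torch.randint(0, 2, (8,)).float()                     # 1 = in-sync pair, 0 = shifted pair (assumed labels)

cos_sim = (vision_embeds * audio_embeds).sum(dim=1)       # cosine similarity, in [-1, 1]

# A common SyncNet-style objective: push similarity toward 1 for positives, 0 for negatives.
loss = F.binary_cross_entropy(cos_sim.clamp(0, 1), y)
print(cos_sim.shape, loss.item())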
LatentSync/latentsync/models/syncnet_wav2lip.py
ADDED
@@ -0,0 +1,90 @@
| 1 |
+
# Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py
|
| 2 |
+
# The code here is for ablation study.
|
| 3 |
+
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SyncNetWav2Lip(nn.Module):
|
| 9 |
+
def __init__(self, act_fn="leaky"):
|
| 10 |
+
super().__init__()
|
| 11 |
+
|
| 12 |
+
# input image sequences: (15, 128, 256)
|
| 13 |
+
self.visual_encoder = nn.Sequential(
|
| 14 |
+
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3, act_fn=act_fn), # (128, 256)
|
| 15 |
+
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1, act_fn=act_fn), # (126, 127)
|
| 16 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 17 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 18 |
+
Conv2d(64, 128, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (63, 64)
|
| 19 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 20 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 21 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 22 |
+
Conv2d(128, 256, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (21, 22)
|
| 23 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 24 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 25 |
+
Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (11, 11)
|
| 26 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 27 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 28 |
+
Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (6, 6)
|
| 29 |
+
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 30 |
+
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 31 |
+
Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act_fn="relu"), # (3, 3)
|
| 32 |
+
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
|
| 33 |
+
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# input audio sequences: (1, 80, 16)
|
| 37 |
+
self.audio_encoder = nn.Sequential(
|
| 38 |
+
Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
|
| 39 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 40 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 41 |
+
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act_fn=act_fn), # (27, 16)
|
| 42 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 43 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 44 |
+
Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (9, 6)
|
| 45 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 46 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 47 |
+
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act_fn=act_fn), # (3, 3)
|
| 48 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 49 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 50 |
+
Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
|
| 51 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 52 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
| 53 |
+
Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
|
| 54 |
+
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def forward(self, image_sequences, audio_sequences):
|
| 58 |
+
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
|
| 59 |
+
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
|
| 60 |
+
|
| 61 |
+
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
|
| 62 |
+
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
|
| 63 |
+
|
| 64 |
+
# Make them unit vectors
|
| 65 |
+
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
|
| 66 |
+
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
|
| 67 |
+
|
| 68 |
+
return vision_embeds, audio_embeds
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Conv2d(nn.Module):
|
| 72 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs):
|
| 73 |
+
super().__init__(*args, **kwargs)
|
| 74 |
+
self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
|
| 75 |
+
if act_fn == "relu":
|
| 76 |
+
self.act_fn = nn.ReLU()
|
| 77 |
+
elif act_fn == "tanh":
|
| 78 |
+
self.act_fn = nn.Tanh()
|
| 79 |
+
elif act_fn == "silu":
|
| 80 |
+
self.act_fn = nn.SiLU()
|
| 81 |
+
elif act_fn == "leaky":
|
| 82 |
+
self.act_fn = nn.LeakyReLU(0.2, inplace=True)
|
| 83 |
+
|
| 84 |
+
self.residual = residual
|
| 85 |
+
|
| 86 |
+
def forward(self, x):
|
| 87 |
+
out = self.conv_block(x)
|
| 88 |
+
if self.residual:
|
| 89 |
+
out += x
|
| 90 |
+
return self.act_fn(out)
|
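For readers skimming the diff, the sketch below shows how this ablation-study SyncNet might be exercised end to end. It is not part of the commit: the import path and batch size are assumptions, and the tensor shapes simply follow the comments in the class above (15-channel stacked frames at 128×256 and a 1×80×16 mel window).

```python
# Hypothetical smoke test (not part of this commit). Assumes the module is
# importable as latentsync.models.syncnet_wav2lip.
import torch
from latentsync.models.syncnet_wav2lip import SyncNetWav2Lip

model = SyncNetWav2Lip(act_fn="leaky").eval()

frames = torch.randn(2, 15, 128, 256)  # 5 RGB frames stacked along the channel axis
mels = torch.randn(2, 1, 80, 16)       # one mel-spectrogram window per sample

with torch.no_grad():
    vision_embeds, audio_embeds = model(frames, mels)  # each (2, 1024), L2-normalized

# Because both embeddings are unit vectors, their dot product is the cosine
# similarity, which can serve as a lip-sync confidence score.
sync_score = (vision_embeds * audio_embeds).sum(dim=1)
print(vision_embeds.shape, audio_embeds.shape, sync_score.shape)
```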
LatentSync/latentsync/models/unet.py
ADDED
|
@@ -0,0 +1,528 @@
|
| 1 |
+
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet.py
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List, Optional, Tuple, Union
|
| 5 |
+
import copy
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.utils.checkpoint
|
| 10 |
+
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
from diffusers import UNet2DConditionModel
|
| 14 |
+
from diffusers.utils import BaseOutput, logging
|
| 15 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
| 16 |
+
from .unet_blocks import (
|
| 17 |
+
CrossAttnDownBlock3D,
|
| 18 |
+
CrossAttnUpBlock3D,
|
| 19 |
+
DownBlock3D,
|
| 20 |
+
UNetMidBlock3DCrossAttn,
|
| 21 |
+
UpBlock3D,
|
| 22 |
+
get_down_block,
|
| 23 |
+
get_up_block,
|
| 24 |
+
)
|
| 25 |
+
from .resnet import InflatedConv3d, InflatedGroupNorm
|
| 26 |
+
|
| 27 |
+
from ..utils.util import zero_rank_log
|
| 28 |
+
from einops import rearrange
|
| 29 |
+
from .utils import zero_module
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class UNet3DConditionOutput(BaseOutput):
|
| 37 |
+
sample: torch.FloatTensor
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
| 41 |
+
_supports_gradient_checkpointing = True
|
| 42 |
+
|
| 43 |
+
@register_to_config
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
sample_size: Optional[int] = None,
|
| 47 |
+
in_channels: int = 4,
|
| 48 |
+
out_channels: int = 4,
|
| 49 |
+
center_input_sample: bool = False,
|
| 50 |
+
flip_sin_to_cos: bool = True,
|
| 51 |
+
freq_shift: int = 0,
|
| 52 |
+
down_block_types: Tuple[str] = (
|
| 53 |
+
"CrossAttnDownBlock3D",
|
| 54 |
+
"CrossAttnDownBlock3D",
|
| 55 |
+
"CrossAttnDownBlock3D",
|
| 56 |
+
"DownBlock3D",
|
| 57 |
+
),
|
| 58 |
+
mid_block_type: str = "UNetMidBlock3DCrossAttn",
|
| 59 |
+
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
| 60 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
| 61 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
| 62 |
+
layers_per_block: int = 2,
|
| 63 |
+
downsample_padding: int = 1,
|
| 64 |
+
mid_block_scale_factor: float = 1,
|
| 65 |
+
act_fn: str = "silu",
|
| 66 |
+
norm_num_groups: int = 32,
|
| 67 |
+
norm_eps: float = 1e-5,
|
| 68 |
+
cross_attention_dim: int = 1280,
|
| 69 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
| 70 |
+
dual_cross_attention: bool = False,
|
| 71 |
+
use_linear_projection: bool = False,
|
| 72 |
+
class_embed_type: Optional[str] = None,
|
| 73 |
+
num_class_embeds: Optional[int] = None,
|
| 74 |
+
upcast_attention: bool = False,
|
| 75 |
+
resnet_time_scale_shift: str = "default",
|
| 76 |
+
use_inflated_groupnorm=False,
|
| 77 |
+
# Additional
|
| 78 |
+
use_motion_module=False,
|
| 79 |
+
motion_module_resolutions=(1, 2, 4, 8),
|
| 80 |
+
motion_module_mid_block=False,
|
| 81 |
+
motion_module_decoder_only=False,
|
| 82 |
+
motion_module_type=None,
|
| 83 |
+
motion_module_kwargs={},
|
| 84 |
+
unet_use_cross_frame_attention=False,
|
| 85 |
+
unet_use_temporal_attention=False,
|
| 86 |
+
add_audio_layer=False,
|
| 87 |
+
audio_condition_method: str = "cross_attn",
|
| 88 |
+
custom_audio_layer=False,
|
| 89 |
+
):
|
| 90 |
+
super().__init__()
|
| 91 |
+
|
| 92 |
+
self.sample_size = sample_size
|
| 93 |
+
time_embed_dim = block_out_channels[0] * 4
|
| 94 |
+
self.use_motion_module = use_motion_module
|
| 95 |
+
self.add_audio_layer = add_audio_layer
|
| 96 |
+
|
| 97 |
+
self.conv_in = zero_module(InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)))
|
| 98 |
+
|
| 99 |
+
# time
|
| 100 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
| 101 |
+
timestep_input_dim = block_out_channels[0]
|
| 102 |
+
|
| 103 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 104 |
+
|
| 105 |
+
# class embedding
|
| 106 |
+
if class_embed_type is None and num_class_embeds is not None:
|
| 107 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
| 108 |
+
elif class_embed_type == "timestep":
|
| 109 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 110 |
+
elif class_embed_type == "identity":
|
| 111 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
| 112 |
+
else:
|
| 113 |
+
self.class_embedding = None
|
| 114 |
+
|
| 115 |
+
self.down_blocks = nn.ModuleList([])
|
| 116 |
+
self.mid_block = None
|
| 117 |
+
self.up_blocks = nn.ModuleList([])
|
| 118 |
+
|
| 119 |
+
if isinstance(only_cross_attention, bool):
|
| 120 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
| 121 |
+
|
| 122 |
+
if isinstance(attention_head_dim, int):
|
| 123 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
| 124 |
+
|
| 125 |
+
# down
|
| 126 |
+
output_channel = block_out_channels[0]
|
| 127 |
+
for i, down_block_type in enumerate(down_block_types):
|
| 128 |
+
res = 2**i
|
| 129 |
+
input_channel = output_channel
|
| 130 |
+
output_channel = block_out_channels[i]
|
| 131 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 132 |
+
|
| 133 |
+
down_block = get_down_block(
|
| 134 |
+
down_block_type,
|
| 135 |
+
num_layers=layers_per_block,
|
| 136 |
+
in_channels=input_channel,
|
| 137 |
+
out_channels=output_channel,
|
| 138 |
+
temb_channels=time_embed_dim,
|
| 139 |
+
add_downsample=not is_final_block,
|
| 140 |
+
resnet_eps=norm_eps,
|
| 141 |
+
resnet_act_fn=act_fn,
|
| 142 |
+
resnet_groups=norm_num_groups,
|
| 143 |
+
cross_attention_dim=cross_attention_dim,
|
| 144 |
+
attn_num_head_channels=attention_head_dim[i],
|
| 145 |
+
downsample_padding=downsample_padding,
|
| 146 |
+
dual_cross_attention=dual_cross_attention,
|
| 147 |
+
use_linear_projection=use_linear_projection,
|
| 148 |
+
only_cross_attention=only_cross_attention[i],
|
| 149 |
+
upcast_attention=upcast_attention,
|
| 150 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 151 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 152 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 153 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 154 |
+
use_motion_module=use_motion_module
|
| 155 |
+
and (res in motion_module_resolutions)
|
| 156 |
+
and (not motion_module_decoder_only),
|
| 157 |
+
motion_module_type=motion_module_type,
|
| 158 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 159 |
+
add_audio_layer=add_audio_layer,
|
| 160 |
+
audio_condition_method=audio_condition_method,
|
| 161 |
+
custom_audio_layer=custom_audio_layer,
|
| 162 |
+
)
|
| 163 |
+
self.down_blocks.append(down_block)
|
| 164 |
+
|
| 165 |
+
# mid
|
| 166 |
+
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
| 167 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
| 168 |
+
in_channels=block_out_channels[-1],
|
| 169 |
+
temb_channels=time_embed_dim,
|
| 170 |
+
resnet_eps=norm_eps,
|
| 171 |
+
resnet_act_fn=act_fn,
|
| 172 |
+
output_scale_factor=mid_block_scale_factor,
|
| 173 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 174 |
+
cross_attention_dim=cross_attention_dim,
|
| 175 |
+
attn_num_head_channels=attention_head_dim[-1],
|
| 176 |
+
resnet_groups=norm_num_groups,
|
| 177 |
+
dual_cross_attention=dual_cross_attention,
|
| 178 |
+
use_linear_projection=use_linear_projection,
|
| 179 |
+
upcast_attention=upcast_attention,
|
| 180 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 181 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 182 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 183 |
+
use_motion_module=use_motion_module and motion_module_mid_block,
|
| 184 |
+
motion_module_type=motion_module_type,
|
| 185 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 186 |
+
add_audio_layer=add_audio_layer,
|
| 187 |
+
audio_condition_method=audio_condition_method,
|
| 188 |
+
custom_audio_layer=custom_audio_layer,
|
| 189 |
+
)
|
| 190 |
+
else:
|
| 191 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
| 192 |
+
|
| 193 |
+
# count how many layers upsample the videos
|
| 194 |
+
self.num_upsamplers = 0
|
| 195 |
+
|
| 196 |
+
# up
|
| 197 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 198 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
| 199 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
| 200 |
+
output_channel = reversed_block_out_channels[0]
|
| 201 |
+
for i, up_block_type in enumerate(up_block_types):
|
| 202 |
+
res = 2 ** (3 - i)
|
| 203 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 204 |
+
|
| 205 |
+
prev_output_channel = output_channel
|
| 206 |
+
output_channel = reversed_block_out_channels[i]
|
| 207 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
| 208 |
+
|
| 209 |
+
# add upsample block for all BUT final layer
|
| 210 |
+
if not is_final_block:
|
| 211 |
+
add_upsample = True
|
| 212 |
+
self.num_upsamplers += 1
|
| 213 |
+
else:
|
| 214 |
+
add_upsample = False
|
| 215 |
+
|
| 216 |
+
up_block = get_up_block(
|
| 217 |
+
up_block_type,
|
| 218 |
+
num_layers=layers_per_block + 1,
|
| 219 |
+
in_channels=input_channel,
|
| 220 |
+
out_channels=output_channel,
|
| 221 |
+
prev_output_channel=prev_output_channel,
|
| 222 |
+
temb_channels=time_embed_dim,
|
| 223 |
+
add_upsample=add_upsample,
|
| 224 |
+
resnet_eps=norm_eps,
|
| 225 |
+
resnet_act_fn=act_fn,
|
| 226 |
+
resnet_groups=norm_num_groups,
|
| 227 |
+
cross_attention_dim=cross_attention_dim,
|
| 228 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
| 229 |
+
dual_cross_attention=dual_cross_attention,
|
| 230 |
+
use_linear_projection=use_linear_projection,
|
| 231 |
+
only_cross_attention=only_cross_attention[i],
|
| 232 |
+
upcast_attention=upcast_attention,
|
| 233 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 234 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 235 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 236 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 237 |
+
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
| 238 |
+
motion_module_type=motion_module_type,
|
| 239 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 240 |
+
add_audio_layer=add_audio_layer,
|
| 241 |
+
audio_condition_method=audio_condition_method,
|
| 242 |
+
custom_audio_layer=custom_audio_layer,
|
| 243 |
+
)
|
| 244 |
+
self.up_blocks.append(up_block)
|
| 245 |
+
prev_output_channel = output_channel
|
| 246 |
+
|
| 247 |
+
# out
|
| 248 |
+
if use_inflated_groupnorm:
|
| 249 |
+
self.conv_norm_out = InflatedGroupNorm(
|
| 250 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
| 251 |
+
)
|
| 252 |
+
else:
|
| 253 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 254 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
| 255 |
+
)
|
| 256 |
+
self.conv_act = nn.SiLU()
|
| 257 |
+
|
| 258 |
+
self.conv_out = zero_module(InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1))
|
| 259 |
+
|
| 260 |
+
def set_attention_slice(self, slice_size):
|
| 261 |
+
r"""
|
| 262 |
+
Enable sliced attention computation.
|
| 263 |
+
|
| 264 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
| 265 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
| 269 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
| 270 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
| 271 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
| 272 |
+
must be a multiple of `slice_size`.
|
| 273 |
+
"""
|
| 274 |
+
sliceable_head_dims = []
|
| 275 |
+
|
| 276 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
| 277 |
+
if hasattr(module, "set_attention_slice"):
|
| 278 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
| 279 |
+
|
| 280 |
+
for child in module.children():
|
| 281 |
+
fn_recursive_retrieve_slicable_dims(child)
|
| 282 |
+
|
| 283 |
+
# retrieve number of attention layers
|
| 284 |
+
for module in self.children():
|
| 285 |
+
fn_recursive_retrieve_slicable_dims(module)
|
| 286 |
+
|
| 287 |
+
num_slicable_layers = len(sliceable_head_dims)
|
| 288 |
+
|
| 289 |
+
if slice_size == "auto":
|
| 290 |
+
# half the attention head size is usually a good trade-off between
|
| 291 |
+
# speed and memory
|
| 292 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
| 293 |
+
elif slice_size == "max":
|
| 294 |
+
# make smallest slice possible
|
| 295 |
+
slice_size = num_slicable_layers * [1]
|
| 296 |
+
|
| 297 |
+
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
| 298 |
+
|
| 299 |
+
if len(slice_size) != len(sliceable_head_dims):
|
| 300 |
+
raise ValueError(
|
| 301 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
| 302 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
for i in range(len(slice_size)):
|
| 306 |
+
size = slice_size[i]
|
| 307 |
+
dim = sliceable_head_dims[i]
|
| 308 |
+
if size is not None and size > dim:
|
| 309 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
| 310 |
+
|
| 311 |
+
# Recursively walk through all the children.
|
| 312 |
+
# Any children which exposes the set_attention_slice method
|
| 313 |
+
# gets the message
|
| 314 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
| 315 |
+
if hasattr(module, "set_attention_slice"):
|
| 316 |
+
module.set_attention_slice(slice_size.pop())
|
| 317 |
+
|
| 318 |
+
for child in module.children():
|
| 319 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
| 320 |
+
|
| 321 |
+
reversed_slice_size = list(reversed(slice_size))
|
| 322 |
+
for module in self.children():
|
| 323 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
| 324 |
+
|
| 325 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 326 |
+
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
| 327 |
+
module.gradient_checkpointing = value
|
| 328 |
+
|
| 329 |
+
def forward(
|
| 330 |
+
self,
|
| 331 |
+
sample: torch.FloatTensor,
|
| 332 |
+
timestep: Union[torch.Tensor, float, int],
|
| 333 |
+
encoder_hidden_states: torch.Tensor,
|
| 334 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 335 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 336 |
+
# support controlnet
|
| 337 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
| 338 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
| 339 |
+
return_dict: bool = True,
|
| 340 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
| 341 |
+
r"""
|
| 342 |
+
Args:
|
| 343 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
| 344 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
| 345 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
| 346 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 347 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
| 348 |
+
|
| 349 |
+
Returns:
|
| 350 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
| 351 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
| 352 |
+
returning a tuple, the first element is the sample tensor.
|
| 353 |
+
"""
|
| 354 |
+
# By default, samples have to be at least a multiple of the overall upsampling factor.
|
| 355 |
+
# The overall upsampling factor is equal to 2 ** (number of upsampling layers).
|
| 356 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
| 357 |
+
# on the fly if necessary.
|
| 358 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
| 359 |
+
|
| 360 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
| 361 |
+
forward_upsample_size = False
|
| 362 |
+
upsample_size = None
|
| 363 |
+
|
| 364 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
| 365 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
| 366 |
+
forward_upsample_size = True
|
| 367 |
+
|
| 368 |
+
# prepare attention_mask
|
| 369 |
+
if attention_mask is not None:
|
| 370 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
| 371 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 372 |
+
|
| 373 |
+
# center input if necessary
|
| 374 |
+
if self.config.center_input_sample:
|
| 375 |
+
sample = 2 * sample - 1.0
|
| 376 |
+
|
| 377 |
+
# time
|
| 378 |
+
timesteps = timestep
|
| 379 |
+
if not torch.is_tensor(timesteps):
|
| 380 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
| 381 |
+
is_mps = sample.device.type == "mps"
|
| 382 |
+
if isinstance(timestep, float):
|
| 383 |
+
dtype = torch.float32 if is_mps else torch.float64
|
| 384 |
+
else:
|
| 385 |
+
dtype = torch.int32 if is_mps else torch.int64
|
| 386 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
| 387 |
+
elif len(timesteps.shape) == 0:
|
| 388 |
+
timesteps = timesteps[None].to(sample.device)
|
| 389 |
+
|
| 390 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 391 |
+
timesteps = timesteps.expand(sample.shape[0])
|
| 392 |
+
|
| 393 |
+
t_emb = self.time_proj(timesteps)
|
| 394 |
+
|
| 395 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 396 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 397 |
+
# there might be better ways to encapsulate this.
|
| 398 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
| 399 |
+
emb = self.time_embedding(t_emb)
|
| 400 |
+
|
| 401 |
+
if self.class_embedding is not None:
|
| 402 |
+
if class_labels is None:
|
| 403 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
| 404 |
+
|
| 405 |
+
if self.config.class_embed_type == "timestep":
|
| 406 |
+
class_labels = self.time_proj(class_labels)
|
| 407 |
+
|
| 408 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
| 409 |
+
emb = emb + class_emb
|
| 410 |
+
|
| 411 |
+
# pre-process
|
| 412 |
+
sample = self.conv_in(sample)
|
| 413 |
+
|
| 414 |
+
# down
|
| 415 |
+
down_block_res_samples = (sample,)
|
| 416 |
+
for downsample_block in self.down_blocks:
|
| 417 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
| 418 |
+
sample, res_samples = downsample_block(
|
| 419 |
+
hidden_states=sample,
|
| 420 |
+
temb=emb,
|
| 421 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 422 |
+
attention_mask=attention_mask,
|
| 423 |
+
)
|
| 424 |
+
else:
|
| 425 |
+
sample, res_samples = downsample_block(
|
| 426 |
+
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
down_block_res_samples += res_samples
|
| 430 |
+
|
| 431 |
+
# support controlnet
|
| 432 |
+
down_block_res_samples = list(down_block_res_samples)
|
| 433 |
+
if down_block_additional_residuals is not None:
|
| 434 |
+
for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
|
| 435 |
+
if down_block_additional_residual.dim() == 4:  # broadcast over the frame dimension
|
| 436 |
+
down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
|
| 437 |
+
down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
|
| 438 |
+
|
| 439 |
+
# mid
|
| 440 |
+
sample = self.mid_block(
|
| 441 |
+
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
# support controlnet
|
| 445 |
+
if mid_block_additional_residual is not None:
|
| 446 |
+
if mid_block_additional_residual.dim() == 4:  # broadcast over the frame dimension
|
| 447 |
+
mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
|
| 448 |
+
sample = sample + mid_block_additional_residual
|
| 449 |
+
|
| 450 |
+
# up
|
| 451 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 452 |
+
is_final_block = i == len(self.up_blocks) - 1
|
| 453 |
+
|
| 454 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 455 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
| 456 |
+
|
| 457 |
+
# if we have not reached the final block and need to forward the
|
| 458 |
+
# upsample size, we do it here
|
| 459 |
+
if not is_final_block and forward_upsample_size:
|
| 460 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
| 461 |
+
|
| 462 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
| 463 |
+
sample = upsample_block(
|
| 464 |
+
hidden_states=sample,
|
| 465 |
+
temb=emb,
|
| 466 |
+
res_hidden_states_tuple=res_samples,
|
| 467 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 468 |
+
upsample_size=upsample_size,
|
| 469 |
+
attention_mask=attention_mask,
|
| 470 |
+
)
|
| 471 |
+
else:
|
| 472 |
+
sample = upsample_block(
|
| 473 |
+
hidden_states=sample,
|
| 474 |
+
temb=emb,
|
| 475 |
+
res_hidden_states_tuple=res_samples,
|
| 476 |
+
upsample_size=upsample_size,
|
| 477 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
# post-process
|
| 481 |
+
sample = self.conv_norm_out(sample)
|
| 482 |
+
sample = self.conv_act(sample)
|
| 483 |
+
sample = self.conv_out(sample)
|
| 484 |
+
|
| 485 |
+
if not return_dict:
|
| 486 |
+
return (sample,)
|
| 487 |
+
|
| 488 |
+
return UNet3DConditionOutput(sample=sample)
|
| 489 |
+
|
| 490 |
+
def load_state_dict(self, state_dict, strict=True):
|
| 491 |
+
# Drop conv_in / conv_out weights if the loaded checkpoint's in_channels or out_channels differ from the config
|
| 492 |
+
temp_state_dict = copy.deepcopy(state_dict)
|
| 493 |
+
if temp_state_dict["conv_in.weight"].shape[1] != self.config.in_channels:
|
| 494 |
+
del temp_state_dict["conv_in.weight"]
|
| 495 |
+
del temp_state_dict["conv_in.bias"]
|
| 496 |
+
if temp_state_dict["conv_out.weight"].shape[0] != self.config.out_channels:
|
| 497 |
+
del temp_state_dict["conv_out.weight"]
|
| 498 |
+
del temp_state_dict["conv_out.bias"]
|
| 499 |
+
|
| 500 |
+
# Drop audio cross-attention projections if the loaded checkpoint's cross_attention_dim differs from the config
|
| 501 |
+
keys_to_remove = []
|
| 502 |
+
for key in temp_state_dict:
|
| 503 |
+
if "audio_cross_attn.attn.to_k." in key or "audio_cross_attn.attn.to_v." in key:
|
| 504 |
+
if temp_state_dict[key].shape[1] != self.config.cross_attention_dim:
|
| 505 |
+
keys_to_remove.append(key)
|
| 506 |
+
|
| 507 |
+
for key in keys_to_remove:
|
| 508 |
+
del temp_state_dict[key]
|
| 509 |
+
|
| 510 |
+
return super().load_state_dict(state_dict=temp_state_dict, strict=strict)
|
| 511 |
+
|
| 512 |
+
@classmethod
|
| 513 |
+
def from_pretrained(cls, model_config: dict, ckpt_path: str, device="cpu"):
|
| 514 |
+
unet = cls.from_config(model_config).to(device)
|
| 515 |
+
if ckpt_path != "":
|
| 516 |
+
zero_rank_log(logger, f"Load from checkpoint: {ckpt_path}")
|
| 517 |
+
ckpt = torch.load(ckpt_path, map_location=device)
|
| 518 |
+
if "global_step" in ckpt:
|
| 519 |
+
zero_rank_log(logger, f"resume from global_step: {ckpt['global_step']}")
|
| 520 |
+
resume_global_step = ckpt["global_step"]
|
| 521 |
+
else:
|
| 522 |
+
resume_global_step = 0
|
| 523 |
+
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
|
| 524 |
+
unet.load_state_dict(state_dict, strict=False)
|
| 525 |
+
else:
|
| 526 |
+
resume_global_step = 0
|
| 527 |
+
|
| 528 |
+
return unet, resume_global_step
|
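Before the block definitions in the next file, it may help to see how this UNet is meant to be instantiated. The sketch below is illustrative only: it relies on the class defaults rather than the repository's actual YAML configs, and it assumes the module is importable as latentsync.models.unet.

```python
# Illustrative sketch (not part of this commit): build the 3D UNet from a
# config dict and skip checkpoint loading by passing an empty ckpt_path.
import torch
from latentsync.models.unet import UNet3DConditionModel  # assumed import path

model_config = {"in_channels": 4, "out_channels": 4}  # other keys fall back to the defaults above

unet, resume_global_step = UNet3DConditionModel.from_pretrained(
    model_config, ckpt_path="", device="cpu"
)
print(resume_global_step)  # 0, since no checkpoint was loaded

# forward() expects a (batch, channels, frames, height, width) latent video,
# a timestep, and encoder_hidden_states whose last dimension matches
# cross_attention_dim; it returns a UNet3DConditionOutput whose .sample has
# the same shape as the input latents for this in/out channel choice.
num_params = sum(p.numel() for p in unet.parameters())
print(f"{num_params / 1e6:.1f}M parameters")
```

Note that the checkpoint branch of `from_pretrained` is deliberately lenient: `load_state_dict` above first drops `conv_in`/`conv_out` and audio cross-attention projections whose shapes disagree with the config, then loads with `strict=False`, so checkpoints with a different channel layout can still be partially loaded.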
LatentSync/latentsync/models/unet_blocks.py
ADDED
|
@@ -0,0 +1,903 @@
|
| 1 |
+
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from .attention import Transformer3DModel
|
| 7 |
+
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
|
| 8 |
+
from .motion_module import get_motion_module
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_down_block(
|
| 12 |
+
down_block_type,
|
| 13 |
+
num_layers,
|
| 14 |
+
in_channels,
|
| 15 |
+
out_channels,
|
| 16 |
+
temb_channels,
|
| 17 |
+
add_downsample,
|
| 18 |
+
resnet_eps,
|
| 19 |
+
resnet_act_fn,
|
| 20 |
+
attn_num_head_channels,
|
| 21 |
+
resnet_groups=None,
|
| 22 |
+
cross_attention_dim=None,
|
| 23 |
+
downsample_padding=None,
|
| 24 |
+
dual_cross_attention=False,
|
| 25 |
+
use_linear_projection=False,
|
| 26 |
+
only_cross_attention=False,
|
| 27 |
+
upcast_attention=False,
|
| 28 |
+
resnet_time_scale_shift="default",
|
| 29 |
+
unet_use_cross_frame_attention=False,
|
| 30 |
+
unet_use_temporal_attention=False,
|
| 31 |
+
use_inflated_groupnorm=False,
|
| 32 |
+
use_motion_module=None,
|
| 33 |
+
motion_module_type=None,
|
| 34 |
+
motion_module_kwargs=None,
|
| 35 |
+
add_audio_layer=False,
|
| 36 |
+
audio_condition_method="cross_attn",
|
| 37 |
+
custom_audio_layer=False,
|
| 38 |
+
):
|
| 39 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
| 40 |
+
if down_block_type == "DownBlock3D":
|
| 41 |
+
return DownBlock3D(
|
| 42 |
+
num_layers=num_layers,
|
| 43 |
+
in_channels=in_channels,
|
| 44 |
+
out_channels=out_channels,
|
| 45 |
+
temb_channels=temb_channels,
|
| 46 |
+
add_downsample=add_downsample,
|
| 47 |
+
resnet_eps=resnet_eps,
|
| 48 |
+
resnet_act_fn=resnet_act_fn,
|
| 49 |
+
resnet_groups=resnet_groups,
|
| 50 |
+
downsample_padding=downsample_padding,
|
| 51 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 52 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 53 |
+
use_motion_module=use_motion_module,
|
| 54 |
+
motion_module_type=motion_module_type,
|
| 55 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 56 |
+
)
|
| 57 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
| 58 |
+
if cross_attention_dim is None:
|
| 59 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
| 60 |
+
return CrossAttnDownBlock3D(
|
| 61 |
+
num_layers=num_layers,
|
| 62 |
+
in_channels=in_channels,
|
| 63 |
+
out_channels=out_channels,
|
| 64 |
+
temb_channels=temb_channels,
|
| 65 |
+
add_downsample=add_downsample,
|
| 66 |
+
resnet_eps=resnet_eps,
|
| 67 |
+
resnet_act_fn=resnet_act_fn,
|
| 68 |
+
resnet_groups=resnet_groups,
|
| 69 |
+
downsample_padding=downsample_padding,
|
| 70 |
+
cross_attention_dim=cross_attention_dim,
|
| 71 |
+
attn_num_head_channels=attn_num_head_channels,
|
| 72 |
+
dual_cross_attention=dual_cross_attention,
|
| 73 |
+
use_linear_projection=use_linear_projection,
|
| 74 |
+
only_cross_attention=only_cross_attention,
|
| 75 |
+
upcast_attention=upcast_attention,
|
| 76 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 77 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 78 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 79 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 80 |
+
use_motion_module=use_motion_module,
|
| 81 |
+
motion_module_type=motion_module_type,
|
| 82 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 83 |
+
add_audio_layer=add_audio_layer,
|
| 84 |
+
audio_condition_method=audio_condition_method,
|
| 85 |
+
custom_audio_layer=custom_audio_layer,
|
| 86 |
+
)
|
| 87 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_up_block(
|
| 91 |
+
up_block_type,
|
| 92 |
+
num_layers,
|
| 93 |
+
in_channels,
|
| 94 |
+
out_channels,
|
| 95 |
+
prev_output_channel,
|
| 96 |
+
temb_channels,
|
| 97 |
+
add_upsample,
|
| 98 |
+
resnet_eps,
|
| 99 |
+
resnet_act_fn,
|
| 100 |
+
attn_num_head_channels,
|
| 101 |
+
resnet_groups=None,
|
| 102 |
+
cross_attention_dim=None,
|
| 103 |
+
dual_cross_attention=False,
|
| 104 |
+
use_linear_projection=False,
|
| 105 |
+
only_cross_attention=False,
|
| 106 |
+
upcast_attention=False,
|
| 107 |
+
resnet_time_scale_shift="default",
|
| 108 |
+
unet_use_cross_frame_attention=False,
|
| 109 |
+
unet_use_temporal_attention=False,
|
| 110 |
+
use_inflated_groupnorm=False,
|
| 111 |
+
use_motion_module=None,
|
| 112 |
+
motion_module_type=None,
|
| 113 |
+
motion_module_kwargs=None,
|
| 114 |
+
add_audio_layer=False,
|
| 115 |
+
audio_condition_method="cross_attn",
|
| 116 |
+
custom_audio_layer=False,
|
| 117 |
+
):
|
| 118 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
| 119 |
+
if up_block_type == "UpBlock3D":
|
| 120 |
+
return UpBlock3D(
|
| 121 |
+
num_layers=num_layers,
|
| 122 |
+
in_channels=in_channels,
|
| 123 |
+
out_channels=out_channels,
|
| 124 |
+
prev_output_channel=prev_output_channel,
|
| 125 |
+
temb_channels=temb_channels,
|
| 126 |
+
add_upsample=add_upsample,
|
| 127 |
+
resnet_eps=resnet_eps,
|
| 128 |
+
resnet_act_fn=resnet_act_fn,
|
| 129 |
+
resnet_groups=resnet_groups,
|
| 130 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 131 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 132 |
+
use_motion_module=use_motion_module,
|
| 133 |
+
motion_module_type=motion_module_type,
|
| 134 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 135 |
+
)
|
| 136 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
| 137 |
+
if cross_attention_dim is None:
|
| 138 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
| 139 |
+
return CrossAttnUpBlock3D(
|
| 140 |
+
num_layers=num_layers,
|
| 141 |
+
in_channels=in_channels,
|
| 142 |
+
out_channels=out_channels,
|
| 143 |
+
prev_output_channel=prev_output_channel,
|
| 144 |
+
temb_channels=temb_channels,
|
| 145 |
+
add_upsample=add_upsample,
|
| 146 |
+
resnet_eps=resnet_eps,
|
| 147 |
+
resnet_act_fn=resnet_act_fn,
|
| 148 |
+
resnet_groups=resnet_groups,
|
| 149 |
+
cross_attention_dim=cross_attention_dim,
|
| 150 |
+
attn_num_head_channels=attn_num_head_channels,
|
| 151 |
+
dual_cross_attention=dual_cross_attention,
|
| 152 |
+
use_linear_projection=use_linear_projection,
|
| 153 |
+
only_cross_attention=only_cross_attention,
|
| 154 |
+
upcast_attention=upcast_attention,
|
| 155 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 156 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 157 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 158 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 159 |
+
use_motion_module=use_motion_module,
|
| 160 |
+
motion_module_type=motion_module_type,
|
| 161 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 162 |
+
add_audio_layer=add_audio_layer,
|
| 163 |
+
audio_condition_method=audio_condition_method,
|
| 164 |
+
custom_audio_layer=custom_audio_layer,
|
| 165 |
+
)
|
| 166 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
in_channels: int,
|
| 173 |
+
temb_channels: int,
|
| 174 |
+
dropout: float = 0.0,
|
| 175 |
+
num_layers: int = 1,
|
| 176 |
+
resnet_eps: float = 1e-6,
|
| 177 |
+
resnet_time_scale_shift: str = "default",
|
| 178 |
+
resnet_act_fn: str = "swish",
|
| 179 |
+
resnet_groups: int = 32,
|
| 180 |
+
resnet_pre_norm: bool = True,
|
| 181 |
+
attn_num_head_channels=1,
|
| 182 |
+
output_scale_factor=1.0,
|
| 183 |
+
cross_attention_dim=1280,
|
| 184 |
+
dual_cross_attention=False,
|
| 185 |
+
use_linear_projection=False,
|
| 186 |
+
upcast_attention=False,
|
| 187 |
+
unet_use_cross_frame_attention=False,
|
| 188 |
+
unet_use_temporal_attention=False,
|
| 189 |
+
use_inflated_groupnorm=False,
|
| 190 |
+
use_motion_module=None,
|
| 191 |
+
motion_module_type=None,
|
| 192 |
+
motion_module_kwargs=None,
|
| 193 |
+
add_audio_layer=False,
|
| 194 |
+
audio_condition_method="cross_attn",
|
| 195 |
+
custom_audio_layer: bool = False,
|
| 196 |
+
):
|
| 197 |
+
super().__init__()
|
| 198 |
+
|
| 199 |
+
self.has_cross_attention = True
|
| 200 |
+
self.attn_num_head_channels = attn_num_head_channels
|
| 201 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
| 202 |
+
|
| 203 |
+
# there is always at least one resnet
|
| 204 |
+
resnets = [
|
| 205 |
+
ResnetBlock3D(
|
| 206 |
+
in_channels=in_channels,
|
| 207 |
+
out_channels=in_channels,
|
| 208 |
+
temb_channels=temb_channels,
|
| 209 |
+
eps=resnet_eps,
|
| 210 |
+
groups=resnet_groups,
|
| 211 |
+
dropout=dropout,
|
| 212 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 213 |
+
non_linearity=resnet_act_fn,
|
| 214 |
+
output_scale_factor=output_scale_factor,
|
| 215 |
+
pre_norm=resnet_pre_norm,
|
| 216 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 217 |
+
)
|
| 218 |
+
]
|
| 219 |
+
attentions = []
|
| 220 |
+
audio_attentions = []
|
| 221 |
+
motion_modules = []
|
| 222 |
+
|
| 223 |
+
for _ in range(num_layers):
|
| 224 |
+
if dual_cross_attention:
|
| 225 |
+
raise NotImplementedError
|
| 226 |
+
attentions.append(
|
| 227 |
+
Transformer3DModel(
|
| 228 |
+
attn_num_head_channels,
|
| 229 |
+
in_channels // attn_num_head_channels,
|
| 230 |
+
in_channels=in_channels,
|
| 231 |
+
num_layers=1,
|
| 232 |
+
cross_attention_dim=cross_attention_dim,
|
| 233 |
+
norm_num_groups=resnet_groups,
|
| 234 |
+
use_linear_projection=use_linear_projection,
|
| 235 |
+
upcast_attention=upcast_attention,
|
| 236 |
+
use_motion_module=use_motion_module,
|
| 237 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 238 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 239 |
+
add_audio_layer=add_audio_layer,
|
| 240 |
+
audio_condition_method=audio_condition_method,
|
| 241 |
+
)
|
| 242 |
+
)
|
| 243 |
+
audio_attentions.append(
|
| 244 |
+
Transformer3DModel(
|
| 245 |
+
attn_num_head_channels,
|
| 246 |
+
in_channels // attn_num_head_channels,
|
| 247 |
+
in_channels=in_channels,
|
| 248 |
+
num_layers=1,
|
| 249 |
+
cross_attention_dim=cross_attention_dim,
|
| 250 |
+
norm_num_groups=resnet_groups,
|
| 251 |
+
use_linear_projection=use_linear_projection,
|
| 252 |
+
upcast_attention=upcast_attention,
|
| 253 |
+
use_motion_module=use_motion_module,
|
| 254 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 255 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 256 |
+
add_audio_layer=add_audio_layer,
|
| 257 |
+
audio_condition_method=audio_condition_method,
|
| 258 |
+
custom_audio_layer=True,
|
| 259 |
+
)
|
| 260 |
+
if custom_audio_layer
|
| 261 |
+
else None
|
| 262 |
+
)
|
| 263 |
+
motion_modules.append(
|
| 264 |
+
get_motion_module(
|
| 265 |
+
in_channels=in_channels,
|
| 266 |
+
motion_module_type=motion_module_type,
|
| 267 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 268 |
+
)
|
| 269 |
+
if use_motion_module
|
| 270 |
+
else None
|
| 271 |
+
)
|
| 272 |
+
resnets.append(
|
| 273 |
+
ResnetBlock3D(
|
| 274 |
+
in_channels=in_channels,
|
| 275 |
+
out_channels=in_channels,
|
| 276 |
+
temb_channels=temb_channels,
|
| 277 |
+
eps=resnet_eps,
|
| 278 |
+
groups=resnet_groups,
|
| 279 |
+
dropout=dropout,
|
| 280 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 281 |
+
non_linearity=resnet_act_fn,
|
| 282 |
+
output_scale_factor=output_scale_factor,
|
| 283 |
+
pre_norm=resnet_pre_norm,
|
| 284 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 285 |
+
)
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
self.attentions = nn.ModuleList(attentions)
|
| 289 |
+
self.audio_attentions = nn.ModuleList(audio_attentions)
|
| 290 |
+
self.resnets = nn.ModuleList(resnets)
|
| 291 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
| 292 |
+
|
| 293 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
| 294 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
| 295 |
+
for attn, audio_attn, resnet, motion_module in zip(
|
| 296 |
+
self.attentions, self.audio_attentions, self.resnets[1:], self.motion_modules
|
| 297 |
+
):
|
| 298 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
| 299 |
+
hidden_states = (
|
| 300 |
+
audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
| 301 |
+
if audio_attn is not None
|
| 302 |
+
else hidden_states
|
| 303 |
+
)
|
| 304 |
+
hidden_states = (
|
| 305 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
| 306 |
+
if motion_module is not None
|
| 307 |
+
else hidden_states
|
| 308 |
+
)
|
| 309 |
+
hidden_states = resnet(hidden_states, temb)
|
| 310 |
+
|
| 311 |
+
return hidden_states
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class CrossAttnDownBlock3D(nn.Module):
|
| 315 |
+
def __init__(
|
| 316 |
+
self,
|
| 317 |
+
in_channels: int,
|
| 318 |
+
out_channels: int,
|
| 319 |
+
temb_channels: int,
|
| 320 |
+
dropout: float = 0.0,
|
| 321 |
+
num_layers: int = 1,
|
| 322 |
+
resnet_eps: float = 1e-6,
|
| 323 |
+
resnet_time_scale_shift: str = "default",
|
| 324 |
+
resnet_act_fn: str = "swish",
|
| 325 |
+
resnet_groups: int = 32,
|
| 326 |
+
resnet_pre_norm: bool = True,
|
| 327 |
+
attn_num_head_channels=1,
|
| 328 |
+
cross_attention_dim=1280,
|
| 329 |
+
output_scale_factor=1.0,
|
| 330 |
+
downsample_padding=1,
|
| 331 |
+
add_downsample=True,
|
| 332 |
+
dual_cross_attention=False,
|
| 333 |
+
use_linear_projection=False,
|
| 334 |
+
only_cross_attention=False,
|
| 335 |
+
upcast_attention=False,
|
| 336 |
+
unet_use_cross_frame_attention=False,
|
| 337 |
+
unet_use_temporal_attention=False,
|
| 338 |
+
use_inflated_groupnorm=False,
|
| 339 |
+
use_motion_module=None,
|
| 340 |
+
motion_module_type=None,
|
| 341 |
+
motion_module_kwargs=None,
|
| 342 |
+
add_audio_layer=False,
|
| 343 |
+
audio_condition_method="cross_attn",
|
| 344 |
+
custom_audio_layer: bool = False,
|
| 345 |
+
):
|
| 346 |
+
super().__init__()
|
| 347 |
+
resnets = []
|
| 348 |
+
attentions = []
|
| 349 |
+
audio_attentions = []
|
| 350 |
+
motion_modules = []
|
| 351 |
+
|
| 352 |
+
self.has_cross_attention = True
|
| 353 |
+
self.attn_num_head_channels = attn_num_head_channels
|
| 354 |
+
|
| 355 |
+
for i in range(num_layers):
|
| 356 |
+
in_channels = in_channels if i == 0 else out_channels
|
| 357 |
+
resnets.append(
|
| 358 |
+
ResnetBlock3D(
|
| 359 |
+
in_channels=in_channels,
|
| 360 |
+
out_channels=out_channels,
|
| 361 |
+
temb_channels=temb_channels,
|
| 362 |
+
eps=resnet_eps,
|
| 363 |
+
groups=resnet_groups,
|
| 364 |
+
dropout=dropout,
|
| 365 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 366 |
+
non_linearity=resnet_act_fn,
|
| 367 |
+
output_scale_factor=output_scale_factor,
|
| 368 |
+
pre_norm=resnet_pre_norm,
|
| 369 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 370 |
+
)
|
| 371 |
+
)
|
| 372 |
+
if dual_cross_attention:
|
| 373 |
+
raise NotImplementedError
|
| 374 |
+
attentions.append(
|
| 375 |
+
Transformer3DModel(
|
| 376 |
+
attn_num_head_channels,
|
| 377 |
+
out_channels // attn_num_head_channels,
|
| 378 |
+
in_channels=out_channels,
|
| 379 |
+
num_layers=1,
|
| 380 |
+
cross_attention_dim=cross_attention_dim,
|
| 381 |
+
norm_num_groups=resnet_groups,
|
| 382 |
+
use_linear_projection=use_linear_projection,
|
| 383 |
+
only_cross_attention=only_cross_attention,
|
| 384 |
+
upcast_attention=upcast_attention,
|
| 385 |
+
use_motion_module=use_motion_module,
|
| 386 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 387 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 388 |
+
add_audio_layer=add_audio_layer,
|
| 389 |
+
audio_condition_method=audio_condition_method,
|
| 390 |
+
)
|
| 391 |
+
)
|
| 392 |
+
audio_attentions.append(
|
| 393 |
+
Transformer3DModel(
|
| 394 |
+
attn_num_head_channels,
|
| 395 |
+
out_channels // attn_num_head_channels,
|
| 396 |
+
in_channels=out_channels,
|
| 397 |
+
num_layers=1,
|
| 398 |
+
cross_attention_dim=cross_attention_dim,
|
| 399 |
+
norm_num_groups=resnet_groups,
|
| 400 |
+
use_linear_projection=use_linear_projection,
|
| 401 |
+
only_cross_attention=only_cross_attention,
|
| 402 |
+
upcast_attention=upcast_attention,
|
| 403 |
+
use_motion_module=use_motion_module,
|
| 404 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 405 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 406 |
+
add_audio_layer=add_audio_layer,
|
| 407 |
+
audio_condition_method=audio_condition_method,
|
| 408 |
+
custom_audio_layer=True,
|
| 409 |
+
)
|
| 410 |
+
if custom_audio_layer
|
| 411 |
+
else None
|
| 412 |
+
)
|
| 413 |
+
motion_modules.append(
|
| 414 |
+
get_motion_module(
|
| 415 |
+
in_channels=out_channels,
|
| 416 |
+
motion_module_type=motion_module_type,
|
| 417 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 418 |
+
)
|
| 419 |
+
if use_motion_module
|
| 420 |
+
else None
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
self.attentions = nn.ModuleList(attentions)
|
| 424 |
+
self.audio_attentions = nn.ModuleList(audio_attentions)
|
| 425 |
+
self.resnets = nn.ModuleList(resnets)
|
| 426 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
| 427 |
+
|
| 428 |
+
if add_downsample:
|
| 429 |
+
self.downsamplers = nn.ModuleList(
|
| 430 |
+
[
|
| 431 |
+
Downsample3D(
|
| 432 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
| 433 |
+
)
|
| 434 |
+
]
|
| 435 |
+
)
|
| 436 |
+
else:
|
| 437 |
+
self.downsamplers = None
|
| 438 |
+
|
| 439 |
+
self.gradient_checkpointing = False
|
| 440 |
+
|
| 441 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
| 442 |
+
output_states = ()
|
| 443 |
+
|
| 444 |
+
for resnet, attn, audio_attn, motion_module in zip(
|
| 445 |
+
self.resnets, self.attentions, self.audio_attentions, self.motion_modules
|
| 446 |
+
):
|
| 447 |
+
if self.training and self.gradient_checkpointing:
|
| 448 |
+
|
| 449 |
+
def create_custom_forward(module, return_dict=None):
|
| 450 |
+
def custom_forward(*inputs):
|
| 451 |
+
if return_dict is not None:
|
| 452 |
+
return module(*inputs, return_dict=return_dict)
|
| 453 |
+
else:
|
| 454 |
+
return module(*inputs)
|
| 455 |
+
|
| 456 |
+
return custom_forward
|
| 457 |
+
|
| 458 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
| 459 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 460 |
+
create_custom_forward(attn, return_dict=False),
|
| 461 |
+
hidden_states,
|
| 462 |
+
encoder_hidden_states,
|
| 463 |
+
)[0]
|
| 464 |
+
if motion_module is not None:
|
| 465 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 466 |
+
create_custom_forward(motion_module),
|
| 467 |
+
hidden_states.requires_grad_(),
|
| 468 |
+
temb,
|
| 469 |
+
encoder_hidden_states,
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
else:
|
| 473 |
+
hidden_states = resnet(hidden_states, temb)
|
| 474 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
| 475 |
+
|
| 476 |
+
hidden_states = (
|
| 477 |
+
audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
| 478 |
+
if audio_attn is not None
|
| 479 |
+
else hidden_states
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
# add motion module
|
| 483 |
+
hidden_states = (
|
| 484 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
| 485 |
+
if motion_module is not None
|
| 486 |
+
else hidden_states
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
output_states += (hidden_states,)
|
| 490 |
+
|
| 491 |
+
if self.downsamplers is not None:
|
| 492 |
+
for downsampler in self.downsamplers:
|
| 493 |
+
hidden_states = downsampler(hidden_states)
|
| 494 |
+
|
| 495 |
+
output_states += (hidden_states,)
|
| 496 |
+
|
| 497 |
+
return hidden_states, output_states
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
class DownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)

                # add motion module
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
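A minimal usage sketch of DownBlock3D, assuming the repository is importable as latentsync.models.unet_blocks and that inputs follow the (batch, channels, frames, height, width) layout used by the 3D blocks; the channel counts and shapes below are illustrative, not taken from a config:

import torch
from latentsync.models.unet_blocks import DownBlock3D  # assumed import path

block = DownBlock3D(in_channels=64, out_channels=64, temb_channels=128)

video_latents = torch.randn(1, 64, 8, 32, 32)  # (batch, channels, frames, height, width)
temb = torch.randn(1, 128)                     # timestep embedding

hidden_states, output_states = block(video_latents, temb=temb)
# hidden_states: spatially downsampled by the trailing Downsample3D
# output_states: num_layers + 1 skip tensors, consumed later by the up blocks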
class CrossAttnUpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        add_audio_layer=False,
        audio_condition_method="cross_attn",
        custom_audio_layer=False,
    ):
        super().__init__()
        resnets = []
        attentions = []
        audio_attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer,
                    audio_condition_method=audio_condition_method,
                )
            )
            audio_attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer,
                    audio_condition_method=audio_condition_method,
                    custom_audio_layer=True,
                )
                if custom_audio_layer
                else None
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.audio_attentions = nn.ModuleList(audio_attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        for resnet, attn, audio_attn, motion_module in zip(
            self.resnets, self.attentions, self.audio_attentions, self.motion_modules
        ):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )

            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                hidden_states = (
                    audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                    if audio_attn is not None
                    else hidden_states
                )

                # add motion module
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
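The up-block forward above consumes the skip connections collected by the down blocks in reverse order, concatenating each one along the channel axis before the resnet runs. A standalone sketch of just that bookkeeping (all shapes are illustrative):

import torch

hidden_states = torch.randn(1, 64, 8, 32, 32)  # (batch, channels, frames, height, width)
res_hidden_states_tuple = tuple(torch.randn(1, 64, 8, 32, 32) for _ in range(3))

for _ in range(3):  # one iteration per resnet in the up block
    res_hidden_states = res_hidden_states_tuple[-1]      # most recent skip first
    res_hidden_states_tuple = res_hidden_states_tuple[:-1]
    hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    # ... resnet / attention / motion module run here and map back to 64 channels ...
    hidden_states = hidden_states[:, :64]                # stand-in for that mapping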
class UpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        upsample_size=None,
        encoder_hidden_states=None,
    ):
        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
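And a matching sketch for UpBlock3D, again assuming the latentsync.models.unet_blocks import path and illustrative shapes; with num_layers=1 the block expects exactly one skip tensor whose channel count equals in_channels:

import torch
from latentsync.models.unet_blocks import UpBlock3D  # assumed import path

up_block = UpBlock3D(in_channels=64, prev_output_channel=64, out_channels=64, temb_channels=128)

hidden_states = torch.randn(1, 64, 8, 16, 16)  # output of the block below
skips = (torch.randn(1, 64, 8, 16, 16),)       # one skip tensor per resnet layer
temb = torch.randn(1, 128)

out = up_block(hidden_states, res_hidden_states_tuple=skips, temb=temb)
# the trailing Upsample3D upsamples the spatial dims, giving (1, 64, 8, 32, 32) here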