LiuhanChen committed
Commit e6c7152 · 0 Parent(s)
.gitattributes ADDED
@@ -0,0 +1,5 @@
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.bmp filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,22 @@
+ # Sci-Fi
+ Official PyTorch implementation of "Sci-Fi: Symmetric Constraint for Frame Inbetweening"
+
+ ## Quick Start
+ ### 1. Set up the repository and environment
+ ```
+ git clone https://github.com/LiuhanChen-github/Sci-Fi.git
+ cd Sci-Fi
+ conda create -n Sci-Fi python=3.12
+ conda activate Sci-Fi
+ pip install -r requirements.txt
+ ```
+ ### 2. Download checkpoints
+ Download the CogVideoX-I2V-5B model (the transformer denoiser weights differ from the original release because of fine-tuning) and EF-Net from the [checkpoint folder](https://drive.google.com/drive/folders/1H7vgiNVbxSeeleyJOqhoyRbJ97kGWGOK?usp=sharing).
+
+ ### 3. Launch the inference script
+ The example input keyframe pairs are in the `examples/` folder, and the corresponding generated videos (720x480, 49 frames) are placed in the `outputs/` folder.
+
+ To interpolate between a keyframe pair, run:
+ ```
+ bash Sci_Fi_frame_inbetweening.sh
+ ```
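The shell script referenced above is not part of this commit, so as a rough illustration the sketch below shows how the components defined in the files of this commit could be wired together for inference. The checkpoint folder layout, example file names, prompt, and seed are assumptions, not values taken from the repository.

```python
# Hedged inference sketch: paths under ./checkpoints and the example image names are assumptions.
import torch
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from diffusers.utils import export_to_video

from cogvideo_transformer import CustomCogVideoXTransformer3DModel
from cogvideo_EF_Net import CogVideoX_EF_Net
from cogvideo_Sci_Fi_inbetweening_pipeline import CogVideoXEFNetInbetweeningPipeline

base = "./checkpoints/CogVideoX-5b-I2V"  # assumed local copy of the fine-tuned base model
tokenizer = T5Tokenizer.from_pretrained(base, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(base, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLCogVideoX.from_pretrained(base, subfolder="vae", torch_dtype=torch.bfloat16)
transformer = CustomCogVideoXTransformer3DModel.from_pretrained(base, subfolder="transformer", torch_dtype=torch.bfloat16)
scheduler = CogVideoXDDIMScheduler.from_pretrained(base, subfolder="scheduler")
ef_net = CogVideoX_EF_Net.from_pretrained("./checkpoints/EF-Net", torch_dtype=torch.bfloat16)  # assumed path

pipe = CogVideoXEFNetInbetweeningPipeline(
    tokenizer=tokenizer, text_encoder=text_encoder, vae=vae,
    transformer=transformer, EF_Net=ef_net, scheduler=scheduler,
).to("cuda")

video = pipe(
    first_image=Image.open("examples/pair0/first.png").convert("RGB"),  # assumed example file names
    last_image=Image.open("examples/pair0/last.png").convert("RGB"),
    prompt="A short, natural transition between the two keyframes.",    # assumed prompt
    height=480, width=720, num_frames=49,
    num_inference_steps=50, guidance_scale=6.0,
    generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(video, "outputs/result.mp4", fps=8)  # assumes outputs/ exists
```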
cogvideo_EF_Net.py ADDED
@@ -0,0 +1,222 @@
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from einops import rearrange
+ import torch.nn.functional as F
+ from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
+ from diffusers.utils import is_torch_version
+ from diffusers.loaders import PeftAdapterMixin
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
+ from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.attention import Attention, FeedForward
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor2_0
+ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero, AdaLayerNormZeroSingle
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+
+
+ class CogVideoX_EF_Net(ModelMixin, ConfigMixin, PeftAdapterMixin):
+     _supports_gradient_checkpointing = True
+
+     @register_to_config
+     def __init__(
+         self,
+         num_attention_heads: int = 30,
+         attention_head_dim: int = 64,
+         vae_channels: int = 16,
+         in_channels: int = 3,
+         downscale_coef: int = 8,
+         flip_sin_to_cos: bool = True,
+         freq_shift: int = 0,
+         time_embed_dim: int = 512,
+         num_layers: int = 8,
+         dropout: float = 0.0,
+         attention_bias: bool = True,
+         sample_width: int = 90,
+         sample_height: int = 60,
+         sample_frames: int = 1,
+         patch_size: int = 2,
+         temporal_compression_ratio: int = 4,
+         max_text_seq_length: int = 226,
+         activation_fn: str = "gelu-approximate",
+         timestep_activation_fn: str = "silu",
+         norm_elementwise_affine: bool = True,
+         norm_eps: float = 1e-5,
+         spatial_interpolation_scale: float = 1.875,
+         temporal_interpolation_scale: float = 1.0,
+         use_rotary_positional_embeddings: bool = False,
+         use_learned_positional_embeddings: bool = False,
+         out_proj_dim=None,
+     ):
+         super().__init__()
+         inner_dim = num_attention_heads * attention_head_dim
+         out_proj_dim = inner_dim
+
+         if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+             raise ValueError(
+                 "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+                 "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+                 "issue at https://github.com/huggingface/diffusers/issues."
+             )
+
+         # 1. Patch embedding
+         self.patch_embed = CogVideoXPatchEmbed(
+             patch_size=patch_size,
+             in_channels=vae_channels,
+             embed_dim=inner_dim,
+             bias=True,
+             sample_width=sample_width,
+             sample_height=sample_height,
+             sample_frames=49,
+             temporal_compression_ratio=temporal_compression_ratio,
+             spatial_interpolation_scale=spatial_interpolation_scale,
+             temporal_interpolation_scale=temporal_interpolation_scale,
+             use_positional_embeddings=not use_rotary_positional_embeddings,
+             use_learned_positional_embeddings=use_learned_positional_embeddings,
+         )
+
+         self.patch_embed_first = CogVideoXPatchEmbed(
+             patch_size=patch_size,
+             in_channels=vae_channels,
+             embed_dim=inner_dim,
+             bias=True,
+             sample_width=sample_width,
+             sample_height=sample_height,
+             sample_frames=sample_frames,
+             temporal_compression_ratio=temporal_compression_ratio,
+             spatial_interpolation_scale=spatial_interpolation_scale,
+             temporal_interpolation_scale=temporal_interpolation_scale,
+             use_positional_embeddings=not use_rotary_positional_embeddings,
+             use_learned_positional_embeddings=use_learned_positional_embeddings,
+         )
+
+         self.embedding_dropout = nn.Dropout(dropout)
+         self.weights = nn.ModuleList([nn.Linear(inner_dim, 13) for _ in range(num_layers)])
+         self.first_weights = nn.ModuleList([nn.Linear(2 * inner_dim, inner_dim) for _ in range(num_layers)])
+
+         # 2. Time embeddings
+         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+         # 3. Define spatio-temporal transformer blocks
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 CogVideoXBlock(
+                     dim=inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
+                     time_embed_dim=time_embed_dim,
+                     dropout=dropout,
+                     activation_fn=activation_fn,
+                     attention_bias=attention_bias,
+                     norm_elementwise_affine=norm_elementwise_affine,
+                     norm_eps=norm_eps,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         self.out_projectors = None
+         self.relu = nn.LeakyReLU(negative_slope=0.01)
+
+         if out_proj_dim is not None:
+             self.out_projectors = nn.ModuleList(
+                 [nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)]
+             )
+
+         self.gradient_checkpointing = False
+
+     def _set_gradient_checkpointing(self, enable=False, gradient_checkpointing_func=None):
+         self.gradient_checkpointing = enable
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         EF_Net_states: torch.Tensor,
+         timestep: Union[int, float, torch.LongTensor],
+         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         timestep_cond: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ):
+         batch_size, num_frames, channels, height, width = EF_Net_states.shape
+         o_hidden_states = hidden_states
+         hidden_states = EF_Net_states
+         encoder_hidden_states_ = encoder_hidden_states
+
+         # 1. Time embedding
+         timesteps = timestep
+         t_emb = self.time_proj(timesteps)
+
+         # timesteps does not contain any weights and will always return f32 tensors
+         # but time_embedding might actually be running in fp16. so we need to cast here.
+         # there might be better ways to encapsulate this.
+         t_emb = t_emb.to(dtype=hidden_states.dtype)
+         emb = self.time_embedding(t_emb, timestep_cond)
+
+         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+         hidden_states = self.embedding_dropout(hidden_states)
+
+         text_seq_length = encoder_hidden_states.shape[1]
+         encoder_hidden_states = hidden_states[:, :text_seq_length]
+         hidden_states = hidden_states[:, text_seq_length:]
+
+         o_hidden_states = self.patch_embed_first(encoder_hidden_states_, o_hidden_states)
+         o_hidden_states = self.embedding_dropout(o_hidden_states)
+
+         text_seq_length = encoder_hidden_states_.shape[1]
+         o_hidden_states = o_hidden_states[:, text_seq_length:]
+
+         EF_Net_hidden_states = ()
+         # 2. Transformer blocks
+         for i, block in enumerate(self.transformer_blocks):
+             # if self.training and self.gradient_checkpointing:
+             if self.gradient_checkpointing:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     encoder_hidden_states,
+                     emb,
+                     image_rotary_emb,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 hidden_states, encoder_hidden_states = block(
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     temb=emb,
+                     image_rotary_emb=image_rotary_emb,
+                 )
+
+             if self.out_projectors is not None:
+                 coff = self.weights[i](hidden_states)
+                 temp_list = []
+                 for j in range(coff.shape[2]):
+                     temp_list.append(hidden_states * coff[:, :, j:(j + 1)])
+                 out = torch.concat(temp_list, dim=1)
+                 out = torch.concat([out, o_hidden_states], dim=2)
+                 out = self.first_weights[i](out)
+                 out = self.relu(out)
+                 out = self.out_projectors[i](out)
+                 EF_Net_hidden_states += (out,)
+             else:
+                 out = torch.concat([weight * hidden_states for weight in self.weights], dim=1)
+                 EF_Net_hidden_states += (out,)
+
+         if not return_dict:
+             return (EF_Net_hidden_states,)
+         return Transformer2DModelOutput(sample=EF_Net_hidden_states)
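For a 49-frame clip the VAE keeps (49 - 1) / 4 + 1 = 13 latent frames, which is where the fixed output size of `self.weights` (`nn.Linear(inner_dim, 13)`) comes from. As wired up in the inbetweening pipeline further down, the stream fed through `patch_embed` appears to be the end-frame latent while `o_hidden_states` carries the full 13-frame latent sequence; under that reading, each block predicts 13 per-token coefficients, broadcasts the end-frame tokens across the 13 latent frames with those coefficients, and fuses the result with the per-frame tokens before the output projection. A toy, self-contained sketch of that fusion (the sizes and variable names below are illustrative, not the module's):

```python
# Toy re-implementation of the per-block EF-Net fusion with tiny dimensions.
import torch
from torch import nn

B, S, D, F = 1, 4, 8, 13                      # batch, tokens per latent frame, channels, latent frames
end_tokens = torch.randn(B, S, D)             # tokens of the end-frame latent (one latent frame)
frame_tokens = torch.randn(B, F * S, D)       # tokens of all 13 latent frames

to_coeff = nn.Linear(D, F)                    # plays the role of self.weights[i]
fuse = nn.Linear(2 * D, D)                    # plays the role of self.first_weights[i]
act = nn.LeakyReLU(0.01)                      # plays the role of self.relu
proj = nn.Linear(D, D)                        # plays the role of self.out_projectors[i]

coeff = to_coeff(end_tokens)                                   # [B, S, F]: one weight per target latent frame
expanded = torch.cat(
    [end_tokens * coeff[:, :, j:j + 1] for j in range(F)],     # re-weighted copy of the end frame per latent frame
    dim=1,
)                                                              # [B, F*S, D]
fused = proj(act(fuse(torch.cat([expanded, frame_tokens], dim=2))))
print(fused.shape)                                             # torch.Size([1, 52, 8]) -> one feature per video token
```

The output of each block therefore has the same token count as the main transformer's video tokens, which is what lets it be added residually inside `CustomCogVideoXTransformer3DModel`.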
cogvideo_Sci_Fi_inbetweening_pipeline.py ADDED
@@ -0,0 +1,824 @@
1
+ import inspect
2
+ import math
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import numpy as np
7
+ import PIL
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from einops import rearrange, repeat
11
+ from transformers import T5EncoderModel, T5Tokenizer
12
+ from diffusers.video_processor import VideoProcessor
13
+ from diffusers.utils.torch_utils import randn_tensor
14
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
15
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
17
+ from diffusers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler, CogVideoXImageToVideoPipeline
18
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
19
+ from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipelineOutput, CogVideoXLoraLoaderMixin
20
+
21
+ from cogvideo_EF_Net import CogVideoX_EF_Net
22
+ import torch
23
+
24
+ def resize_for_crop(image, crop_h, crop_w):
25
+ img_h, img_w = image.shape[-2:]
26
+ if img_h >= crop_h and img_w >= crop_w:
27
+ coef = max(crop_h / img_h, crop_w / img_w)
28
+ elif img_h <= crop_h and img_w <= crop_w:
29
+ coef = max(crop_h / img_h, crop_w / img_w)
30
+ else:
31
+ coef = crop_h / img_h if crop_h > img_h else crop_w / img_w
32
+ out_h, out_w = int(img_h * coef), int(img_w * coef)
33
+ resized_image = transforms.functional.resize(image, (out_h, out_w), antialias=True)
34
+ return resized_image
35
+
36
+
37
+ def prepare_frames(input_images, video_size, do_resize=True, do_crop=True):
38
+ images_tensor = input_images
39
+ if do_resize:
40
+ images_tensor = [resize_for_crop(x, crop_h=video_size[0], crop_w=video_size[1]) for x in images_tensor]
41
+ if do_crop:
42
+ images_tensor = [transforms.functional.center_crop(x, video_size) for x in images_tensor]
43
+ if isinstance(images_tensor, list):
44
+ images_tensor = torch.stack(images_tensor)
45
+ print(images_tensor.shape)
46
+ return images_tensor.unsqueeze(0)
47
+
48
+
49
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
50
+ tw = tgt_width
51
+ th = tgt_height
52
+ h, w = src
53
+ r = h / w
54
+ if r > (th / tw):
55
+ resize_height = th
56
+ resize_width = int(round(th / h * w))
57
+ else:
58
+ resize_width = tw
59
+ resize_height = int(round(tw / w * h))
60
+
61
+ crop_top = int(round((th - resize_height) / 2.0))
62
+ crop_left = int(round((tw - resize_width) / 2.0))
63
+
64
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
65
+
66
+
67
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
68
+ def retrieve_timesteps(
69
+ scheduler,
70
+ num_inference_steps: Optional[int] = None,
71
+ device: Optional[Union[str, torch.device]] = None,
72
+ timesteps: Optional[List[int]] = None,
73
+ sigmas: Optional[List[float]] = None,
74
+ **kwargs,
75
+ ):
76
+ """
77
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
78
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
79
+
80
+ Args:
81
+ scheduler (`SchedulerMixin`):
82
+ The scheduler to get timesteps from.
83
+ num_inference_steps (`int`):
84
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
85
+ must be `None`.
86
+ device (`str` or `torch.device`, *optional*):
87
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
88
+ timesteps (`List[int]`, *optional*):
89
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
90
+ `num_inference_steps` and `sigmas` must be `None`.
91
+ sigmas (`List[float]`, *optional*):
92
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
93
+ `num_inference_steps` and `timesteps` must be `None`.
94
+
95
+ Returns:
96
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
97
+ second element is the number of inference steps.
98
+ """
99
+ if timesteps is not None and sigmas is not None:
100
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
101
+ if timesteps is not None:
102
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
103
+ if not accepts_timesteps:
104
+ raise ValueError(
105
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
106
+ f" timestep schedules. Please check whether you are using the correct scheduler."
107
+ )
108
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
109
+ timesteps = scheduler.timesteps
110
+ num_inference_steps = len(timesteps)
111
+ elif sigmas is not None:
112
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
113
+ if not accept_sigmas:
114
+ raise ValueError(
115
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
116
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
117
+ )
118
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
119
+ timesteps = scheduler.timesteps
120
+ num_inference_steps = len(timesteps)
121
+ else:
122
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
123
+ timesteps = scheduler.timesteps
124
+ return timesteps, num_inference_steps
125
+
126
+
127
+ def retrieve_latents(
128
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
129
+ ):
130
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
131
+ return encoder_output.latent_dist.sample(generator)
132
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
133
+ return encoder_output.latent_dist.mode()
134
+ elif hasattr(encoder_output, "latents"):
135
+ return encoder_output.latents
136
+ else:
137
+ raise AttributeError("Could not access latents of provided encoder_output")
138
+
139
+
140
+ class CogVideoXEFNetInbetweeningPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
141
+ r"""
142
+ Pipeline for frame inbetweening generation using CogVideoX.
143
+
144
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
145
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
146
+
147
+ Args:
148
+ vae ([`AutoencoderKL`]):
149
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
150
+ text_encoder ([`T5EncoderModel`]):
151
+ Frozen text-encoder. CogVideoX uses
152
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
153
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
154
+ tokenizer (`T5Tokenizer`):
155
+ Tokenizer of class
156
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
157
+ transformer ([`CogVideoXTransformer3DModel`]):
158
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
159
+ scheduler ([`SchedulerMixin`]):
160
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
161
+ EF-Net ([CogVideoX_EF_Net]):
162
+ Our proposed EF-Net.
163
+ """
164
+
165
+ _optional_components = []
166
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
167
+
168
+ _callback_tensor_inputs = [
169
+ "latents",
170
+ "prompt_embeds",
171
+ "negative_prompt_embeds",
172
+ ]
173
+
174
+ def __init__(
175
+ self,
176
+ tokenizer: T5Tokenizer,
177
+ text_encoder: T5EncoderModel,
178
+ vae: AutoencoderKLCogVideoX,
179
+ transformer: CogVideoXTransformer3DModel,
180
+ EF_Net: CogVideoX_EF_Net,
181
+ scheduler: CogVideoXDDIMScheduler,
182
+ ):
183
+ super().__init__()
184
+
185
+ self.register_modules(
186
+ tokenizer=tokenizer,
187
+ text_encoder=text_encoder,
188
+ vae=vae,
189
+ transformer=transformer,
190
+ EF_Net=EF_Net,
191
+ scheduler=scheduler,
192
+ )
193
+ self.vae_scale_factor_spatial = (
194
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
195
+ )
196
+ self.vae_scale_factor_temporal = (
197
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
198
+ )
199
+ self.vae_scaling_factor_image = (
200
+ self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
201
+ )
202
+
203
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
204
+
205
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
206
+ def _get_t5_prompt_embeds(
207
+ self,
208
+ prompt: Union[str, List[str]] = None,
209
+ num_videos_per_prompt: int = 1,
210
+ max_sequence_length: int = 226,
211
+ device: Optional[torch.device] = None,
212
+ dtype: Optional[torch.dtype] = None,
213
+ ):
214
+ device = device or self._execution_device
215
+ dtype = dtype or self.text_encoder.dtype
216
+
217
+ prompt = [prompt] if isinstance(prompt, str) else prompt
218
+ batch_size = len(prompt)
219
+
220
+ text_inputs = self.tokenizer(
221
+ prompt,
222
+ padding="max_length",
223
+ max_length=max_sequence_length,
224
+ truncation=True,
225
+ add_special_tokens=True,
226
+ return_tensors="pt",
227
+ )
228
+ text_input_ids = text_inputs.input_ids
229
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
230
+
231
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
232
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
233
+ """
234
+ logger.warning(
235
+ "The following part of your input was truncated because `max_sequence_length` is set to "
236
+ f" {max_sequence_length} tokens: {removed_text}"
237
+ )
238
+ """
239
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
240
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
241
+
242
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
243
+ _, seq_len, _ = prompt_embeds.shape
244
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
245
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
246
+
247
+ return prompt_embeds
248
+
249
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
250
+ def encode_prompt(
251
+ self,
252
+ prompt: Union[str, List[str]],
253
+ negative_prompt: Optional[Union[str, List[str]]] = None,
254
+ do_classifier_free_guidance: bool = True,
255
+ num_videos_per_prompt: int = 1,
256
+ prompt_embeds: Optional[torch.Tensor] = None,
257
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
258
+ max_sequence_length: int = 226,
259
+ device: Optional[torch.device] = None,
260
+ dtype: Optional[torch.dtype] = None,
261
+ ):
262
+ r"""
263
+ Encodes the prompt into text encoder hidden states.
264
+
265
+ Args:
266
+ prompt (`str` or `List[str]`, *optional*):
267
+ prompt to be encoded
268
+ negative_prompt (`str` or `List[str]`, *optional*):
269
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
270
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
271
+ less than `1`).
272
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
273
+ Whether to use classifier free guidance or not.
274
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
275
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
276
+ prompt_embeds (`torch.Tensor`, *optional*):
277
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
278
+ provided, text embeddings will be generated from `prompt` input argument.
279
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
280
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
281
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
282
+ argument.
283
+ device: (`torch.device`, *optional*):
284
+ torch device
285
+ dtype: (`torch.dtype`, *optional*):
286
+ torch dtype
287
+ """
288
+ device = device or self._execution_device
289
+
290
+ prompt = [prompt] if isinstance(prompt, str) else prompt
291
+ if prompt is not None:
292
+ batch_size = len(prompt)
293
+ else:
294
+ batch_size = prompt_embeds.shape[0]
295
+
296
+ if prompt_embeds is None:
297
+ prompt_embeds = self._get_t5_prompt_embeds(
298
+ prompt=prompt,
299
+ num_videos_per_prompt=num_videos_per_prompt,
300
+ max_sequence_length=max_sequence_length,
301
+ device=device,
302
+ dtype=dtype,
303
+ )
304
+
305
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
306
+ negative_prompt = negative_prompt or ""
307
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
308
+
309
+ if prompt is not None and type(prompt) is not type(negative_prompt):
310
+ raise TypeError(
311
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
312
+ f" {type(prompt)}."
313
+ )
314
+ elif batch_size != len(negative_prompt):
315
+ raise ValueError(
316
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
317
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
318
+ " the batch size of `prompt`."
319
+ )
320
+
321
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
322
+ prompt=negative_prompt,
323
+ num_videos_per_prompt=num_videos_per_prompt,
324
+ max_sequence_length=max_sequence_length,
325
+ device=device,
326
+ dtype=dtype,
327
+ )
328
+
329
+ return prompt_embeds, negative_prompt_embeds
330
+
331
+ def prepare_latents(
332
+ self,
333
+ first_image: torch.Tensor,
334
+ last_image: torch.Tensor,
335
+ batch_size: int = 1,
336
+ num_channels_latents: int = 16,
337
+ num_frames: int = 13,
338
+ height: int = 60,
339
+ width: int = 90,
340
+ dtype: Optional[torch.dtype] = None,
341
+ device: Optional[torch.device] = None,
342
+ generator: Optional[torch.Generator] = None,
343
+ latents: Optional[torch.Tensor] = None,
344
+ ):
345
+ num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
346
+ shape = (
347
+ batch_size,
348
+ num_frames,
349
+ num_channels_latents,
350
+ height // self.vae_scale_factor_spatial,
351
+ width // self.vae_scale_factor_spatial,
352
+ )
353
+
354
+ if isinstance(generator, list) and len(generator) != batch_size:
355
+ raise ValueError(
356
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
357
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
358
+ )
359
+
360
+ first_image = first_image.unsqueeze(2) # [B, C, F, H, W]
361
+ last_image = last_image.unsqueeze(2) # [B, C, F, H, W]
362
+
363
+ if isinstance(generator, list):
364
+ first_image_latents = [
365
+ retrieve_latents(self.vae.encode(first_image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
366
+ ]
367
+ else:
368
+ first_image_latents = [retrieve_latents(self.vae.encode(first_img.unsqueeze(0)), generator) for first_img in first_image]
369
+
370
+ if isinstance(generator, list):
371
+ last_image_latents = [
372
+ retrieve_latents(self.vae.encode(last_image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
373
+ ]
374
+ else:
375
+ last_image_latents = [retrieve_latents(self.vae.encode(last_img.unsqueeze(0)), generator) for last_img in last_image]
376
+
377
+
378
+ first_image_latents = torch.cat(first_image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
379
+ first_image_latents = self.vae.config.scaling_factor * first_image_latents
380
+ last_image_latents = torch.cat(last_image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
381
+ last_image_latents = self.vae.config.scaling_factor * last_image_latents
382
+
383
+
384
+ padding_shape = (
385
+ batch_size,
386
+ num_frames - 2,
387
+ num_channels_latents,
388
+ height // self.vae_scale_factor_spatial,
389
+ width // self.vae_scale_factor_spatial,
390
+ )
391
+ latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
392
+ image_latents = torch.cat([first_image_latents, latent_padding, last_image_latents], dim=1)
393
+
394
+ if latents is None:
395
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
396
+ else:
397
+ latents = latents.to(device)
398
+
399
+ # scale the initial noise by the standard deviation required by the scheduler
400
+ latents = latents * self.scheduler.init_noise_sigma
401
+ return latents, image_latents
402
+
403
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
404
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
405
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
406
+ latents = 1 / self.vae_scaling_factor_image * latents
407
+
408
+ frames = self.vae.decode(latents).sample
409
+ return frames
410
+
411
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
412
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
413
+ # get the original timestep using init_timestep
414
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
415
+
416
+ t_start = max(num_inference_steps - init_timestep, 0)
417
+ timesteps = timesteps[t_start * self.scheduler.order :]
418
+
419
+ return timesteps, num_inference_steps - t_start
420
+
421
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
422
+ def prepare_extra_step_kwargs(self, generator, eta):
423
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
424
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
425
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
426
+ # and should be between [0, 1]
427
+
428
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
429
+ extra_step_kwargs = {}
430
+ if accepts_eta:
431
+ extra_step_kwargs["eta"] = eta
432
+
433
+ # check if the scheduler accepts generator
434
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
435
+ if accepts_generator:
436
+ extra_step_kwargs["generator"] = generator
437
+ return extra_step_kwargs
438
+
439
+ def check_inputs(
440
+ self,
441
+ image,
442
+ prompt,
443
+ height,
444
+ width,
445
+ negative_prompt,
446
+ callback_on_step_end_tensor_inputs,
447
+ latents=None,
448
+ prompt_embeds=None,
449
+ negative_prompt_embeds=None,
450
+ ):
451
+ if (
452
+ not isinstance(image, torch.Tensor)
453
+ and not isinstance(image, PIL.Image.Image)
454
+ and not isinstance(image, list)
455
+ ):
456
+ raise ValueError(
457
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
458
+ f" {type(image)}"
459
+ )
460
+
461
+ if height % 8 != 0 or width % 8 != 0:
462
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
463
+
464
+ if callback_on_step_end_tensor_inputs is not None and not all(
465
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
466
+ ):
467
+ raise ValueError(
468
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
469
+ )
470
+ if prompt is not None and prompt_embeds is not None:
471
+ raise ValueError(
472
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
473
+ " only forward one of the two."
474
+ )
475
+ elif prompt is None and prompt_embeds is None:
476
+ raise ValueError(
477
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
478
+ )
479
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
480
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
481
+
482
+ if prompt is not None and negative_prompt_embeds is not None:
483
+ raise ValueError(
484
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
485
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
486
+ )
487
+
488
+ if negative_prompt is not None and negative_prompt_embeds is not None:
489
+ raise ValueError(
490
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
491
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
492
+ )
493
+
494
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
495
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
496
+ raise ValueError(
497
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
498
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
499
+ f" {negative_prompt_embeds.shape}."
500
+ )
501
+
502
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
503
+ def fuse_qkv_projections(self) -> None:
504
+ r"""Enables fused QKV projections."""
505
+ self.fusing_transformer = True
506
+ self.transformer.fuse_qkv_projections()
507
+
508
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
509
+ def unfuse_qkv_projections(self) -> None:
510
+ r"""Disable QKV projection fusion if enabled."""
511
+ if not self.fusing_transformer:
512
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
513
+ else:
514
+ self.transformer.unfuse_qkv_projections()
515
+ self.fusing_transformer = False
516
+
517
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
518
+ def _prepare_rotary_positional_embeddings(
519
+ self,
520
+ height: int,
521
+ width: int,
522
+ num_frames: int,
523
+ device: torch.device,
524
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
525
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
526
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
527
+
528
+ p = self.transformer.config.patch_size
529
+ p_t = self.transformer.config.patch_size_t
530
+
531
+ base_size_width = self.transformer.config.sample_width // p
532
+ base_size_height = self.transformer.config.sample_height // p
533
+
534
+ if p_t is None:
535
+ # CogVideoX 1.0
536
+ grid_crops_coords = get_resize_crop_region_for_grid(
537
+ (grid_height, grid_width), base_size_width, base_size_height
538
+ )
539
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
540
+ embed_dim=self.transformer.config.attention_head_dim,
541
+ crops_coords=grid_crops_coords,
542
+ grid_size=(grid_height, grid_width),
543
+ temporal_size=num_frames,
544
+ )
545
+ else:
546
+ # CogVideoX 1.5
547
+ base_num_frames = (num_frames + p_t - 1) // p_t
548
+
549
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
550
+ embed_dim=self.transformer.config.attention_head_dim,
551
+ crops_coords=None,
552
+ grid_size=(grid_height, grid_width),
553
+ temporal_size=base_num_frames,
554
+ grid_type="slice",
555
+ max_size=(base_size_height, base_size_width),
556
+ )
557
+
558
+ freqs_cos = freqs_cos.to(device=device)
559
+ freqs_sin = freqs_sin.to(device=device)
560
+ return freqs_cos, freqs_sin
561
+
562
+ def prepare_EF_Net_frames(self, EF_Net_frames, height, width, do_classifier_free_guidance):
563
+ prepared_frames = prepare_frames(EF_Net_frames, (height, width))
564
+ EF_Net_encoded_frames = prepared_frames.to(dtype=self.vae.dtype, device='cuda')
565
+ EF_Net_encoded_frames = EF_Net_encoded_frames.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
566
+ EF_Net_encoded_frames = self.vae.encode(EF_Net_encoded_frames).latent_dist.sample() * self.vae.config.scaling_factor
567
+
568
+ EF_Net_encoded_frames = EF_Net_encoded_frames.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format)
569
+ EF_Net_encoded_frames = torch.cat([EF_Net_encoded_frames] * 2) if do_classifier_free_guidance else EF_Net_encoded_frames
570
+
571
+ return EF_Net_encoded_frames.contiguous()
572
+
573
+ @property
574
+ def guidance_scale(self):
575
+ return self._guidance_scale
576
+
577
+ @property
578
+ def num_timesteps(self):
579
+ return self._num_timesteps
580
+
581
+ @property
582
+ def attention_kwargs(self):
583
+ return self._attention_kwargs
584
+
585
+ @property
586
+ def interrupt(self):
587
+ return self._interrupt
588
+
589
+ @torch.no_grad()
590
+ def __call__(
591
+ self,
592
+ first_image,
593
+ last_image,
594
+ EF_Net_frames = None,
595
+ prompt: Optional[Union[str, List[str]]] = None,
596
+ negative_prompt: Optional[Union[str, List[str]]] = None,
597
+ height: Optional[int] = None,
598
+ width: Optional[int] = None,
599
+ num_frames: int = 49,
600
+ num_inference_steps: int = 50,
601
+ timesteps: Optional[List[int]] = None,
602
+ guidance_scale: float = 6,
603
+ use_dynamic_cfg: bool = False,
604
+ num_videos_per_prompt: int = 1,
605
+ eta: float = 0.0,
606
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
607
+ latents: Optional[torch.FloatTensor] = None,
608
+ prompt_embeds: Optional[torch.FloatTensor] = None,
609
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
610
+ EF_Net_latents: Optional[torch.FloatTensor] = None,
611
+ output_type: str = "pil",
612
+ return_dict: bool = True,
613
+ attention_kwargs: Optional[Dict[str, Any]] = None,
614
+ callback_on_step_end: Optional[
615
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
616
+ ] = None,
617
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
618
+ max_sequence_length: int = 226,
619
+ EF_Net_weights: Optional[Union[float, list, np.ndarray, torch.FloatTensor]] = 1.0,
620
+ EF_Net_guidance_start: float = 0.0,
621
+ EF_Net_guidance_end: float = 1.0,
622
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
623
+
624
+
625
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
626
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
627
+
628
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
629
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
630
+ num_frames = num_frames or self.transformer.config.sample_frames
631
+
632
+ num_videos_per_prompt = 1
633
+
634
+ # 1. Check inputs. Raise error if not correct
635
+ self.check_inputs(
636
+ image=first_image,
637
+ prompt=prompt,
638
+ height=height,
639
+ width=width,
640
+ negative_prompt=negative_prompt,
641
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
642
+ latents=latents,
643
+ prompt_embeds=prompt_embeds,
644
+ negative_prompt_embeds=negative_prompt_embeds,
645
+ )
646
+ self._guidance_scale = guidance_scale
647
+ self._attention_kwargs = attention_kwargs
648
+ self._interrupt = False
649
+
650
+ # 2. Default call parameters
651
+ if prompt is not None and isinstance(prompt, str):
652
+ batch_size = 1
653
+ elif prompt is not None and isinstance(prompt, list):
654
+ batch_size = len(prompt)
655
+ else:
656
+ batch_size = prompt_embeds.shape[0]
657
+
658
+ device = self._execution_device
659
+
660
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
661
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
662
+ # corresponds to doing no classifier free guidance.
663
+ do_classifier_free_guidance = guidance_scale > 1.0
664
+
665
+ # 3. Encode input prompt
666
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
667
+ prompt=prompt,
668
+ negative_prompt=negative_prompt,
669
+ do_classifier_free_guidance=do_classifier_free_guidance,
670
+ num_videos_per_prompt=num_videos_per_prompt,
671
+ prompt_embeds=prompt_embeds,
672
+ negative_prompt_embeds=negative_prompt_embeds,
673
+ max_sequence_length=max_sequence_length,
674
+ device=device,
675
+ )
676
+ if do_classifier_free_guidance:
677
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
678
+
679
+ # 4. Prepare timesteps
680
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
681
+
682
+ self._num_timesteps = len(timesteps)
683
+
684
+ # 5. Prepare latents
685
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
686
+
687
+ first_image = self.video_processor.preprocess(first_image, height=height, width=width).to(
688
+ device, dtype=prompt_embeds.dtype
689
+ )
690
+ last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
691
+ device, dtype=prompt_embeds.dtype
692
+ )
693
+
694
+ latent_channels = self.transformer.config.in_channels // 2
695
+ latents, image_latents = self.prepare_latents(
696
+ first_image,
697
+ last_image,
698
+ batch_size * num_videos_per_prompt,
699
+ latent_channels,
700
+ num_frames,
701
+ height,
702
+ width,
703
+ prompt_embeds.dtype,
704
+ device,
705
+ generator,
706
+ latents,
707
+ )
708
+
709
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
710
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
711
+
712
+ # 7. Create rotary embeds if required
713
+ image_rotary_emb = (
714
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
715
+ if self.transformer.config.use_rotary_positional_embeddings
716
+ else None
717
+ )
718
+
719
+ # 8. Create ofs embeds if required
720
+ ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
721
+
722
+ # 9. Denoising loop
723
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
724
+
725
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
726
+ # for DPM-solver++
727
+ old_pred_original_sample = None
728
+ for i, t in enumerate(timesteps):
729
+ if self.interrupt:
730
+ continue
731
+
732
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
733
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
734
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
735
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
736
+
737
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
738
+ timestep = t.expand(latent_model_input.shape[0])
739
+
740
+ current_sampling_percent = i / len(timesteps)
741
+
742
+ EF_Net_states = []
743
+ if (EF_Net_guidance_start <= current_sampling_percent < EF_Net_guidance_end):
744
+ # extract EF_Net hidden state
745
+ EF_Net_states = self.EF_Net(
746
+ hidden_states=latent_image_input[:,:,0:16,:,:],
747
+ encoder_hidden_states=prompt_embeds,
748
+ image_rotary_emb=None,
749
+ EF_Net_states=latent_image_input[:,12::,:,:,:],
750
+ timestep=timestep,
751
+ return_dict=False,
752
+ )[0]
753
+ if isinstance(EF_Net_states, (tuple, list)):
754
+ EF_Net_states = [x.to(dtype=self.transformer.dtype) for x in EF_Net_states]
755
+ else:
756
+ EF_Net_states = EF_Net_states.to(dtype=self.transformer.dtype)
757
+
758
+ # predict noise model_output
759
+
760
+ noise_pred = self.transformer(
761
+ hidden_states=latent_model_input,
762
+ encoder_hidden_states=prompt_embeds,
763
+ timestep=timestep,
764
+ # ofs=ofs_emb,
765
+ image_rotary_emb=image_rotary_emb,
766
+ # attention_kwargs=attention_kwargs,
767
+ EF_Net_states=EF_Net_states,
768
+ EF_Net_weights=EF_Net_weights,
769
+ return_dict=False,
770
+ )[0]
771
+ noise_pred = noise_pred.float()
772
+
773
+ # perform guidance
774
+ if use_dynamic_cfg:
775
+ self._guidance_scale = 1 + guidance_scale * (
776
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
777
+ )
778
+ if do_classifier_free_guidance:
779
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
780
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
781
+
782
+ # compute the previous noisy sample x_t -> x_t-1
783
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
784
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
785
+
786
+ else:
787
+ latents, old_pred_original_sample = self.scheduler.step(
788
+ noise_pred,
789
+ old_pred_original_sample,
790
+ t,
791
+ timesteps[i - 1] if i > 0 else None,
792
+ latents,
793
+ **extra_step_kwargs,
794
+ return_dict=False,
795
+ )
796
+ latents = latents.to(prompt_embeds.dtype)
797
+
798
+ # call the callback, if provided
799
+ if callback_on_step_end is not None:
800
+ callback_kwargs = {}
801
+ for k in callback_on_step_end_tensor_inputs:
802
+ callback_kwargs[k] = locals()[k]
803
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
804
+
805
+ latents = callback_outputs.pop("latents", latents)
806
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
807
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
808
+
809
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
810
+ progress_bar.update()
811
+
812
+ if not output_type == "latent":
813
+ video = self.decode_latents(latents.to(torch.bfloat16))
814
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
815
+ else:
816
+ video = latents
817
+
818
+ # Offload all models
819
+ self.maybe_free_model_hooks()
820
+
821
+ if not return_dict:
822
+ return (video,)
823
+
824
+ return CogVideoXPipelineOutput(frames=video)
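Two knobs of the denoising loop above are worth calling out: `EF_Net_guidance_start` / `EF_Net_guidance_end` restrict EF-Net conditioning to a window of sampling progress `i / len(timesteps)` (with the defaults 0.0 and 1.0 it is active on every step), and outside that window the transformer is called with an empty `EF_Net_states`, so no injection happens. A small stand-alone sketch of the gating, with an illustrative window:

```python
# Mirrors the window check in the denoising loop; the 0.0-0.6 window is an example value.
num_inference_steps = 50
EF_Net_guidance_start, EF_Net_guidance_end = 0.0, 0.6

active_steps = [
    i for i in range(num_inference_steps)
    if EF_Net_guidance_start <= i / num_inference_steps < EF_Net_guidance_end
]
print(f"EF-Net conditioning is computed on {len(active_steps)}/{num_inference_steps} steps "
      f"(steps {active_steps[0]}..{active_steps[-1]}); "
      "on the remaining steps the transformer denoises without EF_Net_states.")
```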
cogvideo_transformer.py ADDED
@@ -0,0 +1,113 @@
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import torch
+ import numpy as np
+ from diffusers.utils import is_torch_version
+ from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXTransformer3DModel, Transformer2DModelOutput
+
+
+ class CustomCogVideoXTransformer3DModel(CogVideoXTransformer3DModel):
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         timestep: Union[int, float, torch.LongTensor],
+         start_frame=None,
+         timestep_cond: Optional[torch.Tensor] = None,
+         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         EF_Net_states: torch.Tensor = None,
+         EF_Net_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
+         return_dict: bool = True,
+     ):
+         batch_size, num_frames, channels, height, width = hidden_states.shape
+
+         if start_frame is not None:
+             hidden_states = torch.cat([start_frame, hidden_states], dim=2)
+         # 1. Time embedding
+         timesteps = timestep
+
+         t_emb = self.time_proj(timesteps)
+
+         # timesteps does not contain any weights and will always return f32 tensors
+         # but time_embedding might actually be running in fp16. so we need to cast here.
+         # there might be better ways to encapsulate this.
+         t_emb = t_emb.to(dtype=hidden_states.dtype)
+         emb = self.time_embedding(t_emb, timestep_cond)
+
+         # 2. Patch embedding
+         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+         hidden_states = self.embedding_dropout(hidden_states)
+
+         text_seq_length = encoder_hidden_states.shape[1]
+         encoder_hidden_states = hidden_states[:, :text_seq_length]
+         hidden_states = hidden_states[:, text_seq_length:]
+
+         # 3. Transformer blocks
+         for i, block in enumerate(self.transformer_blocks):
+             # if self.training and self.gradient_checkpointing:
+             if self.gradient_checkpointing:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     encoder_hidden_states,
+                     emb,
+                     image_rotary_emb,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 hidden_states, encoder_hidden_states = block(
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     temb=emb,
+                     image_rotary_emb=image_rotary_emb,
+                 )
+
+             if (EF_Net_states is not None) and (i < len(EF_Net_states)):
+                 EF_Net_states_block = EF_Net_states[i]
+                 EF_Net_block_weight = 1.0
+
+                 if isinstance(EF_Net_weights, (float, int)):
+                     EF_Net_block_weight = EF_Net_weights
+                 else:
+                     EF_Net_block_weight = EF_Net_weights[i]
+
+                 hidden_states = hidden_states + EF_Net_states_block * EF_Net_block_weight
+
+         if not self.config.use_rotary_positional_embeddings:
+             hidden_states = self.norm_final(hidden_states)
+         else:
+             hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+             hidden_states = self.norm_final(hidden_states)
+             hidden_states = hidden_states[:, text_seq_length:]
+
+         # 4. Final block
+         hidden_states = self.norm_out(hidden_states, temb=emb)
+         hidden_states = self.proj_out(hidden_states)
+
+         # 5. Unpatchify
+         p = self.config.patch_size
+         p_t = self.config.patch_size_t
+
+         if p_t is None:
+             output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+             output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+         else:
+             output = hidden_states.reshape(
+                 batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+             )
+             output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
+
+         if not return_dict:
+             return (output,)
+         return Transformer2DModelOutput(sample=output)
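The main change relative to the stock CogVideoX transformer forward is the per-block residual injection of the EF-Net features: after block `i`, `EF_Net_states[i]` is added to the video tokens, scaled by `EF_Net_weights`, which may be a single scalar or one entry per block. A minimal sketch of that rule with toy shapes (the dimensions and the decaying weight list are illustrative):

```python
# Toy illustration of the EF-Net injection rule inside the block loop above.
import torch

num_layers, B, N, D = 4, 1, 52, 8
hidden_states = torch.randn(B, N, D)
EF_Net_states = [torch.randn(B, N, D) for _ in range(num_layers)]   # one conditioning tensor per block
EF_Net_weights = [1.0, 0.8, 0.6, 0.4]                               # per-layer weights; a plain float also works

for i in range(num_layers):
    # ... transformer block i would update hidden_states here ...
    w = EF_Net_weights if isinstance(EF_Net_weights, (float, int)) else EF_Net_weights[i]
    hidden_states = hidden_states + EF_Net_states[i] * w
print(hidden_states.shape)                                          # torch.Size([1, 52, 8])
```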
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ spaces>=0.29.3
+ safetensors>=0.4.5
+ spandrel>=0.4.0
+ tqdm>=4.66.5
+ scikit-video>=1.1.11
+ diffusers==0.32.0
+ transformers>=4.44.0
+ accelerate>=0.34.2
+ opencv-python>=4.10.0.84
+ sentencepiece>=0.2.0
+ numpy==1.26.0
+ torch>=2.4.0
+ torchvision>=0.19.0
+ gradio>=4.44.0
+ imageio>=2.34.2
+ imageio-ffmpeg>=0.5.1
+ openai>=1.45.0
+ moviepy>=1.0.3
+ pillow>=9.5.0
+ denku==0.0.51
+ decord==0.6.0