mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
support ExVideo-CogVideoX-LoRA-129f-v1
This commit is contained in:
@@ -283,7 +283,7 @@ class CogDiT(torch.nn.Module):
|
||||
return value
|
||||
|
||||
|
||||
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30):
|
||||
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
|
||||
if tiled:
|
||||
return TileWorker2Dto3D().tiled_forward(
|
||||
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
|
||||
@@ -298,8 +298,21 @@ class CogDiT(torch.nn.Module):
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, time_emb, image_rotary_emb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user