diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index ac5b09a..55c9270 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -280,6 +280,9 @@ preset_models_on_modelscope = {
     "ExVideo-SVD-128f-v1": [
         ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
     ],
+    "ExVideo-CogVideoX-LoRA-129f-v1": [
+        ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
+    ],
     # Stable Diffusion
     "StableDiffusion_v15": [
         ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
@@ -380,7 +383,6 @@ preset_models_on_modelscope = {
         ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
         ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
     ],
-
     # Translator
     "opus-mt-zh-en": [
         ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
@@ -453,6 +455,7 @@ Preset_model_id: TypeAlias = Literal[
     "HunyuanDiT",
     "stable-video-diffusion-img2vid-xt",
     "ExVideo-SVD-128f-v1",
+    "ExVideo-CogVideoX-LoRA-129f-v1",
     "StableDiffusion_v15",
     "DreamShaper_8",
     "AingDiffusion_v12",
diff --git a/diffsynth/models/cog_dit.py b/diffsynth/models/cog_dit.py
index c6d0cc6..e93c4c3 100644
--- a/diffsynth/models/cog_dit.py
+++ b/diffsynth/models/cog_dit.py
@@ -283,7 +283,7 @@ class CogDiT(torch.nn.Module):
         return value
 
 
-    def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30):
+    def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
         if tiled:
             return TileWorker2Dto3D().tiled_forward(
                 forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
@@ -298,8 +298,21 @@ class CogDiT(torch.nn.Module):
         hidden_states = self.patchify(hidden_states)
         time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
         prompt_emb = self.context_embedder(prompt_emb)
+
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+            return custom_forward
+
         for block in self.blocks:
-            hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
+            if self.training and use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, time_emb, image_rotary_emb,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
 
         hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
         hidden_states = self.norm_final(hidden_states)
diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py
index aa21034..e948945 100644
--- a/diffsynth/models/lora.py
+++ b/diffsynth/models/lora.py
@@ -6,6 +6,7 @@ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
 from .sd3_dit import SD3DiT
 from .flux_dit import FluxDiT
 from .hunyuan_dit import HunyuanDiT
+from .cog_dit import CogDiT
 
 
 
@@ -189,7 +190,7 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
 
 class GeneralLoRAFromPeft:
     def __init__(self):
-        self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT]
+        self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT]
 
 
     def fetch_device_dtype_from_state_dict(self, state_dict):
@@ -301,4 +302,4 @@ class FluxLoRAConverter:
 
 
 def get_lora_loaders():
-    return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()]
+    return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()]