support ExVideo-CogVideoX-LoRA-129f-v1

This commit is contained in:
Artiprocher
2024-09-30 17:33:15 +08:00
parent d91c603875
commit 2e487a2c55
3 changed files with 22 additions and 5 deletions

View File

@@ -280,6 +280,9 @@ preset_models_on_modelscope = {
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
"ExVideo-CogVideoX-LoRA-129f-v1": [
("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
],
# Stable Diffusion
"StableDiffusion_v15": [
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
@@ -380,7 +383,6 @@ preset_models_on_modelscope = {
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
],
# Translator
"opus-mt-zh-en": [
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
@@ -453,6 +455,7 @@ Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
"stable-video-diffusion-img2vid-xt",
"ExVideo-SVD-128f-v1",
"ExVideo-CogVideoX-LoRA-129f-v1",
"StableDiffusion_v15",
"DreamShaper_8",
"AingDiffusion_v12",

View File

@@ -283,7 +283,7 @@ class CogDiT(torch.nn.Module):
return value
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30):
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
if tiled:
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
@@ -298,8 +298,21 @@ class CogDiT(torch.nn.Module):
hidden_states = self.patchify(hidden_states)
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, time_emb, image_rotary_emb,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)

View File

@@ -6,6 +6,7 @@ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sd3_dit import SD3DiT
from .flux_dit import FluxDiT
from .hunyuan_dit import HunyuanDiT
from .cog_dit import CogDiT
@@ -189,7 +190,7 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
class GeneralLoRAFromPeft:
def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT]
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT]
def fetch_device_dtype_from_state_dict(self, state_dict):
@@ -301,4 +302,4 @@ class FluxLoRAConverter:
def get_lora_loaders():
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()]
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()]