mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
support ExVideo-CogVideoX-LoRA-129f-v1
This commit is contained in:
@@ -280,6 +280,9 @@ preset_models_on_modelscope = {
|
|||||||
"ExVideo-SVD-128f-v1": [
|
"ExVideo-SVD-128f-v1": [
|
||||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
||||||
],
|
],
|
||||||
|
"ExVideo-CogVideoX-LoRA-129f-v1": [
|
||||||
|
("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
|
||||||
|
],
|
||||||
# Stable Diffusion
|
# Stable Diffusion
|
||||||
"StableDiffusion_v15": [
|
"StableDiffusion_v15": [
|
||||||
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
||||||
@@ -380,7 +383,6 @@ preset_models_on_modelscope = {
|
|||||||
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
],
|
],
|
||||||
|
|
||||||
# Translator
|
# Translator
|
||||||
"opus-mt-zh-en": [
|
"opus-mt-zh-en": [
|
||||||
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
||||||
@@ -453,6 +455,7 @@ Preset_model_id: TypeAlias = Literal[
|
|||||||
"HunyuanDiT",
|
"HunyuanDiT",
|
||||||
"stable-video-diffusion-img2vid-xt",
|
"stable-video-diffusion-img2vid-xt",
|
||||||
"ExVideo-SVD-128f-v1",
|
"ExVideo-SVD-128f-v1",
|
||||||
|
"ExVideo-CogVideoX-LoRA-129f-v1",
|
||||||
"StableDiffusion_v15",
|
"StableDiffusion_v15",
|
||||||
"DreamShaper_8",
|
"DreamShaper_8",
|
||||||
"AingDiffusion_v12",
|
"AingDiffusion_v12",
|
||||||
|
|||||||
@@ -283,7 +283,7 @@ class CogDiT(torch.nn.Module):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30):
|
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
|
||||||
if tiled:
|
if tiled:
|
||||||
return TileWorker2Dto3D().tiled_forward(
|
return TileWorker2Dto3D().tiled_forward(
|
||||||
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
|
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
|
||||||
@@ -298,8 +298,21 @@ class CogDiT(torch.nn.Module):
|
|||||||
hidden_states = self.patchify(hidden_states)
|
hidden_states = self.patchify(hidden_states)
|
||||||
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
|
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
|
||||||
prompt_emb = self.context_embedder(prompt_emb)
|
prompt_emb = self.context_embedder(prompt_emb)
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
|
if self.training and use_gradient_checkpointing:
|
||||||
|
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
hidden_states, prompt_emb, time_emb, image_rotary_emb,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
|
||||||
|
|
||||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||||
hidden_states = self.norm_final(hidden_states)
|
hidden_states = self.norm_final(hidden_states)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
|||||||
from .sd3_dit import SD3DiT
|
from .sd3_dit import SD3DiT
|
||||||
from .flux_dit import FluxDiT
|
from .flux_dit import FluxDiT
|
||||||
from .hunyuan_dit import HunyuanDiT
|
from .hunyuan_dit import HunyuanDiT
|
||||||
|
from .cog_dit import CogDiT
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -189,7 +190,7 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
|
|||||||
|
|
||||||
class GeneralLoRAFromPeft:
|
class GeneralLoRAFromPeft:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT]
|
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT]
|
||||||
|
|
||||||
|
|
||||||
def fetch_device_dtype_from_state_dict(self, state_dict):
|
def fetch_device_dtype_from_state_dict(self, state_dict):
|
||||||
@@ -301,4 +302,4 @@ class FluxLoRAConverter:
|
|||||||
|
|
||||||
|
|
||||||
def get_lora_loaders():
|
def get_lora_loaders():
|
||||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()]
|
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||||
|
|||||||
Reference in New Issue
Block a user