Compare commits

...

22 Commits

|Author|SHA1|Message|Date|
|-|-|-|-|
|xuyixuan.xyx|a10635818a|multi-node|2025-04-22 10:40:40 +08:00|
|Artiprocher|9ea29a769d|fix base model|2025-04-14 13:18:31 +08:00|
|Artiprocher|2a5355b7cb|fix base model|2025-04-14 13:11:45 +08:00|
|Artiprocher|7a06a58f49|support reference image|2025-04-11 16:23:14 +08:00|
|Artiprocher|a3b4f235a0|support reference image|2025-04-11 16:17:39 +08:00|
|Artiprocher|a572254a1d|support reference image|2025-04-11 16:14:24 +08:00|
|Artiprocher|9e78bf5e89|support reference image|2025-04-11 15:36:51 +08:00|
|Artiprocher|d21676b4dc|support reference image|2025-04-11 15:31:30 +08:00|
|Zhongjie Duan|53f01e72e6|Update setup.py|2025-04-09 15:38:17 +08:00|
|Zhongjie Duan|55e5e373dd|Update publish.yaml|2025-04-09 15:37:46 +08:00|
|Zhongjie Duan|4a0921ada1|Update requirements.txt|2025-04-09 15:37:16 +08:00|
|Zhongjie Duan|5129d3dc52|Update setup.py|2025-04-09 15:34:02 +08:00|
|Zhongjie Duan|ee9bab80f2|Update requirements.txt|2025-04-09 15:33:21 +08:00|
|Zhongjie Duan|cd8884c9ef|Update setup.py|2025-04-09 15:27:36 +08:00|
|Zhongjie Duan|46744362de|Update requirements.txt|2025-04-09 15:26:13 +08:00|
|Zhongjie Duan|0f0cdc3afc|Update setup.py|2025-04-09 15:15:18 +08:00|
|Zhongjie Duan|a33c63af87|Merge pull request #518 from modelscope/wan-fun (Wan fun)|2025-04-08 19:25:12 +08:00|
|Artiprocher|3cc9764bc9|support more wan models|2025-04-08 19:22:53 +08:00|
|Artiprocher|f6c6e3c640|support more wan models|2025-04-08 17:19:54 +08:00|
|Artiprocher|60a9db706e|support more wan models|2025-04-08 17:07:10 +08:00|
|lzw478614@alibaba-inc.com|a98700feb2|support wan-fun-inp generating|2025-04-06 22:55:42 +08:00|
|lzw478614@alibaba-inc.com|5418ca781e|support load wan2.1-fun-inp-1.3B and 14B model|2025-04-03 16:37:59 +08:00|
16 changed files with 1113 additions and 67 deletions

View File

@@ -20,7 +20,7 @@ jobs:
with:
python-version: '3.10'
- name: Install wheel
run: pip install wheel && pip install -r requirements.txt
run: pip install wheel==0.44.0 && pip install -r requirements.txt
- name: Build DiffSynth
run: python setup.py sdist bdist_wheel
- name: Publish package to PyPI

View File

@@ -59,6 +59,7 @@ from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_motion_controller import WanMotionControllerModel
model_loader_configs = [
@@ -120,11 +121,16 @@ model_loader_configs = [
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
]
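For context (illustrative, not part of the diff): each entry above appears to be a tuple of (path or None, a hash of the state-dict keys, model names, model classes, source format), so new checkpoints are recognized by their key hash alone. A rough sketch of that lookup, using a hypothetical helper name `find_loader_config`; `hash_state_dict_keys` is the same function referenced in `WanModelStateDictConverter` further down, and the real dispatch lives in `ModelManager`, which is not shown in this diff.

```python
# Illustrative sketch only: how an entry such as
# (None, "<hash>", ["wan_video_dit"], [WanModel], "civitai") could be matched.
def find_loader_config(state_dict, configs):
    key_hash = hash_state_dict_keys(state_dict)
    for path, expected_hash, model_names, model_classes, source_format in configs:
        if expected_hash == key_hash:
            return model_names, model_classes, source_format
    return None
```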
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.

View File

@@ -0,0 +1,125 @@
import torch, os, json, torchvision
from PIL import Image
from torchvision.transforms import v2
class SingleTaskDataset(torch.utils.data.Dataset):
def __init__(self, base_path, keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")), height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None):
self.base_path = base_path
self.keys = keys
self.metadata = []
self.bad_data = []
self.height = height
self.width = width
self.random = random
self.steps_per_epoch = steps_per_epoch
self.image_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
if metadata_path is None:
self.search_for_data("", report_data_log=True)
self.report_data_log()
else:
with open(metadata_path, "r", encoding="utf-8-sig") as f:
self.metadata = json.load(f)
def report_data_log(self):
print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
def dump_metadata(self, path):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=4)
def parse_json_file(self, absolute_path, relative_path):
data_list = []
with open(absolute_path, "r") as f:
metadata = json.load(f)
for image_1, image_2, instruction in self.keys:
image_1 = os.path.join(relative_path, metadata[image_1])
image_2 = os.path.join(relative_path, metadata[image_2])
instruction = metadata[instruction]
data_list.append((image_1, image_2, instruction))
return data_list
def search_for_data(self, path, report_data_log=False):
now_path = os.path.join(self.base_path, path)
if os.path.isfile(now_path) and path.endswith(".json"):
try:
data_list = self.parse_json_file(now_path, os.path.dirname(path))
self.metadata.extend(data_list)
except:
self.bad_data.append(now_path)
elif os.path.isdir(now_path):
for sub_path in os.listdir(now_path):
self.search_for_data(os.path.join(path, sub_path))
if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
self.report_data_log()
def load_image(self, image_path):
image_path = os.path.join(self.base_path, image_path)
image = Image.open(image_path).convert("RGB")
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = self.image_process(image)
return image
def load_data(self, data_id):
image_1, image_2, instruction = self.metadata[data_id]
image_1 = self.load_image(image_1)
image_2 = self.load_image(image_2)
return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
def __getitem__(self, data_id):
if self.random:
while True:
try:
data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
data = self.load_data(data_id)
return data
except:
continue
else:
return self.load_data(data_id)
def __len__(self):
return self.steps_per_epoch if self.random else len(self.metadata)
class MultiTaskDataset(torch.utils.data.Dataset):
def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
self.dataset_list = dataset_list
self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
self.steps_per_epoch = steps_per_epoch
def __getitem__(self, data_id):
while True:
try:
dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
data = self.dataset_list[dataset_id][data_id]
return data
except:
continue
def __len__(self):
return self.steps_per_epoch
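A minimal usage sketch of the two dataset classes above (illustrative, not part of the diff): the directory paths are placeholders and should point at folders containing the JSON metadata files the datasets expect, mirroring the training scripts at the end of this page.

```python
import torch
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset

# Placeholder paths: each base_path is scanned recursively for .json metadata files,
# unless a pre-built metadata_path is passed instead.
editing = SingleTaskDataset("data/dataset_change_add_remove", height=512, width=512)
style = SingleTaskDataset(
    "data/dataset_style_transfer",
    keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
    height=512, width=512,
)

# MultiTaskDataset draws each sample from one sub-dataset, chosen according to dataset_weight.
dataset = MultiTaskDataset([editing, style], dataset_weight=(4, 1), steps_per_epoch=1000)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
# Each item is a dict with keys "image_1", "image_2" and "instruction".
```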

View File

@@ -20,10 +20,11 @@ class RoPEEmbedding(torch.nn.Module):
self.axes_dim = axes_dim
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
def rope(self, pos: torch.Tensor, dim: int, theta: int, device="cpu") -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
scale = scale.to(device)
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
@@ -36,9 +37,9 @@ class RoPEEmbedding(torch.nn.Module):
return out.float()
def forward(self, ids):
def forward(self, ids, device="cpu"):
n_axes = ids.shape[-1]
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta, device) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)

View File

@@ -0,0 +1,31 @@
from .sd3_dit import TimestepEmbeddings
from .flux_dit import RoPEEmbedding
import torch
from einops import repeat
class FluxReferenceEmbedder(torch.nn.Module):
def __init__(self):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.idx_embedder = TimestepEmbeddings(256, 256)
self.proj = torch.nn.Linear(3072, 3072)
def forward(self, image_ids, idx, dtype, device):
pos_emb = self.pos_embedder(image_ids, device=device)
idx_emb = self.idx_embedder(idx, dtype=dtype).to(device)
length = pos_emb.shape[2]
pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
image_rotary_emb = pos_emb + idx_emb
return image_rotary_emb
def init(self):
self.idx_embedder.timestep_embedder[-1].load_state_dict({
"weight": torch.zeros((256, 256)),
"bias": torch.zeros((256,))
}),
self.proj.load_state_dict({
"weight": torch.eye(3072),
"bias": torch.zeros((3072,))
})

View File

@@ -493,6 +493,62 @@ class WanModelStateDictConverter:
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
else:
config = {}
return state_dict, config

View File

@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
from .wan_video_dit import sinusoidal_embedding_1d
class WanMotionControllerModel(torch.nn.Module):
def __init__(self, freq_dim=256, dim=1536):
super().__init__()
self.freq_dim = freq_dim
self.linear = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim),
nn.SiLU(),
nn.Linear(dim, dim * 6),
)
def forward(self, motion_bucket_id):
emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
emb = self.linear(emb)
return emb
def init(self):
state_dict = self.linear[-1].state_dict()
state_dict = {i: state_dict[i] * 0 for i in state_dict}
self.linear[-1].load_state_dict(state_dict)
@staticmethod
def state_dict_converter():
return WanMotionControllerModelDictConverter()
class WanMotionControllerModelDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -1,10 +1,12 @@
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..models.flux_reference_embedder import FluxReferenceEmbedder
from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
from typing import List
import torch
from einops import rearrange
from tqdm import tqdm
import numpy as np
from PIL import Image
@@ -32,6 +34,7 @@ class FluxImagePipeline(BasePipeline):
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder: SiglipVisionModel = None
self.infinityou_processor: InfinitYou = None
self.reference_embedder: FluxReferenceEmbedder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
@@ -360,6 +363,20 @@ class FluxImagePipeline(BasePipeline):
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
else:
return {}, controlnet_image
def prepare_reference_images(self, reference_images, tiled=False, tile_size=64, tile_stride=32):
if reference_images is not None:
hidden_states_ref = []
for reference_image in reference_images:
self.load_models_to_device(['vae_encoder'])
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
hidden_states_ref.append(latents)
hidden_states_ref = torch.concat(hidden_states_ref, dim=0)
return {"hidden_states_ref": hidden_states_ref}
else:
return {"hidden_states_ref": None}
@torch.no_grad()
@@ -398,6 +415,8 @@ class FluxImagePipeline(BasePipeline):
# InfiniteYou
infinityou_id_image=None,
infinityou_guidance=1.0,
# Reference images
reference_images=None,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
@@ -436,6 +455,9 @@ class FluxImagePipeline(BasePipeline):
# ControlNets
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# Reference images
reference_kwargs = self.prepare_reference_images(reference_images, **tiler_kwargs)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
@@ -447,9 +469,9 @@ class FluxImagePipeline(BasePipeline):
# Positive side
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **reference_kwargs,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -464,9 +486,9 @@ class FluxImagePipeline(BasePipeline):
if cfg_scale != 1.0:
# Negative side
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **reference_kwargs,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -586,6 +608,7 @@ class TeaCache:
def lets_dance_flux(
dit: FluxDiT,
controlnet: FluxMultiControlNetManager = None,
reference_embedder: FluxReferenceEmbedder = None,
hidden_states=None,
timestep=None,
prompt_emb=None,
@@ -594,6 +617,7 @@ def lets_dance_flux(
text_ids=None,
image_ids=None,
controlnet_frames=None,
hidden_states_ref=None,
tiled=False,
tile_size=128,
tile_stride=64,
@@ -603,6 +627,7 @@ def lets_dance_flux(
id_emb=None,
infinityou_guidance=None,
tea_cache: TeaCache = None,
use_gradient_checkpointing=False,
**kwargs
):
if tiled:
@@ -669,28 +694,55 @@ def lets_dance_flux(
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1), device=hidden_states.device)
attention_mask = None
# Reference images
if hidden_states_ref is not None:
# RoPE
image_ids_ref = dit.prepare_image_ids(hidden_states_ref)
idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100
image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype, device=hidden_states.device)
image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2)
# hidden_states
original_hidden_states_length = hidden_states.shape[1]
hidden_states_ref = dit.patchify(hidden_states_ref)
hidden_states_ref = dit.x_embedder(hidden_states_ref)
hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
hidden_states_ref = reference_embedder.proj(hidden_states_ref)
hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
else:
tea_cache_update = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tea_cache_update:
hidden_states = tea_cache.update(hidden_states)
else:
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -699,14 +751,21 @@ def lets_dance_flux(
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
@@ -715,6 +774,8 @@ def lets_dance_flux(
if tea_cache is not None:
tea_cache.store(hidden_states)
if hidden_states_ref is not None:
hidden_states = hidden_states[:, :original_hidden_states_length]
hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states)
hidden_states = dit.unpatchify(hidden_states, height, width)
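An illustrative call using the new `reference_images` argument (not part of the diff). It assumes `pipe` is a `FluxImagePipeline` with models already loaded and `pipe.reference_embedder` set to a trained `FluxReferenceEmbedder` (see the training scripts at the end of this page); the reference image path is a placeholder.

```python
from PIL import Image

# `pipe` is assumed to be a fully loaded FluxImagePipeline whose reference_embedder
# holds trained weights; "reference.jpg" is a placeholder path.
image = pipe(
    prompt="put sunglasses on the dog",
    reference_images=[Image.open("reference.jpg")],
    num_inference_steps=30,
    seed=0,
)
image.save("image_with_reference.jpg")
```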

View File

@@ -18,6 +18,7 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_motion_controller import WanMotionControllerModel
@@ -31,7 +32,8 @@ class WanVideoPipeline(BasePipeline):
self.image_encoder: WanImageEncoder = None
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
self.motion_controller: WanMotionControllerModel = None
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
@@ -122,6 +124,22 @@ class WanVideoPipeline(BasePipeline):
computation_device=self.device,
),
)
if self.motion_controller is not None:
dtype = next(iter(self.motion_controller.parameters())).dtype
enable_vram_management(
self.motion_controller,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
@@ -134,6 +152,7 @@ class WanVideoPipeline(BasePipeline):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
@staticmethod
@@ -163,22 +182,47 @@ class WanVideoPipeline(BasePipeline):
return {"context": prompt_emb}
def encode_image(self, image, num_frames, height, width):
def encode_image(self, image, end_image, num_frames, height, width):
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0
if end_image is not None:
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
y = y.to(dtype=self.torch_dtype, device=self.device)
return {"clip_feature": clip_context, "y": y}
def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
control_video = self.preprocess_images(control_video)
control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
return latents
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
if control_video is not None:
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if clip_feature is None or y is None:
clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
else:
y = y[:, -16:]
y = torch.concat([control_latents, y], dim=1)
return {"clip_feature": clip_feature, "y": y}
def tensor2video(self, frames):
@@ -204,6 +248,11 @@ class WanVideoPipeline(BasePipeline):
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_motion_bucket_id(self, motion_bucket_id):
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
return {"motion_bucket_id": motion_bucket_id}
@torch.no_grad()
@@ -212,7 +261,9 @@ class WanVideoPipeline(BasePipeline):
prompt,
negative_prompt="",
input_image=None,
end_image=None,
input_video=None,
control_video=None,
denoising_strength=1.0,
seed=None,
rand_device="cpu",
@@ -222,6 +273,7 @@ class WanVideoPipeline(BasePipeline):
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
motion_bucket_id=None,
tiled=True,
tile_size=(30, 52),
tile_stride=(15, 26),
@@ -263,10 +315,21 @@ class WanVideoPipeline(BasePipeline):
# Encode image
if input_image is not None and self.image_encoder is not None:
self.load_models_to_device(["image_encoder", "vae"])
image_emb = self.encode_image(input_image, num_frames, height, width)
image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
else:
image_emb = {}
# ControlNet
if control_video is not None:
self.load_models_to_device(["image_encoder", "vae"])
image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
# Motion Controller
if self.motion_controller is not None and motion_bucket_id is not None:
motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
else:
motion_kwargs = {}
# Extra input
extra_input = self.prepare_extra_input(latents)
@@ -278,14 +341,24 @@ class WanVideoPipeline(BasePipeline):
usp_kwargs = self.prepare_unified_sequence_parallel()
# Denoise
self.load_models_to_device(["dit"])
self.load_models_to_device(["dit", "motion_controller"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
noise_pred_posi = model_fn_wan_video(
self.dit, motion_controller=self.motion_controller,
x=latents, timestep=timestep,
**prompt_emb_posi, **image_emb, **extra_input,
**tea_cache_posi, **usp_kwargs, **motion_kwargs
)
if cfg_scale != 1.0:
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
noise_pred_nega = model_fn_wan_video(
self.dit, motion_controller=self.motion_controller,
x=latents, timestep=timestep,
**prompt_emb_nega, **image_emb, **extra_input,
**tea_cache_nega, **usp_kwargs, **motion_kwargs
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -358,13 +431,15 @@ class TeaCache:
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
motion_controller: WanMotionControllerModel = None,
x: torch.Tensor = None,
timestep: torch.Tensor = None,
context: torch.Tensor = None,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False,
motion_bucket_id: Optional[torch.Tensor] = None,
**kwargs,
):
if use_unified_sequence_parallel:
@@ -375,6 +450,8 @@ def model_fn_wan_video(
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
if motion_bucket_id is not None and motion_controller is not None:
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
context = dit.text_embedding(context)
if dit.has_image_input:

View File

@@ -10,34 +10,52 @@ cd DiffSynth-Studio
pip install -e .
```
Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
## Model Zoo
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
* [Sage Attention](https://github.com/thu-ml/SageAttention)
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
|Developer|Name|Link|Scripts|
|-|-|-|-|
|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)|
|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
## Inference
Base model features
### Wan-Video-1.3B-T2V
||Text-to-video|Image-to-video|End frame|Control|
|-|-|-|-|-|
|1.3B text-to-video|✅||||
|14B text-to-video|✅||||
|14B image-to-video 480P||✅|||
|14B image-to-video 720P||✅|||
|1.3B InP||✅|✅||
|14B InP||✅|✅||
|1.3B Control||||✅|
|14B Control||||✅|
Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
Adapter model compatibility
Required VRAM: 6G
||1.3B text-to-video|1.3B InP|
|-|-|-|
|1.3B aesthetics LoRA|✅||
|1.3B Highres-fix LoRA|✅||
|1.3B ExVideo LoRA|✅||
|1.3B Speed Control adapter|✅|✅|
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
## VRAM Usage
Put sunglasses on the dog.
* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!). A combined sketch of FP8 loading and fine-grained offload follows the settings table below.
[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both T2V and I2V models. It can significantly improve the efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
### Wan-Video-14B-T2V
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
We present a detailed table here. The model is tested on a single A100.
We present a detailed table here. The model (14B text-to-video) is tested on a single A100.
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|-|-|-|-|-|
@@ -47,31 +65,46 @@ We present a detailed table here. The model is tested on a single A100.
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|torch.float8_e4m3fn|0|24.0s/it|10G||
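To make the table concrete, a minimal sketch combining the two settings (illustrative, not part of the diff): the file paths follow the 1.3B example scripts further down this page, the 14B checkpoints use the same loading pattern, and the integer passed to `num_persistent_param_in_dit` is only an example value.

```python
import torch
from diffsynth import ModelManager, WanVideoPipeline

model_manager = ModelManager(device="cpu")
model_manager.load_models(
    [
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ],
    torch_dtype=torch.float8_e4m3fn,  # FP8 quantization is set when loading models, not on the pipeline
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
# None keeps every DiT parameter on the GPU (fastest, most VRAM), 0 offloads everything,
# and an intermediate integer (example value below) trades speed for VRAM.
pipe.enable_vram_management(num_persistent_param_in_dit=7_000_000_000)
```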
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
**We found that the 14B image-to-video model is more sensitive to precision, so if the generated video shows issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
### Parallel Inference
## Efficient Attention Implementation
1. Unified Sequence Parallel (USP)
DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled in order of priority. However, we recommend using the default torch SDPA.
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
* [Sage Attention](https://github.com/thu-ml/SageAttention)
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
## Acceleration
We support multiple acceleration solutions:
* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py)
```bash
pip install "xfuser>=0.4.3"
```
```bash
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
```
2. Tensor Parallel
* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).
Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
## Gallery
### Wan-Video-14B-I2V
1.3B text-to-video.
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so if the generated video shows issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
Put sunglasses on the dog.
![Image](https://github.com/user-attachments/assets/adf8047f-7943-4aaa-a555-2b32dc415f39)
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
14B text-to-video.
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
14B image-to-video.
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75

View File

@@ -0,0 +1,41 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download
# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
"models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Text-to-video
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=1, tiled=True,
motion_bucket_id=0
)
save_video(video, "video_slow.mp4", fps=15, quality=5)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=1, tiled=True,
motion_bucket_id=100
)
save_video(video, "video_fast.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,42 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors",
"models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth",
"models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth",
"models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Download example image
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")
# Image-to-video
video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
input_image=image,
# You can input `end_image=xxx` to control the last frame of the video.
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,40 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors",
"models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth",
"models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth",
"models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Download example video
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/control_video.mp4"
)
# Control-to-video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
video = pipe(
prompt="扁平风格动漫一位长发少女优雅起舞。她五官精致大眼睛明亮有神黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
control_video=control_video, height=832, width=576, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

View File

@@ -14,7 +14,7 @@ else:
setup(
name="diffsynth",
version="1.1.3",
version="1.1.7",
description="Enjoy the magic of Diffusion models!",
author="Artiprocher",
packages=find_packages(),

train_flux_reference.py (new file, 241 lines)
View File

@@ -0,0 +1,241 @@
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
import torch, os, argparse
import lightning as pl
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
from diffsynth.pipelines.flux_image import lets_dance_flux
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
os.environ["TOKENIZERS_PARALLELISM"] = "True"
class LightningModel(LightningModelForT2ILoRA):
def __init__(
self,
torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
state_dict_converter=None, quantize = None
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
# Load models
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
if quantize is None:
model_manager.load_models(pretrained_weights)
else:
model_manager.load_models(pretrained_weights[1:])
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
if preset_lora_path is not None:
preset_lora_path = preset_lora_path.split(",")
for path in preset_lora_path:
model_manager.load_lora(path)
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
self.pipe.reference_embedder = FluxReferenceEmbedder()
self.pipe.reference_embedder.init()
if quantize is not None:
self.pipe.dit.quantize()
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
self.pipe.reference_embedder.requires_grad_(True)
self.pipe.reference_embedder.train()
self.pipe.dit.requires_grad_(True)
self.pipe.dit.train()
# self.add_lora_to_model(
# self.pipe.denoising_model(),
# lora_rank=lora_rank,
# lora_alpha=lora_alpha,
# lora_target_modules=lora_target_modules,
# init_lora_weights=init_lora_weights,
# pretrained_lora_path=pretrained_lora_path,
# state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
# )
def training_step(self, batch, batch_idx):
# Data
text, image = batch["instruction"], batch["image_2"]
image_ref = batch["image_1"]
# Prepare input parameters
self.pipe.device = self.device
prompt_emb = self.pipe.encode_prompt(text, positive=True)
if "latents" in batch:
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
else:
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Reference image
hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))
# Compute loss
noise_pred = lets_dance_flux(
self.pipe.denoising_model(),
reference_embedder=self.pipe.reference_embedder,
hidden_states_ref=hidden_states_ref,
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.state_dict()
lora_state_dict = {}
for name, param in state_dict.items():
if name in trainable_param_names:
lora_state_dict[name] = param
if self.state_dict_converter is not None:
lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
checkpoint.update(lora_state_dict)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_text_encoder_path",
type=str,
default=None,
required=True,
help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
)
parser.add_argument(
"--pretrained_text_encoder_2_path",
type=str,
default=None,
required=True,
help="Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
)
parser.add_argument(
"--pretrained_dit_path",
type=str,
default=None,
required=True,
help="Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
)
parser.add_argument(
"--pretrained_vae_path",
type=str,
default=None,
required=True,
help="Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--align_to_opensource_format",
default=False,
action="store_true",
help="Whether to export lora files aligned with other opensource format.",
)
parser.add_argument(
"--quantize",
type=str,
default=None,
choices=["float8_e4m3fn"],
help="Whether to use quantization when training the model, and in which format.",
)
parser.add_argument(
"--preset_lora_path",
type=str,
default=None,
help="Preset LoRA path.",
)
parser = add_general_parsers(parser)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
model = LightningModel(
torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
preset_lora_path=args.preset_lora_path,
learning_rate=args.learning_rate,
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
pretrained_lora_path=args.pretrained_lora_path,
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
)
# dataset and data loader
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
height=512, width=512,
),
],
dataset_weight=(4, 1, 4, 1),
steps_per_epoch=args.steps_per_epoch,
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=args.batch_size,
num_workers=args.dataloader_num_workers
)
# train
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
precision=args.precision,
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=None,
)
trainer.fit(model=model, train_dataloaders=train_loader)

View File

@@ -0,0 +1,248 @@
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
import torch, os, argparse
import lightning as pl
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
from diffsynth.pipelines.flux_image import lets_dance_flux
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
os.environ["TOKENIZERS_PARALLELISM"] = "True"
class LightningModel(LightningModelForT2ILoRA):
def __init__(
self,
torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
state_dict_converter=None, quantize = None
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
# Load models
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
if quantize is None:
model_manager.load_models(pretrained_weights)
else:
model_manager.load_models(pretrained_weights[1:])
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
if preset_lora_path is not None:
preset_lora_path = preset_lora_path.split(",")
for path in preset_lora_path:
model_manager.load_lora(path)
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
self.pipe.reference_embedder = FluxReferenceEmbedder()
self.pipe.reference_embedder.init()
if quantize is not None:
self.pipe.dit.quantize()
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
self.pipe.reference_embedder.requires_grad_(True)
self.pipe.reference_embedder.train()
self.pipe.dit.requires_grad_(True)
self.pipe.dit.train()
# self.add_lora_to_model(
# self.pipe.denoising_model(),
# lora_rank=lora_rank,
# lora_alpha=lora_alpha,
# lora_target_modules=lora_target_modules,
# init_lora_weights=init_lora_weights,
# pretrained_lora_path=pretrained_lora_path,
# state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
# )
def training_step(self, batch, batch_idx):
# Data
text, image = batch["instruction"], batch["image_2"]
image_ref = batch["image_1"]
# Prepare input parameters
self.pipe.device = self.device
prompt_emb = self.pipe.encode_prompt(text, positive=True)
if "latents" in batch:
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
else:
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Reference image
hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))
# Compute loss
noise_pred = lets_dance_flux(
self.pipe.denoising_model(),
reference_embedder=self.pipe.reference_embedder,
hidden_states_ref=hidden_states_ref,
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.state_dict()
lora_state_dict = {}
for name, param in state_dict.items():
if name in trainable_param_names:
lora_state_dict[name] = param
if self.state_dict_converter is not None:
lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
checkpoint.update(lora_state_dict)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_text_encoder_path",
type=str,
default=None,
required=True,
help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
)
parser.add_argument(
"--pretrained_text_encoder_2_path",
type=str,
default=None,
required=True,
help="Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
)
parser.add_argument(
"--pretrained_dit_path",
type=str,
default=None,
required=True,
help="Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
)
parser.add_argument(
"--pretrained_vae_path",
type=str,
default=None,
required=True,
help="Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--align_to_opensource_format",
default=False,
action="store_true",
help="Whether to export lora files aligned with other opensource format.",
)
parser.add_argument(
"--quantize",
type=str,
default=None,
choices=["float8_e4m3fn"],
help="Whether to use quantization when training the model, and in which format.",
)
parser.add_argument(
"--preset_lora_path",
type=str,
default=None,
help="Preset LoRA path.",
)
parser.add_argument(
"--num_nodes",
type=int,
default=1,
help="Num nodes.",
)
parser = add_general_parsers(parser)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
model = LightningModel(
torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
preset_lora_path=args.preset_lora_path,
learning_rate=args.learning_rate,
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
pretrained_lora_path=args.pretrained_lora_path,
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
)
# dataset and data loader
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
height=512, width=512,
),
],
dataset_weight=(4, 1, 4, 1),
steps_per_epoch=args.steps_per_epoch,
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=args.batch_size,
num_workers=args.dataloader_num_workers
)
# train
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
num_nodes=args.num_nodes,
precision=args.precision,
strategy="ddp",
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=None,
)
trainer.fit(model=model, train_dataloaders=train_loader)