mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
26 Commits
wan-lora-f
...
flux-ref
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a10635818a | ||
|
|
9ea29a769d | ||
|
|
2a5355b7cb | ||
|
|
7a06a58f49 | ||
|
|
a3b4f235a0 | ||
|
|
a572254a1d | ||
|
|
9e78bf5e89 | ||
|
|
d21676b4dc | ||
|
|
53f01e72e6 | ||
|
|
55e5e373dd | ||
|
|
4a0921ada1 | ||
|
|
5129d3dc52 | ||
|
|
ee9bab80f2 | ||
|
|
cd8884c9ef | ||
|
|
46744362de | ||
|
|
0f0cdc3afc | ||
|
|
a33c63af87 | ||
|
|
3cc9764bc9 | ||
|
|
f6c6e3c640 | ||
|
|
60a9db706e | ||
|
|
a98700feb2 | ||
|
|
5418ca781e | ||
|
|
71eee780fb | ||
|
|
4864453e0a | ||
|
|
c5a32f76c2 | ||
|
|
c4ed3d3e4b |
2
.github/workflows/publish.yaml
vendored
2
.github/workflows/publish.yaml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install wheel
|
||||
run: pip install wheel && pip install -r requirements.txt
|
||||
run: pip install wheel==0.44.0 && pip install -r requirements.txt
|
||||
- name: Build DiffSynth
|
||||
run: python setup.py sdist bdist_wheel
|
||||
- name: Publish package to PyPI
|
||||
|
||||
@@ -59,6 +59,7 @@ from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
|
||||
|
||||
model_loader_configs = [
|
||||
@@ -120,11 +121,16 @@ model_loader_configs = [
|
||||
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
125
diffsynth/data/image_pulse.py
Normal file
125
diffsynth/data/image_pulse.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import torch, os, json, torchvision
|
||||
from PIL import Image
|
||||
from torchvision.transforms import v2
|
||||
|
||||
|
||||
|
||||
class SingleTaskDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")), height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None):
|
||||
self.base_path = base_path
|
||||
self.keys = keys
|
||||
self.metadata = []
|
||||
self.bad_data = []
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.random = random
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.image_process = v2.Compose([
|
||||
v2.CenterCrop(size=(height, width)),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
if metadata_path is None:
|
||||
self.search_for_data("", report_data_log=True)
|
||||
self.report_data_log()
|
||||
else:
|
||||
with open(metadata_path, "r", encoding="utf-8-sig") as f:
|
||||
self.metadata = json.load(f)
|
||||
|
||||
|
||||
def report_data_log(self):
|
||||
print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
|
||||
|
||||
|
||||
def dump_metadata(self, path):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.metadata, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def parse_json_file(self, absolute_path, relative_path):
|
||||
data_list = []
|
||||
with open(absolute_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
for image_1, image_2, instruction in self.keys:
|
||||
image_1 = os.path.join(relative_path, metadata[image_1])
|
||||
image_2 = os.path.join(relative_path, metadata[image_2])
|
||||
instruction = metadata[instruction]
|
||||
data_list.append((image_1, image_2, instruction))
|
||||
return data_list
|
||||
|
||||
|
||||
def search_for_data(self, path, report_data_log=False):
|
||||
now_path = os.path.join(self.base_path, path)
|
||||
if os.path.isfile(now_path) and path.endswith(".json"):
|
||||
try:
|
||||
data_list = self.parse_json_file(now_path, os.path.dirname(path))
|
||||
self.metadata.extend(data_list)
|
||||
except:
|
||||
self.bad_data.append(now_path)
|
||||
elif os.path.isdir(now_path):
|
||||
for sub_path in os.listdir(now_path):
|
||||
self.search_for_data(os.path.join(path, sub_path))
|
||||
if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
|
||||
self.report_data_log()
|
||||
|
||||
|
||||
def load_image(self, image_path):
|
||||
image_path = os.path.join(self.base_path, image_path)
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
width, height = image.size
|
||||
scale = max(self.width / width, self.height / height)
|
||||
image = torchvision.transforms.functional.resize(
|
||||
image,
|
||||
(round(height*scale), round(width*scale)),
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||
)
|
||||
image = self.image_process(image)
|
||||
return image
|
||||
|
||||
|
||||
def load_data(self, data_id):
|
||||
image_1, image_2, instruction = self.metadata[data_id]
|
||||
image_1 = self.load_image(image_1)
|
||||
image_2 = self.load_image(image_2)
|
||||
return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
|
||||
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
if self.random:
|
||||
while True:
|
||||
try:
|
||||
data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
|
||||
data = self.load_data(data_id)
|
||||
return data
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
return self.load_data(data_id)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch if self.random else len(self.metadata)
|
||||
|
||||
|
||||
|
||||
class MultiTaskDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
|
||||
self.dataset_list = dataset_list
|
||||
self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
while True:
|
||||
try:
|
||||
dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
|
||||
data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
|
||||
data = self.dataset_list[dataset_id][data_id]
|
||||
return data
|
||||
except:
|
||||
continue
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch
|
||||
@@ -20,10 +20,11 @@ class RoPEEmbedding(torch.nn.Module):
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
|
||||
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
def rope(self, pos: torch.Tensor, dim: int, theta: int, device="cpu") -> torch.Tensor:
|
||||
assert dim % 2 == 0, "The dimension must be even."
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
scale = scale.to(device)
|
||||
omega = 1.0 / (theta**scale)
|
||||
|
||||
batch_size, seq_length = pos.shape
|
||||
@@ -36,9 +37,9 @@ class RoPEEmbedding(torch.nn.Module):
|
||||
return out.float()
|
||||
|
||||
|
||||
def forward(self, ids):
|
||||
def forward(self, ids, device="cpu"):
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
|
||||
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta, device) for i in range(n_axes)], dim=-3)
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
|
||||
31
diffsynth/models/flux_reference_embedder.py
Normal file
31
diffsynth/models/flux_reference_embedder.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from .sd3_dit import TimestepEmbeddings
|
||||
from .flux_dit import RoPEEmbedding
|
||||
import torch
|
||||
from einops import repeat
|
||||
|
||||
|
||||
class FluxReferenceEmbedder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.idx_embedder = TimestepEmbeddings(256, 256)
|
||||
self.proj = torch.nn.Linear(3072, 3072)
|
||||
|
||||
def forward(self, image_ids, idx, dtype, device):
|
||||
pos_emb = self.pos_embedder(image_ids, device=device)
|
||||
idx_emb = self.idx_embedder(idx, dtype=dtype).to(device)
|
||||
length = pos_emb.shape[2]
|
||||
pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
|
||||
idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
|
||||
image_rotary_emb = pos_emb + idx_emb
|
||||
return image_rotary_emb
|
||||
|
||||
def init(self):
|
||||
self.idx_embedder.timestep_embedder[-1].load_state_dict({
|
||||
"weight": torch.zeros((256, 256)),
|
||||
"bias": torch.zeros((256,))
|
||||
}),
|
||||
self.proj.load_state_dict({
|
||||
"weight": torch.eye(3072),
|
||||
"bias": torch.zeros((3072,))
|
||||
})
|
||||
@@ -365,7 +365,22 @@ class FluxLoRAConverter:
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
class WanLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, **kwargs):
|
||||
state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
@staticmethod
|
||||
def align_to_diffsynth_format(state_dict, **kwargs):
|
||||
state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
|
||||
@@ -493,6 +493,62 @@ class WanModelStateDictConverter:
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
44
diffsynth/models/wan_video_motion_controller.py
Normal file
44
diffsynth/models/wan_video_motion_controller.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .wan_video_dit import sinusoidal_embedding_1d
|
||||
|
||||
|
||||
|
||||
class WanMotionControllerModel(torch.nn.Module):
|
||||
def __init__(self, freq_dim=256, dim=1536):
|
||||
super().__init__()
|
||||
self.freq_dim = freq_dim
|
||||
self.linear = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 6),
|
||||
)
|
||||
|
||||
def forward(self, motion_bucket_id):
|
||||
emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
|
||||
emb = self.linear(emb)
|
||||
return emb
|
||||
|
||||
def init(self):
|
||||
state_dict = self.linear[-1].state_dict()
|
||||
state_dict = {i: state_dict[i] * 0 for i in state_dict}
|
||||
self.linear[-1].load_state_dict(state_dict)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanMotionControllerModelDictConverter()
|
||||
|
||||
|
||||
|
||||
class WanMotionControllerModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
|
||||
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..models.flux_reference_embedder import FluxReferenceEmbedder
|
||||
from ..prompters import FluxPrompter
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from typing import List
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -32,6 +34,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.ipadapter: FluxIpAdapter = None
|
||||
self.ipadapter_image_encoder: SiglipVisionModel = None
|
||||
self.infinityou_processor: InfinitYou = None
|
||||
self.reference_embedder: FluxReferenceEmbedder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||
|
||||
|
||||
@@ -360,6 +363,20 @@ class FluxImagePipeline(BasePipeline):
|
||||
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
|
||||
else:
|
||||
return {}, controlnet_image
|
||||
|
||||
|
||||
def prepare_reference_images(self, reference_images, tiled=False, tile_size=64, tile_stride=32):
|
||||
if reference_images is not None:
|
||||
hidden_states_ref = []
|
||||
for reference_image in reference_images:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
hidden_states_ref.append(latents)
|
||||
hidden_states_ref = torch.concat(hidden_states_ref, dim=0)
|
||||
return {"hidden_states_ref": hidden_states_ref}
|
||||
else:
|
||||
return {"hidden_states_ref": None}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -398,6 +415,8 @@ class FluxImagePipeline(BasePipeline):
|
||||
# InfiniteYou
|
||||
infinityou_id_image=None,
|
||||
infinityou_guidance=1.0,
|
||||
# Reference images
|
||||
reference_images=None,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
# Tile
|
||||
@@ -436,6 +455,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
# ControlNets
|
||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||
|
||||
# Reference images
|
||||
reference_kwargs = self.prepare_reference_images(reference_images, **tiler_kwargs)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||
@@ -447,9 +469,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
# Positive side
|
||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **reference_kwargs,
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||
@@ -464,9 +486,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
if cfg_scale != 1.0:
|
||||
# Negative side
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **reference_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -586,6 +608,7 @@ class TeaCache:
|
||||
def lets_dance_flux(
|
||||
dit: FluxDiT,
|
||||
controlnet: FluxMultiControlNetManager = None,
|
||||
reference_embedder: FluxReferenceEmbedder = None,
|
||||
hidden_states=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
@@ -594,6 +617,7 @@ def lets_dance_flux(
|
||||
text_ids=None,
|
||||
image_ids=None,
|
||||
controlnet_frames=None,
|
||||
hidden_states_ref=None,
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
@@ -603,6 +627,7 @@ def lets_dance_flux(
|
||||
id_emb=None,
|
||||
infinityou_guidance=None,
|
||||
tea_cache: TeaCache = None,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
if tiled:
|
||||
@@ -669,28 +694,55 @@ def lets_dance_flux(
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
else:
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1), device=hidden_states.device)
|
||||
attention_mask = None
|
||||
|
||||
# Reference images
|
||||
if hidden_states_ref is not None:
|
||||
# RoPE
|
||||
image_ids_ref = dit.prepare_image_ids(hidden_states_ref)
|
||||
idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100
|
||||
image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2)
|
||||
# hidden_states
|
||||
original_hidden_states_length = hidden_states.shape[1]
|
||||
hidden_states_ref = dit.patchify(hidden_states_ref)
|
||||
hidden_states_ref = dit.x_embedder(hidden_states_ref)
|
||||
hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
|
||||
hidden_states_ref = reference_embedder.proj(hidden_states_ref)
|
||||
hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
if tea_cache_update:
|
||||
hidden_states = tea_cache.update(hidden_states)
|
||||
else:
|
||||
# Joint Blocks
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
|
||||
)
|
||||
if use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states = hidden_states + controlnet_res_stack[block_id]
|
||||
@@ -699,14 +751,21 @@ def lets_dance_flux(
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
num_joint_blocks = len(dit.blocks)
|
||||
for block_id, block in enumerate(dit.single_blocks):
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
|
||||
)
|
||||
if use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||
@@ -715,6 +774,8 @@ def lets_dance_flux(
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(hidden_states)
|
||||
|
||||
if hidden_states_ref is not None:
|
||||
hidden_states = hidden_states[:, :original_hidden_states_length]
|
||||
hidden_states = dit.final_norm_out(hidden_states, conditioning)
|
||||
hidden_states = dit.final_proj_out(hidden_states)
|
||||
hidden_states = dit.unpatchify(hidden_states, height, width)
|
||||
|
||||
@@ -18,6 +18,7 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
|
||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
|
||||
|
||||
|
||||
@@ -31,7 +32,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.image_encoder: WanImageEncoder = None
|
||||
self.dit: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
|
||||
self.height_division_factor = 16
|
||||
self.width_division_factor = 16
|
||||
self.use_unified_sequence_parallel = False
|
||||
@@ -122,6 +124,22 @@ class WanVideoPipeline(BasePipeline):
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
if self.motion_controller is not None:
|
||||
dtype = next(iter(self.motion_controller.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.motion_controller,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
|
||||
@@ -134,6 +152,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.dit = model_manager.fetch_model("wan_video_dit")
|
||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
||||
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -163,22 +182,47 @@ class WanVideoPipeline(BasePipeline):
|
||||
return {"context": prompt_emb}
|
||||
|
||||
|
||||
def encode_image(self, image, num_frames, height, width):
|
||||
def encode_image(self, image, end_image, num_frames, height, width):
|
||||
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||
clip_context = self.image_encoder.encode_image([image])
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||
msk[:, 1:] = 0
|
||||
if end_image is not None:
|
||||
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
|
||||
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
|
||||
msk[:, -1:] = 1
|
||||
else:
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
|
||||
y = torch.concat([msk, y])
|
||||
y = y.unsqueeze(0)
|
||||
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
|
||||
y = y.to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"clip_feature": clip_context, "y": y}
|
||||
|
||||
|
||||
def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
control_video = self.preprocess_images(control_video)
|
||||
control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
return latents
|
||||
|
||||
|
||||
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
if control_video is not None:
|
||||
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if clip_feature is None or y is None:
|
||||
clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
|
||||
y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
y = y[:, -16:]
|
||||
y = torch.concat([control_latents, y], dim=1)
|
||||
return {"clip_feature": clip_feature, "y": y}
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
@@ -204,6 +248,11 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
def prepare_unified_sequence_parallel(self):
|
||||
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
||||
|
||||
|
||||
def prepare_motion_bucket_id(self, motion_bucket_id):
|
||||
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"motion_bucket_id": motion_bucket_id}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -212,7 +261,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
input_image=None,
|
||||
end_image=None,
|
||||
input_video=None,
|
||||
control_video=None,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device="cpu",
|
||||
@@ -222,6 +273,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
cfg_scale=5.0,
|
||||
num_inference_steps=50,
|
||||
sigma_shift=5.0,
|
||||
motion_bucket_id=None,
|
||||
tiled=True,
|
||||
tile_size=(30, 52),
|
||||
tile_stride=(15, 26),
|
||||
@@ -263,10 +315,21 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Encode image
|
||||
if input_image is not None and self.image_encoder is not None:
|
||||
self.load_models_to_device(["image_encoder", "vae"])
|
||||
image_emb = self.encode_image(input_image, num_frames, height, width)
|
||||
image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
|
||||
else:
|
||||
image_emb = {}
|
||||
|
||||
# ControlNet
|
||||
if control_video is not None:
|
||||
self.load_models_to_device(["image_encoder", "vae"])
|
||||
image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
|
||||
|
||||
# Motion Controller
|
||||
if self.motion_controller is not None and motion_bucket_id is not None:
|
||||
motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
|
||||
else:
|
||||
motion_kwargs = {}
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
@@ -278,14 +341,24 @@ class WanVideoPipeline(BasePipeline):
|
||||
usp_kwargs = self.prepare_unified_sequence_parallel()
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(["dit"])
|
||||
self.load_models_to_device(["dit", "motion_controller"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
|
||||
noise_pred_posi = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **image_emb, **extra_input,
|
||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
|
||||
noise_pred_nega = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **image_emb, **extra_input,
|
||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
@@ -358,13 +431,15 @@ class TeaCache:
|
||||
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
motion_controller: WanMotionControllerModel = None,
|
||||
x: torch.Tensor = None,
|
||||
timestep: torch.Tensor = None,
|
||||
context: torch.Tensor = None,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if use_unified_sequence_parallel:
|
||||
@@ -375,6 +450,8 @@ def model_fn_wan_video(
|
||||
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||
if motion_bucket_id is not None and motion_controller is not None:
|
||||
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
|
||||
context = dit.text_embedding(context)
|
||||
|
||||
if dit.has_image_input:
|
||||
|
||||
@@ -10,34 +10,52 @@ cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
|
||||
## Model Zoo
|
||||
|
||||
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
|
||||
|Developer|Name|Link|Scripts|
|
||||
|-|-|-|-|
|
||||
|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)|
|
||||
|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
|
||||
|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|
||||
|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|
||||
|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
|
||||
|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
|
||||
|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
|
||||
|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
|
||||
|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|
||||
|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|
||||
|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|
||||
|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|
||||
|
||||
## Inference
|
||||
Base model features
|
||||
|
||||
### Wan-Video-1.3B-T2V
|
||||
||Text-to-video|Image-to-video|End frame|Control|
|
||||
|-|-|-|-|-|
|
||||
|1.3B text-to-video|✅||||
|
||||
|14B text-to-video|✅||||
|
||||
|14B image-to-video 480P||✅|||
|
||||
|14B image-to-video 720P||✅|||
|
||||
|1.3B InP||✅|✅||
|
||||
|14B InP||✅|✅||
|
||||
|1.3B Control||||✅|
|
||||
|14B Control||||✅|
|
||||
|
||||
Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
|
||||
Adapter model compatibility
|
||||
|
||||
Required VRAM: 6G
|
||||
||1.3B text-to-video|1.3B InP|
|
||||
|-|-|-|
|
||||
|1.3B aesthetics LoRA|✅||
|
||||
|1.3B Highres-fix LoRA|✅||
|
||||
|1.3B ExVideo LoRA|✅||
|
||||
|1.3B Speed Control adapter|✅|✅|
|
||||
|
||||
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
|
||||
## VRAM Usage
|
||||
|
||||
Put sunglasses on the dog.
|
||||
* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
|
||||
|
||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
||||
* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!).
|
||||
|
||||
[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both T2V and I2V models. It can significantly improve the efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
|
||||
|
||||
### Wan-Video-14B-T2V
|
||||
|
||||
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
|
||||
|
||||
We present a detailed table here. The model is tested on a single A100.
|
||||
We present a detailed table here. The model (14B text-to-video) is tested on a single A100.
|
||||
|
||||
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|
||||
|-|-|-|-|-|
|
||||
@@ -47,31 +65,46 @@ We present a detailed table here. The model is tested on a single A100.
|
||||
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|
||||
|torch.float8_e4m3fn|0|24.0s/it|10G||
|
||||
|
||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
||||
**We found that 14B image-to-video model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
|
||||
|
||||
### Parallel Inference
|
||||
## Efficient Attention Implementation
|
||||
|
||||
1. Unified Sequence Parallel (USP)
|
||||
DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. However, we recommend to use the default torch SDPA.
|
||||
|
||||
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
|
||||
|
||||
## Acceleration
|
||||
|
||||
We support multiple acceleration solutions:
|
||||
* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
|
||||
|
||||
* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py)
|
||||
|
||||
```bash
|
||||
pip install xfuser>=0.4.3
|
||||
```
|
||||
|
||||
```bash
|
||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
|
||||
```
|
||||
|
||||
2. Tensor Parallel
|
||||
* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).
|
||||
|
||||
Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
|
||||
## Gallery
|
||||
|
||||
### Wan-Video-14B-I2V
|
||||
1.3B text-to-video.
|
||||
|
||||
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
|
||||
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
|
||||
|
||||
**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
|
||||
Put sunglasses on the dog.
|
||||
|
||||

|
||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
||||
|
||||
14B text-to-video.
|
||||
|
||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
||||
|
||||
14B image-to-video.
|
||||
|
||||
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
|
||||
|
||||
|
||||
41
examples/wanvideo/wan_1.3b_motion_controller.py
Normal file
41
examples/wanvideo/wan_1.3b_motion_controller.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
from modelscope import snapshot_download
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
|
||||
snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
||||
"models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors",
|
||||
],
|
||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
||||
)
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
seed=1, tiled=True,
|
||||
motion_bucket_id=0
|
||||
)
|
||||
save_video(video, "video_slow.mp4", fps=15, quality=5)
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
seed=1, tiled=True,
|
||||
motion_bucket_id=100
|
||||
)
|
||||
save_video(video, "video_fast.mp4", fps=15, quality=5)
|
||||
42
examples/wanvideo/wan_fun_InP.py
Normal file
42
examples/wanvideo/wan_fun_InP.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
from modelscope import snapshot_download, dataset_snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
|
||||
],
|
||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
||||
)
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Download example image
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
# Image-to-video
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
input_image=image,
|
||||
# You can input `end_image=xxx` to control the last frame of the video.
|
||||
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
40
examples/wanvideo/wan_fun_control.py
Normal file
40
examples/wanvideo/wan_fun_control.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
from modelscope import snapshot_download, dataset_snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
|
||||
],
|
||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
||||
)
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Download example video
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/control_video.mp4"
|
||||
)
|
||||
|
||||
# Control-to-video
|
||||
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
|
||||
video = pipe(
|
||||
prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
control_video=control_video, height=832, width=576, num_frames=49,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
2
setup.py
2
setup.py
@@ -14,7 +14,7 @@ else:
|
||||
|
||||
setup(
|
||||
name="diffsynth",
|
||||
version="1.1.2",
|
||||
version="1.1.7",
|
||||
description="Enjoy the magic of Diffusion models!",
|
||||
author="Artiprocher",
|
||||
packages=find_packages(),
|
||||
|
||||
241
train_flux_reference.py
Normal file
241
train_flux_reference.py
Normal file
@@ -0,0 +1,241 @@
|
||||
from diffsynth import ModelManager, FluxImagePipeline
|
||||
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
import torch, os, argparse
|
||||
import lightning as pl
|
||||
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
|
||||
from diffsynth.pipelines.flux_image import lets_dance_flux
|
||||
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
||||
|
||||
|
||||
class LightningModel(LightningModelForT2ILoRA):
|
||||
def __init__(
|
||||
self,
|
||||
torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
|
||||
learning_rate=1e-4, use_gradient_checkpointing=True,
|
||||
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
|
||||
state_dict_converter=None, quantize = None
|
||||
):
|
||||
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
|
||||
if quantize is None:
|
||||
model_manager.load_models(pretrained_weights)
|
||||
else:
|
||||
model_manager.load_models(pretrained_weights[1:])
|
||||
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
|
||||
if preset_lora_path is not None:
|
||||
preset_lora_path = preset_lora_path.split(",")
|
||||
for path in preset_lora_path:
|
||||
model_manager.load_lora(path)
|
||||
|
||||
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
self.pipe.reference_embedder = FluxReferenceEmbedder()
|
||||
self.pipe.reference_embedder.init()
|
||||
|
||||
if quantize is not None:
|
||||
self.pipe.dit.quantize()
|
||||
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
self.freeze_parameters()
|
||||
self.pipe.reference_embedder.requires_grad_(True)
|
||||
self.pipe.reference_embedder.train()
|
||||
self.pipe.dit.requires_grad_(True)
|
||||
self.pipe.dit.train()
|
||||
# self.add_lora_to_model(
|
||||
# self.pipe.denoising_model(),
|
||||
# lora_rank=lora_rank,
|
||||
# lora_alpha=lora_alpha,
|
||||
# lora_target_modules=lora_target_modules,
|
||||
# init_lora_weights=init_lora_weights,
|
||||
# pretrained_lora_path=pretrained_lora_path,
|
||||
# state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
|
||||
# )
|
||||
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# Data
|
||||
text, image = batch["instruction"], batch["image_2"]
|
||||
image_ref = batch["image_1"]
|
||||
|
||||
# Prepare input parameters
|
||||
self.pipe.device = self.device
|
||||
prompt_emb = self.pipe.encode_prompt(text, positive=True)
|
||||
if "latents" in batch:
|
||||
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
|
||||
else:
|
||||
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
|
||||
noise = torch.randn_like(latents)
|
||||
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
|
||||
extra_input = self.pipe.prepare_extra_input(latents)
|
||||
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
|
||||
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
|
||||
|
||||
# Reference image
|
||||
hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))
|
||||
|
||||
# Compute loss
|
||||
noise_pred = lets_dance_flux(
|
||||
self.pipe.denoising_model(),
|
||||
reference_embedder=self.pipe.reference_embedder,
|
||||
hidden_states_ref=hidden_states_ref,
|
||||
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||
use_gradient_checkpointing=self.use_gradient_checkpointing
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||
|
||||
# Record log
|
||||
self.log("train_loss", loss, prog_bar=True)
|
||||
return loss
|
||||
|
||||
|
||||
def configure_optimizers(self):
|
||||
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
|
||||
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
checkpoint.clear()
|
||||
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
|
||||
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||
state_dict = self.pipe.state_dict()
|
||||
lora_state_dict = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in trainable_param_names:
|
||||
lora_state_dict[name] = param
|
||||
if self.state_dict_converter is not None:
|
||||
lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
|
||||
checkpoint.update(lora_state_dict)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_text_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_text_encoder_2_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_dit_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_vae_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_target_modules",
|
||||
type=str,
|
||||
default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
|
||||
help="Layers with LoRA modules.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--align_to_opensource_format",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to export lora files aligned with other opensource format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["float8_e4m3fn"],
|
||||
help="Whether to use quantization when training the model, and in which format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preset_lora_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Preset LoRA path.",
|
||||
)
|
||||
parser = add_general_parsers(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
model = LightningModel(
|
||||
torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
|
||||
pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
|
||||
preset_lora_path=args.preset_lora_path,
|
||||
learning_rate=args.learning_rate,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
init_lora_weights=args.init_lora_weights,
|
||||
pretrained_lora_path=args.pretrained_lora_path,
|
||||
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
|
||||
quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
|
||||
)
|
||||
# dataset and data loader
|
||||
dataset = MultiTaskDataset(
|
||||
dataset_list=[
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
|
||||
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
|
||||
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
|
||||
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
|
||||
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
],
|
||||
dataset_weight=(4, 1, 4, 1),
|
||||
steps_per_epoch=args.steps_per_epoch,
|
||||
)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
shuffle=True,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.dataloader_num_workers
|
||||
)
|
||||
# train
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=args.max_epochs,
|
||||
accelerator="gpu",
|
||||
devices="auto",
|
||||
precision=args.precision,
|
||||
strategy=args.training_strategy,
|
||||
default_root_dir=args.output_path,
|
||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
||||
logger=None,
|
||||
)
|
||||
trainer.fit(model=model, train_dataloaders=train_loader)
|
||||
248
train_flux_reference_multi_node.py
Normal file
248
train_flux_reference_multi_node.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from diffsynth import ModelManager, FluxImagePipeline
|
||||
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
import torch, os, argparse
|
||||
import lightning as pl
|
||||
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
|
||||
from diffsynth.pipelines.flux_image import lets_dance_flux
|
||||
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
||||
|
||||
|
||||
class LightningModel(LightningModelForT2ILoRA):
|
||||
def __init__(
|
||||
self,
|
||||
torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
|
||||
learning_rate=1e-4, use_gradient_checkpointing=True,
|
||||
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
|
||||
state_dict_converter=None, quantize = None
|
||||
):
|
||||
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
|
||||
if quantize is None:
|
||||
model_manager.load_models(pretrained_weights)
|
||||
else:
|
||||
model_manager.load_models(pretrained_weights[1:])
|
||||
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
|
||||
if preset_lora_path is not None:
|
||||
preset_lora_path = preset_lora_path.split(",")
|
||||
for path in preset_lora_path:
|
||||
model_manager.load_lora(path)
|
||||
|
||||
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
self.pipe.reference_embedder = FluxReferenceEmbedder()
|
||||
self.pipe.reference_embedder.init()
|
||||
|
||||
if quantize is not None:
|
||||
self.pipe.dit.quantize()
|
||||
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
self.freeze_parameters()
|
||||
self.pipe.reference_embedder.requires_grad_(True)
|
||||
self.pipe.reference_embedder.train()
|
||||
self.pipe.dit.requires_grad_(True)
|
||||
self.pipe.dit.train()
|
||||
# self.add_lora_to_model(
|
||||
# self.pipe.denoising_model(),
|
||||
# lora_rank=lora_rank,
|
||||
# lora_alpha=lora_alpha,
|
||||
# lora_target_modules=lora_target_modules,
|
||||
# init_lora_weights=init_lora_weights,
|
||||
# pretrained_lora_path=pretrained_lora_path,
|
||||
# state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
|
||||
# )
|
||||
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# Data
|
||||
text, image = batch["instruction"], batch["image_2"]
|
||||
image_ref = batch["image_1"]
|
||||
|
||||
# Prepare input parameters
|
||||
self.pipe.device = self.device
|
||||
prompt_emb = self.pipe.encode_prompt(text, positive=True)
|
||||
if "latents" in batch:
|
||||
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
|
||||
else:
|
||||
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
|
||||
noise = torch.randn_like(latents)
|
||||
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
|
||||
extra_input = self.pipe.prepare_extra_input(latents)
|
||||
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
|
||||
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
|
||||
|
||||
# Reference image
|
||||
hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))
|
||||
|
||||
# Compute loss
|
||||
noise_pred = lets_dance_flux(
|
||||
self.pipe.denoising_model(),
|
||||
reference_embedder=self.pipe.reference_embedder,
|
||||
hidden_states_ref=hidden_states_ref,
|
||||
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||
use_gradient_checkpointing=self.use_gradient_checkpointing
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||
|
||||
# Record log
|
||||
self.log("train_loss", loss, prog_bar=True)
|
||||
return loss
|
||||
|
||||
|
||||
def configure_optimizers(self):
|
||||
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
|
||||
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
checkpoint.clear()
|
||||
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
|
||||
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||
state_dict = self.pipe.state_dict()
|
||||
lora_state_dict = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in trainable_param_names:
|
||||
lora_state_dict[name] = param
|
||||
if self.state_dict_converter is not None:
|
||||
lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
|
||||
checkpoint.update(lora_state_dict)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_text_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_text_encoder_2_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_dit_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_vae_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_target_modules",
|
||||
type=str,
|
||||
default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
|
||||
help="Layers with LoRA modules.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--align_to_opensource_format",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to export lora files aligned with other opensource format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["float8_e4m3fn"],
|
||||
help="Whether to use quantization when training the model, and in which format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preset_lora_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Preset LoRA path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Num nodes.",
|
||||
)
|
||||
parser = add_general_parsers(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
model = LightningModel(
|
||||
torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
|
||||
pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
|
||||
preset_lora_path=args.preset_lora_path,
|
||||
learning_rate=args.learning_rate,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
init_lora_weights=args.init_lora_weights,
|
||||
pretrained_lora_path=args.pretrained_lora_path,
|
||||
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
|
||||
quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
|
||||
)
|
||||
# dataset and data loader
|
||||
dataset = MultiTaskDataset(
|
||||
dataset_list=[
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
|
||||
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
|
||||
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
|
||||
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
|
||||
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
],
|
||||
dataset_weight=(4, 1, 4, 1),
|
||||
steps_per_epoch=args.steps_per_epoch,
|
||||
)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
shuffle=True,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.dataloader_num_workers
|
||||
)
|
||||
# train
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=args.max_epochs,
|
||||
accelerator="gpu",
|
||||
devices="auto",
|
||||
num_nodes=args.num_nodes,
|
||||
precision=args.precision,
|
||||
strategy="ddp",
|
||||
default_root_dir=args.output_path,
|
||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
||||
logger=None,
|
||||
)
|
||||
trainer.fit(model=model, train_dataloaders=train_loader)
|
||||
Reference in New Issue
Block a user