Compare commits

66 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Artiprocher | ba421a9ab9 | value control | 2025-07-04 14:26:33 +08:00 |
| Artiprocher | 6c30a7f080 | value control | 2025-07-04 14:23:07 +08:00 |
| Artiprocher | 1363a0559f | support json dataset | 2025-07-02 20:07:16 +08:00 |
| Artiprocher | 9bb51fe879 | refine wan readme | 2025-07-02 11:36:41 +08:00 |
| Zhongjie Duan | d9c812818d | Merge pull request #653 from mi804/main; fix step1xedit | 2025-07-01 17:16:41 +08:00 |
| mi804 | c8e9a96196 | fix step1xedit | 2025-07-01 17:12:53 +08:00 |
| Zhongjie Duan | 6143af4654 | Merge pull request #651 from mi804/infiniteyou_controlnet_replace; infiniteyou_controlnet outof pipeline | 2025-07-01 13:39:47 +08:00 |
| Zhongjie Duan | 9458e382b0 | Merge pull request #652 from modelscope/flux-refactor; refine readme | 2025-07-01 11:34:00 +08:00 |
| Artiprocher | 4f2d9226cf | refine readme | 2025-07-01 11:33:04 +08:00 |
| mi804 | f688a469b1 | infiniteyou_controlnet outof pipeline | 2025-07-01 11:10:46 +08:00 |
| Zhongjie Duan | c8ea3b3356 | Merge pull request #649 from modelscope/flux-refactor; refine readme | 2025-06-30 11:46:16 +08:00 |
| Artiprocher | 6e9472b470 | refine readme | 2025-06-30 11:45:40 +08:00 |
| Zhongjie Duan | a5c03c5272 | Merge pull request #648 from modelscope/flux-refactor; refine readme | 2025-06-30 11:44:47 +08:00 |
| Artiprocher | 8068ac2592 | refine readme | 2025-06-30 11:43:59 +08:00 |
| Zhongjie Duan | 5f80e7ac5e | Merge pull request #647 from modelscope/flux-refactor; kontext training | 2025-06-30 11:09:22 +08:00 |
| Artiprocher | 157e0be49d | kontext training | 2025-06-30 11:00:10 +08:00 |
| Zhongjie Duan | 3dbe271aab | Merge pull request #646 from modelscope/flux-refactor; Flux refactor | 2025-06-29 18:04:05 +08:00 |
| Artiprocher | 44e2eecdf1 | flux-kontext | 2025-06-29 15:59:04 +08:00 |
| Artiprocher | 8c226e83a6 | flux-kontext | 2025-06-29 15:51:45 +08:00 |
| Artiprocher | 009f26bb40 | kontext | 2025-06-27 18:38:40 +08:00 |
| Artiprocher | fcf2fbc07f | flux-refactor | 2025-06-27 10:20:11 +08:00 |
| Artiprocher | b603acd36a | refine examples | 2025-06-25 13:38:21 +08:00 |
| Artiprocher | 6c8bb6438b | infiniteyou | 2025-06-25 10:33:11 +08:00 |
| Artiprocher | 8072d3839d | refine examples | 2025-06-24 19:17:54 +08:00 |
| Artiprocher | c8ad643374 | refine examples | 2025-06-24 19:17:43 +08:00 |
| Zhongjie Duan | 31f9df5e62 | Merge pull request #567 from emmanuel-ferdman/main; Migrate to modern Python Logger API | 2025-06-24 15:32:14 +08:00 |
| Zhongjie Duan | e2f415524a | Merge pull request #587 from ernestchu/patch-1; Fix typo | 2025-06-24 15:23:19 +08:00 |
| Zhongjie Duan | 3eb7e7530e | Merge pull request #632 from lzws/flux-refactor; step1x, teacache, flex refactor | 2025-06-24 15:19:54 +08:00 |
| Zhongjie Duan | 916aa54595 | Merge branch 'flux-refactor' into flux-refactor | 2025-06-24 15:19:42 +08:00 |
| Zhongjie Duan | 6ddbd43f7b | Merge pull request #634 from modelscope/bugfix; fix videodataset to load images | 2025-06-24 11:42:14 +08:00 |
| Artiprocher | a37a83ecc3 | fix videodataset to load images | 2025-06-24 11:38:43 +08:00 |
| Zhongjie Duan | f2a0d0c85f | Merge pull request #633 from modelscope/bugfix; fix i2v resolution | 2025-06-24 10:59:31 +08:00 |
| Artiprocher | 93194f44e8 | fix i2v resolution | 2025-06-24 10:56:52 +08:00 |
| Artiprocher | c4e5033532 | flux controlnet | 2025-06-23 21:01:53 +08:00 |
| lzw478614@alibaba-inc.com | cc6cd26733 | step1x, teacache, flex refactor | 2025-06-23 17:06:00 +08:00 |
| Zhongjie Duan | 1113d305d1 | Merge pull request #626 from mi804/flux-refactor; Flux refactor | 2025-06-23 10:20:40 +08:00 |
| mi804 | 6d5f8b7423 | flux_eligen_refactor | 2025-06-20 16:53:41 +08:00 |
| mi804 | 1b3c204d20 | flux_ipadapter_refactor | 2025-06-20 14:49:09 +08:00 |
| Artiprocher | 1788d50f0a | flux-refactor | 2025-06-19 15:04:30 +08:00 |
| Artiprocher | e7a21dbf0b | flux-refactor | 2025-06-19 14:53:11 +08:00 |
| Zhongjie Duan | 3b3e1e4d44 | Merge pull request #623 from modelscope/usp; Usp | 2025-06-19 10:15:39 +08:00 |
| Artiprocher | 24426e3a32 | update README_zh | 2025-06-19 10:06:55 +08:00 |
| Artiprocher | 31369bab15 | update import | 2025-06-19 10:04:24 +08:00 |
| mi804 | 551721658b | fix bug for usp with refimage | 2025-06-16 19:38:45 +08:00 |
| mi804 | 46f052375f | fix vace usp | 2025-06-16 18:54:29 +08:00 |
| Zhongjie Duan | c2d35a2157 | update wan training (#614); update wan training | 2025-06-16 15:48:35 +08:00 |
| mi804 | 4c052e42bc | fix usp download | 2025-06-16 15:43:39 +08:00 |
| Zhongjie Duan | a88613555d | Merge pull request #612 from Yunnglin/update/eval_news; update readme for eval | 2025-06-16 14:06:52 +08:00 |
| Zhongjie Duan | c164519ef1 | vram management support torch<2.6.0 (#613); support torch<2.6.0 | 2025-06-16 13:08:29 +08:00 |
| Yunnglin | afff5ffb21 | update readme | 2025-06-16 11:08:53 +08:00 |
| Yunnglin | a8481fd5e1 | update readme | 2025-06-16 11:00:53 +08:00 |
| Zhongjie Duan | 8584e50309 | Merge pull request #611 from modelscope/refactor; fix model id | 2025-06-16 10:58:14 +08:00 |
| Artiprocher | 9f3e02f167 | fix model id | 2025-06-16 10:57:33 +08:00 |
| Zhongjie Duan | 7ad9b9aecc | Merge pull request #609 from modelscope/refactor; refine readme | 2025-06-13 14:14:14 +08:00 |
| Artiprocher | b6a111d3a2 | refine readme | 2025-06-13 14:13:38 +08:00 |
| Zhongjie Duan | bd6f2695a9 | Merge pull request #608 from modelscope/refactor; Refactor | 2025-06-13 14:02:49 +08:00 |
| Artiprocher | 6eecc9d442 | refine readme | 2025-06-13 14:02:20 +08:00 |
| Artiprocher | 35269783d7 | refine readme | 2025-06-13 14:00:58 +08:00 |
| Zhongjie Duan | 9534a78167 | Merge pull request #607 from modelscope/refactor; wan-refactor | 2025-06-13 13:49:00 +08:00 |
| Artiprocher | 830b1b7202 | wan-refactor | 2025-06-13 13:46:17 +08:00 |
| Zhongjie Duan | 436a91e0c9 | Merge pull request #602 from modelscope/revert-601-wan-refactor; Revert "Wan refactor" | 2025-06-11 17:30:06 +08:00 |
| Zhongjie Duan | 40760ab88b | Revert "Wan refactor" | 2025-06-11 17:29:27 +08:00 |
| CD22104 | 8badd63a2d | Merge pull request #601 from CD22104/wan-refactor; Wan refactor | 2025-06-11 17:26:58 +08:00 |
| CD22104 | b1afff1728 | camera | 2025-06-11 17:24:09 +08:00 |
| Ernie Chu | 4e00c109e3 | Fix typo; Change "Only `num_frames % 4 != 1` is acceptable" to "Only `num_frames % 4 == 1` is acceptable" | 2025-05-27 21:20:38 -04:00 |
| Emmanuel Ferdman | a3a35acc7e | Migrate to modern Python Logger API; Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> | 2025-05-12 14:09:26 -07:00 |
51 changed files with 3260 additions and 79 deletions

View File

@@ -42,6 +42,8 @@ Until now, DiffSynth-Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
## News
- **June 15, 2025** ModelScope's official evaluation framework, [EvalScope](https://github.com/modelscope/evalscope), now supports text-to-image generation evaluation. Try it with the [Best Practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide.
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.

View File

@@ -64,6 +64,8 @@ from ..models.wan_video_vace import VaceWanModel
from ..models.step1x_connector import Qwen2Connector
from ..models.flux_value_control import SingleValueEncoder
model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -102,6 +104,7 @@ model_loader_configs = [
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
(None, "3ede90c44b2c161240b659f3b8393c9d", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),

View File

@@ -413,7 +413,7 @@ class BertEncoder(nn.Module):
 if self.gradient_checkpointing and self.training:
     if use_cache:
-        logger.warn(
+        logger.warning(
             "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
         )
         use_cache = False
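The rename in this hunk matters because `Logger.warn` is a deprecated alias of `Logger.warning` in the standard `logging` module. A minimal sketch of the replacement call follows; the logger name and the in-memory handler are assumptions added for illustration and are not part of the commit.

```python
import io
import logging

# Hypothetical logger name, used only for this illustration.
logger = logging.getLogger("diffsynth.example")
logger.setLevel(logging.WARNING)
logger.propagate = False

# Capture the output in memory so it can be inspected.
buffer = io.StringIO()
handler = logging.StreamHandler(buffer)
handler.setFormatter(logging.Formatter("%(levelname)s:%(message)s"))
logger.addHandler(handler)

use_cache = True
gradient_checkpointing = True
if gradient_checkpointing and use_cache:
    # `warning` is the supported spelling; `warn` is the deprecated alias.
    logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
    use_cache = False

captured = buffer.getvalue()
print(captured, end="")
```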

View File

@@ -0,0 +1,13 @@
import torch
from diffsynth.lora import GeneralLoRALoader
from diffsynth.models.lora import FluxLoRAFromCivitai
class FluxLoRALoader(GeneralLoRALoader):
def __init__(self, device="cpu", torch_dtype=torch.float32):
super().__init__(device=device, torch_dtype=torch_dtype)
self.loader = FluxLoRAFromCivitai()
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
lora_prefix, model_resource = self.loader.match(model, state_dict_lora)
self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource)

View File

@@ -0,0 +1,58 @@
import torch
from diffsynth.models.svd_unet import TemporalTimesteps
class MultiValueEncoder(torch.nn.Module):
def __init__(self, encoders=()):
super().__init__()
self.encoders = torch.nn.ModuleList(encoders)
def __call__(self, values, dtype):
emb = []
for encoder, value in zip(self.encoders, values):
if value is not None:
value = value.unsqueeze(0)
emb.append(encoder(value, dtype))
emb = torch.concat(emb, dim=0)
return emb
class SingleValueEncoder(torch.nn.Module):
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
super().__init__()
self.prefer_len = prefer_len
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
self.prefer_value_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
self.positional_embedding = torch.nn.Parameter(
torch.randn(self.prefer_len, dim_in)
)
self._initialize_weights()
def _initialize_weights(self):
last_linear = self.prefer_value_embedder[-1]
torch.nn.init.zeros_(last_linear.weight)
torch.nn.init.zeros_(last_linear.bias)
def forward(self, value, dtype):
emb = self.prefer_proj(value).to(dtype)
emb = emb.expand(self.prefer_len, -1)
emb = emb + self.positional_embedding
emb = self.prefer_value_embedder(emb)
return emb
@staticmethod
def state_dict_converter():
return SingleValueEncoderStateDictConverter()
class SingleValueEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -1373,7 +1373,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 elif generation_config.max_new_tokens is not None:
     generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
     if not has_default_max_length:
-        logger.warn(
+        logger.warning(
             f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
             f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
             "Please refer to the documentation for more information. "

File diff suppressed because it is too large.

View File

@@ -379,7 +379,7 @@ class WanVideoPipeline(BasePipeline):
 height, width = self.check_resize_height_width(height, width)
 if num_frames % 4 != 1:
     num_frames = (num_frames + 2) // 4 * 4 + 1
-    print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
+    print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
 # Tiler parameters
 tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
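The corrected message matches what the arithmetic does: any frame count that does not satisfy `num_frames % 4 == 1` is moved up to the next value that does. A small standalone sketch of that expression, with a hypothetical function name:

```python
def round_up_num_frames(num_frames: int) -> int:
    """Return num_frames unchanged if num_frames % 4 == 1, else the next larger value of that form."""
    if num_frames % 4 != 1:
        # Same expression as in the hunk above.
        num_frames = (num_frames + 2) // 4 * 4 + 1
    return num_frames


for requested in (80, 81, 82, 83, 84):
    print(requested, "->", round_up_num_frames(requested))
```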

View File

@@ -143,10 +143,8 @@ class BasePipeline(torch.nn.Module):
     self.vram_management_enabled = True
-def get_free_vram(self):
-    total_memory = torch.cuda.get_device_properties(self.device).total_memory
-    allocated_memory = torch.cuda.device_memory_used(self.device)
-    return (total_memory - allocated_memory) / (1024 ** 3)
+def get_vram(self):
+    return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
 def freeze_except(self, model_names):
@@ -168,20 +166,50 @@ class ModelConfig:
offload_device: Optional[Union[str, torch.device]] = None
offload_dtype: Optional[torch.dtype] = None
def download_if_necessary(self, local_model_path="./models", skip_download=False):
def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False):
if self.path is None:
if self.model_id is None or self.origin_file_pattern is None:
# Check model_id and origin_file_pattern
if self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
# Skip if not in rank 0
if use_usp:
import torch.distributed as dist
skip_download = dist.get_rank() != 0
# Check whether the origin path is a folder
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.origin_file_pattern = ""
allow_file_pattern = None
is_folder = True
elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
allow_file_pattern = self.origin_file_pattern + "*"
is_folder = True
else:
allow_file_pattern = self.origin_file_pattern
is_folder = False
# Download
if not skip_download:
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
snapshot_download(
self.model_id,
local_dir=os.path.join(local_model_path, self.model_id),
allow_file_pattern=self.origin_file_pattern,
allow_file_pattern=allow_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
# Let rank 1, 2, ... wait for rank 0
if use_usp:
import torch.distributed as dist
dist.barrier(device_ids=[dist.get_rank()])
# Return downloaded files
if is_folder:
self.path = os.path.join(local_model_path, self.model_id, self.origin_file_pattern)
else:
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
if isinstance(self.path, list) and len(self.path) == 1:
self.path = self.path[0]
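The new branching in `download_if_necessary` distinguishes three kinds of `origin_file_pattern`: empty (the whole repository), a trailing-slash folder, and an ordinary file pattern. The sketch below isolates only that classification as a pure function; the function name and the example patterns are hypothetical, and no download is performed.

```python
def classify_origin_file_pattern(origin_file_pattern):
    """Return (normalized_pattern, allow_file_pattern, is_folder) following the three branches in the hunk."""
    if origin_file_pattern is None or origin_file_pattern == "":
        # Whole repository: no filter, and the resulting path is a folder.
        return "", None, True
    if isinstance(origin_file_pattern, str) and origin_file_pattern.endswith("/"):
        # Sub-folder: download everything underneath it.
        return origin_file_pattern, origin_file_pattern + "*", True
    # Plain file pattern: passed through unchanged.
    return origin_file_pattern, origin_file_pattern, False


# Hypothetical patterns, for illustration only.
for pattern in (None, "text_encoder/", "diffusion_pytorch_model*.safetensors"):
    print(repr(pattern), "->", classify_origin_file_pattern(pattern))
```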
@@ -247,7 +275,7 @@ class WanVideoPipeline(BasePipeline):
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_free_vram()
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None:
dtype = next(iter(self.text_encoder.parameters())).dtype
@@ -427,19 +455,21 @@ class WanVideoPipeline(BasePipeline):
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
model_config.model_id = redirect_dict[model_config.origin_file_pattern]
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
if use_usp: pipe.initialize_usp()
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
model_config.download_if_necessary(local_model_path, skip_download=skip_download)
model_config.download_if_necessary(local_model_path, skip_download=skip_download, use_usp=use_usp)
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
torch_dtype=model_config.offload_dtype or torch_dtype
)
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
if use_usp: pipe.initialize_usp()
# Load models
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
pipe.dit = model_manager.fetch_model("wan_video_dit")
pipe.vae = model_manager.fetch_model("wan_video_vae")
@@ -608,11 +638,17 @@ class PipelineUnitRunner:
elif unit.seperate_cfg:
# Positive side
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_posi.update(processor_outputs)
# Negative side
if inputs_shared["cfg_scale"] != 1:
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_nega.update(processor_outputs)
else:
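The added lines let a unit with separate positive and negative inputs also read values from the shared inputs. A dependency-free sketch of how the processor inputs are assembled for one side; the helper name and the dictionary keys are hypothetical stand-ins for whatever a real unit declares.

```python
def build_processor_inputs(side_inputs, shared_inputs, side_params, shared_params=None):
    """Map side-specific inputs by name, then add any shared inputs the unit requests."""
    processor_inputs = {name: side_inputs.get(source) for name, source in side_params.items()}
    if shared_params is not None:
        for name in shared_params:
            processor_inputs[name] = shared_inputs.get(name)
    return processor_inputs


# Hypothetical inputs, for illustration only.
inputs_posi = {"prompt": "a cat"}
inputs_shared = {"cfg_scale": 5.0, "height": 480}
result = build_processor_inputs(
    inputs_posi,
    inputs_shared,
    side_params={"prompt": "prompt"},
    shared_params=("height",),
)
print(result)
```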
@@ -1150,17 +1186,20 @@ def model_fn_wan_video(
else:
x = block(x, context, t_mod, freqs)
if vace_context is not None and block_id in vace.vace_layers_mapping:
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
x = x + current_vace_hint * vace_scale
if tea_cache is not None:
tea_cache.store(x)
if reference_latents is not None:
x = x[:, reference_latents.shape[1]:]
f -= 1
x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
# Remove reference latents
if reference_latents is not None:
x = x[:, reference_latents.shape[1]:]
f -= 1
x = dit.unpatchify(x, (f, h, w))
return x

View File

@@ -1,4 +1,4 @@
import imageio, os, torch, warnings, torchvision, argparse
import imageio, os, torch, warnings, torchvision, argparse, json
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
@@ -7,12 +7,139 @@ from accelerate import Accelerator
class ImageDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path=None, metadata_path=None,
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
data_file_keys=("image",),
image_file_extension=("jpg", "jpeg", "png", "webp"),
repeat=1,
args=None,
):
if args is not None:
base_path = args.dataset_base_path
metadata_path = args.dataset_metadata_path
height = args.height
width = args.width
max_pixels = args.max_pixels
data_file_keys = args.data_file_keys.split(",")
repeat = args.dataset_repeat
self.base_path = base_path
self.max_pixels = max_pixels
self.height = height
self.width = width
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.data_file_keys = data_file_keys
self.image_file_extension = image_file_extension
self.repeat = repeat
if height is not None and width is not None:
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
self.dynamic_resolution = False
elif height is None and width is None:
print("Height and width are none. Setting `dynamic_resolution` to True.")
self.dynamic_resolution = True
if metadata_path is None:
print("No metadata. Trying to generate it.")
metadata = self.generate_metadata(base_path)
print(f"{len(metadata)} lines in metadata.")
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
elif metadata_path.endswith(".json"):
with open(metadata_path, "r") as f:
metadata = json.load(f)
self.data = metadata
else:
metadata = pd.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
def generate_metadata(self, folder):
image_list, prompt_list = [], []
file_set = set(os.listdir(folder))
for file_name in file_set:
if "." not in file_name:
continue
file_ext_name = file_name.split(".")[-1].lower()
file_base_name = file_name[:-len(file_ext_name)-1]
if file_ext_name not in self.image_file_extension:
continue
prompt_file_name = file_base_name + ".txt"
if prompt_file_name not in file_set:
continue
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
prompt = f.read().strip()
image_list.append(file_name)
prompt_list.append(prompt)
metadata = pd.DataFrame()
metadata["image"] = image_list
metadata["prompt"] = prompt_list
return metadata
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def get_height_width(self, image):
if self.dynamic_resolution:
width, height = image.size
if width * height > self.max_pixels:
scale = (width * height / self.max_pixels) ** 0.5
height, width = int(height / scale), int(width / scale)
height = height // self.height_division_factor * self.height_division_factor
width = width // self.width_division_factor * self.width_division_factor
else:
height, width = self.height, self.width
return height, width
def load_image(self, file_path):
image = Image.open(file_path).convert("RGB")
image = self.crop_and_resize(image, *self.get_height_width(image))
return image
def load_data(self, file_path):
return self.load_image(file_path)
def __getitem__(self, data_id):
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
path = os.path.join(self.base_path, data[key])
data[key] = self.load_data(path)
if data[key] is None:
warnings.warn(f"cannot load file {data[key]}.")
return None
return data
def __len__(self):
return len(self.data) * self.repeat
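With dynamic resolution enabled, `get_height_width` scales an image down to fit `max_pixels` and then floors each side to a multiple of the division factor. The following sketch reproduces that arithmetic as a standalone function; the function name is hypothetical, and a single division factor is used for both sides for brevity.

```python
def dynamic_height_width(width, height, max_pixels=1920 * 1080, division_factor=16):
    """Shrink to at most max_pixels (keeping aspect ratio), then floor to multiples of division_factor."""
    if width * height > max_pixels:
        scale = (width * height / max_pixels) ** 0.5
        height, width = int(height / scale), int(width / scale)
    height = height // division_factor * division_factor
    width = width // division_factor * division_factor
    return height, width


for size in ((1024, 768), (1000, 700), (3840, 2160)):
    print(size, "->", dynamic_height_width(*size))
```

Note that flooring can drop a few rows or columns, for example a height of 1080 becomes 1072.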
class VideoDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path=None, metadata_path=None,
frame_interval=1, num_frames=81,
dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None,
num_frames=81,
time_division_factor=4, time_division_remainder=1,
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
data_file_keys=("video",),
image_file_extension=("jpg", "jpeg", "png", "webp"),
@@ -25,17 +152,15 @@ class VideoDataset(torch.utils.data.Dataset):
metadata_path = args.dataset_metadata_path
height = args.height
width = args.width
max_pixels = args.max_pixels
num_frames = args.num_frames
data_file_keys = args.data_file_keys.split(",")
repeat = args.dataset_repeat
metadata = pd.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
self.base_path = base_path
self.frame_interval = frame_interval
self.num_frames = num_frames
self.dynamic_resolution = dynamic_resolution
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
self.max_pixels = max_pixels
self.height = height
self.width = width
@@ -46,9 +171,48 @@ class VideoDataset(torch.utils.data.Dataset):
self.video_file_extension = video_file_extension
self.repeat = repeat
if height is not None and width is not None and dynamic_resolution == True:
if height is not None and width is not None:
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
self.dynamic_resolution = False
elif height is None and width is None:
print("Height and width are none. Setting `dynamic_resolution` to True.")
self.dynamic_resolution = True
if metadata_path is None:
print("No metadata. Trying to generate it.")
metadata = self.generate_metadata(base_path)
print(f"{len(metadata)} lines in metadata.")
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
elif metadata_path.endswith(".json"):
with open(metadata_path, "r") as f:
metadata = json.load(f)
self.data = metadata
else:
metadata = pd.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
def generate_metadata(self, folder):
video_list, prompt_list = [], []
file_set = set(os.listdir(folder))
for file_name in file_set:
if "." not in file_name:
continue
file_ext_name = file_name.split(".")[-1].lower()
file_base_name = file_name[:-len(file_ext_name)-1]
if file_ext_name not in self.image_file_extension and file_ext_name not in self.video_file_extension:
continue
prompt_file_name = file_base_name + ".txt"
if prompt_file_name not in file_set:
continue
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
prompt = f.read().strip()
video_list.append(file_name)
prompt_list.append(prompt)
metadata = pd.DataFrame()
metadata["video"] = video_list
metadata["prompt"] = prompt_list
return metadata
def crop_and_resize(self, image, target_height, target_width):
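When no metadata file is given, `generate_metadata` pairs each media file with a same-named `.txt` prompt file and skips anything without one. A minimal sketch of that pairing using only the standard library; it returns a list of dictionaries rather than a pandas DataFrame, and the function name, the sample file names, and the `"mp4"` extension are assumptions for illustration.

```python
import os
import tempfile

# "mp4" is an assumed video extension; the image extensions come from the diff.
EXTENSIONS = ("jpg", "jpeg", "png", "webp", "mp4")


def pair_files_with_prompts(folder, file_key="video", extensions=EXTENSIONS):
    """Collect {file_key: file_name, "prompt": text} for every media file with a matching .txt file."""
    rows = []
    file_set = set(os.listdir(folder))
    for file_name in sorted(file_set):
        if "." not in file_name:
            continue
        ext = file_name.split(".")[-1].lower()
        base = file_name[:-len(ext) - 1]
        if ext not in extensions:
            continue
        prompt_file = base + ".txt"
        if prompt_file not in file_set:
            continue
        with open(os.path.join(folder, prompt_file), "r", encoding="utf-8") as f:
            rows.append({file_key: file_name, "prompt": f.read().strip()})
    return rows


with tempfile.TemporaryDirectory() as folder:
    samples = [("clip.mp4", ""), ("clip.txt", "a short clip\n"), ("frame.png", ""), ("notes.md", "x")]
    for name, text in samples:
        with open(os.path.join(folder, name), "w", encoding="utf-8") as f:
            f.write(text)
    rows = pair_files_with_prompts(folder)

print(rows)
```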
@@ -75,15 +239,22 @@ class VideoDataset(torch.utils.data.Dataset):
height, width = self.height, self.width
return height, width
def get_num_frames(self, reader):
num_frames = self.num_frames
if int(reader.count_frames()) < num_frames:
num_frames = int(reader.count_frames())
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
num_frames -= 1
return num_frames
def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames):
def load_video(self, file_path):
reader = imageio.get_reader(file_path)
if reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
num_frames = self.get_num_frames(reader)
frames = []
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = reader.get_data(frame_id)
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
frames.append(frame)
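`get_num_frames` caps the requested frame count at the video length and then decrements until the count matches the required remainder. A standalone sketch of the same loop, with a hypothetical function name and a plain integer in place of the imageio reader:

```python
def trim_num_frames(available, requested=81, factor=4, remainder=1):
    """Largest count <= min(requested, available) with count % factor == remainder, never below 1."""
    num_frames = min(requested, available)
    while num_frames > 1 and num_frames % factor != remainder:
        num_frames -= 1
    return num_frames


for available in (100, 80, 50, 3):
    print(available, "->", trim_num_frames(available))
```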
@@ -94,11 +265,7 @@ class VideoDataset(torch.utils.data.Dataset):
def load_image(self, file_path):
image = Image.open(file_path).convert("RGB")
image = self.crop_and_resize(image, *self.get_height_width(image))
return image
def load_video(self, file_path):
frames = self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames)
frames = [image]
return frames
@@ -182,34 +349,52 @@ class DiffusionTrainingModule(torch.nn.Module):
def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate=1e-4, num_epochs=1, output_path="./models", remove_prefix_in_ckpt=None, args=None):
if args is not None:
learning_rate = args.learning_rate
num_epochs = args.num_epochs
output_path = args.output_path
remove_prefix_in_ckpt = args.remove_prefix_in_ckpt
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
class ModelLogger:
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
self.output_path = output_path
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
self.state_dict_converter = state_dict_converter
accelerator = Accelerator(gradient_accumulation_steps=1)
def on_step_end(self, loss):
pass
def on_epoch_end(self, accelerator, model, epoch_id):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
def launch_training_task(
dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule,
model_logger: ModelLogger,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LRScheduler,
num_epochs: int = 1,
gradient_accumulation_steps: int = 1,
):
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch in range(num_epochs):
for epoch_id in range(num_epochs):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
loss = model(data)
accelerator.backward(loss)
optimizer.step()
scheduler.step()
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt)
os.makedirs(output_path, exist_ok=True)
path = os.path.join(output_path, f"epoch-{epoch}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
model_logger.on_step_end(loss)
scheduler.step()
model_logger.on_epoch_end(accelerator, model, epoch_id)
@@ -228,8 +413,9 @@ def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_pat
def wan_parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Path to the metadata file of the dataset.")
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
parser.add_argument("--max_pixels", type=int, default=1280*720, help="Maximum number of pixels per frame, used for dynamic resolution.")
parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
@@ -247,5 +433,33 @@ def wan_parser():
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser
def flux_parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.")
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser

View File

@@ -1 +1,2 @@
from .layers import *
from .gradient_checkpointing import *

View File

@@ -0,0 +1,34 @@
import torch


def create_custom_forward(module):
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)
    return custom_forward


def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    if use_gradient_checkpointing_offload:
        with torch.autograd.graph.save_on_cpu():
            model_output = torch.utils.checkpoint.checkpoint(
                create_custom_forward(model),
                *args,
                **kwargs,
                use_reentrant=False,
            )
    elif use_gradient_checkpointing:
        model_output = torch.utils.checkpoint.checkpoint(
            create_custom_forward(model),
            *args,
            **kwargs,
            use_reentrant=False,
        )
    else:
        model_output = model(*args, **kwargs)
    return model_output

View File

@@ -13,7 +13,8 @@ class AutoTorchModule(torch.nn.Module):
         super().__init__()

     def check_free_vram(self):
-        used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3)
+        gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
+        used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
         return used_memory < self.vram_limit

     def offload(self):

318
examples/flux/README.md Normal file
View File

@@ -0,0 +1,318 @@
# FLUX
[切换到中文](./README_zh.md)
FLUX is a series of image generation models open-sourced by Black-Forest-Labs.
**DiffSynth-Studio has introduced a new inference and training framework. If you need to use the old version, please click [here](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c).**
## Installation
Before using these models, please install DiffSynth-Studio from source code:
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
## Quick Start
You can quickly load the FLUX.1-dev model and perform inference by running the following code:
```python
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
image = pipe(prompt="a cat", seed=0)
image.save("image.jpg")
```
## Model Overview
**Support for the new framework of the FLUX series models is under active development. Stay tuned!**
| Model ID | Additional Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
## Model Inference
The following sections will help you understand our features and write inference code.
<details>
<summary>Loading Models</summary>
Models are loaded using `from_pretrained`:
```python
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
```
Here, `torch_dtype` and `device` refer to the computation precision and device, respectively. The `model_configs` can be configured in various ways to specify model paths:
* Download the model from [ModelScope Community](https://modelscope.cn/) and load it. In this case, provide `model_id` and `origin_file_pattern`, for example:
```python
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
```
* Load the model from a local file path. In this case, provide the `path`, for example:
```python
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors")
```
For models that consist of multiple files, use a list as follows:
```python
ModelConfig(path=[
"models/xxx/diffusion_pytorch_model-00001-of-00003.safetensors",
"models/xxx/diffusion_pytorch_model-00002-of-00003.safetensors",
"models/xxx/diffusion_pytorch_model-00003-of-00003.safetensors",
])
```
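When a model is split into many shards, typing each path by hand is error-prone. The sketch below builds such a list. The helper name `shard_paths` is hypothetical (it is not part of DiffSynth-Studio), `models/xxx` is the same placeholder directory used above, and the code assumes the shards follow the `-0000i-of-0000n` naming shown in the example.

```python
def shard_paths(directory, stem, num_shards, suffix=".safetensors"):
    """Return shard paths such as '<dir>/<stem>-00001-of-00003.safetensors'.

    Hypothetical helper for illustration only.
    """
    return [
        f"{directory}/{stem}-{i:05d}-of-{num_shards:05d}{suffix}"
        for i in range(1, num_shards + 1)
    ]


paths = shard_paths("models/xxx", "diffusion_pytorch_model", 3)
for p in paths:
    print(p)
```

The resulting list can then be passed as `ModelConfig(path=paths)`.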
The `from_pretrained` method also provides additional parameters to control model loading behavior:
* `local_model_path`: Path for saving downloaded models. The default is `"./models"`.
* `skip_download`: Whether to skip downloading models. The default is `False`. If your network cannot access [ModelScope Community](https://modelscope.cn/), manually download the required files and set this to `True`.
</details>
<details>
<summary>VRAM Management</summary>
DiffSynth-Studio provides fine-grained VRAM management for FLUX models, enabling inference on devices with limited VRAM. You can enable offloading functionality via the following code, which moves certain modules to system memory on devices with limited GPU memory.
```python
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
],
)
pipe.enable_vram_management()
```
The `enable_vram_management` function provides the following parameters to control VRAM usage:
* `vram_limit`: VRAM usage limit in GB. By default, it uses the remaining VRAM available on the device. Note that this is not an absolute limit: if the configured limit is too small for inference but more VRAM is actually available, the model still runs, using as little VRAM as it can. Setting it to 0 achieves the theoretical minimum VRAM usage.
* `vram_buffer`: VRAM buffer size in GB. The default is 0.5GB. Since some large neural network layers may consume extra VRAM during onload phases, a VRAM buffer is necessary. Ideally, the optimal value should match the VRAM occupied by the largest layer in the model.
* `num_persistent_param_in_dit`: Number of persistent parameters in the DiT model (default: no limit). We plan to remove this parameter in the future, so please avoid relying on it.
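To make the role of `vram_limit` more concrete, the following sketch shows the kind of check a module can run before keeping weights on the GPU: used memory is the total minus the free memory, converted to GB, and compared with the limit. The function name `within_vram_limit` is hypothetical, and the `(free, total)` byte pair is assumed to come from a call such as `torch.cuda.mem_get_info`.

```python
GIB = 1024 ** 3


def within_vram_limit(free_bytes, total_bytes, vram_limit_gb):
    """Return True when used VRAM (total - free) is below the limit in GB."""
    used_gb = (total_bytes - free_bytes) / GIB
    return used_gb < vram_limit_gb


# A 24 GB card with 20 GB free has 4 GB in use.
print(within_vram_limit(20 * GIB, 24 * GIB, vram_limit_gb=6))  # True
print(within_vram_limit(20 * GIB, 24 * GIB, vram_limit_gb=0))  # False
```

With a limit of 0 the check never passes, which is consistent with the statement that 0 gives the theoretical minimum VRAM usage.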
</details>
<details>
<summary>Inference Acceleration</summary>
* TeaCache: The [TeaCache](https://github.com/ali-vilab/TeaCache) acceleration technique is supported. Please refer to the [sample code](./acceleration/teacache.py).
</details>
<details>
<summary>Input Parameters</summary>
The pipeline accepts the following input parameters during inference:
* `prompt`: Prompt describing what should appear in the image.
* `negative_prompt`: Negative prompt describing what should **not** appear in the image. Default is `""`.
* `cfg_scale`: Classifier-free guidance scale. Default is 1. It becomes effective when set to a value greater than 1.
* `embedded_guidance`: Embedded guidance parameter for FLUX-dev. Default is 3.5.
* `t5_sequence_length`: Sequence length of T5 text embeddings. Default is 512.
* `input_image`: Input image used for image-to-image generation. This works together with `denoising_strength`.
* `denoising_strength`: Denoising strength, ranging from 0 to 1. Default is 1. When close to 0, the generated image will be similar to the input image; when close to 1, the generated image will differ significantly from the input. Leave this at 1 if no `input_image` is provided.
* `height`: Height of the generated image. Must be a multiple of 16.
* `width`: Width of the generated image. Must be a multiple of 16.
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: Device for generating random Gaussian noise. Default is `"cpu"`. Setting it to `"cuda"` may lead to different results across GPUs.
* `sigma_shift`: Parameter from Rectified Flow theory. Default is 3. A larger value increases the number of steps spent at the beginning of denoising and can improve image quality. However, because the generation process then deviates from the training process, the generated images may differ from the training data.
* `num_inference_steps`: Number of inference steps. Default is 30.
* `kontext_images`: Input images for the Kontext model.
* `controlnet_inputs`: Inputs for the ControlNet model.
* `ipadapter_images`: Input images for the IP-Adapter model.
* `ipadapter_scale`: Control strength of the IP-Adapter model.
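The effect of `sigma_shift` can be illustrated with the time-shift formula commonly used for rectified-flow samplers, `sigma' = shift * sigma / (1 + (shift - 1) * sigma)`. Whether the pipeline uses exactly this formula is an assumption here, and the function name `shift_sigma` is illustrative.

```python
def shift_sigma(sigma, shift):
    """Apply a common rectified-flow time shift to a noise level in [0, 1]."""
    return shift * sigma / (1 + (shift - 1) * sigma)


num_steps = 5
# Evenly spaced noise levels from 1.0 (pure noise) down to 0.2.
linear = [1 - i / num_steps for i in range(num_steps)]
shifted = [shift_sigma(s, shift=3.0) for s in linear]

for before, after in zip(linear, shifted):
    print(f"{before:.2f} -> {after:.3f}")
```

With `shift > 1`, every intermediate noise level moves toward 1, so more of the schedule is spent at high noise, that is, at the beginning of denoising.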
</details>
## Model Training
FLUX series models are trained using a unified script [`./model_training/train.py`](./model_training/train.py).
<details>
<summary>Script Parameters</summary>
The script supports the following parameters:
* Dataset
* `--dataset_base_path`: Root path to the dataset.
* `--dataset_metadata_path`: Path to the metadata file of the dataset.
* `--max_pixels`: Maximum pixel area, default is 1024*1024. When dynamic resolution is enabled, any image with a resolution larger than this value will be scaled down.
* `--height`: Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
* `--width`: Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
* `--data_file_keys`: Keys in metadata for data files. Comma-separated.
* `--dataset_repeat`: Number of times the dataset repeats per epoch.
* Models
* `--model_paths`: Paths to load models. JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Comma-separated.
* Training
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of training epochs.
* `--output_path`: Output path for saving checkpoints.
* `--remove_prefix_in_ckpt`: Prefix to remove from parameter names in the saved checkpoint.
* Trainable Modules
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
* `--lora_base_model`: Which base model to apply LoRA on.
* `--lora_target_modules`: Which layers to apply LoRA on.
* `--lora_rank`: Rank of LoRA.
* Extra Inputs
* `--extra_inputs`: Additional model inputs. Comma-separated.
* VRAM Management
* `--use_gradient_checkpointing`: Whether to use gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload the activations saved for gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of steps for gradient accumulation.
* Miscellaneous
* `--align_to_opensource_format`: Whether to align the FLUX DiT LoRA format with the open-source version. Only applicable to LoRA training for FLUX.1-dev and FLUX.1-Kontext-dev.
</details>
<details>
<summary>Step 1: Prepare Dataset</summary>
The dataset contains a series of files. We recommend organizing your dataset files as follows:
```
data/example_image_dataset/
├── metadata.csv
├── image1.jpg
└── image2.jpg
```
Here, `image1.jpg` and `image2.jpg` are training images, and `metadata.csv` is the metadata list, for example:
```
image,prompt
image1.jpg,"a cat is sleeping"
image2.jpg,"a dog is running"
```
We have built a sample image dataset to help you test more conveniently. You can download this dataset using the following command:
```shell
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
The dataset supports multiple image formats: `"jpg", "jpeg", "png", "webp"`.
The image resolution can be controlled via script parameters `--height` and `--width`. When both `--height` and `--width` are left empty, dynamic resolution will be enabled, allowing training with the actual width and height of each image in the dataset.
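As a rough illustration of how dynamic resolution and `--max_pixels` interact, the sketch below shrinks an image size until its area fits the pixel budget and then rounds each side down to a multiple of 16. The rounding rule and the function name `fit_resolution` are assumptions made for illustration; the actual dataset code may resize differently.

```python
import math


def fit_resolution(height, width, max_pixels=1024 * 1024, multiple=16):
    """Scale (height, width) down to fit max_pixels, keeping the aspect ratio."""
    if height * width > max_pixels:
        scale = math.sqrt(max_pixels / (height * width))
        height, width = height * scale, width * scale
    # Round down so that each side is divisible by `multiple`.
    height = int(height) // multiple * multiple
    width = int(width) // multiple * multiple
    return height, width


print(fit_resolution(768, 1024))    # already within the pixel budget
print(fit_resolution(3000, 4000))   # scaled down to fit the budget
```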
**We strongly recommend using fixed-resolution training, as there may be load-balancing issues in multi-GPU training with dynamic resolution.**
When the model requires additional inputs, for instance the `kontext_images` required by the controllable model [`black-forest-labs/FLUX.1-Kontext-dev`](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev), please add the corresponding columns to the dataset, for example:
```
image,prompt,kontext_images
image1.jpg,"a cat is sleeping",image1_reference.jpg
```
If additional inputs include image files, you need to specify the column names to parse using the `--data_file_keys` parameter. Add the extra column names accordingly, e.g., `--data_file_keys "image,kontext_images"`, and also enable `--extra_inputs "kontext_images"`.
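A metadata file with an extra column can be generated with Python's standard `csv` module. This is only a sketch: the helper name `write_metadata` is hypothetical, the file names are the placeholders from the example above, and a temporary directory stands in for your real dataset folder.

```python
import csv
import tempfile
from pathlib import Path


def write_metadata(dataset_dir, rows, fieldnames=("image", "prompt", "kontext_images")):
    """Write metadata.csv with one row per training sample."""
    path = Path(dataset_dir) / "metadata.csv"
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(fieldnames))
        writer.writeheader()
        writer.writerows(rows)
    return path


rows = [
    {"image": "image1.jpg", "prompt": "a cat is sleeping", "kontext_images": "image1_reference.jpg"},
]
# A temporary directory is used here; point this at your dataset folder instead.
dataset_dir = Path(tempfile.mkdtemp())
metadata_path = write_metadata(dataset_dir, rows)
print(metadata_path.read_text(encoding="utf-8"))
```

The `csv` module quotes a field only when it contains a comma, quote, or newline, so prompts with commas are still written safely.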
</details>
<details>
<summary>Step 2: Load Model</summary>
Similar to the model loading logic during inference, you can directly configure the model to be loaded using its model ID. For example, during inference we load the model with the following configuration:
```python
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
]
```
Then during training, simply provide the following parameter to load the corresponding model:
```shell
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors"
```
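The value of `--model_id_with_origin_paths` is simply a comma-separated list of `model_id:origin_file_pattern` pairs. The sketch below assembles it from the same values used in `ModelConfig`; the helper name `build_model_id_arg` is hypothetical.

```python
def build_model_id_arg(model_id, file_patterns):
    """Join '<model_id>:<pattern>' entries with commas."""
    return ",".join(f"{model_id}:{pattern}" for pattern in file_patterns)


arg = build_model_id_arg(
    "black-forest-labs/FLUX.1-dev",
    [
        "flux1-dev.safetensors",
        "text_encoder/model.safetensors",
        "text_encoder_2/",
        "ae.safetensors",
    ],
)
print(f'--model_id_with_origin_paths "{arg}"')
```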
If you prefer to load the model from local files, as in the inference example:
```python
model_configs=[
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors"),
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors"),
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder_2/"),
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/ae.safetensors"),
]
```
Then during training, set it up as follows:
```shell
--model_paths '[
"models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
"models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
"models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
"models/black-forest-labs/FLUX.1-dev/ae.safetensors"
]' \
```
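Because `--model_paths` expects JSON, it can be safer to generate the argument than to write it by hand. The following sketch serializes a list of local paths with `json.dumps` and quotes it for the shell with `shlex.quote`; the variable names are illustrative.

```python
import json
import shlex

local_paths = [
    "models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
    "models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
    "models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
    "models/black-forest-labs/FLUX.1-dev/ae.safetensors",
]

# Serialize to JSON, then quote the whole string as a single shell argument.
json_arg = json.dumps(local_paths)
print("--model_paths " + shlex.quote(json_arg))
```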
</details>
<details>
<summary>Step 3: Configure Trainable Modules</summary>
The training framework supports both full-model training and LoRA-based fine-tuning. Below are some examples:
* Full training of the DiT module: `--trainable_models dit`
* Training a LoRA model on the DiT module: `--lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32`
Additionally, since the training script loads multiple modules (text encoder, DiT, VAE), you need to remove prefixes when saving the model files. For example, when performing full DiT training or LoRA training on the DiT module, please set `--remove_prefix_in_ckpt pipe.dit.`
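The prefix removal can be pictured as a renaming of checkpoint keys. The sketch below uses a plain dictionary in place of a real state dict; the key names are made up for illustration, and dropping keys that do not start with the prefix is an assumption rather than confirmed behavior of the training script.

```python
def remove_prefix(state_dict, prefix):
    """Keep only keys starting with `prefix` and strip it from them."""
    return {
        name[len(prefix):]: value
        for name, value in state_dict.items()
        if name.startswith(prefix)
    }


# Hypothetical parameter names; the values stand in for tensors.
state_dict = {
    "pipe.dit.blocks.0.attn.weight": 1,
    "pipe.dit.blocks.0.attn.bias": 2,
    "pipe.vae.decoder.weight": 3,
}
print(remove_prefix(state_dict, "pipe.dit."))
```

After this renaming, the saved keys match the names the DiT module expects when it is loaded on its own.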
</details>
<details>
<summary>Step 4: Launch the Training Script</summary>
We have written specific training commands for each model. Please refer to the table at the beginning of this document for details.
</details>

327
examples/flux/README_zh.md Normal file
View File

@@ -0,0 +1,327 @@
# FLUX
[Switch to English](./README.md)
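FLUX is a series of image generation models open-sourced by Black-Forest-Labs.
**DiffSynth-Studio has introduced a new inference and training framework. If you need to use the old version, please click [here](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c).**
## Installation
Before using these models, please install DiffSynth-Studio from source: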
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
## Quick Start
Run the following code to quickly load the FLUX.1-dev model and perform inference:
```python
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
image = pipe(prompt="a cat", seed=0)
image.save("image.jpg")
```
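## Model Overview
**Support for the new framework of the FLUX series models is under active development. Stay tuned!**
| Model ID | Additional Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |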
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
## Model Inference
The following sections will help you understand our features and write inference code.
<details>
<summary>Loading Models</summary>
Models are loaded using `from_pretrained`:
```python
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
```
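Here, `torch_dtype` and `device` are the computation precision and the computation device. `model_configs` can specify model paths in several ways:
* Download the model from [ModelScope Community](https://modelscope.cn/) and load it. In this case, provide `model_id` and `origin_file_pattern`, for example: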
```python
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
```
* Load the model from a local file path. In this case, provide `path`, for example:
```python
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors")
```
For a single model loaded from multiple files, pass a list, for example:
```python
ModelConfig(path=[
"models/xxx/diffusion_pytorch_model-00001-of-00003.safetensors",
"models/xxx/diffusion_pytorch_model-00002-of-00003.safetensors",
"models/xxx/diffusion_pytorch_model-00003-of-00003.safetensors",
])
```
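`from_pretrained` also provides additional parameters to control model loading behavior:
* `local_model_path`: Path for saving downloaded models. The default is `"./models"`.
* `skip_download`: Whether to skip downloading. The default is `False`. If your network cannot access [ModelScope Community](https://modelscope.cn/), manually download the required files and set this to `True`.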
</details>
<details>
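<summary>VRAM Management</summary>
DiffSynth-Studio provides fine-grained VRAM management for FLUX models, enabling inference on low-VRAM devices. The following code enables offloading, which moves some modules to system memory when GPU memory is limited.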
```python
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
],
)
pipe.enable_vram_management()
```
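The `enable_vram_management` function provides the following parameters to control VRAM usage:
* `vram_limit`: VRAM usage limit (GB). By default, it uses the remaining VRAM on the device. Note that this is not an absolute limit: when the configured value is too small for inference but enough VRAM is actually available, inference runs with minimal VRAM usage. Setting it to 0 achieves the theoretical minimum VRAM usage.
* `vram_buffer`: VRAM buffer size (GB). The default is 0.5GB. Some large neural network layers uncontrollably consume extra VRAM during the onload phase, so a buffer is necessary. The theoretically optimal value is the VRAM occupied by the largest layer in the model.
* `num_persistent_param_in_dit`: Number of parameters in the DiT model that stay resident in VRAM. The default is no limit. We will remove this parameter in the future, so please do not rely on it.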
</details>
<details>
<summary>Inference Acceleration</summary>
* TeaCache: The [TeaCache](https://github.com/ali-vilab/TeaCache) acceleration technique is supported. Please refer to the [sample code](./acceleration/teacache.py).
</details>
<details>
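<summary>Input Parameters</summary>
The pipeline accepts the following input parameters during inference:
* `prompt`: Prompt describing what should appear in the image.
* `negative_prompt`: Negative prompt describing what should not appear in the image. Default is `""`.
* `cfg_scale`: Classifier-free guidance scale. Default is 1. It takes effect when set to a value greater than 1.
* `embedded_guidance`: Embedded guidance parameter for FLUX-dev. Default is 3.5.
* `t5_sequence_length`: Sequence length of the T5 text embeddings. Default is 512.
* `input_image`: Input image for image-to-image generation. This works together with `denoising_strength`.
* `denoising_strength`: Denoising strength, ranging from 0 to 1. Default is 1. When close to 0, the generated image is similar to the input image; when close to 1, the generated image differs more from the input image. Leave this at 1 when no `input_image` is provided.
* `height`: Image height. Must be a multiple of 16.
* `width`: Image width. Must be a multiple of 16.
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: Device used to generate the random Gaussian noise. Default is `"cpu"`. Setting it to `cuda` leads to different results on different GPUs.
* `sigma_shift`: Parameter from Rectified Flow theory. Default is 3. The larger the value, the more steps the model spends at the beginning of denoising. Increasing it moderately can improve image quality, but because the generation process then deviates from the training process, the generated images may differ from the training data.
* `num_inference_steps`: Number of inference steps. Default is 30.
* `kontext_images`: Input images for the Kontext model.
* `controlnet_inputs`: Inputs for the ControlNet model.
* `ipadapter_images`: Input images for the IP-Adapter model.
* `ipadapter_scale`: Control strength of the IP-Adapter model.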
</details>
## Model Training
FLUX series models are trained using the unified script [`./model_training/train.py`](./model_training/train.py).
<details>
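<summary>Script Parameters</summary>
The script supports the following parameters:
* Dataset
* `--dataset_base_path`: Root path of the dataset.
* `--dataset_metadata_path`: Path to the metadata file of the dataset.
* `--max_pixels`: Maximum pixel area, default is 1024*1024. When dynamic resolution is enabled, any image with a resolution larger than this value will be scaled down.
* `--height`: Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
* `--width`: Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
* `--data_file_keys`: Keys in the metadata that refer to data files. Comma-separated.
* `--dataset_repeat`: Number of times the dataset repeats per epoch.
* Models
* `--model_paths`: Paths of the models to load. JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Comma-separated.
* Training
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--output_path`: Output save path.
* `--remove_prefix_in_ckpt`: Prefix to remove in the checkpoint.
* Trainable Modules
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
* `--lora_base_model`: Which model LoRA is added to.
* `--lora_target_modules`: Which layers LoRA is added to.
* `--lora_rank`: Rank of LoRA.
* Extra Inputs
* `--extra_inputs`: Additional model inputs. Comma-separated.
* VRAM Management
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Miscellaneous
* `--align_to_opensource_format`: Whether to align the FLUX DiT LoRA format with the open-source version. Only takes effect for LoRA training of FLUX.1-dev and FLUX.1-Kontext-dev.
In addition, the training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index). Run `accelerate config` before training to configure GPU-related parameters. For some training scripts (for example, full training of a model), we provide a recommended `accelerate` configuration file, which can be found in the corresponding training script.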
</details>
<details>
<summary>Step 1: Prepare Dataset</summary>
The dataset contains a series of files. We recommend organizing your dataset files as follows:
```
data/example_image_dataset/
├── metadata.csv
├── image1.jpg
└── image2.jpg
```
Here, `image1.jpg` and `image2.jpg` are training images, and `metadata.csv` is the metadata list, for example:
```
image,prompt
image1.jpg,"a cat is sleeping"
image2.jpg,"a dog is running"
```
We have built a sample image dataset to make testing easier. You can download it with the following command:
```shell
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
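The dataset supports multiple image formats: `"jpg", "jpeg", "png", "webp"`.
The image size can be controlled via the script parameters `--height` and `--width`. When `--height` and `--width` are left empty, dynamic resolution is enabled and training uses the actual width and height of each image in the dataset.
**We strongly recommend fixed-resolution training, because dynamic resolution can cause load-balancing issues in multi-GPU training.**
When the model requires additional inputs, for instance the `kontext_images` required by the controllable model [`black-forest-labs/FLUX.1-Kontext-dev`](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev), please add the corresponding columns to the dataset, for example: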
```
image,prompt,kontext_images
image1.jpg,"a cat is sleeping",image1_reference.jpg
```
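If additional inputs include image files, you need to specify the column names to parse in the `--data_file_keys` parameter. Add the extra column names accordingly, e.g., `--data_file_keys "image,kontext_images"`, and also enable `--extra_inputs "kontext_images"`.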
</details>
<details>
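<summary>Step 2: Load Model</summary>
Similar to the model loading logic during inference, you can configure the model to load directly by its model ID. For example, during inference we load the model with the following configuration: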
```python
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
]
```
Then during training, provide the following parameter to load the corresponding model:
```shell
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors"
```
If you prefer to load the model from local files, for example during inference:
```python
model_configs=[
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors"),
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors"),
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder_2/"),
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/ae.safetensors"),
]
```
Then during training, set it as follows:
```shell
--model_paths '[
"models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
"models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
"models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
"models/black-forest-labs/FLUX.1-dev/ae.safetensors"
]' \
```
</details>
<details>
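<summary>Step 3: Configure Trainable Modules</summary>
The training framework supports training either the base model or a LoRA model. Below are some examples:
* Full training of the DiT module: `--trainable_models dit`
* Training a LoRA model on the DiT module: `--lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32`
Additionally, since the training script loads multiple modules (text encoder, dit, vae), prefixes need to be removed when saving the model files. For example, when performing full training of the DiT module or training a LoRA model on the DiT module, please set `--remove_prefix_in_ckpt pipe.dit.`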
</details>
<details>
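<summary>Step 4: Launch the Training Script</summary>
We have written training commands for each model. Please refer to the table at the beginning of this document.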
</details>

View File

@@ -0,0 +1,24 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
for tea_cache_l1_thresh in [None, 0.2, 0.4, 0.6, 0.8]:
    image = pipe(
        prompt=prompt, embedded_guidance=3.5, seed=0,
        num_inference_steps=50, tea_cache_l1_thresh=tea_cache_l1_thresh
    )
    image.save(f"image_{tea_cache_l1_thresh}.png")

View File

@@ -0,0 +1,147 @@
import random
import torch
from PIL import Image, ImageDraw, ImageFont
from diffsynth import download_customized_models
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
    # Create a blank image for overlays
    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
    colors = [
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
    ]
    # Generate random colors for each mask
    if use_random_colors:
        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
    # Font settings
    try:
        font = ImageFont.truetype("arial", font_size)  # Adjust as needed
    except IOError:
        font = ImageFont.load_default(font_size)
    # Overlay each mask onto the overlay image
    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
        # Convert mask to RGBA mode
        mask_rgba = mask.convert('RGBA')
        mask_data = mask_rgba.getdata()
        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
        mask_rgba.putdata(new_data)
        # Draw the mask prompt text on the mask
        draw = ImageDraw.Draw(mask_rgba)
        mask_bbox = mask.getbbox()  # Get the bounding box of the mask
        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position
        draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
        # Alpha composite the overlay with this mask
        overlay = Image.alpha_composite(overlay, mask_rgba)
    # Composite the overlay onto the original image
    result = Image.alpha_composite(image.convert('RGBA'), overlay)
    # Save or display the resulting image
    result.save(output_path)
    return result
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
    dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
    masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
    negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
    for seed in seeds:
        # generate image
        image = pipe(
            prompt=global_prompt,
            cfg_scale=3.0,
            negative_prompt=negative_prompt,
            num_inference_steps=50,
            embedded_guidance=3.5,
            seed=seed,
            height=1024,
            width=1024,
            eligen_entity_prompts=entity_prompts,
            eligen_entity_masks=masks,
        )
        image.save(f"eligen_example_{example_id}_{seed}.png")
        visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
download_from_modelscope = True
if download_from_modelscope:
    model_id = "DiffSynth-Studio/Eligen"
    downloading_priority = ["ModelScope"]
else:
    model_id = "modelscope/EliGen"
    downloading_priority = ["HuggingFace"]
EliGen_path = download_customized_models(
    model_id=model_id,
    origin_file_path="model_bf16.safetensors",
    local_dir="models/lora/entity_control",
    downloading_priority=downloading_priority)[0]
pipe.load_lora(pipe.dit, EliGen_path, alpha=1)
# example 1
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
example(pipe, [0], 1, global_prompt, entity_prompts)
# example 2
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
example(pipe, [0], 2, global_prompt, entity_prompts)
# example 3
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
example(pipe, [27], 3, global_prompt, entity_prompts)
# example 4
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
example(pipe, [21], 4, global_prompt, entity_prompts)
# example 5
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
example(pipe, [0], 5, global_prompt, entity_prompts)
# example 6
global_prompt = "Snow White and the 6 Dwarfs."
entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
example(pipe, [8], 6, global_prompt, entity_prompts)
# example 7, same prompt with different seeds
seeds = range(5, 9)
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
example(pipe, seeds, 7, global_prompt, entity_prompts)

@@ -0,0 +1,50 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from diffsynth.controlnets.processors import Annotator
import numpy as np
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
image = pipe(
prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach",
num_inference_steps=50, embedded_guidance=3.5,
seed=0
)
image.save("image_1.jpg")
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[200:400, 400:700] = 255
mask = Image.fromarray(mask)
mask.save("image_mask.jpg")
inpaint_image = image
image = pipe(
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
num_inference_steps=50, embedded_guidance=3.5,
flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
seed=4
)
image.save("image_2_new.jpg")
control_image = Annotator("canny")(image)
control_image.save("image_control.jpg")
image = pipe(
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
num_inference_steps=50, embedded_guidance=3.5,
flex_control_image=control_image,
seed=4
)
image.save("image_3_new.jpg")

@@ -0,0 +1,54 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
image_1 = pipe(
prompt="a beautiful Asian long-haired female college student.",
embedded_guidance=2.5,
seed=1,
)
image_1.save("image_1.jpg")
image_2 = pipe(
prompt="transform the style to anime style.",
kontext_images=image_1,
embedded_guidance=2.5,
seed=2,
)
image_2.save("image_2.jpg")
image_3 = pipe(
prompt="let her smile.",
kontext_images=image_1,
embedded_guidance=2.5,
seed=3,
)
image_3.save("image_3.jpg")
image_4 = pipe(
prompt="let the girl play basketball.",
kontext_images=image_1,
embedded_guidance=2.5,
seed=4,
)
image_4.save("image_4.jpg")
image_5 = pipe(
prompt="move the girl to a park, let her sit on a chair.",
kontext_images=image_1,
embedded_guidance=2.5,
seed=5,
)
image_5.save("image_5.jpg")

@@ -0,0 +1,37 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
import numpy as np
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
image_1 = pipe(
prompt="a cat sitting on a chair",
height=1024, width=1024,
seed=8, rand_device="cuda",
)
image_1.save("image_1.jpg")
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[100:350, 350: -300] = 255
mask = Image.fromarray(mask)
mask.save("mask.jpg")
image_2 = pipe(
prompt="a cat sitting on a chair, wearing sunglasses",
controlnet_inputs=[ControlNetInput(image=image_1, inpaint_mask=mask, scale=0.9)],
height=1024, width=1024,
seed=9, rand_device="cuda",
)
image_2.save("image_2.jpg")
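The inpainting mask in this example is filled with the NumPy slice `mask[100:350, 350: -300]`, where the negative stop counts back from the right edge. As a quick check of which pixels that selects on the 1024x1024 canvas, the sketch below resolves the same slices with Python's built-in `slice.indices`. It uses only the standard library, and the variable names are illustrative rather than taken from the example.

```python
height, width = 1024, 1024

# Resolve the slices used in mask[100:350, 350:-300] against the canvas size.
row_start, row_stop, _ = slice(100, 350).indices(height)
col_start, col_stop, _ = slice(350, -300).indices(width)

mask_height = row_stop - row_start
mask_width = col_stop - col_start
coverage = mask_height * mask_width / (height * width)

print(f"rows {row_start}..{row_stop - 1}, columns {col_start}..{col_stop - 1}")
print(f"mask is {mask_width}x{mask_height} px, {coverage:.1%} of the image")
```

The repainted region is therefore a fairly small box near the top of the frame, so the rest of `image_1` is left to the ControlNet to preserve.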

@@ -0,0 +1,40 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth.controlnets.processors import Annotator
from diffsynth import download_models
download_models(["Annotators:Depth"])
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
image_1 = pipe(
prompt="a beautiful Asian girl, full body, red dress, summer",
height=1024, width=1024,
seed=6, rand_device="cuda",
)
image_1.save("image_1.jpg")
image_canny = Annotator("canny")(image_1)
image_depth = Annotator("depth")(image_1)
image_2 = pipe(
prompt="a beautiful Asian girl, full body, red dress, winter",
controlnet_inputs=[
ControlNetInput(image=image_canny, scale=0.3, processor_id="canny"),
ControlNetInput(image=image_depth, scale=0.3, processor_id="depth"),
],
height=1024, width=1024,
seed=7, rand_device="cuda",
)
image_2.save("image_2.jpg")

@@ -0,0 +1,33 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
image_1 = pipe(
prompt="a photo of a cat, highly detailed",
height=768, width=768,
seed=0, rand_device="cuda",
)
image_1.save("image_1.jpg")
image_1 = image_1.resize((2048, 2048))
image_2 = pipe(
prompt="a photo of a cat, highly detailed",
controlnet_inputs=[ControlNetInput(image=image_1, scale=0.7)],
input_image=image_1,
denoising_strength=0.99,
height=2048, width=2048, tiled=True,
seed=1, rand_device="cuda",
)
image_2.save("image_2.jpg")

@@ -0,0 +1,24 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"),
ModelConfig(model_id="google/siglip-so400m-patch14-384"),
],
)
origin_prompt = "a rabbit in a garden, colorful flowers"
image = pipe(prompt=origin_prompt, height=1280, width=960, seed=42)
image.save("style image.jpg")
image = pipe(prompt="A piggy", height=1280, width=960, seed=42,
ipadapter_images=[image], ipadapter_scale=0.7)
image.save("A piggy.jpg")

@@ -0,0 +1,59 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from modelscope import dataset_snapshot_download
from modelscope import snapshot_download
from PIL import Image
import numpy as np
snapshot_download(
"ByteDance/InfiniteYou",
allow_file_pattern="supports/insightface/models/antelopev2/*",
local_dir="models/ByteDance/InfiniteYou",
)
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin"),
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors"),
],
)
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern="data/examples/infiniteyou/*",
)
height, width = 1024, 1024
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
controlnet_inputs = [ControlNetInput(image=controlnet_image, scale=1.0, processor_id="None")]
prompt = "A man, portrait, cinematic"
id_image = "data/examples/infiniteyou/man.jpg"
id_image = Image.open(id_image).convert('RGB')
image = pipe(
prompt=prompt, seed=1,
infinityou_id_image=id_image, infinityou_guidance=1.0,
controlnet_inputs=controlnet_inputs,
num_inference_steps=50, embedded_guidance=3.5,
height=height, width=width,
)
image.save("man.jpg")
prompt = "A woman, portrait, cinematic"
id_image = "data/examples/infiniteyou/woman.jpg"
id_image = Image.open(id_image).convert('RGB')
image = pipe(
prompt=prompt, seed=1,
infinityou_id_image=id_image, infinityou_guidance=1.0,
controlnet_inputs=controlnet_inputs,
num_inference_steps=50, embedded_guidance=3.5,
height=height, width=width,
)
image.save("woman.jpg")

@@ -0,0 +1,20 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
from diffsynth.models.flux_value_control import SingleValueEncoder, MultiValueEncoder
pipe.value_controller = MultiValueEncoder(encoders=[SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder()]).to(dtype=torch.bfloat16, device="cuda")
image = pipe(prompt="a cat", seed=0, value_controller_inputs=[0.5, 0.5, 1, 0])
image.save("flux.jpg")

@@ -0,0 +1,26 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(prompt=prompt, seed=0)
image.save("flux.jpg")
image = pipe(
prompt=prompt, negative_prompt=negative_prompt,
seed=0, cfg_scale=2, num_inference_steps=50,
)
image.save("flux_cfg.jpg")

@@ -0,0 +1,32 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from PIL import Image
import numpy as np
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"),
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
],
)
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
image = pipe(
prompt="draw red flowers in Chinese ink painting style",
step1x_reference_image=image,
width=832, height=1248, cfg_scale=6,
seed=1, rand_device='cuda'
)
image.save("image_1.jpg")
image = pipe(
prompt="add more flowers in Chinese ink painting style",
step1x_reference_image=image,
width=832, height=1248, cfg_scale=6,
seed=2, rand_device='cuda'
)
image.save("image_2.jpg")

@@ -0,0 +1,14 @@
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
--data_file_keys "image,kontext_images" \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-Kontext-dev_full" \
--trainable_models "dit" \
--extra_inputs "kontext_images" \
--use_gradient_checkpointing

@@ -0,0 +1,14 @@
accelerate launch examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_ipadapter.csv \
--data_file_keys "image,ipadapter_images" \
--max_pixels 1048576 \
--dataset_repeat 100 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.ipadapter." \
--output_path "./models/train/FLUX.1-dev-IP-Adapter_full" \
--trainable_models "ipadapter" \
--extra_inputs "ipadapter_images" \
--use_gradient_checkpointing

@@ -0,0 +1,12 @@
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-dev_full" \
--trainable_models "dit" \
--use_gradient_checkpointing

@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

@@ -0,0 +1,17 @@
accelerate launch examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
--data_file_keys "image,kontext_images" \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-Kontext-dev_lora" \
--lora_base_model "dit" \
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
--lora_rank 32 \
--align_to_opensource_format \
--extra_inputs "kontext_images" \
--use_gradient_checkpointing

@@ -0,0 +1,15 @@
accelerate launch examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-dev_lora" \
--lora_base_model "dit" \
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
--lora_rank 32 \
--align_to_opensource_format \
--use_gradient_checkpointing

@@ -0,0 +1,117 @@
import torch, os, json
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
from diffsynth.models.lora import FluxLoRAConverter
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class FluxTrainingModule(DiffusionTrainingModule):
    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
    ):
        super().__init__()
        # Load models
        model_configs = []
        if model_paths is not None:
            model_paths = json.loads(model_paths)
            model_configs += [ModelConfig(path=path) for path in model_paths]
        if model_id_with_origin_paths is not None:
            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
        self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
        # Reset training scheduler
        self.pipe.scheduler.set_timesteps(1000, training=True)
        # Freeze untrainable models
        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
        # Add LoRA to the base models
        if lora_base_model is not None:
            model = self.add_lora_to_model(
                getattr(self.pipe, lora_base_model),
                target_modules=lora_target_modules.split(","),
                lora_rank=lora_rank
            )
            setattr(self.pipe, lora_base_model, model)
        # Store other configs
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []

    def forward_preprocess(self, data):
        # CFG-sensitive parameters
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {}
        # CFG-insensitive parameters
        inputs_shared = {
            # Fill in the input parameters as you would when calling
            # this pipeline for inference.
            "input_image": data["image"],
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            # Do not modify the following parameters unless you know
            # exactly what effect the change will have.
            "cfg_scale": 1,
            "embedded_guidance": 1,
            "t5_sequence_length": 512,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
        }
        # Extra inputs
        for extra_input in self.extra_inputs:
            inputs_shared[extra_input] = data[extra_input]
        # Pipeline units will automatically process the input parameters.
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        return {**inputs_shared, **inputs_posi}

    def forward(self, data, inputs=None):
        if inputs is None: inputs = self.forward_preprocess(data)
        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
        loss = self.pipe.training_loss(**models, **inputs)
        return loss
if __name__ == "__main__":
    parser = flux_parser()
    args = parser.parse_args()
    dataset = ImageDataset(args=args)
    model = FluxTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
    )
    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
        state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
    )
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    launch_training_task(
        dataset, model, model_logger, optimizer, scheduler,
        num_epochs=args.num_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )
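The `--model_id_with_origin_paths` argument passed by the training scripts is a comma-separated list of `model_id:origin_file_pattern` pairs, which `FluxTrainingModule.__init__` splits on `,` and then on `:`. The sketch below reproduces only that string handling in plain Python, so the resulting pairs can be inspected without loading any model. `parse_model_specs` is a hypothetical helper name used for illustration and is not part of DiffSynth-Studio.

```python
def parse_model_specs(spec: str) -> list[tuple[str, str]]:
    """Split "id:pattern,id:pattern" the way train.py does (hypothetical helper)."""
    pairs = []
    for item in spec.split(","):
        parts = item.split(":")
        # Same indexing as train.py: element 0 is the model id, element 1 the pattern.
        pairs.append((parts[0], parts[1]))
    return pairs


specs = parse_model_specs(
    "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,"
    "black-forest-labs/FLUX.1-dev:text_encoder_2/,"
    "google/siglip-so400m-patch14-384:"
)
for model_id, pattern in specs:
    print(repr(model_id), repr(pattern))
```

An entry that ends in a bare colon, as the SigLIP entry in the IP-Adapter script does, yields an empty pattern string. An entry with no colon at all would raise `IndexError` under this indexing, so every entry needs a `:` even when there is no file pattern to give.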

@@ -0,0 +1,120 @@
import torch, os, json
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.models.flux_value_control import SingleValueEncoder, MultiValueEncoder
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class FluxTrainingModule(DiffusionTrainingModule):
    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
    ):
        super().__init__()
        # Load models
        model_configs = []
        if model_paths is not None:
            model_paths = json.loads(model_paths)
            model_configs += [ModelConfig(path=path) for path in model_paths]
        if model_id_with_origin_paths is not None:
            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
        self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
        self.pipe.value_controller = MultiValueEncoder(encoders=[SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder()]).to(dtype=torch.bfloat16)
        # Reset training scheduler
        self.pipe.scheduler.set_timesteps(1000, training=True)
        # Freeze untrainable models
        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
        # Add LoRA to the base models
        if lora_base_model is not None:
            model = self.add_lora_to_model(
                getattr(self.pipe, lora_base_model),
                target_modules=lora_target_modules.split(","),
                lora_rank=lora_rank
            )
            setattr(self.pipe, lora_base_model, model)
        # Store other configs
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []

    def forward_preprocess(self, data):
        # CFG-sensitive parameters
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {}
        # CFG-insensitive parameters
        inputs_shared = {
            # Fill in the input parameters as you would when calling
            # this pipeline for inference.
            "input_image": data["image"],
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            # Do not modify the following parameters unless you know
            # exactly what effect the change will have.
            "cfg_scale": 1,
            "embedded_guidance": 1,
            "t5_sequence_length": 512,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
        }
        # Extra inputs
        for extra_input in self.extra_inputs:
            inputs_shared[extra_input] = data[extra_input]
        # Pipeline units will automatically process the input parameters.
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        return {**inputs_shared, **inputs_posi}

    def forward(self, data, inputs=None):
        if inputs is None: inputs = self.forward_preprocess(data)
        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
        loss = self.pipe.training_loss(**models, **inputs)
        return loss
if __name__ == "__main__":
    parser = flux_parser()
    args = parser.parse_args()
    dataset = ImageDataset(args=args)
    model = FluxTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
    )
    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
        state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
    )
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    launch_training_task(
        dataset, model, model_logger, optimizer, scheduler,
        num_epochs=args.num_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )

@@ -0,0 +1,26 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from diffsynth import load_state_dict
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
state_dict = load_state_dict("models/train/FLUX.1-Kontext-dev_full/epoch-0.safetensors")
pipe.dit.load_state_dict(state_dict)
image = pipe(
prompt="Make the dog turn its head around.",
kontext_images=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)),
height=768, width=768,
seed=0
)
image.save("image_FLUX.1-Kontext-dev_full.jpg")

@@ -0,0 +1,28 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from diffsynth import load_state_dict
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"),
ModelConfig(model_id="google/siglip-so400m-patch14-384"),
],
)
state_dict = load_state_dict("models/train/FLUX.1-dev-IP-Adapter_full/epoch-0.safetensors")
pipe.ipadapter.load_state_dict(state_dict)
image = pipe(
prompt="a dog",
ipadapter_images=Image.open("data/example_image_dataset/1.jpg"),
height=768, width=768,
seed=0
)
image.save("image_FLUX.1-dev-IP-Adapter_full.jpg")

View File

@@ -0,0 +1,20 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from diffsynth import load_state_dict
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
state_dict = load_state_dict("models/train/FLUX.1-dev_full/epoch-0.safetensors")
pipe.dit.load_state_dict(state_dict)
image = pipe(prompt="a dog", seed=0)
image.save("image_FLUX.1-dev_full.jpg")

View File

@@ -0,0 +1,24 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
pipe.load_lora(pipe.dit, "models/train/FLUX.1-Kontext-dev_lora/epoch-4.safetensors", alpha=1)
image = pipe(
prompt="Make the dog turn its head around.",
kontext_images=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)),
height=768, width=768,
seed=0
)
image.save("image_FLUX.1-Kontext-dev_lora.jpg")

View File

@@ -0,0 +1,18 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev_lora/epoch-4.safetensors", alpha=1)
image = pipe(prompt="a dog", seed=0)
image.save("image_FLUX.1-dev_lora.jpg")

View File

@@ -4,6 +4,10 @@
Wan 2.1 is a collection of video synthesis models open-sourced by Alibaba.
**DiffSynth-Studio has adopted a new inference and training framework. To use the previous version, please click [here](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c).**
## Installation
Before using this model, please install DiffSynth-Studio from **source code**.
```shell
@@ -12,6 +16,34 @@ cd DiffSynth-Studio
pip install -e .
```
## Quick Start
```python
import torch
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
save_video(video, "video1.mp4", fps=15, quality=5)
```
## Overview
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
@@ -167,7 +199,7 @@ Wan supports multiple acceleration techniques, including:
* **Unified Sequence Parallel**: Sequence parallelism based on [xDiT](https://github.com/xdit-project/xDiT). Please refer to [this example](./acceleration/unified_sequence_parallel.py), and run it using the command:
```shell
-pip install xfuser>=0.4.3
+pip install "xfuser[flash-attn]>=0.4.3"
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
```
@@ -283,7 +315,7 @@ video2.mp4,"a dog is running"
We have prepared a sample video dataset to help you test. You can download it using the following command:
```shell
-modelscope download --dataset DiffSynth-Studio/example_video_dataset README.md --local_dir ./data/example_video_dataset
+modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
```
The dataset supports mixed training of videos and images. Supported video formats include `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`, and supported image formats include `"jpg", "jpeg", "png", "webp"`.
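As an illustration of how a mixed dataset might be sorted by file type, here is a minimal sketch. The extension sets are copied from the sentence above; the helper name `classify_media` and the column names `video` and `prompt` in the sample metadata are assumptions made for illustration, not confirmed parts of DiffSynth-Studio.

```python
import csv
import io
from pathlib import Path

# Extension lists copied from the README sentence above.
VIDEO_EXTS = {"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"}
IMAGE_EXTS = {"jpg", "jpeg", "png", "webp"}


def classify_media(path: str) -> str:
    """Return "video", "image", or "unsupported" based on the file extension."""
    ext = Path(path).suffix.lower().lstrip(".")
    if ext in VIDEO_EXTS:
        return "video"
    if ext in IMAGE_EXTS:
        return "image"
    return "unsupported"


# Hypothetical metadata; the header names are an assumption.
SAMPLE_METADATA = 'video,prompt\nvideo2.mp4,"a dog is running"\nphoto1.JPG,"a dog"\n'

rows = list(csv.DictReader(io.StringIO(SAMPLE_METADATA)))
kinds = {row["video"]: classify_media(row["video"]) for row in rows}
print(kinds)
```

The comparison is case-insensitive, so `photo1.JPG` is treated the same as `photo1.jpg`; whether the real loader does the same is not stated here.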
@@ -387,3 +419,25 @@ Note that full fine-tuning of the 14B model requires 8 GPUs, each with at least
The default video resolution in the training script is `480*832*81`. Increasing the resolution may cause out-of-memory errors. To reduce VRAM usage, add the parameter `--use_gradient_checkpointing_offload`.
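To get a rough feel for why higher resolutions strain memory, the sketch below compares raw pixel counts per clip. It assumes `480*832*81` means height × width × frames, and it uses the `720 × 1280 × 49` setting from the I2V-14B-720P training scripts as a comparison point. Pixel count is only a crude proxy: actual VRAM usage also depends on the model size, the VAE compression, and attention cost.

```python
def pixels_per_clip(height: int, width: int, num_frames: int) -> int:
    """Raw number of pixels in one training clip (per color channel)."""
    return height * width * num_frames


default_clip = pixels_per_clip(480, 832, 81)
hd_clip = pixels_per_clip(720, 1280, 49)

print(f"480x832x81 : {default_clip:,} pixels")
print(f"720x1280x49: {hd_clip:,} pixels")
print(f"ratio      : {hd_clip / default_clip:.2f}x")
```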
</details>
## Gallery
1.3B text-to-video:
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
Put sunglasses on the dog (1.3B video-to-video):
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
14B text-to-video:
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
14B image-to-video:
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
LoRA training:
https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9

View File

@@ -4,6 +4,10 @@
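Wan 2.1 is a collection of video generation models open-sourced by Alibaba's Tongyi Lab.
**DiffSynth-Studio has adopted a new inference and training framework. To use the previous version, please click [here](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c).**
## Installation
Before using this series of models, please install DiffSynth-Studio from source code.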
```shell
@@ -12,6 +16,34 @@ cd DiffSynth-Studio
pip install -e .
```
## Quick Start
```python
import torch
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
save_video(video, "video1.mp4", fps=15, quality=5)
```
## Model Overview
|Model ID|Extra Parameters|Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
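@@ -169,7 +201,7 @@ Wan supports multiple acceleration techniques, including
* Unified Sequence Parallel: Sequence parallelism based on [xDiT](https://github.com/xdit-project/xDiT). Please refer to the [example code](./acceleration/unified_sequence_parallel.py) and run it with the following command: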
```shell
-pip install xfuser>=0.4.3
+pip install "xfuser[flash-attn]>=0.4.3"
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
```
@@ -286,7 +318,7 @@ video2.mp4,"a dog is running"
We have built a sample video dataset to make testing easier. You can download it with the following command:
```shell
-modelscope download --dataset DiffSynth-Studio/example_video_dataset README.md --local_dir ./data/example_video_dataset
+modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
```
The dataset supports mixed training of videos and images. Supported video formats include `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`, and supported image formats include `"jpg", "jpeg", "png", "webp"`.
@@ -390,3 +422,25 @@ model_configs=[
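The default video size in the training script is `480*832*81`. Increasing the resolution may cause out-of-memory errors; add the parameter `--use_gradient_checkpointing_offload` to reduce VRAM usage.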
</details>
## Gallery
1.3B text-to-video:
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
Put sunglasses on the dog (1.3B video-to-video):
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
14B text-to-video:
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
14B image-to-video:
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
LoRA training:
https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9

View File

@@ -1,8 +1,9 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
---height 480 \
---width 832 \
+--height 720 \
+--width 1280 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--learning_rate 1e-5 \
@@ -10,4 +11,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.1-I2V-14B-720P_full" \
--trainable_models "dit" \
---extra_inputs "input_image"
+--extra_inputs "input_image" \
+--use_gradient_checkpointing_offload
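The `--remove_prefix_in_ckpt "pipe.dit."` flag suggests that checkpoint keys are saved without the `pipe.dit.` prefix, so the validation scripts can pass the file straight to `pipe.dit.load_state_dict`. Below is a minimal sketch of that key renaming, assuming this is what the flag does. `strip_prefix` is a hypothetical helper, not DiffSynth-Studio's implementation; the key names are made up, and whether non-matching keys are kept or dropped by the real code is not shown in this diff.

```python
def strip_prefix(state_dict: dict, prefix: str) -> dict:
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Keys without the prefix are kept unchanged in this sketch.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Placeholder strings stand in for tensors; the key names are hypothetical.
trained = {
    "pipe.dit.blocks.0.attn.q.weight": "tensor-a",
    "pipe.dit.blocks.0.attn.q.bias": "tensor-b",
}
print(strip_prefix(trained, "pipe.dit."))
```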

View File

@@ -8,7 +8,7 @@ accelerate launch examples/wanvideo/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
---output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_full" \
+--output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \

View File

@@ -8,7 +8,7 @@ accelerate launch examples/wanvideo/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
---output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_full" \
+--output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \

View File

@@ -1,8 +1,9 @@
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
---height 480 \
---width 832 \
+--height 720 \
+--width 1280 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--learning_rate 1e-4 \
@@ -12,4 +13,5 @@ accelerate launch examples/wanvideo/model_training/train.py \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
---extra_inputs "input_image"
+--extra_inputs "input_image" \
+--use_gradient_checkpointing_offload
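The `--model_id_with_origin_paths` value appears to be a comma-separated list of `model_id:origin_file_pattern` pairs, matching the `ModelConfig(model_id=..., origin_file_pattern=...)` entries in the inference scripts. Here is a minimal parsing sketch under that assumption; `parse_model_paths` is a hypothetical helper, and the real parser in `train.py` may differ.

```python
def parse_model_paths(spec: str) -> list[tuple[str, str]]:
    """Split "id:pattern,id:pattern" into (model_id, origin_file_pattern) pairs."""
    pairs = []
    for item in spec.split(","):
        # Split on the first colon only, in case a pattern ever contains one.
        model_id, pattern = item.split(":", 1)
        pairs.append((model_id, pattern))
    return pairs


spec = (
    "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,"
    "Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth"
)
for model_id, pattern in parse_model_paths(spec):
    print(f"model_id={model_id!r}, origin_file_pattern={pattern!r}")
```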

View File

@@ -1,6 +1,6 @@
import torch, os, json
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
-from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task, wan_parser
+from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -107,4 +107,14 @@ if __name__ == "__main__":
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
)
-launch_training_task(model, dataset, args=args)
+model_logger = ModelLogger(
+args.output_path,
+remove_prefix_in_ckpt=args.remove_prefix_in_ckpt
+)
+optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
+scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
+launch_training_task(
+dataset, model, model_logger, optimizer, scheduler,
+num_epochs=args.num_epochs,
+gradient_accumulation_steps=args.gradient_accumulation_steps,
+)
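One detail worth knowing about `torch.optim.lr_scheduler.ConstantLR(optimizer)` when called with no extra arguments: as far as the PyTorch documentation describes the defaults (`factor=1/3`, `total_iters=5`), the learning rate is scaled to one third for the first five scheduler steps and only then returns to the configured value. The plain-Python model below reproduces that schedule without importing PyTorch. `constant_lr` is a hypothetical helper, and how often `launch_training_task` steps the scheduler is not shown in this diff.

```python
def constant_lr(base_lr: float, step: int, factor: float = 1.0 / 3, total_iters: int = 5) -> float:
    """Learning rate after `step` scheduler steps under a ConstantLR-style schedule."""
    return base_lr * factor if step < total_iters else base_lr


base_lr = 1e-4
for step in range(7):
    # Steps 0-4 use base_lr * factor; from step 5 on the full base_lr applies.
    print(step, f"{constant_lr(base_lr, step):.2e}")
```

If a strictly constant rate from the first step is intended, passing `factor=1.0` to the scheduler would presumably avoid the initial reduction.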

View File

@@ -19,12 +19,13 @@ state_dict = load_state_dict("models/train/Wan2.1-I2V-14B-720P_full/epoch-1.safe
pipe.dit.load_state_dict(state_dict)
pipe.enable_vram_management()
-input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+input_image = VideoData("data/example_video_dataset/video1.mp4", height=720, width=1280)[0]
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=input_image,
height=720, width=1280, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5)

View File

@@ -13,7 +13,7 @@ pipe = WanVideoPipeline.from_pretrained(
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
-state_dict = load_state_dict("models/train/VACE-Wan2.1-1.3B-Preview_full/epoch-1.safetensors")
+state_dict = load_state_dict("models/train/Wan2.1-VACE-1.3B-Preview_full/epoch-1.safetensors")
pipe.vace.load_state_dict(state_dict)
pipe.enable_vram_management()

View File

@@ -18,12 +18,13 @@ pipe = WanVideoPipeline.from_pretrained(
pipe.load_lora(pipe.dit, "models/train/Wan2.1-I2V-14B-720P_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()
-input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+input_image = VideoData("data/example_video_dataset/video1.mp4", height=720, width=1280)[0]
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=input_image,
height=720, width=1280, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5)

View File

@@ -12,3 +12,5 @@ protobuf
modelscope
ftfy
pynvml
pandas
accelerate