mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
Compare commits
57 Commits
refactor
...
value-cont
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ba421a9ab9 | ||
|
|
6c30a7f080 | ||
|
|
1363a0559f | ||
|
|
9bb51fe879 | ||
|
|
d9c812818d | ||
|
|
c8e9a96196 | ||
|
|
6143af4654 | ||
|
|
9458e382b0 | ||
|
|
4f2d9226cf | ||
|
|
f688a469b1 | ||
|
|
c8ea3b3356 | ||
|
|
6e9472b470 | ||
|
|
a5c03c5272 | ||
|
|
8068ac2592 | ||
|
|
5f80e7ac5e | ||
|
|
157e0be49d | ||
|
|
3dbe271aab | ||
|
|
44e2eecdf1 | ||
|
|
8c226e83a6 | ||
|
|
009f26bb40 | ||
|
|
fcf2fbc07f | ||
|
|
b603acd36a | ||
|
|
6c8bb6438b | ||
|
|
8072d3839d | ||
|
|
c8ad643374 | ||
|
|
31f9df5e62 | ||
|
|
e2f415524a | ||
|
|
3eb7e7530e | ||
|
|
916aa54595 | ||
|
|
6ddbd43f7b | ||
|
|
a37a83ecc3 | ||
|
|
f2a0d0c85f | ||
|
|
93194f44e8 | ||
|
|
c4e5033532 | ||
|
|
cc6cd26733 | ||
|
|
1113d305d1 | ||
|
|
6d5f8b7423 | ||
|
|
1b3c204d20 | ||
|
|
1788d50f0a | ||
|
|
e7a21dbf0b | ||
|
|
3b3e1e4d44 | ||
|
|
24426e3a32 | ||
|
|
31369bab15 | ||
|
|
551721658b | ||
|
|
46f052375f | ||
|
|
c2d35a2157 | ||
|
|
4c052e42bc | ||
|
|
a88613555d | ||
|
|
c164519ef1 | ||
|
|
afff5ffb21 | ||
|
|
a8481fd5e1 | ||
|
|
8584e50309 | ||
|
|
7ad9b9aecc | ||
|
|
bd6f2695a9 | ||
|
|
9534a78167 | ||
|
|
4e00c109e3 | ||
|
|
a3a35acc7e |
@@ -42,6 +42,8 @@ Until now, DiffSynth-Studio has supported the following models:
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
|
||||
## News
|
||||
- **June 15, 2025** ModelScope's official evaluation framework, [EvalScope](https://github.com/modelscope/evalscope), now supports text-to-image generation evaluation. Try it with the [Best Practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide.
|
||||
|
||||
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
|
||||
|
||||
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
|
||||
@@ -64,6 +64,8 @@ from ..models.wan_video_vace import VaceWanModel
|
||||
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
|
||||
from ..models.flux_value_control import SingleValueEncoder
|
||||
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -102,6 +104,7 @@ model_loader_configs = [
|
||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||
(None, "3ede90c44b2c161240b659f3b8393c9d", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||
|
||||
@@ -413,7 +413,7 @@ class BertEncoder(nn.Module):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
13
diffsynth/lora/flux_lora.py
Normal file
13
diffsynth/lora/flux_lora.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import torch
|
||||
from diffsynth.lora import GeneralLoRALoader
|
||||
from diffsynth.models.lora import FluxLoRAFromCivitai
|
||||
|
||||
|
||||
class FluxLoRALoader(GeneralLoRALoader):
|
||||
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.loader = FluxLoRAFromCivitai()
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
lora_prefix, model_resource = self.loader.match(model, state_dict_lora)
|
||||
self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource)
|
||||
58
diffsynth/models/flux_value_control.py
Normal file
58
diffsynth/models/flux_value_control.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
|
||||
|
||||
class MultiValueEncoder(torch.nn.Module):
|
||||
def __init__(self, encoders=()):
|
||||
super().__init__()
|
||||
self.encoders = torch.nn.ModuleList(encoders)
|
||||
|
||||
def __call__(self, values, dtype):
|
||||
emb = []
|
||||
for encoder, value in zip(self.encoders, values):
|
||||
if value is not None:
|
||||
value = value.unsqueeze(0)
|
||||
emb.append(encoder(value, dtype))
|
||||
emb = torch.concat(emb, dim=0)
|
||||
return emb
|
||||
|
||||
|
||||
class SingleValueEncoder(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
|
||||
super().__init__()
|
||||
self.prefer_len = prefer_len
|
||||
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||
self.prefer_value_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
self.positional_embedding = torch.nn.Parameter(
|
||||
torch.randn(self.prefer_len, dim_in)
|
||||
)
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
last_linear = self.prefer_value_embedder[-1]
|
||||
torch.nn.init.zeros_(last_linear.weight)
|
||||
torch.nn.init.zeros_(last_linear.bias)
|
||||
|
||||
def forward(self, value, dtype):
|
||||
emb = self.prefer_proj(value).to(dtype)
|
||||
emb = emb.expand(self.prefer_len, -1)
|
||||
emb = emb + self.positional_embedding
|
||||
emb = self.prefer_value_embedder(emb)
|
||||
return emb
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SingleValueEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SingleValueEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
@@ -1373,7 +1373,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
elif generation_config.max_new_tokens is not None:
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
if not has_default_max_length:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||
"Please refer to the documentation for more information. "
|
||||
|
||||
1052
diffsynth/pipelines/flux_image_new.py
Normal file
1052
diffsynth/pipelines/flux_image_new.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -379,7 +379,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
if num_frames % 4 != 1:
|
||||
num_frames = (num_frames + 2) // 4 * 4 + 1
|
||||
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
|
||||
print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
@@ -143,10 +143,8 @@ class BasePipeline(torch.nn.Module):
|
||||
self.vram_management_enabled = True
|
||||
|
||||
|
||||
def get_free_vram(self):
|
||||
total_memory = torch.cuda.get_device_properties(self.device).total_memory
|
||||
allocated_memory = torch.cuda.device_memory_used(self.device)
|
||||
return (total_memory - allocated_memory) / (1024 ** 3)
|
||||
def get_vram(self):
|
||||
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
|
||||
|
||||
|
||||
def freeze_except(self, model_names):
|
||||
@@ -168,20 +166,50 @@ class ModelConfig:
|
||||
offload_device: Optional[Union[str, torch.device]] = None
|
||||
offload_dtype: Optional[torch.dtype] = None
|
||||
|
||||
def download_if_necessary(self, local_model_path="./models", skip_download=False):
|
||||
def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False):
|
||||
if self.path is None:
|
||||
if self.model_id is None or self.origin_file_pattern is None:
|
||||
# Check model_id and origin_file_pattern
|
||||
if self.model_id is None:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
|
||||
|
||||
# Skip if not in rank 0
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
skip_download = dist.get_rank() != 0
|
||||
|
||||
# Check whether the origin path is a folder
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
self.origin_file_pattern = ""
|
||||
allow_file_pattern = None
|
||||
is_folder = True
|
||||
elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
|
||||
allow_file_pattern = self.origin_file_pattern + "*"
|
||||
is_folder = True
|
||||
else:
|
||||
allow_file_pattern = self.origin_file_pattern
|
||||
is_folder = False
|
||||
|
||||
# Download
|
||||
if not skip_download:
|
||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(local_model_path, self.model_id),
|
||||
allow_file_pattern=self.origin_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_file_pattern=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
|
||||
|
||||
# Let rank 1, 2, ... wait for rank 0
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
dist.barrier(device_ids=[dist.get_rank()])
|
||||
|
||||
# Return downloaded files
|
||||
if is_folder:
|
||||
self.path = os.path.join(local_model_path, self.model_id, self.origin_file_pattern)
|
||||
else:
|
||||
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
|
||||
if isinstance(self.path, list) and len(self.path) == 1:
|
||||
self.path = self.path[0]
|
||||
|
||||
@@ -247,7 +275,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
vram_limit = None
|
||||
else:
|
||||
if vram_limit is None:
|
||||
vram_limit = self.get_free_vram()
|
||||
vram_limit = self.get_vram()
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
if self.text_encoder is not None:
|
||||
dtype = next(iter(self.text_encoder.parameters())).dtype
|
||||
@@ -427,19 +455,21 @@ class WanVideoPipeline(BasePipeline):
|
||||
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
|
||||
model_config.model_id = redirect_dict[model_config.origin_file_pattern]
|
||||
|
||||
# Initialize pipeline
|
||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
if use_usp: pipe.initialize_usp()
|
||||
|
||||
# Download and load models
|
||||
model_manager = ModelManager()
|
||||
for model_config in model_configs:
|
||||
model_config.download_if_necessary(local_model_path, skip_download=skip_download)
|
||||
model_config.download_if_necessary(local_model_path, skip_download=skip_download, use_usp=use_usp)
|
||||
model_manager.load_model(
|
||||
model_config.path,
|
||||
device=model_config.offload_device or device,
|
||||
torch_dtype=model_config.offload_dtype or torch_dtype
|
||||
)
|
||||
|
||||
# Initialize pipeline
|
||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
if use_usp: pipe.initialize_usp()
|
||||
# Load models
|
||||
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
|
||||
pipe.dit = model_manager.fetch_model("wan_video_dit")
|
||||
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
||||
@@ -608,11 +638,17 @@ class PipelineUnitRunner:
|
||||
elif unit.seperate_cfg:
|
||||
# Positive side
|
||||
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_posi.update(processor_outputs)
|
||||
# Negative side
|
||||
if inputs_shared["cfg_scale"] != 1:
|
||||
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_nega.update(processor_outputs)
|
||||
else:
|
||||
@@ -1150,17 +1186,20 @@ def model_fn_wan_video(
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
|
||||
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
x = x + current_vace_hint * vace_scale
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(x)
|
||||
|
||||
if reference_latents is not None:
|
||||
x = x[:, reference_latents.shape[1]:]
|
||||
f -= 1
|
||||
|
||||
x = dit.head(x, t)
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
# Remove reference latents
|
||||
if reference_latents is not None:
|
||||
x = x[:, reference_latents.shape[1]:]
|
||||
f -= 1
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import imageio, os, torch, warnings, torchvision, argparse
|
||||
import imageio, os, torch, warnings, torchvision, argparse, json
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
@@ -7,12 +7,139 @@ from accelerate import Accelerator
|
||||
|
||||
|
||||
|
||||
class ImageDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
base_path=None, metadata_path=None,
|
||||
max_pixels=1920*1080, height=None, width=None,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
data_file_keys=("image",),
|
||||
image_file_extension=("jpg", "jpeg", "png", "webp"),
|
||||
repeat=1,
|
||||
args=None,
|
||||
):
|
||||
if args is not None:
|
||||
base_path = args.dataset_base_path
|
||||
metadata_path = args.dataset_metadata_path
|
||||
height = args.height
|
||||
width = args.width
|
||||
max_pixels = args.max_pixels
|
||||
data_file_keys = args.data_file_keys.split(",")
|
||||
repeat = args.dataset_repeat
|
||||
|
||||
self.base_path = base_path
|
||||
self.max_pixels = max_pixels
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
self.data_file_keys = data_file_keys
|
||||
self.image_file_extension = image_file_extension
|
||||
self.repeat = repeat
|
||||
|
||||
if height is not None and width is not None:
|
||||
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
|
||||
self.dynamic_resolution = False
|
||||
elif height is None and width is None:
|
||||
print("Height and width are none. Setting `dynamic_resolution` to True.")
|
||||
self.dynamic_resolution = True
|
||||
|
||||
if metadata_path is None:
|
||||
print("No metadata. Trying to generate it.")
|
||||
metadata = self.generate_metadata(base_path)
|
||||
print(f"{len(metadata)} lines in metadata.")
|
||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||
elif metadata_path.endswith(".json"):
|
||||
with open(metadata_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
self.data = metadata
|
||||
else:
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||
|
||||
|
||||
def generate_metadata(self, folder):
|
||||
image_list, prompt_list = [], []
|
||||
file_set = set(os.listdir(folder))
|
||||
for file_name in file_set:
|
||||
if "." not in file_name:
|
||||
continue
|
||||
file_ext_name = file_name.split(".")[-1].lower()
|
||||
file_base_name = file_name[:-len(file_ext_name)-1]
|
||||
if file_ext_name not in self.image_file_extension:
|
||||
continue
|
||||
prompt_file_name = file_base_name + ".txt"
|
||||
if prompt_file_name not in file_set:
|
||||
continue
|
||||
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
|
||||
prompt = f.read().strip()
|
||||
image_list.append(file_name)
|
||||
prompt_list.append(prompt)
|
||||
metadata = pd.DataFrame()
|
||||
metadata["image"] = image_list
|
||||
metadata["prompt"] = prompt_list
|
||||
return metadata
|
||||
|
||||
|
||||
def crop_and_resize(self, image, target_height, target_width):
|
||||
width, height = image.size
|
||||
scale = max(target_width / width, target_height / height)
|
||||
image = torchvision.transforms.functional.resize(
|
||||
image,
|
||||
(round(height*scale), round(width*scale)),
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||
)
|
||||
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
|
||||
return image
|
||||
|
||||
|
||||
def get_height_width(self, image):
|
||||
if self.dynamic_resolution:
|
||||
width, height = image.size
|
||||
if width * height > self.max_pixels:
|
||||
scale = (width * height / self.max_pixels) ** 0.5
|
||||
height, width = int(height / scale), int(width / scale)
|
||||
height = height // self.height_division_factor * self.height_division_factor
|
||||
width = width // self.width_division_factor * self.width_division_factor
|
||||
else:
|
||||
height, width = self.height, self.width
|
||||
return height, width
|
||||
|
||||
|
||||
def load_image(self, file_path):
|
||||
image = Image.open(file_path).convert("RGB")
|
||||
image = self.crop_and_resize(image, *self.get_height_width(image))
|
||||
return image
|
||||
|
||||
|
||||
def load_data(self, file_path):
|
||||
return self.load_image(file_path)
|
||||
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
data = self.data[data_id % len(self.data)].copy()
|
||||
for key in self.data_file_keys:
|
||||
if key in data:
|
||||
path = os.path.join(self.base_path, data[key])
|
||||
data[key] = self.load_data(path)
|
||||
if data[key] is None:
|
||||
warnings.warn(f"cannot load file {data[key]}.")
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data) * self.repeat
|
||||
|
||||
|
||||
|
||||
class VideoDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
base_path=None, metadata_path=None,
|
||||
frame_interval=1, num_frames=81,
|
||||
dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None,
|
||||
num_frames=81,
|
||||
time_division_factor=4, time_division_remainder=1,
|
||||
max_pixels=1920*1080, height=None, width=None,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
data_file_keys=("video",),
|
||||
image_file_extension=("jpg", "jpeg", "png", "webp"),
|
||||
@@ -25,17 +152,15 @@ class VideoDataset(torch.utils.data.Dataset):
|
||||
metadata_path = args.dataset_metadata_path
|
||||
height = args.height
|
||||
width = args.width
|
||||
max_pixels = args.max_pixels
|
||||
num_frames = args.num_frames
|
||||
data_file_keys = args.data_file_keys.split(",")
|
||||
repeat = args.dataset_repeat
|
||||
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||
|
||||
self.base_path = base_path
|
||||
self.frame_interval = frame_interval
|
||||
self.num_frames = num_frames
|
||||
self.dynamic_resolution = dynamic_resolution
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
self.max_pixels = max_pixels
|
||||
self.height = height
|
||||
self.width = width
|
||||
@@ -46,9 +171,48 @@ class VideoDataset(torch.utils.data.Dataset):
|
||||
self.video_file_extension = video_file_extension
|
||||
self.repeat = repeat
|
||||
|
||||
if height is not None and width is not None and dynamic_resolution == True:
|
||||
if height is not None and width is not None:
|
||||
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
|
||||
self.dynamic_resolution = False
|
||||
elif height is None and width is None:
|
||||
print("Height and width are none. Setting `dynamic_resolution` to True.")
|
||||
self.dynamic_resolution = True
|
||||
|
||||
if metadata_path is None:
|
||||
print("No metadata. Trying to generate it.")
|
||||
metadata = self.generate_metadata(base_path)
|
||||
print(f"{len(metadata)} lines in metadata.")
|
||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||
elif metadata_path.endswith(".json"):
|
||||
with open(metadata_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
self.data = metadata
|
||||
else:
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||
|
||||
|
||||
def generate_metadata(self, folder):
|
||||
video_list, prompt_list = [], []
|
||||
file_set = set(os.listdir(folder))
|
||||
for file_name in file_set:
|
||||
if "." not in file_name:
|
||||
continue
|
||||
file_ext_name = file_name.split(".")[-1].lower()
|
||||
file_base_name = file_name[:-len(file_ext_name)-1]
|
||||
if file_ext_name not in self.image_file_extension and file_ext_name not in self.video_file_extension:
|
||||
continue
|
||||
prompt_file_name = file_base_name + ".txt"
|
||||
if prompt_file_name not in file_set:
|
||||
continue
|
||||
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
|
||||
prompt = f.read().strip()
|
||||
video_list.append(file_name)
|
||||
prompt_list.append(prompt)
|
||||
metadata = pd.DataFrame()
|
||||
metadata["video"] = video_list
|
||||
metadata["prompt"] = prompt_list
|
||||
return metadata
|
||||
|
||||
|
||||
def crop_and_resize(self, image, target_height, target_width):
|
||||
@@ -75,15 +239,22 @@ class VideoDataset(torch.utils.data.Dataset):
|
||||
height, width = self.height, self.width
|
||||
return height, width
|
||||
|
||||
|
||||
def get_num_frames(self, reader):
|
||||
num_frames = self.num_frames
|
||||
if int(reader.count_frames()) < num_frames:
|
||||
num_frames = int(reader.count_frames())
|
||||
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames -= 1
|
||||
return num_frames
|
||||
|
||||
|
||||
def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames):
|
||||
def load_video(self, file_path):
|
||||
reader = imageio.get_reader(file_path)
|
||||
if reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
|
||||
reader.close()
|
||||
return None
|
||||
num_frames = self.get_num_frames(reader)
|
||||
frames = []
|
||||
for frame_id in range(num_frames):
|
||||
frame = reader.get_data(start_frame_id + frame_id * interval)
|
||||
frame = reader.get_data(frame_id)
|
||||
frame = Image.fromarray(frame)
|
||||
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
|
||||
frames.append(frame)
|
||||
@@ -94,11 +265,7 @@ class VideoDataset(torch.utils.data.Dataset):
|
||||
def load_image(self, file_path):
|
||||
image = Image.open(file_path).convert("RGB")
|
||||
image = self.crop_and_resize(image, *self.get_height_width(image))
|
||||
return image
|
||||
|
||||
|
||||
def load_video(self, file_path):
|
||||
frames = self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames)
|
||||
frames = [image]
|
||||
return frames
|
||||
|
||||
|
||||
@@ -182,34 +349,52 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
|
||||
|
||||
|
||||
def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate=1e-4, num_epochs=1, output_path="./models", remove_prefix_in_ckpt=None, args=None):
|
||||
if args is not None:
|
||||
learning_rate = args.learning_rate
|
||||
num_epochs = args.num_epochs
|
||||
output_path = args.output_path
|
||||
remove_prefix_in_ckpt = args.remove_prefix_in_ckpt
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
class ModelLogger:
|
||||
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
|
||||
self.output_path = output_path
|
||||
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
||||
self.state_dict_converter = state_dict_converter
|
||||
|
||||
|
||||
accelerator = Accelerator(gradient_accumulation_steps=1)
|
||||
def on_step_end(self, loss):
|
||||
pass
|
||||
|
||||
|
||||
def on_epoch_end(self, accelerator, model, epoch_id):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
|
||||
|
||||
|
||||
def launch_training_task(
|
||||
dataset: torch.utils.data.Dataset,
|
||||
model: DiffusionTrainingModule,
|
||||
model_logger: ModelLogger,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
||||
num_epochs: int = 1,
|
||||
gradient_accumulation_steps: int = 1,
|
||||
):
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
|
||||
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
|
||||
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for epoch_id in range(num_epochs):
|
||||
for data in tqdm(dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
loss = model(data)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt)
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
path = os.path.join(output_path, f"epoch-{epoch}.safetensors")
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
model_logger.on_step_end(loss)
|
||||
scheduler.step()
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
|
||||
|
||||
|
||||
@@ -228,8 +413,9 @@ def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_pat
|
||||
|
||||
def wan_parser():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the dataset.")
|
||||
parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Path to the metadata file of the dataset.")
|
||||
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
|
||||
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
||||
parser.add_argument("--max_pixels", type=int, default=1280*720, help="Maximum number of pixels per frame, used for dynamic resolution..")
|
||||
parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
|
||||
@@ -247,5 +433,33 @@ def wan_parser():
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
def flux_parser():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
|
||||
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
||||
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution..")
|
||||
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.")
|
||||
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
|
||||
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
|
||||
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
|
||||
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
return parser
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from .layers import *
|
||||
from .gradient_checkpointing import *
|
||||
|
||||
34
diffsynth/vram_management/gradient_checkpointing.py
Normal file
34
diffsynth/vram_management/gradient_checkpointing.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs, **kwargs):
|
||||
return module(*inputs, **kwargs)
|
||||
return custom_forward
|
||||
|
||||
|
||||
def gradient_checkpoint_forward(
|
||||
model,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
model_output = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(model),
|
||||
*args,
|
||||
**kwargs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
model_output = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(model),
|
||||
*args,
|
||||
**kwargs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
model_output = model(*args, **kwargs)
|
||||
return model_output
|
||||
@@ -13,7 +13,8 @@ class AutoTorchModule(torch.nn.Module):
|
||||
super().__init__()
|
||||
|
||||
def check_free_vram(self):
|
||||
used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3)
|
||||
gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
|
||||
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
|
||||
return used_memory < self.vram_limit
|
||||
|
||||
def offload(self):
|
||||
|
||||
318
examples/flux/README.md
Normal file
318
examples/flux/README.md
Normal file
@@ -0,0 +1,318 @@
|
||||
# FLUX
|
||||
|
||||
[切换到中文](./README_zh.md)
|
||||
|
||||
FLUX is a series of image generation models open-sourced by Black-Forest-Labs.
|
||||
|
||||
**DiffSynth-Studio has introduced a new inference and training framework. If you need to use the old version, please click [here](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c).**
|
||||
|
||||
## Installation
|
||||
|
||||
Before using these models, please install DiffSynth-Studio from source code:
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
You can quickly load the FLUX.1-dev model and perform inference by running the following code:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
**Support for the new framework of the FLUX series models is under active development. Stay tuned!**
|
||||
|
||||
| Model ID | Additional Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|
||||
## Model Inference
|
||||
|
||||
The following sections will help you understand our features and write inference code.
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Loading Models</summary>
|
||||
|
||||
Models are loaded using `from_pretrained`:
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
Here, `torch_dtype` and `device` refer to the computation precision and device, respectively. The `model_configs` can be configured in various ways to specify model paths:
|
||||
|
||||
* Download the model from [ModelScope Community](https://modelscope.cn/) and load it. In this case, provide `model_id` and `origin_file_pattern`, for example:
|
||||
|
||||
```python
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
|
||||
```
|
||||
|
||||
* Load the model from a local file path. In this case, provide the `path`, for example:
|
||||
|
||||
```python
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors")
|
||||
```
|
||||
|
||||
For models that consist of multiple files, use a list as follows:
|
||||
|
||||
```python
|
||||
ModelConfig(path=[
|
||||
"models/xxx/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"models/xxx/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"models/xxx/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
])
|
||||
```
|
||||
|
||||
The `from_pretrained` method also provides additional parameters to control model loading behavior:
|
||||
|
||||
* `local_model_path`: Path for saving downloaded models. The default is `"./models"`.
|
||||
* `skip_download`: Whether to skip downloading models. The default is `False`. If your network cannot access [ModelScope Community](https://modelscope.cn/), manually download the required files and set this to `True`.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>VRAM Management</summary>
|
||||
|
||||
DiffSynth-Studio provides fine-grained VRAM management for FLUX models, enabling inference on devices with limited VRAM. You can enable offloading functionality via the following code, which moves certain modules to system memory on devices with limited GPU memory.
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
The `enable_vram_management` function provides the following parameters to control VRAM usage:
|
||||
|
||||
* `vram_limit`: VRAM usage limit in GB. By default, it uses the remaining VRAM available on the device. Note that this is not an absolute limit; if the set VRAM is insufficient but more VRAM is actually available, the model will run with minimal VRAM consumption. Setting it to 0 achieves the theoretical minimum VRAM usage.
|
||||
* `vram_buffer`: VRAM buffer size in GB. The default is 0.5GB. Since some large neural network layers may consume extra VRAM during onload phases, a VRAM buffer is necessary. Ideally, the optimal value should match the VRAM occupied by the largest layer in the model.
|
||||
* `num_persistent_param_in_dit`: Number of persistent parameters in the DiT model (default: no limit). We plan to remove this parameter in the future, so please avoid relying on it.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Inference Acceleration</summary>
|
||||
|
||||
* TeaCache: Acceleration technique [TeaCache](https://github.com/ali-vilab/TeaCache), please refer to the [sample code](./acceleration/teacache.py).
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Input Parameters</summary>
|
||||
|
||||
The pipeline accepts the following input parameters during inference:
|
||||
|
||||
* `prompt`: Prompt describing what should appear in the image.
|
||||
* `negative_prompt`: Negative prompt describing what should **not** appear in the image. Default is `""`.
|
||||
* `cfg_scale`: Classifier-free guidance scale. Default is 1. It becomes effective when set to a value greater than 1.
|
||||
* `embedded_guidance`: Embedded guidance parameter for FLUX-dev. Default is 3.5.
|
||||
* `t5_sequence_length`: Sequence length of T5 text embeddings. Default is 512.
|
||||
* `input_image`: Input image used for image-to-image generation. This works together with `denoising_strength`.
|
||||
* `denoising_strength`: Denoising strength, ranging from 0 to 1. Default is 1. When close to 0, the generated image will be similar to the input image; when close to 1, the generated image will differ significantly from the input. Do not set this to a non-1 value if no `input_image` is provided.
|
||||
* `height`: Height of the generated image. Must be a multiple of 16.
|
||||
* `width`: Width of the generated image. Must be a multiple of 16.
|
||||
* `seed`: Random seed. Default is `None`, meaning completely random.
|
||||
* `rand_device`: Device for generating random Gaussian noise. Default is `"cpu"`. Setting it to `"cuda"` may lead to different results across GPUs.
|
||||
* `sigma_shift`: Parameter from Rectified Flow theory. Default is 3. A larger value increases the number of steps spent at the beginning of denoising and can improve image quality. However, it may cause inconsistencies between the generation process and training data.
|
||||
* `num_inference_steps`: Number of inference steps. Default is 30.
|
||||
* `kontext_images`: Input images for the Kontext model.
|
||||
* `controlnet_inputs`: Inputs for the ControlNet model.
|
||||
* `ipadapter_images`: Input images for the IP-Adapter model.
|
||||
* `ipadapter_scale`: Control strength of the IP-Adapter model.
|
||||
|
||||
</details>
|
||||
|
||||
## Model Training
|
||||
|
||||
FLUX series models are trained using a unified script [`./model_training/train.py`](./model_training/train.py).
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Script Parameters</summary>
|
||||
|
||||
The script supports the following parameters:
|
||||
|
||||
* Dataset
|
||||
* `--dataset_base_path`: Root path to the dataset.
|
||||
* `--dataset_metadata_path`: Path to the metadata file of the dataset.
|
||||
* `--max_pixels`: Maximum pixel area, default is 1024*1024. When dynamic resolution is enabled, any image with a resolution larger than this value will be scaled down.。
|
||||
* `--height`: Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--width`: Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--data_file_keys`: Keys in metadata for data files. Comma-separated.
|
||||
* `--dataset_repeat`: Number of times the dataset repeats per epoch.
|
||||
* Models
|
||||
* `--model_paths`: Paths to load models. JSON format.
|
||||
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Comma-separated.
|
||||
* Training
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of training epochs.
|
||||
* `--output_path`: Output path for saving checkpoints.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint filenames.
|
||||
* Trainable Modules
|
||||
* `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder.
|
||||
* `--lora_base_model`: Which base model to apply LoRA on.
|
||||
* `--lora_target_modules`: Which layers to apply LoRA on.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* Extra Inputs
|
||||
* `--extra_inputs`: Additional model inputs. Comma-separated.
|
||||
* VRAM Management
|
||||
* `--use_gradient_checkpointing`: Whether to use gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||
* `--gradient_accumulation_steps`: Number of steps for gradient accumulation.
|
||||
* Miscellaneous
|
||||
* `--align_to_opensource_format`: Whether to align the FLUX DiT LoRA format with the open-source version. Only applicable to LoRA training for FLUX.1-dev and FLUX.1-Kontext-dev.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 1: Prepare Dataset</summary>
|
||||
|
||||
The dataset contains a series of files. We recommend organizing your dataset files as follows:
|
||||
|
||||
```
|
||||
data/example_image_dataset/
|
||||
├── metadata.csv
|
||||
├── image1.jpg
|
||||
└── image2.jpg
|
||||
```
|
||||
|
||||
Here, `image1.jpg`, `image2.jpg` are training image data, and `metadata.csv` is the metadata list, for example:
|
||||
|
||||
```
|
||||
image,prompt
|
||||
image1.jpg,"a cat is sleeping"
|
||||
image2.jpg,"a dog is running"
|
||||
```
|
||||
|
||||
We have built a sample image dataset to help you test more conveniently. You can download this dataset using the following command:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
The dataset supports multiple image formats: `"jpg", "jpeg", "png", "webp"`.
|
||||
|
||||
The image resolution can be controlled via script parameters `--height` and `--width`. When both `--height` and `--width` are left empty, dynamic resolution will be enabled, allowing training with the actual width and height of each image in the dataset.
|
||||
|
||||
**We strongly recommend using fixed-resolution training, as there may be load-balancing issues in multi-GPU training with dynamic resolution.**
|
||||
|
||||
When the model requires additional inputs—for instance, `kontext_images` required by the controllable model [`black-forest-labs/FLUX.1-Kontext-dev`](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)—please add corresponding columns in the dataset, for example:
|
||||
|
||||
```
|
||||
image,prompt,kontext_images
|
||||
image1.jpg,"a cat is sleeping",image1_reference.jpg
|
||||
```
|
||||
|
||||
If additional inputs include image files, you need to specify the column names to parse using the `--data_file_keys` parameter. You can add more column names accordingly, e.g., `--data_file_keys "image,kontext_images"`.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 2: Load Model</summary>
|
||||
|
||||
Similar to the model loading logic during inference, you can directly configure the model to be loaded using its model ID. For example, during inference we load the model with the following configuration:
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
]
|
||||
```
|
||||
|
||||
Then during training, simply provide the following parameter to load the corresponding model:
|
||||
|
||||
```shell
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors"
|
||||
```
|
||||
|
||||
If you prefer to load the model from local files, as in the inference example:
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder_2/"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/ae.safetensors"),
|
||||
]
|
||||
```
|
||||
|
||||
Then during training, set it up as follows:
|
||||
|
||||
```shell
|
||||
--model_paths '[
|
||||
"models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
|
||||
"models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
|
||||
"models/black-forest-labs/FLUX.1-dev/ae.safetensors"
|
||||
]' \
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 3: Configure Trainable Modules</summary>
|
||||
|
||||
The training framework supports both full-model training and LoRA-based fine-tuning. Below are some examples:
|
||||
|
||||
* Full training of the DiT module: `--trainable_models dit`
|
||||
* Training a LoRA model on the DiT module: `--lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32`
|
||||
|
||||
Additionally, since the training script loads multiple modules (text encoder, DiT, VAE), you need to remove prefixes when saving the model files. For example, when performing full DiT training or LoRA training on the DiT module, please set `--remove_prefix_in_ckpt pipe.dit.`
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 4: Launch the Training Script</summary>
|
||||
|
||||
We have written specific training commands for each model. Please refer to the table at the beginning of this document for details.
|
||||
|
||||
</details>
|
||||
327
examples/flux/README_zh.md
Normal file
327
examples/flux/README_zh.md
Normal file
@@ -0,0 +1,327 @@
|
||||
# FLUX
|
||||
|
||||
[Switch to English](./README.md)
|
||||
|
||||
FLUX 是由 Black-Forest-Labs 开源的一系列图像生成模型。
|
||||
|
||||
**DiffSynth-Studio 启用了新的推理和训练框架,如需使用旧版本,请点击[这里](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c)。**
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本系列模型之前,请通过源码安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
通过运行以下代码可以快速加载 FLUX.1-dev 模型并进行推理。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
**FLUX 系列模型的全新框架支持正在开发中,敬请期待!**
|
||||
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|
||||
## 模型推理
|
||||
|
||||
以下部分将会帮助您理解我们的功能并编写推理代码。
|
||||
|
||||
<details>
|
||||
|
||||
<summary>加载模型</summary>
|
||||
|
||||
模型通过 `from_pretrained` 加载:
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
其中 `torch_dtype` 和 `device` 是计算精度和计算设备。`model_configs` 可通过多种方式配置模型路径:
|
||||
|
||||
* 从[魔搭社区](https://modelscope.cn/)下载模型并加载。此时需要填写 `model_id` 和 `origin_file_pattern`,例如
|
||||
|
||||
```python
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
|
||||
```
|
||||
|
||||
* 从本地文件路径加载模型。此时需要填写 `path`,例如
|
||||
|
||||
```python
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors")
|
||||
```
|
||||
|
||||
对于从多个文件加载的单一模型,使用列表即可,例如
|
||||
|
||||
```python
|
||||
ModelConfig(path=[
|
||||
"models/xxx/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"models/xxx/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"models/xxx/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
])
|
||||
```
|
||||
|
||||
`from_pretrained` 还提供了额外的参数用于控制模型加载时的行为:
|
||||
|
||||
* `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`。
|
||||
* `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>显存管理</summary>
|
||||
|
||||
DiffSynth-Studio 为 FLUX 模型提供了细粒度的显存管理,让模型能够在低显存设备上进行推理,可通过以下代码开启 offload 功能,在显存有限的设备上将部分模块 offload 到内存中。
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
`enable_vram_management` 函数提供了以下参数,用于控制显存使用情况:
|
||||
|
||||
* `vram_limit`: 显存占用量(GB),默认占用设备上的剩余显存。注意这不是一个绝对限制,当设置的显存不足以支持模型进行推理,但实际可用显存足够时,将会以最小化显存占用的形式进行推理。将其设置为0时,将会实现理论最小显存占用。
|
||||
* `vram_buffer`: 显存缓冲区大小(GB),默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。
|
||||
* `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>推理加速</summary>
|
||||
|
||||
* TeaCache:加速技术 [TeaCache](https://github.com/ali-vilab/TeaCache),请参考[示例代码](./acceleration/teacache.py)。
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>输入参数</summary>
|
||||
|
||||
Pipeline 在推理阶段能够接收以下输入参数:
|
||||
|
||||
* `prompt`: 提示词,描述画面中出现的内容。
|
||||
* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。
|
||||
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 1,当设置为大于1的数值时生效。
|
||||
* `embedded_guidance`: FLUX-dev 的内嵌引导参数,默认值为 3.5。
|
||||
* `t5_sequence_length`: T5 模型的文本向量序列长度,默认值为 512。
|
||||
* `input_image`: 输入图像,用于图生图,该参数与 `denoising_strength` 配合使用。
|
||||
* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1,当数值接近 0 时,生成图像与输入图像相似;当数值接近 1 时,生成图像与输入图像相差更大。在不输入 `input_image` 参数时,请不要将其设置为非 1 的数值。
|
||||
* `height`: 图像高度,需保证高度为 16 的倍数。
|
||||
* `width`: 图像宽度,需保证宽度为 16 的倍数。
|
||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||
* `sigma_shift`: Rectified Flow 理论中的参数,默认为 3。数值越大,模型在去噪的开始阶段停留的步骤数越多,可适当调大这个参数来提高画面质量,但会因生成过程与训练过程不一致导致生成的图像内容与训练数据存在差异。
|
||||
* `num_inference_steps`: 推理次数,默认值为 30。
|
||||
* `kontext_images`: Kontext 模型的输入图像。
|
||||
* `controlnet_inputs`: ControlNet 模型的输入。
|
||||
* `ipadapter_images`: IP-Adapter 模型的输入图像。
|
||||
* `ipadapter_scale`: IP-Adapter 模型的控制强度。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## 模型训练
|
||||
|
||||
FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_training/train.py) 脚本进行。
|
||||
|
||||
<details>
|
||||
|
||||
<summary>脚本参数</summary>
|
||||
|
||||
脚本包含以下参数:
|
||||
|
||||
* 数据集
|
||||
* `--dataset_base_path`: 数据集的根路径。
|
||||
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||
* `--max_pixels`: 最大像素面积,默认为 1024*1024,当启用动态分辨率时,任何分辨率大于这个数值的图片都会被缩小。
|
||||
* `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* 模型
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。
|
||||
* 训练
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)数量。
|
||||
* `--output_path`: 保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
|
||||
* 可训练模块
|
||||
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪一层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* 额外模型输入
|
||||
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
|
||||
* 显存管理
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 其他
|
||||
* `--align_to_opensource_format`: 是否将 FLUX DiT LoRA 的格式与开源版本对齐,仅对 FLUX.1-dev 和 FLUX.1-Kontext-dev 的 LoRA 训练生效。
|
||||
|
||||
|
||||
此外,训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建,在开始训练前运行 `accelerate config` 可配置 GPU 的相关参数。对于部分模型训练(例如模型的全量训练)脚本,我们提供了建议的 `accelerate` 配置文件,可在对应的训练脚本中查看。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 1: 准备数据集</summary>
|
||||
|
||||
数据集包含一系列文件,我们建议您这样组织数据集文件:
|
||||
|
||||
```
|
||||
data/example_image_dataset/
|
||||
├── metadata.csv
|
||||
├── image1.jpg
|
||||
└── image2.jpg
|
||||
```
|
||||
|
||||
其中 `image1.jpg`、`image2.jpg` 为训练用图像数据,`metadata.csv` 为元数据列表,例如
|
||||
|
||||
```
|
||||
image,prompt
|
||||
image1.jpg,"a cat is sleeping"
|
||||
image2.jpg,"a dog is running"
|
||||
```
|
||||
|
||||
我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
数据集支持多种图片格式,`"jpg", "jpeg", "png", "webp"`。
|
||||
|
||||
图片的尺寸可通过脚本参数 `--height`、`--width` 控制。当 `--height` 和 `--width` 为空时将会开启动态分辨率,按照数据集中每个图像的实际宽高训练。
|
||||
|
||||
**我们强烈建议使用固定分辨率训练,因为在多卡训练中存在负载均衡问题。**
|
||||
|
||||
当模型需要额外输入时,例如具备控制能力的模型 [`black-forest-labs/FLUX.1-Kontext-dev`](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) 所需的 `kontext_images`,请在数据集中补充相应的列,例如:
|
||||
|
||||
```
|
||||
image,prompt,kontext_images
|
||||
image1.jpg,"a cat is sleeping",image1_reference.jpg
|
||||
```
|
||||
|
||||
额外输入若包含图像文件,则需要在 `--data_file_keys` 参数中指定要解析的列名。可根据额外输入增加相应的列名,例如 `--data_file_keys "image,kontext_images"`,同时启用 `--extra_inputs "kontext_images"`。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 2: 加载模型</summary>
|
||||
|
||||
类似于推理时的模型加载逻辑,可直接通过模型 ID 配置要加载的模型。例如,推理时我们通过以下设置加载模型
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
]
|
||||
```
|
||||
|
||||
那么在训练时,填入以下参数即可加载对应的模型。
|
||||
|
||||
```shell
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors"
|
||||
```
|
||||
|
||||
如果您希望从本地文件加载模型,例如推理时
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder_2/"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/ae.safetensors"),
|
||||
]
|
||||
```
|
||||
|
||||
那么训练时需设置为
|
||||
|
||||
```shell
|
||||
--model_paths '[
|
||||
"models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
|
||||
"models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
|
||||
"models/black-forest-labs/FLUX.1-dev/ae.safetensors"
|
||||
]' \
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 3: 设置可训练模块</summary>
|
||||
|
||||
训练框架支持训练基础模型,或 LoRA 模型。以下是几个例子:
|
||||
|
||||
* 全量训练 DiT 部分:`--trainable_models dit`
|
||||
* 训练 DiT 部分的 LoRA 模型:`--lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32`
|
||||
|
||||
此外,由于训练脚本中加载了多个模块(text encoder、dit、vae),保存模型文件时需要移除前缀,例如在全量训练 DiT 部分或者训练 DiT 部分的 LoRA 模型时,请设置 `--remove_prefix_in_ckpt pipe.dit.`
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 4: 启动训练程序</summary>
|
||||
|
||||
我们为每一个模型编写了训练命令,请参考本文档开头的表格。
|
||||
|
||||
</details>
|
||||
24
examples/flux/acceleration/teacache.py
Normal file
24
examples/flux/acceleration/teacache.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
|
||||
for tea_cache_l1_thresh in [None, 0.2, 0.4, 0.6, 0.8]:
|
||||
image = pipe(
|
||||
prompt=prompt, embedded_guidance=3.5, seed=0,
|
||||
num_inference_steps=50, tea_cache_l1_thresh=tea_cache_l1_thresh
|
||||
)
|
||||
image.save(f"image_{tea_cache_l1_thresh}.png")
|
||||
147
examples/flux/model_inference/EliGen.py
Normal file
147
examples/flux/model_inference/EliGen.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import random
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from diffsynth import download_customized_models
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("arial", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=3.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=50,
|
||||
embedded_guidance=3.5,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
download_from_modelscope = True
|
||||
if download_from_modelscope:
|
||||
model_id = "DiffSynth-Studio/Eligen"
|
||||
downloading_priority = ["ModelScope"]
|
||||
else:
|
||||
model_id = "modelscope/EliGen"
|
||||
downloading_priority = ["HuggingFace"]
|
||||
EliGen_path = download_customized_models(
|
||||
model_id=model_id,
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control",
|
||||
downloading_priority=downloading_priority)[0]
|
||||
pipe.load_lora(pipe.dit, EliGen_path, alpha=1)
|
||||
|
||||
# example 1
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||
|
||||
# example 2
|
||||
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
|
||||
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||
|
||||
# example 3
|
||||
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||
|
||||
# example 4
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||
|
||||
# example 5
|
||||
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||
|
||||
# example 6
|
||||
global_prompt = "Snow White and the 6 Dwarfs."
|
||||
entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
|
||||
example(pipe, [8], 6, global_prompt, entity_prompts)
|
||||
|
||||
# example 7, same prompt with different seeds
|
||||
seeds = range(5, 9)
|
||||
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
50
examples/flux/model_inference/FLEX.2-preview.py
Normal file
50
examples/flux/model_inference/FLEX.2-preview.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
seed=0
|
||||
)
|
||||
image.save(f"image_1.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[200:400, 400:700] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save(f"image_mask.jpg")
|
||||
|
||||
inpaint_image = image
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
|
||||
seed=4
|
||||
)
|
||||
image.save(f"image_2_new.jpg")
|
||||
|
||||
control_image = Annotator("canny")(image)
|
||||
control_image.save("image_control.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_control_image=control_image,
|
||||
seed=4
|
||||
)
|
||||
image.save(f"image_3_new.jpg")
|
||||
54
examples/flux/model_inference/FLUX.1-Kontext-dev.py
Normal file
54
examples/flux/model_inference/FLUX.1-Kontext-dev.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian long-haired female college student.",
|
||||
embedded_guidance=2.5,
|
||||
seed=1,
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="transform the style to anime style.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=2,
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
|
||||
image_3 = pipe(
|
||||
prompt="let her smile.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=3,
|
||||
)
|
||||
image_3.save("image_3.jpg")
|
||||
|
||||
image_4 = pipe(
|
||||
prompt="let the girl play basketball.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=4,
|
||||
)
|
||||
image_4.save("image_4.jpg")
|
||||
|
||||
image_5 = pipe(
|
||||
prompt="move the girl to a park, let her sit on a chair.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=5,
|
||||
)
|
||||
image_5.save("image_5.jpg")
|
||||
@@ -0,0 +1,37 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a cat sitting on a chair",
|
||||
height=1024, width=1024,
|
||||
seed=8, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[100:350, 350: -300] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save("mask.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a cat sitting on a chair, wearing sunglasses",
|
||||
controlnet_inputs=[ControlNetInput(image=image_1, inpaint_mask=mask, scale=0.9)],
|
||||
height=1024, width=1024,
|
||||
seed=9, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
from diffsynth import download_models
|
||||
|
||||
|
||||
|
||||
download_models(["Annotators:Depth"])
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, summer",
|
||||
height=1024, width=1024,
|
||||
seed=6, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_canny = Annotator("canny")(image_1)
|
||||
image_depth = Annotator("depth")(image_1)
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, winter",
|
||||
controlnet_inputs=[
|
||||
ControlNetInput(image=image_canny, scale=0.3, processor_id="canny"),
|
||||
ControlNetInput(image=image_depth, scale=0.3, processor_id="depth"),
|
||||
],
|
||||
height=1024, width=1024,
|
||||
seed=7, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
height=768, width=768,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_1 = image_1.resize((2048, 2048))
|
||||
image_2 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
controlnet_inputs=[ControlNetInput(image=image_1, scale=0.7)],
|
||||
input_image=image_1,
|
||||
denoising_strength=0.99,
|
||||
height=2048, width=2048, tiled=True,
|
||||
seed=1, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
24
examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py
Normal file
24
examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"),
|
||||
ModelConfig(model_id="google/siglip-so400m-patch14-384"),
|
||||
],
|
||||
)
|
||||
|
||||
origin_prompt = "a rabbit in a garden, colorful flowers"
|
||||
image = pipe(prompt=origin_prompt, height=1280, width=960, seed=42)
|
||||
image.save("style image.jpg")
|
||||
|
||||
image = pipe(prompt="A piggy", height=1280, width=960, seed=42,
|
||||
ipadapter_images=[image], ipadapter_scale=0.7)
|
||||
image.save("A piggy.jpg")
|
||||
59
examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py
Normal file
59
examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from modelscope import dataset_snapshot_download
|
||||
from modelscope import snapshot_download
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
snapshot_download(
|
||||
"ByteDance/InfiniteYou",
|
||||
allow_file_pattern="supports/insightface/models/antelopev2/*",
|
||||
local_dir="models/ByteDance/InfiniteYou",
|
||||
)
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin"),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/infiniteyou/*",
|
||||
)
|
||||
|
||||
height, width = 1024, 1024
|
||||
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
|
||||
controlnet_inputs = [ControlNetInput(image=controlnet_image, scale=1.0, processor_id="None")]
|
||||
|
||||
prompt = "A man, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/man.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
controlnet_inputs=controlnet_inputs,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=height, width=width,
|
||||
)
|
||||
image.save("man.jpg")
|
||||
|
||||
prompt = "A woman, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/woman.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
controlnet_inputs=controlnet_inputs,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=height, width=width,
|
||||
)
|
||||
image.save("woman.jpg")
|
||||
20
examples/flux/model_inference/FLUX.1-dev-ValueControl.py
Normal file
20
examples/flux/model_inference/FLUX.1-dev-ValueControl.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
from diffsynth.models.flux_value_control import SingleValueEncoder, MultiValueEncoder
|
||||
pipe.value_controller = MultiValueEncoder(encoders=[SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder()]).to(dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
image = pipe(prompt="a cat", seed=0, value_controller_inputs=[0.5, 0.5, 1, 0])
|
||||
image.save("flux.jpg")
|
||||
26
examples/flux/model_inference/FLUX.1-dev.py
Normal file
26
examples/flux/model_inference/FLUX.1-dev.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
|
||||
image = pipe(prompt=prompt, seed=0)
|
||||
image.save("flux.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt,
|
||||
seed=0, cfg_scale=2, num_inference_steps=50,
|
||||
)
|
||||
image.save("flux_cfg.jpg")
|
||||
32
examples/flux/model_inference/Step1X-Edit.py
Normal file
32
examples/flux/model_inference/Step1X-Edit.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
|
||||
image = pipe(
|
||||
prompt="draw red flowers in Chinese ink painting style",
|
||||
step1x_reference_image=image,
|
||||
width=832, height=1248, cfg_scale=6,
|
||||
seed=1, rand_device='cuda'
|
||||
)
|
||||
image.save("image_1.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt="add more flowers in Chinese ink painting style",
|
||||
step1x_reference_image=image,
|
||||
width=832, height=1248, cfg_scale=6,
|
||||
seed=2, rand_device='cuda'
|
||||
)
|
||||
image.save("image_2.jpg")
|
||||
14
examples/flux/model_training/full/FLUX.1-Kontext-dev.sh
Normal file
14
examples/flux/model_training/full/FLUX.1-Kontext-dev.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
|
||||
--data_file_keys "image,kontext_images" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-Kontext-dev_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "kontext_images" \
|
||||
--use_gradient_checkpointing
|
||||
14
examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh
Normal file
14
examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_ipadapter.csv \
|
||||
--data_file_keys "image,ipadapter_images" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.ipadapter." \
|
||||
--output_path "./models/train/FLUX.1-dev-IP-Adapter_full" \
|
||||
--trainable_models "ipadapter" \
|
||||
--extra_inputs "ipadapter_images" \
|
||||
--use_gradient_checkpointing
|
||||
12
examples/flux/model_training/full/FLUX.1-dev.sh
Normal file
12
examples/flux/model_training/full/FLUX.1-dev.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-dev_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
22
examples/flux/model_training/full/accelerate_config.yaml
Normal file
22
examples/flux/model_training/full/accelerate_config.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
17
examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh
Normal file
17
examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
|
||||
--data_file_keys "image,kontext_images" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-Kontext-dev_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--extra_inputs "kontext_images" \
|
||||
--use_gradient_checkpointing
|
||||
15
examples/flux/model_training/lora/FLUX.1-dev.sh
Normal file
15
examples/flux/model_training/lora/FLUX.1-dev.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-dev_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
117
examples/flux/model_training/train.py
Normal file
117
examples/flux/model_training/train.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import torch, os, json
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
|
||||
class FluxTrainingModule(DiffusionTrainingModule):
|
||||
def __init__(
|
||||
self,
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
model_configs = []
|
||||
if model_paths is not None:
|
||||
model_paths = json.loads(model_paths)
|
||||
model_configs += [ModelConfig(path=path) for path in model_paths]
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
|
||||
self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
||||
|
||||
# Reset training scheduler
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
# Freeze untrainable models
|
||||
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||
|
||||
# Add LoRA to the base models
|
||||
if lora_base_model is not None:
|
||||
model = self.add_lora_to_model(
|
||||
getattr(self.pipe, lora_base_model),
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank
|
||||
)
|
||||
setattr(self.pipe, lora_base_model, model)
|
||||
|
||||
# Store other configs
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
|
||||
|
||||
def forward_preprocess(self, data):
|
||||
# CFG-sensitive parameters
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {}
|
||||
|
||||
# CFG-unsensitive parameters
|
||||
inputs_shared = {
|
||||
# Assume you are using this pipeline for inference,
|
||||
# please fill in the input parameters.
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
# Please do not modify the following parameters
|
||||
# unless you clearly know what this will cause.
|
||||
"cfg_scale": 1,
|
||||
"embedded_guidance": 1,
|
||||
"t5_sequence_length": 512,
|
||||
"tiled": False,
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||
}
|
||||
|
||||
# Extra inputs
|
||||
for extra_input in self.extra_inputs:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
|
||||
# Pipeline units will automatically process the input parameters.
|
||||
for unit in self.pipe.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||
return {**inputs_shared, **inputs_posi}
|
||||
|
||||
|
||||
def forward(self, data, inputs=None):
|
||||
if inputs is None: inputs = self.forward_preprocess(data)
|
||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||
loss = self.pipe.training_loss(**models, **inputs)
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = flux_parser()
|
||||
args = parser.parse_args()
|
||||
dataset = ImageDataset(args=args)
|
||||
model = FluxTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||
trainable_models=args.trainable_models,
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
)
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
|
||||
)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
launch_training_task(
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
num_epochs=args.num_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
)
|
||||
120
examples/flux/model_training/train_value_controller.py
Normal file
120
examples/flux/model_training/train_value_controller.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import torch, os, json
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
from diffsynth.models.flux_value_control import SingleValueEncoder, MultiValueEncoder
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
|
||||
class FluxTrainingModule(DiffusionTrainingModule):
|
||||
def __init__(
|
||||
self,
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
model_configs = []
|
||||
if model_paths is not None:
|
||||
model_paths = json.loads(model_paths)
|
||||
model_configs += [ModelConfig(path=path) for path in model_paths]
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
|
||||
self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
||||
|
||||
self.pipe.value_controller = MultiValueEncoder(encoders=[SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder()]).to(dtype=torch.bfloat16)
|
||||
|
||||
# Reset training scheduler
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
# Freeze untrainable models
|
||||
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||
|
||||
# Add LoRA to the base models
|
||||
if lora_base_model is not None:
|
||||
model = self.add_lora_to_model(
|
||||
getattr(self.pipe, lora_base_model),
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank
|
||||
)
|
||||
setattr(self.pipe, lora_base_model, model)
|
||||
|
||||
# Store other configs
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
|
||||
|
||||
def forward_preprocess(self, data):
|
||||
# CFG-sensitive parameters
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {}
|
||||
|
||||
# CFG-unsensitive parameters
|
||||
inputs_shared = {
|
||||
# Assume you are using this pipeline for inference,
|
||||
# please fill in the input parameters.
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
# Please do not modify the following parameters
|
||||
# unless you clearly know what this will cause.
|
||||
"cfg_scale": 1,
|
||||
"embedded_guidance": 1,
|
||||
"t5_sequence_length": 512,
|
||||
"tiled": False,
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||
}
|
||||
|
||||
# Extra inputs
|
||||
for extra_input in self.extra_inputs:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
|
||||
# Pipeline units will automatically process the input parameters.
|
||||
for unit in self.pipe.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||
return {**inputs_shared, **inputs_posi}
|
||||
|
||||
|
||||
def forward(self, data, inputs=None):
|
||||
if inputs is None: inputs = self.forward_preprocess(data)
|
||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||
loss = self.pipe.training_loss(**models, **inputs)
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = flux_parser()
|
||||
args = parser.parse_args()
|
||||
dataset = ImageDataset(args=args)
|
||||
model = FluxTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||
trainable_models=args.trainable_models,
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
)
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
|
||||
)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
launch_training_task(
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
num_epochs=args.num_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth import load_state_dict
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/FLUX.1-Kontext-dev_full/epoch-0.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
image = pipe(
|
||||
prompt="Make the dog turn its head around.",
|
||||
kontext_images=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)),
|
||||
height=768, width=768,
|
||||
seed=0
|
||||
)
|
||||
image.save("image_FLUX.1-Kontext-dev_full.jpg")
|
||||
@@ -0,0 +1,28 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth import load_state_dict
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"),
|
||||
ModelConfig(model_id="google/siglip-so400m-patch14-384"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/FLUX.1-dev-IP-Adapter_full/epoch-0.safetensors")
|
||||
pipe.ipadapter.load_state_dict(state_dict)
|
||||
|
||||
image = pipe(
|
||||
prompt="a dog",
|
||||
ipadapter_images=Image.open("data/example_image_dataset/1.jpg"),
|
||||
height=768, width=768,
|
||||
seed=0
|
||||
)
|
||||
image.save("image_FLUX.1-dev-IP-Adapter_full.jpg")
|
||||
20
examples/flux/model_training/validate_full/FLUX.1-dev.py
Normal file
20
examples/flux/model_training/validate_full/FLUX.1-dev.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth import load_state_dict
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/FLUX.1-dev_full/epoch-0.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
image = pipe(prompt="a dog", seed=0)
|
||||
image.save("image_FLUX.1-dev_full.jpg")
|
||||
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/FLUX.1-Kontext-dev_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
image = pipe(
|
||||
prompt="Make the dog turn its head around.",
|
||||
kontext_images=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)),
|
||||
height=768, width=768,
|
||||
seed=0
|
||||
)
|
||||
image.save("image_FLUX.1-Kontext-dev_lora.jpg")
|
||||
18
examples/flux/model_training/validate_lora/FLUX.1-dev.py
Normal file
18
examples/flux/model_training/validate_lora/FLUX.1-dev.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
image = pipe(prompt="a dog", seed=0)
|
||||
image.save("image_FLUX.1-dev_lora.jpg")
|
||||
@@ -16,6 +16,32 @@ cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
## Overview
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
@@ -173,7 +199,7 @@ Wan supports multiple acceleration techniques, including:
|
||||
* **Unified Sequence Parallel**: Sequence parallelism based on [xDiT](https://github.com/xdit-project/xDiT). Please refer to [this example](./acceleration/unified_sequence_parallel.py), and run it using the command:
|
||||
|
||||
```shell
|
||||
pip install xfuser>=0.4.3
|
||||
pip install "xfuser[flash-attn]>=0.4.3"
|
||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
|
||||
```
|
||||
|
||||
|
||||
@@ -16,6 +16,32 @@ cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
@@ -175,7 +201,7 @@ Wan 支持多种加速方案,包括
|
||||
* 统一序列并行:基于 [xDiT](https://github.com/xdit-project/xDiT) 实现的序列并行,请参考[示例代码](./acceleration/unified_sequence_parallel.py),使用以下命令运行:
|
||||
|
||||
```shell
|
||||
pip install xfuser>=0.4.3
|
||||
pip install "xfuser[flash-attn]>=0.4.3"
|
||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
|
||||
```
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--height 720 \
|
||||
--width 1280 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
@@ -10,4 +11,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-I2V-14B-720P_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image"
|
||||
--extra_inputs "input_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -1,8 +1,9 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--height 720 \
|
||||
--width 1280 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
@@ -12,4 +13,5 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image"
|
||||
--extra_inputs "input_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch, os, json
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task, wan_parser
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@@ -107,4 +107,14 @@ if __name__ == "__main__":
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
)
|
||||
launch_training_task(model, dataset, args=args)
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt
|
||||
)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
launch_training_task(
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
num_epochs=args.num_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
@@ -19,12 +19,13 @@ state_dict = load_state_dict("models/train/Wan2.1-I2V-14B-720P_full/epoch-1.safe
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
|
||||
input_image = VideoData("data/example_video_dataset/video1.mp4", height=720, width=1280)[0]
|
||||
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=input_image,
|
||||
height=720, width=1280, num_frames=49,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5)
|
||||
|
||||
@@ -18,12 +18,13 @@ pipe = WanVideoPipeline.from_pretrained(
|
||||
pipe.load_lora(pipe.dit, "models/train/Wan2.1-I2V-14B-720P_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
|
||||
input_image = VideoData("data/example_video_dataset/video1.mp4", height=720, width=1280)[0]
|
||||
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=input_image,
|
||||
height=720, width=1280, num_frames=49,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5)
|
||||
|
||||
@@ -12,3 +12,5 @@ protobuf
|
||||
modelscope
|
||||
ftfy
|
||||
pynvml
|
||||
pandas
|
||||
accelerate
|
||||
|
||||
Reference in New Issue
Block a user