mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
DiffSynth-Studio 2.0 major update
This commit is contained in:
6
diffsynth/diffusion/__init__.py
Normal file
6
diffsynth/diffusion/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .flow_match import FlowMatchScheduler
|
||||
from .training_module import DiffusionTrainingModule
|
||||
from .logger import ModelLogger
|
||||
from .runner import launch_training_task, launch_data_process_task
|
||||
from .parsers import *
|
||||
from .loss import *
|
||||
439
diffsynth/diffusion/base_pipeline.py
Normal file
439
diffsynth/diffusion/base_pipeline.py
Normal file
@@ -0,0 +1,439 @@
|
||||
from PIL import Image
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import repeat, reduce
|
||||
from typing import Union
|
||||
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig
|
||||
from ..utils.lora import GeneralLoRALoader
|
||||
from ..models.model_loader import ModelPool
|
||||
from ..utils.controlnet import ControlNetInput
|
||||
|
||||
|
||||
class PipelineUnit:
    """One node in a pipeline's computation graph.

    A unit declares the variables it reads (``input_params`` plus the
    positive/negative CFG mappings) and the variables it writes
    (``output_params``), together with the models that must be on-device
    while it runs. Subclasses override ``process`` / ``post_process``.
    """

    def __init__(
        self,
        seperate_cfg: bool = False,
        take_over: bool = False,
        input_params: tuple[str] = None,
        output_params: tuple[str] = None,
        input_params_posi: dict[str, str] = None,
        input_params_nega: dict[str, str] = None,
        onload_model_names: tuple[str] = None
    ):
        # Run once per CFG side (positive / negative) instead of once overall.
        self.seperate_cfg = seperate_cfg
        # When True the unit takes full control of the runner invocation.
        self.take_over = take_over
        self.input_params = input_params
        self.output_params = output_params
        # Mappings: processor argument name -> variable name on each CFG side.
        self.input_params_posi = input_params_posi
        self.input_params_nega = input_params_nega
        # Models that must be loaded onto the device while this unit runs.
        self.onload_model_names = onload_model_names

    def fetch_input_params(self):
        """Return a sorted, de-duplicated list of every variable this unit reads."""
        collected = list(self.input_params or ())
        for mapping in (self.input_params_posi, self.input_params_nega):
            if mapping is not None:
                collected.extend(mapping.values())
        return sorted(set(collected))

    def fetch_output_params(self):
        """Return the list of variables this unit writes (empty if none declared)."""
        if self.output_params is None:
            return []
        return list(self.output_params)

    def process(self, pipe, **kwargs) -> dict:
        """Subclass hook; return a dict of variables to merge into the inputs."""
        return {}

    def post_process(self, pipe, **kwargs) -> dict:
        """Subclass hook; return a dict of variables to merge into the inputs."""
        return {}
|
||||
|
||||
|
||||
class BasePipeline(torch.nn.Module):
    """Shared infrastructure for diffusion pipelines.

    Provides device/dtype bookkeeping, image/video tensor conversion,
    VRAM management (model on/offloading), LoRA loading, scheduler stepping
    and classifier-free-guidance helpers. Concrete pipelines subclass this
    and register their models as child modules.
    """

    def __init__(
        self,
        device="cuda", torch_dtype=torch.float16,
        height_division_factor=64, width_division_factor=64,
        time_division_factor=None, time_division_remainder=None,
    ):
        super().__init__()
        # The device and torch_dtype describe where intermediate variables
        # (noise, latents, ...) live -- not where the models live.
        self.device = device
        self.torch_dtype = torch_dtype
        # Shape constraints: spatial sizes must divide evenly by these
        # factors, and frame counts must satisfy n % factor == remainder.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # Flipped on when child models enable VRAM management.
        self.vram_management_enabled = False
        # Executes pipeline units with shared / positive / negative inputs.
        self.unit_runner = PipelineUnitRunner()
        # LoRA loader class; instantiated on demand in `load_lora`.
        self.lora_loader = GeneralLoRALoader

    def to(self, *args, **kwargs):
        """Move the pipeline; also record the target device/dtype for intermediates."""
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def check_resize_height_width(self, height, width, num_frames=None):
        """Round sizes up to the nearest valid values, printing a notice when adjusted."""
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
        if num_frames is None:
            return height, width
        if num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
            print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
        return height, width, num_frames

    def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
        """Convert a PIL.Image to a tensor laid out as `pattern`, scaled to [min_value, max_value]."""
        tensor = torch.Tensor(np.array(image, dtype=np.float32))
        tensor = tensor.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        # Map pixel range [0, 255] onto [min_value, max_value].
        tensor = tensor * ((max_value - min_value) / 255) + min_value
        tensor = repeat(tensor, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
        return tensor

    def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
        """Convert a list of PIL.Image frames to a single tensor laid out as `pattern`."""
        frames = [
            self.preprocess_image(frame, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value)
            for frame in video
        ]
        # Stack along the axis where "T" sits in the (space-separated) pattern.
        return torch.stack(frames, dim=pattern.index("T") // 2)

    def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
        """Convert a decoded VAE tensor to a PIL.Image."""
        if pattern != "H W C":
            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
        pixels = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
        pixels = pixels.to(device="cpu", dtype=torch.uint8)
        return Image.fromarray(pixels.numpy())

    def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
        """Convert a decoded VAE tensor to a list of PIL.Image frames."""
        if pattern != "T H W C":
            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
        return [
            self.vae_output_to_image(frame, pattern="H W C", min_value=min_value, max_value=max_value)
            for frame in vae_output
        ]

    def load_models_to_device(self, model_names):
        """Offload every model not in `model_names`, then onload the named ones.

        No-op unless VRAM management is enabled on the pipeline.
        """
        if not self.vram_management_enabled:
            return
        # Offload models that are not needed for the next stage.
        for name, model in self.named_children():
            if name in model_names:
                continue
            if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                if hasattr(model, "offload"):
                    model.offload()
                else:
                    for submodule in model.modules():
                        if hasattr(submodule, "offload"):
                            submodule.offload()
        torch.cuda.empty_cache()
        # Onload models that are needed.
        for name, model in self.named_children():
            if name not in model_names:
                continue
            if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                if hasattr(model, "onload"):
                    model.onload()
                else:
                    for submodule in model.modules():
                        if hasattr(submodule, "onload"):
                            submodule.onload()

    def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
        """Sample Gaussian noise reproducibly on `rand_device`, then cast/move it."""
        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
        return noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)

    def get_vram(self):
        # mem_get_info returns (free, total); index 1 is the total VRAM, in GiB here.
        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)

    def get_module(self, model, name):
        """Resolve a dotted attribute path (e.g. "blocks.0.attn") on `model`."""
        if "." not in name:
            return getattr(model, name)
        head, tail = name.split(".", 1)
        # Numeric components index into sequence-like containers.
        child = model[int(head)] if head.isdigit() else getattr(model, head)
        return self.get_module(child, tail)

    def freeze_except(self, model_names):
        """Freeze the whole pipeline, then re-enable training only on `model_names`."""
        self.eval()
        self.requires_grad_(False)
        for name in model_names:
            module = self.get_module(self, name)
            if module is None:
                print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
                continue
            module.train()
            module.requires_grad_(True)

    def blend_with_mask(self, base, addition, mask):
        """Linear blend: `base` where mask==0, `addition` where mask==1."""
        return base * (1 - mask) + addition * mask

    def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
        """Advance latents one scheduler step, optionally constrained by an inpaint mask."""
        timestep = scheduler.timesteps[progress_id]
        if inpaint_mask is not None:
            # Outside the mask, replace the prediction with the one that would
            # reproduce `input_latents`; inside, keep the model prediction.
            noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
            noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
        return scheduler.step(noise_pred, timestep, latents)

    def split_pipeline_units(self, model_names: list[str]):
        """Split `self.units` into (model-related, unrelated) unit lists."""
        return PipelineUnitGraph().split_pipeline_units(self.units, model_names)

    def flush_vram_management_device(self, device):
        """Point every VRAM-managed module's device slots at `device`."""
        for module in self.modules():
            if isinstance(module, AutoTorchModule):
                module.offload_device = device
                module.onload_device = device
                module.preparing_device = device
                module.computation_device = device

    def load_lora(
        self,
        module: torch.nn.Module,
        lora_config: Union[ModelConfig, str] = None,
        alpha=1,
        hotload=None,
        state_dict=None,
    ):
        """Load a LoRA into `module`, hot-patched (VRAM-managed) or fused into weights.

        Provide either `state_dict` directly, or a `lora_config` path / ModelConfig.
        When `hotload` is None it defaults to True iff the module is VRAM-managed.
        """
        if state_dict is None:
            if isinstance(lora_config, str):
                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
            else:
                lora_config.download_if_necessary()
                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
        else:
            lora = state_dict
        lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
        lora = lora_loader.convert_state_dict(lora)
        if hotload is None:
            hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
        if not hotload:
            lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
            return
        if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
            raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
        updated_num = 0
        for _, submodule in module.named_modules():
            if isinstance(submodule, AutoWrappedLinear):
                lora_a_name = f'{submodule.name}.lora_A.weight'
                lora_b_name = f'{submodule.name}.lora_B.weight'
                if lora_a_name in lora and lora_b_name in lora:
                    updated_num += 1
                    # Alpha is folded into the A matrix; B is stored as-is.
                    submodule.lora_A_weights.append(lora[lora_a_name] * alpha)
                    submodule.lora_B_weights.append(lora[lora_b_name])
        print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")

    def clear_lora(self):
        """Detach every hot-loaded LoRA from the pipeline's wrapped linear layers."""
        cleared_num = 0
        for name, module in self.named_modules():
            if isinstance(module, AutoWrappedLinear):
                if hasattr(module, "lora_A_weights"):
                    if len(module.lora_A_weights) > 0:
                        cleared_num += 1
                    module.lora_A_weights.clear()
                if hasattr(module, "lora_B_weights"):
                    module.lora_B_weights.clear()
        print(f"{cleared_num} LoRA layers are cleared.")

    def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
        """Download (if necessary) and load every configured model into a ModelPool."""
        model_pool = ModelPool()
        for model_config in model_configs:
            model_config.download_if_necessary()
            vram_config = model_config.vram_config()
            # Fall back to the pipeline's dtype/device where the config leaves them unset.
            vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
            vram_config["computation_device"] = vram_config["computation_device"] or self.device
            model_pool.auto_load_model(
                model_config.path,
                vram_config=vram_config,
                vram_limit=vram_limit,
                clear_parameters=model_config.clear_parameters,
            )
        return model_pool

    def check_vram_management_state(self):
        """Return True when any direct child module has VRAM management enabled."""
        return any(
            hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
            for module in self.children()
        )

    def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
        """Run `model_fn` with classifier-free guidance; skip the negative pass at scale 1."""
        noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
        if cfg_scale == 1.0:
            return noise_pred_posi
        noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
        return noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
|
||||
class PipelineUnitGraph:
    """Dependency analysis over a list of pipeline units.

    Units form a dataflow graph whose edges follow variable
    producer -> consumer relations; the graph can be split into the
    subgraph touching a given set of models and everything else.
    """

    def __init__(self):
        pass

    def build_edges(self, units: list[PipelineUnit]):
        """Return (producer_id, consumer_id) edges between units."""
        # Most recent unit that produced each variable, updated in order.
        latest_producer = {}
        edges = []
        for unit_id, unit in enumerate(units):
            for param in unit.fetch_input_params():
                if param in latest_producer:
                    edges.append((latest_producer[param], unit_id))
            for param in unit.fetch_output_params():
                latest_producer[param] = unit_id
        return edges

    def build_chains(self, units: list[PipelineUnit]):
        """Return, per variable, the ordered list of unit ids that write it."""
        all_params = sorted({
            param
            for unit in units
            for param in unit.fetch_input_params() + unit.fetch_output_params()
        })
        chains = {param: [] for param in all_params}
        for unit_id, unit in enumerate(units):
            for param in unit.fetch_output_params():
                chains[param].append(unit_id)
        return chains

    def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
        """Return ids of units that directly run one of the named models."""
        direct_ids = []
        for unit_id, unit in enumerate(units):
            if unit.onload_model_names is None:
                continue
            if any(name in unit.onload_model_names for name in model_names):
                direct_ids.append(unit_id)
        return direct_ids

    def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
        """Transitively expand `start_unit_ids` along `edges`.

        direction="target" walks downstream to consumers;
        direction="source" walks upstream to producers.
        """
        related = list(start_unit_ids)
        while True:
            frontier = []
            for source, target in edges:
                if direction == "target" and source in related and target not in related:
                    frontier.append(target)
                elif direction == "source" and source not in related and target in related:
                    frontier.append(source)
            frontier = sorted(set(frontier))
            if not frontier:
                break
            related.extend(frontier)
        return sorted(set(related))

    def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
        """Pull in outside units that rewrite a variable after the subgraph first reads it."""
        # First unit inside the subgraph that reads each variable.
        first_reader = {}
        for unit_id in related_unit_ids:
            for param in units[unit_id].fetch_input_params():
                if param not in first_reader:
                    first_reader[param] = unit_id
        extra_ids = []
        for param, unit_id in first_reader.items():
            chain = chains[param]
            if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
                # Writers that come after this reader but live outside the subgraph.
                for later_id in chain[chain.index(unit_id) + 1:]:
                    if later_id not in related_unit_ids:
                        extra_ids.append(later_id)
        related_unit_ids.extend(extra_ids)
        return sorted(set(related_unit_ids))

    def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
        """Split `units` into (related, unrelated) with respect to `model_names`."""
        related_ids = self.search_direct_unit_ids(units, model_names)
        edges = self.build_edges(units)
        chains = self.build_chains(units)
        # Iterate to a fixed point: each downstream expansion may expose
        # further external updates, and vice versa.
        while True:
            previous_count = len(related_ids)
            related_ids = self.search_related_unit_ids(edges, related_ids, "target")
            related_ids = self.search_updating_unit_ids(units, chains, related_ids)
            if len(related_ids) == previous_count:
                break
        related_units = [units[i] for i in related_ids]
        unrelated_units = [units[i] for i in range(len(units)) if i not in related_ids]
        return related_units, unrelated_units
|
||||
|
||||
|
||||
class PipelineUnitRunner:
    """Executes one PipelineUnit against the shared / positive / negative inputs."""

    def __init__(self):
        pass

    def _gather(self, unit: PipelineUnit, side_inputs: dict, mapping: dict, inputs_shared: dict) -> dict:
        # Pull the per-side variables through the name mapping, then let
        # shared variables (declared in input_params) override/extend them.
        fetched = {name: side_inputs.get(source) for name, source in mapping.items()}
        for name in (unit.input_params or ()):
            fetched[name] = inputs_shared.get(name)
        return fetched

    def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
        """Run `unit` and merge its outputs back into the input dicts."""
        if unit.take_over:
            # The unit takes full control and returns all three dicts itself.
            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
        elif unit.seperate_cfg:
            # Positive side always runs.
            posi_outputs = unit.process(pipe, **self._gather(unit, inputs_posi, unit.input_params_posi, inputs_shared))
            inputs_posi.update(posi_outputs)
            if inputs_shared["cfg_scale"] != 1:
                # Negative side runs only when CFG is active.
                nega_outputs = unit.process(pipe, **self._gather(unit, inputs_nega, unit.input_params_nega, inputs_shared))
                inputs_nega.update(nega_outputs)
            else:
                # CFG disabled: mirror the positive outputs onto the negative side.
                inputs_nega.update(posi_outputs)
        else:
            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
            inputs_shared.update(unit.process(pipe, **processor_inputs))
        return inputs_shared, inputs_posi, inputs_nega
|
||||
179
diffsynth/diffusion/flow_match.py
Normal file
179
diffsynth/diffusion/flow_match.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import torch, math
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class FlowMatchScheduler():
    """Flow-matching noise scheduler with per-model timestep templates.

    The template only selects how inference sigmas/timesteps are laid out;
    stepping, noising and training-target logic are shared across templates.
    """

    def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
        template_table = {
            "FLUX.1": FlowMatchScheduler.set_timesteps_flux,
            "Wan": FlowMatchScheduler.set_timesteps_wan,
            "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
            "FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
            "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
        }
        # Unknown template names silently fall back to the FLUX.1 schedule.
        self.set_timesteps_fn = template_table.get(template, FlowMatchScheduler.set_timesteps_flux)
        self.num_train_timesteps = 1000

    @staticmethod
    def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
        """FLUX.1 schedule: linear sigmas warped by a rational shift (default 3)."""
        if shift is None:
            shift = 3
        sigma_min = 0.003 / 1.002
        sigma_max = 1.0
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        return sigmas, sigmas * num_train_timesteps

    @staticmethod
    def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
        """Wan schedule: sigma_min=0, default shift 5, terminal endpoint dropped."""
        if shift is None:
            shift = 5
        sigma_min, sigma_max = 0.0, 1.0
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        return sigmas, sigmas * num_train_timesteps

    @staticmethod
    def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
        """Linearly interpolate the exponential-shift mu from the sequence length."""
        slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        intercept = base_shift - slope * base_seq_len
        return image_seq_len * slope + intercept

    @staticmethod
    def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
        """Qwen-Image schedule: exponential shift by mu, then terminal stretching."""
        sigma_min, sigma_max = 0.0, 1.0
        num_train_timesteps = 1000
        shift_terminal = 0.02
        # Linear sigmas with the terminal endpoint dropped.
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        # Mu priority: explicit value > derived from sequence length > default.
        if exponential_shift_mu is not None:
            mu = exponential_shift_mu
        elif dynamic_shift_len is not None:
            mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
        else:
            mu = 0.8
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        # Stretch so the final sigma lands on shift_terminal.
        one_minus_z = 1 - sigmas
        scale_factor = one_minus_z[-1] / (1 - shift_terminal)
        sigmas = 1 - (one_minus_z / scale_factor)
        return sigmas, sigmas * num_train_timesteps

    @staticmethod
    def compute_empirical_mu(image_seq_len, num_steps):
        """Empirical mu for FLUX.2, interpolating between 10- and 200-step fits."""
        a1, b1 = 8.73809524e-05, 1.89833333
        a2, b2 = 0.00016927, 0.45666666
        # Long sequences use the 200-step fit directly.
        if image_seq_len > 4300:
            return float(a2 * image_seq_len + b2)
        m_200 = a2 * image_seq_len + b2
        m_10 = a1 * image_seq_len + b1
        slope = (m_200 - m_10) / 190.0
        intercept = m_200 - 200.0 * slope
        return float(slope * num_steps + intercept)

    @staticmethod
    def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=1024//16*1024//16):
        """FLUX.2 schedule: exponential shift driven by an empirically fitted mu."""
        sigma_min = 1 / num_inference_steps
        sigma_max = 1.0
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        return sigmas, sigmas * num_train_timesteps

    @staticmethod
    def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
        """Z-Image schedule: Wan-style sigmas with default shift 3.

        When `target_timesteps` is given, the nearest generated timesteps are
        snapped onto those exact values.
        """
        if shift is None:
            shift = 3
        sigma_min, sigma_max = 0.0, 1.0
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        if target_timesteps is not None:
            target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
            for target in target_timesteps:
                nearest = torch.argmin((timesteps - target).abs())
                timesteps[nearest] = target
        return sigmas, timesteps

    def set_training_weight(self):
        """Precompute per-timestep loss weights (bell curve centred mid-schedule)."""
        steps = 1000
        x = self.timesteps
        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
        y_shifted = y - y.min()
        weights = y_shifted * (steps / y_shifted.sum())
        if len(self.timesteps) != 1000:
            # This is an empirical formula.
            weights = weights * (len(self.timesteps) / steps)
        weights = weights + weights[1]
        self.linear_timesteps_weights = weights

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
        """Build sigmas/timesteps via the template; optionally prepare training weights."""
        self.sigmas, self.timesteps = self.set_timesteps_fn(
            num_inference_steps=num_inference_steps,
            denoising_strength=denoising_strength,
            **kwargs,
        )
        self.training = training
        if training:
            self.set_training_weight()

    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        """One Euler step of the flow ODE: x <- x + v * (sigma_next - sigma)."""
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_next = 0
        else:
            sigma_next = self.sigmas[timestep_id + 1]
        return sample + model_output * (sigma_next - sigma)

    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Recover the model output mapping `sample_stablized` to `sample` at `timestep`."""
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        return (sample - sample_stablized) / self.sigmas[timestep_id]

    def add_noise(self, original_samples, noise, timestep):
        """Interpolate sample and noise at the sigma closest to `timestep`."""
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        return (1 - sigma) * original_samples + sigma * noise

    def training_target(self, sample, noise, timestep):
        """Flow-matching target: the straight-line velocity `noise - sample`."""
        return noise - sample

    def training_weight(self, timestep):
        """Look up the precomputed loss weight for the nearest scheduled timestep."""
        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
        return self.linear_timesteps_weights[timestep_id]
|
||||
43
diffsynth/diffusion/logger.py
Normal file
43
diffsynth/diffusion/logger.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import os, torch
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
class ModelLogger:
    """Checkpoint writer for training loops.

    Collects the trainable state dict from the (possibly wrapped) model and
    saves it under `output_path`, either every `save_steps` steps or at the
    end of each epoch / of training.
    """

    def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
        # Directory where checkpoints are written (created lazily on save).
        self.output_path = output_path
        # Optional prefix stripped from checkpoint keys (e.g. a wrapper prefix).
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
        # Hook to post-process the state dict right before saving.
        self.state_dict_converter = state_dict_converter
        # Number of optimizer steps observed so far.
        self.num_steps = 0

    def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
        """Count the step and save a checkpoint every `save_steps` steps."""
        self.num_steps += 1
        if save_steps is not None and self.num_steps % save_steps == 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
        """Save an end-of-epoch checkpoint.

        Delegates to `save_model` (previously this duplicated its body verbatim).
        """
        self.save_model(accelerator, model, f"epoch-{epoch_id}.safetensors")

    def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
        """Save a final checkpoint unless the last step already produced one."""
        if save_steps is not None and self.num_steps % save_steps != 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
        """Gather the trainable state dict on the main process and save it to `file_name`."""
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            state_dict = accelerator.get_state_dict(model)
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
            state_dict = self.state_dict_converter(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, file_name)
            accelerator.save(state_dict, path, safe_serialization=True)
|
||||
119
diffsynth/diffusion/loss.py
Normal file
119
diffsynth/diffusion/loss.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from .base_pipeline import BasePipeline
|
||||
import torch
|
||||
|
||||
|
||||
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
    """Flow-matching supervised fine-tuning loss.

    Samples one random timestep inside the configured boundary window,
    noises the clean latents, predicts with the pipeline's models, and
    returns the timestep-weighted MSE against the flow-matching target.
    """
    num_timesteps = len(pipe.scheduler.timesteps)
    max_boundary = int(inputs.get("max_timestep_boundary", 1) * num_timesteps)
    min_boundary = int(inputs.get("min_timestep_boundary", 0) * num_timesteps)

    # Draw a single training timestep uniformly from the boundary window.
    timestep_id = torch.randint(min_boundary, max_boundary, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)

    noise = torch.randn_like(inputs["input_latents"])
    inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
    target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)

    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)

    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())
    return loss * pipe.scheduler.training_weight(timestep)
|
||||
|
||||
|
||||
def DirectDistillLoss(pipe: BasePipeline, **inputs):
    """Distillation loss over a full sampling trajectory.

    Runs the complete denoising loop and returns the MSE between the final
    latents and the reference `input_latents`.
    """
    pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
    # NOTE(review): marks the scheduler as being in training mode -- confirm
    # how downstream code consumes this flag.
    pipe.scheduler.training = True
    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
        inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
    return torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
|
||||
|
||||
|
||||
class TrajectoryImitationLoss(torch.nn.Module):
    """Distillation loss that makes a student imitate a teacher's sampling trajectory.

    The teacher's denoising trajectory is recorded once (no grad), then the
    student is trained to (1) reproduce the teacher's per-step velocity along
    that trajectory and (2) produce a final image perceptually close to the
    teacher's (LPIPS regularization).
    """

    def __init__(self):
        super().__init__()
        # LPIPS is loaded lazily on first forward() because it needs a device.
        self.initialized = False


    def initialize(self, device):
        """Lazily construct the LPIPS perceptual metric on `device`."""
        import lpips # TODO: remove it
        self.loss_fn = lpips.LPIPS(net='alex').to(device)
        self.initialized = True


    def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Run `pipe`'s CFG-guided sampling loop and record every intermediate latent.

        `timesteps_student` is passed to the scheduler as `target_timesteps` —
        presumably so the teacher's timestep grid covers the student's; confirm
        against the scheduler implementation.

        Returns (scheduler timesteps, list of latents with len = steps + 1,
        including the initial latents).
        """
        trajectory = [inputs_shared["latents"].clone()]

        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            # detach(): the trajectory is reference data, no grad flows through it.
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)

            trajectory.append(inputs_shared["latents"].clone())
        return pipe.scheduler.timesteps, trajectory


    def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Velocity-matching loss: at each student timestep, predict the slope of
        the teacher trajectory between the two nearest teacher latents."""
        loss = 0
        pipe.scheduler.set_timesteps(num_inference_steps, training=True)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)

            # Snap the student timestep to the nearest teacher timestep and
            # start the student from the teacher's latents at that point.
            progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
            inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]

            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )

            # Target velocity = (next teacher latent - current) / (sigma_next - sigma).
            sigma = pipe.scheduler.sigmas[progress_id]
            sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
            if progress_id + 1 >= len(pipe.scheduler.timesteps):
                # Last step: aim at the teacher's final latents.
                latents_ = trajectory_teacher[-1]
            else:
                progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
                latents_ = trajectory_teacher[progress_id_teacher]

            target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
        return loss


    def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """LPIPS loss between the student's decoded output (sampled from the
        teacher's initial latents) and the teacher's decoded final latents."""
        inputs_shared["latents"] = trajectory_teacher[0]
        pipe.scheduler.set_timesteps(num_inference_steps)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)

        image_pred = pipe.vae_decoder(inputs_shared["latents"])
        image_real = pipe.vae_decoder(trajectory_teacher[-1])
        loss = self.loss_fn(image_pred.float(), image_real.float())
        return loss


    def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
        """Total loss = trajectory alignment + LPIPS regularization.

        The teacher pipeline is expected in ``inputs_shared["teacher"]``; its
        50-step, CFG-2 trajectory is recorded without gradients, then the
        8-step student is trained against it. The 8 / 50 / cfg values are
        hard-coded here — NOTE(review): consider making them configurable.
        """
        if not self.initialized:
            self.initialize(pipe.device)
        with torch.no_grad():
            # Student grid (8 steps) is set first so the teacher trajectory
            # can be sampled on timesteps covering it.
            pipe.scheduler.set_timesteps(8)
            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
            timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
        loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss = loss_1 + loss_2
        return loss
|
||||
70
diffsynth/diffusion/parsers.py
Normal file
70
diffsynth/diffusion/parsers.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import argparse
|
||||
|
||||
|
||||
def add_dataset_base_config(parser: argparse.ArgumentParser):
    """Register dataset-location and data-loading options on *parser*.

    Returns the same parser so calls can be chained.
    """
    dataset_options = (
        ("--dataset_base_path", dict(type=str, default="", required=True, help="Base path of the dataset.")),
        ("--dataset_metadata_path", dict(type=str, default=None, help="Path to the metadata file of the dataset.")),
        ("--dataset_repeat", dict(type=int, default=1, help="Number of times to repeat the dataset per epoch.")),
        ("--dataset_num_workers", dict(type=int, default=0, help="Number of workers for data loading.")),
        ("--data_file_keys", dict(type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")),
    )
    for flag, options in dataset_options:
        parser.add_argument(flag, **options)
    return parser
|
||||
|
||||
def add_image_size_config(parser: argparse.ArgumentParser):
    """Register image-resolution options on *parser* and return it."""
    dynamic_note = "Leave `height` and `width` empty to enable dynamic resolution."
    parser.add_argument("--height", type=int, default=None, help="Height of images. " + dynamic_note)
    parser.add_argument("--width", type=int, default=None, help="Width of images. " + dynamic_note)
    parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
    return parser
|
||||
|
||||
def add_video_size_config(parser: argparse.ArgumentParser):
    """Register frame-size and frame-count options for video data on *parser*."""
    dynamic_note = "Leave `height` and `width` empty to enable dynamic resolution."
    parser.add_argument("--height", type=int, default=None, help="Height of images. " + dynamic_note)
    parser.add_argument("--width", type=int, default=None, help="Width of images. " + dynamic_note)
    parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
    return parser
|
||||
|
||||
def add_model_config(parser: argparse.ArgumentParser):
    """Register model-loading and precision/offload options on *parser*."""
    model_options = (
        ("--model_paths", dict(type=str, default=None, help="Paths to load models. In JSON format.")),
        ("--model_id_with_origin_paths", dict(type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")),
        ("--extra_inputs", dict(default=None, help="Additional model inputs, comma-separated.")),
        ("--fp8_models", dict(default=None, help="Models with FP8 precision, comma-separated.")),
        ("--offload_models", dict(default=None, help="Models with offload, comma-separated. Only used in splited training.")),
    )
    for flag, options in model_options:
        parser.add_argument(flag, **options)
    return parser
|
||||
|
||||
def add_training_config(parser: argparse.ArgumentParser):
    """Register optimizer / schedule / task options on *parser* and return it."""
    # Optimizer hyper-parameters.
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    # Training length and trainable-model selection.
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
    # DDP behavior.
    parser.add_argument("--find_unused_parameters", action="store_true", default=False, help="Whether to find unused parameters in DDP.")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
    # Task selector (e.g. "sft").
    parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
    return parser
|
||||
|
||||
def add_output_config(parser: argparse.ArgumentParser):
    """Register checkpoint-output options on *parser* and return it.

    ``--save_steps`` left as ``None`` means a checkpoint is saved once per
    epoch instead of every N steps.
    """
    parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
    parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
    # Fixed typo in the help text: "invervals" -> "intervals".
    parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving intervals. If None, checkpoints will be saved every epoch.")
    return parser
|
||||
|
||||
def add_lora_config(parser: argparse.ArgumentParser):
    """Register LoRA training and preset-LoRA options on *parser*."""
    lora_options = (
        ("--lora_base_model", dict(type=str, default=None, help="Which model LoRA is added to.")),
        ("--lora_target_modules", dict(type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")),
        ("--lora_rank", dict(type=int, default=32, help="Rank of LoRA.")),
        ("--lora_checkpoint", dict(type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")),
        ("--preset_lora_path", dict(type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.")),
        ("--preset_lora_model", dict(type=str, default=None, help="Which model the preset LoRA is fused to.")),
    )
    for flag, options in lora_options:
        parser.add_argument(flag, **options)
    return parser
|
||||
|
||||
def add_gradient_config(parser: argparse.ArgumentParser):
    """Register gradient checkpointing / accumulation options on *parser*."""
    # Both checkpointing switches are plain store_true flags.
    for flag, description in (
        ("--use_gradient_checkpointing", "Whether to use gradient checkpointing."),
        ("--use_gradient_checkpointing_offload", "Whether to offload gradient checkpointing to CPU memory."),
    ):
        parser.add_argument(flag, default=False, action="store_true", help=description)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
    return parser
|
||||
|
||||
def add_general_config(parser: argparse.ArgumentParser):
    """Apply the full default set of training option groups to *parser*.

    Image/video size options are intentionally excluded; callers add the one
    that matches their modality.
    """
    for configure in (
        add_dataset_base_config,
        add_model_config,
        add_training_config,
        add_output_config,
        add_lora_config,
        add_gradient_config,
    ):
        parser = configure(parser)
    return parser
|
||||
71
diffsynth/diffusion/runner.py
Normal file
71
diffsynth/diffusion/runner.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import os, torch
|
||||
from tqdm import tqdm
|
||||
from accelerate import Accelerator
|
||||
from .training_module import DiffusionTrainingModule
|
||||
from .logger import ModelLogger
|
||||
|
||||
|
||||
def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    args = None,
):
    """Run the standard training loop with Hugging Face Accelerate.

    When `args` (a parsed argparse namespace) is given, its fields override
    the keyword defaults above. `save_steps=None` means checkpoints are
    written per epoch via the logger instead of every N steps.
    """
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    # Constant LR — no warmup/decay schedule.
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    # collate_fn unwraps the single-sample batch: batch size is fixed at 1.
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)

    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

    for epoch_id in range(num_epochs):
        for data in tqdm(dataloader):
            # accumulate() makes gradient accumulation transparent: backward()
            # only syncs/updates on accumulation boundaries.
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                # Cached datasets store precomputed model inputs; the module's
                # forward then takes them via `inputs=` instead of raw data.
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                # Logger handles step-based checkpointing when save_steps is set.
                model_logger.on_step_end(accelerator, model, save_steps)
                scheduler.step()
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)
|
||||
|
||||
|
||||
def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args = None,
):
    """Run the model over the dataset once and cache each output as a .pth file.

    Each process writes into `<output_path>/<process_index>/` so multi-process
    runs do not collide on file names.
    """
    if args is not None:
        num_workers = args.dataset_num_workers
    # Deterministic order, batch size 1; collate unwraps the single sample.
    dataloader = torch.utils.data.DataLoader(
        dataset, shuffle=False, collate_fn=lambda batch: batch[0], num_workers=num_workers
    )
    model, dataloader = accelerator.prepare(model, dataloader)

    rank_folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
    for sample_id, sample in enumerate(tqdm(dataloader)):
        with accelerator.accumulate(model):
            with torch.no_grad():
                os.makedirs(rank_folder, exist_ok=True)
                processed = model(sample)
                torch.save(processed, os.path.join(rank_folder, f"{sample_id}.pth"))
|
||||
212
diffsynth/diffusion/training_module.py
Normal file
212
diffsynth/diffusion/training_module.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import torch, json
|
||||
from ..core import ModelConfig, load_state_dict
|
||||
from ..utils.controlnet import ControlNetInput
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
|
||||
|
||||
class DiffusionTrainingModule(torch.nn.Module):
    """Base class for trainable diffusion pipelines.

    Bundles utilities shared by concrete training modules: device transfer,
    LoRA injection and checkpoint-key mapping, trainable-parameter export,
    CLI-to-ModelConfig parsing, and switching a pipeline into training mode.
    """
    def __init__(self):
        super().__init__()


    def to(self, *args, **kwargs):
        # Forward `.to(...)` to every direct child module, then return self
        # so calls can be chained like the stock nn.Module.to.
        for name, model in self.named_children():
            model.to(*args, **kwargs)
        return self


    def trainable_modules(self):
        # Iterator over parameters with requires_grad=True (despite the name,
        # these are parameters, not modules) — handed to the optimizer.
        trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
        return trainable_modules


    def trainable_param_names(self):
        """Return the set of names of parameters with requires_grad=True."""
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        return trainable_param_names


    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
        """Inject LoRA adapters into `model` via peft and return the model.

        `lora_alpha` defaults to `lora_rank`. When `upcast_dtype` is given,
        the trainable (LoRA) parameters are cast to that dtype.
        """
        if lora_alpha is None:
            lora_alpha = lora_rank
        # A one-element list is unwrapped to a bare string — presumably so
        # peft treats it as a pattern rather than an exact-match list; see
        # peft's LoraConfig.target_modules semantics.
        if isinstance(target_modules, list) and len(target_modules) == 1:
            target_modules = target_modules[0]
        lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
        model = inject_adapter_in_model(lora_config, model)
        if upcast_dtype is not None:
            for param in model.parameters():
                if param.requires_grad:
                    param.data = param.to(upcast_dtype)
        return model


    def mapping_lora_state_dict(self, state_dict):
        """Normalize LoRA checkpoint keys to peft's `.default.` adapter naming.

        Keys that are neither `lora_A`/`lora_B` weights nor already in the
        `.default.` form are dropped from the result.
        """
        new_state_dict = {}
        for key, value in state_dict.items():
            if "lora_A.weight" in key or "lora_B.weight" in key:
                new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
                new_state_dict[new_key] = value
            elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
                new_state_dict[key] = value
        return new_state_dict


    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        """Filter `state_dict` down to trainable parameters.

        When `remove_prefix` is given (e.g. "pipe.dit."), it is stripped from
        the front of matching key names before returning.
        """
        trainable_param_names = self.trainable_param_names()
        state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
        if remove_prefix is not None:
            state_dict_ = {}
            for name, param in state_dict.items():
                if name.startswith(remove_prefix):
                    name = name[len(remove_prefix):]
                state_dict_[name] = param
            state_dict = state_dict_
        return state_dict


    def transfer_data_to_device(self, data, device, torch_float_dtype=None):
        """Recursively move tensors in (possibly nested) containers to `device`.

        float32/float16/bfloat16 tensors are additionally cast to
        `torch_float_dtype` when it is provided. Non-tensor leaves and
        unrecognized container types are returned unchanged.
        """
        if data is None:
            return data
        elif isinstance(data, torch.Tensor):
            data = data.to(device)
            if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
                data = data.to(torch_float_dtype)
            return data
        elif isinstance(data, tuple):
            data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
            return data
        elif isinstance(data, list):
            data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
            return data
        elif isinstance(data, dict):
            data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
            return data
        else:
            return data

    def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
        """Build VRAM-management kwargs for ModelConfig.

        `fp8` stores weights as float8_e4m3fn and computes in bfloat16;
        `offload` keeps weights on disk until preparation time. Returns an
        empty dict when neither flag is set. If both are set, `fp8` wins.
        """
        if fp8:
            return {
                "offload_dtype": torch.float8_e4m3fn,
                "offload_device": device,
                "onload_dtype": torch.float8_e4m3fn,
                "onload_device": device,
                "preparing_dtype": torch.float8_e4m3fn,
                "preparing_device": device,
                "computation_dtype": torch.bfloat16,
                "computation_device": device,
            }
        elif offload:
            return {
                "offload_dtype": "disk",
                "offload_device": "disk",
                "onload_dtype": "disk",
                "onload_device": "disk",
                "preparing_dtype": torch.bfloat16,
                "preparing_device": device,
                "computation_dtype": torch.bfloat16,
                "computation_device": device,
                "clear_parameters": True,
            }
        else:
            return {}

    def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
        """Translate CLI model arguments into a list of ModelConfig objects.

        `model_paths` is a JSON list of local paths;
        `model_id_with_origin_paths` is a comma-separated list of
        "model_id:file_pattern" entries. `fp8_models` / `offload_models` are
        comma-separated identifiers matched against each entry to pick its
        VRAM configuration.
        """
        fp8_models = [] if fp8_models is None else fp8_models.split(",")
        offload_models = [] if offload_models is None else offload_models.split(",")
        model_configs = []
        if model_paths is not None:
            model_paths = json.loads(model_paths)
            for path in model_paths:
                vram_config = self.parse_vram_config(
                    fp8=path in fp8_models,
                    offload=path in offload_models,
                    device=device
                )
                model_configs.append(ModelConfig(path=path, **vram_config))
        if model_id_with_origin_paths is not None:
            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
            for model_id_with_origin_path in model_id_with_origin_paths:
                model_id, origin_file_pattern = model_id_with_origin_path.split(":")
                vram_config = self.parse_vram_config(
                    fp8=model_id_with_origin_path in fp8_models,
                    offload=model_id_with_origin_path in offload_models,
                    device=device
                )
                model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config))
        return model_configs


    def switch_pipe_to_training_mode(
        self,
        pipe,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
        preset_lora_path=None, preset_lora_model=None,
        task="sft",
    ):
        """Prepare `pipe` for training.

        Sets a 1000-step training timestep grid, freezes everything except
        `trainable_models`, optionally fuses a preset LoRA, and (unless the
        task is a data-process stage) injects trainable LoRA adapters into
        `lora_base_model`, loading `lora_checkpoint` into them if provided.
        """
        # Scheduler
        pipe.scheduler.set_timesteps(1000, training=True)

        # Freeze untrainable models
        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))

        # Preset LoRA
        if preset_lora_path is not None:
            pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)

        # FP8
        # FP8 relies on a model-specific memory management scheme.
        # It is delegated to the subclass.

        # Add LoRA to the base models
        if lora_base_model is not None and not task.endswith(":data_process"):
            if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
                print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
                return
            model = self.add_lora_to_model(
                getattr(pipe, lora_base_model),
                target_modules=lora_target_modules.split(","),
                lora_rank=lora_rank,
                upcast_dtype=pipe.torch_dtype,
            )
            if lora_checkpoint is not None:
                state_dict = load_state_dict(lora_checkpoint)
                state_dict = self.mapping_lora_state_dict(state_dict)
                # strict=False: load_result[1] holds unexpected keys.
                load_result = model.load_state_dict(state_dict, strict=False)
                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
                if len(load_result[1]) > 0:
                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
            setattr(pipe, lora_base_model, model)


    def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
        """Keep only the pipeline units relevant to a split (two-stage) task.

        ":data_process" tasks keep the units that do NOT require backward;
        ":train" tasks keep the units that do. Other task suffixes leave
        `pipe.units` untouched.
        """
        models_require_backward = []
        if trainable_models is not None:
            models_require_backward += trainable_models.split(",")
        if lora_base_model is not None:
            models_require_backward += [lora_base_model]
        if task.endswith(":data_process"):
            _, pipe.units = pipe.split_pipeline_units(models_require_backward)
        elif task.endswith(":train"):
            pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
        return pipe


    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
        """Route extra input names from `data` into `inputs_shared`.

        Names prefixed with "blockwise_controlnet_" / "controlnet_" are
        stripped of their prefix and gathered into a single-element list of
        ControlNetInput under the corresponding `*_inputs` key; everything
        else is copied through unchanged.
        """
        # Longer prefix first so "blockwise_controlnet_*" is not swallowed
        # by the plain "controlnet_" prefix.
        controlnet_keys_map = (
            ("blockwise_controlnet_", "blockwise_controlnet_inputs",),
            ("controlnet_", "controlnet_inputs"),
        )
        controlnet_inputs = {}
        for extra_input in extra_inputs:
            for prefix, name in controlnet_keys_map:
                if extra_input.startswith(prefix):
                    if name not in controlnet_inputs:
                        controlnet_inputs[name] = {}
                    controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
                    break
            else:
                inputs_shared[extra_input] = data[extra_input]
        for name, params in controlnet_inputs.items():
            inputs_shared[name] = [ControlNetInput(**params)]
        return inputs_shared
|
||||
Reference in New Issue
Block a user