Compare commits

...

8 Commits

Author SHA1 Message Date
xuyixuan.xyx
a10635818a multi-node 2025-04-22 10:40:40 +08:00
Artiprocher
9ea29a769d fix base model 2025-04-14 13:18:31 +08:00
Artiprocher
2a5355b7cb fix base model 2025-04-14 13:11:45 +08:00
Artiprocher
7a06a58f49 support reference image 2025-04-11 16:23:14 +08:00
Artiprocher
a3b4f235a0 support reference image 2025-04-11 16:17:39 +08:00
Artiprocher
a572254a1d support reference image 2025-04-11 16:14:24 +08:00
Artiprocher
9e78bf5e89 support reference image 2025-04-11 15:36:51 +08:00
Artiprocher
d21676b4dc support reference image 2025-04-11 15:31:30 +08:00
6 changed files with 731 additions and 24 deletions

View File

@@ -0,0 +1,125 @@
import torch, os, json, torchvision
from PIL import Image
from torchvision.transforms import v2
class SingleTaskDataset(torch.utils.data.Dataset):
def __init__(self, base_path, keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")), height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None):
self.base_path = base_path
self.keys = keys
self.metadata = []
self.bad_data = []
self.height = height
self.width = width
self.random = random
self.steps_per_epoch = steps_per_epoch
self.image_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
if metadata_path is None:
self.search_for_data("", report_data_log=True)
self.report_data_log()
else:
with open(metadata_path, "r", encoding="utf-8-sig") as f:
self.metadata = json.load(f)
def report_data_log(self):
print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
def dump_metadata(self, path):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=4)
def parse_json_file(self, absolute_path, relative_path):
data_list = []
with open(absolute_path, "r") as f:
metadata = json.load(f)
for image_1, image_2, instruction in self.keys:
image_1 = os.path.join(relative_path, metadata[image_1])
image_2 = os.path.join(relative_path, metadata[image_2])
instruction = metadata[instruction]
data_list.append((image_1, image_2, instruction))
return data_list
def search_for_data(self, path, report_data_log=False):
now_path = os.path.join(self.base_path, path)
if os.path.isfile(now_path) and path.endswith(".json"):
try:
data_list = self.parse_json_file(now_path, os.path.dirname(path))
self.metadata.extend(data_list)
except:
self.bad_data.append(now_path)
elif os.path.isdir(now_path):
for sub_path in os.listdir(now_path):
self.search_for_data(os.path.join(path, sub_path))
if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
self.report_data_log()
def load_image(self, image_path):
image_path = os.path.join(self.base_path, image_path)
image = Image.open(image_path).convert("RGB")
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = self.image_process(image)
return image
def load_data(self, data_id):
image_1, image_2, instruction = self.metadata[data_id]
image_1 = self.load_image(image_1)
image_2 = self.load_image(image_2)
return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
def __getitem__(self, data_id):
if self.random:
while True:
try:
data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
data = self.load_data(data_id)
return data
except:
continue
else:
return self.load_data(data_id)
def __len__(self):
return self.steps_per_epoch if self.random else len(self.metadata)
class MultiTaskDataset(torch.utils.data.Dataset):
def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
self.dataset_list = dataset_list
self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
self.steps_per_epoch = steps_per_epoch
def __getitem__(self, data_id):
while True:
try:
dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
data = self.dataset_list[dataset_id][data_id]
return data
except:
continue
def __len__(self):
return self.steps_per_epoch

View File

@@ -20,10 +20,11 @@ class RoPEEmbedding(torch.nn.Module):
self.axes_dim = axes_dim self.axes_dim = axes_dim
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: def rope(self, pos: torch.Tensor, dim: int, theta: int, device="cpu") -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even." assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
scale = scale.to(device)
omega = 1.0 / (theta**scale) omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape batch_size, seq_length = pos.shape
@@ -36,9 +37,9 @@ class RoPEEmbedding(torch.nn.Module):
return out.float() return out.float()
def forward(self, ids): def forward(self, ids, device="cpu"):
n_axes = ids.shape[-1] n_axes = ids.shape[-1]
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta, device) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1) return emb.unsqueeze(1)

View File

@@ -0,0 +1,31 @@
from .sd3_dit import TimestepEmbeddings
from .flux_dit import RoPEEmbedding
import torch
from einops import repeat
class FluxReferenceEmbedder(torch.nn.Module):
def __init__(self):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.idx_embedder = TimestepEmbeddings(256, 256)
self.proj = torch.nn.Linear(3072, 3072)
def forward(self, image_ids, idx, dtype, device):
pos_emb = self.pos_embedder(image_ids, device=device)
idx_emb = self.idx_embedder(idx, dtype=dtype).to(device)
length = pos_emb.shape[2]
pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
image_rotary_emb = pos_emb + idx_emb
return image_rotary_emb
def init(self):
self.idx_embedder.timestep_embedder[-1].load_state_dict({
"weight": torch.zeros((256, 256)),
"bias": torch.zeros((256,))
}),
self.proj.load_state_dict({
"weight": torch.eye(3072),
"bias": torch.zeros((3072,))
})

View File

@@ -1,10 +1,12 @@
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..models.flux_reference_embedder import FluxReferenceEmbedder
from ..prompters import FluxPrompter from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler from ..schedulers import FlowMatchScheduler
from .base import BasePipeline from .base import BasePipeline
from typing import List from typing import List
import torch import torch
from einops import rearrange
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@@ -32,6 +34,7 @@ class FluxImagePipeline(BasePipeline):
self.ipadapter: FluxIpAdapter = None self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder: SiglipVisionModel = None self.ipadapter_image_encoder: SiglipVisionModel = None
self.infinityou_processor: InfinitYou = None self.infinityou_processor: InfinitYou = None
self.reference_embedder: FluxReferenceEmbedder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder'] self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
@@ -362,6 +365,20 @@ class FluxImagePipeline(BasePipeline):
return {}, controlnet_image return {}, controlnet_image
def prepare_reference_images(self, reference_images, tiled=False, tile_size=64, tile_stride=32):
if reference_images is not None:
hidden_states_ref = []
for reference_image in reference_images:
self.load_models_to_device(['vae_encoder'])
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
hidden_states_ref.append(latents)
hidden_states_ref = torch.concat(hidden_states_ref, dim=0)
return {"hidden_states_ref": hidden_states_ref}
else:
return {"hidden_states_ref": None}
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
@@ -398,6 +415,8 @@ class FluxImagePipeline(BasePipeline):
# InfiniteYou # InfiniteYou
infinityou_id_image=None, infinityou_id_image=None,
infinityou_guidance=1.0, infinityou_guidance=1.0,
# Reference images
reference_images=None,
# TeaCache # TeaCache
tea_cache_l1_thresh=None, tea_cache_l1_thresh=None,
# Tile # Tile
@@ -437,6 +456,9 @@ class FluxImagePipeline(BasePipeline):
# ControlNets # ControlNets
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# Reference images
reference_kwargs = self.prepare_reference_images(reference_images, **tiler_kwargs)
# TeaCache # TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None} tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
@@ -447,9 +469,9 @@ class FluxImagePipeline(BasePipeline):
# Positive side # Positive side
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet, dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
hidden_states=latents, timestep=timestep, hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **reference_kwargs,
) )
noise_pred_posi = self.control_noise_via_local_prompts( noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -464,9 +486,9 @@ class FluxImagePipeline(BasePipeline):
if cfg_scale != 1.0: if cfg_scale != 1.0:
# Negative side # Negative side
noise_pred_nega = lets_dance_flux( noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet, dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
hidden_states=latents, timestep=timestep, hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **reference_kwargs,
) )
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else: else:
@@ -586,6 +608,7 @@ class TeaCache:
def lets_dance_flux( def lets_dance_flux(
dit: FluxDiT, dit: FluxDiT,
controlnet: FluxMultiControlNetManager = None, controlnet: FluxMultiControlNetManager = None,
reference_embedder: FluxReferenceEmbedder = None,
hidden_states=None, hidden_states=None,
timestep=None, timestep=None,
prompt_emb=None, prompt_emb=None,
@@ -594,6 +617,7 @@ def lets_dance_flux(
text_ids=None, text_ids=None,
image_ids=None, image_ids=None,
controlnet_frames=None, controlnet_frames=None,
hidden_states_ref=None,
tiled=False, tiled=False,
tile_size=128, tile_size=128,
tile_stride=64, tile_stride=64,
@@ -603,6 +627,7 @@ def lets_dance_flux(
id_emb=None, id_emb=None,
infinityou_guidance=None, infinityou_guidance=None,
tea_cache: TeaCache = None, tea_cache: TeaCache = None,
use_gradient_checkpointing=False,
**kwargs **kwargs
): ):
if tiled: if tiled:
@@ -669,20 +694,47 @@ def lets_dance_flux(
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids) prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else: else:
prompt_emb = dit.context_embedder(prompt_emb) prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1), device=hidden_states.device)
attention_mask = None attention_mask = None
# Reference images
if hidden_states_ref is not None:
# RoPE
image_ids_ref = dit.prepare_image_ids(hidden_states_ref)
idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100
image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype, device=hidden_states.device)
image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2)
# hidden_states
original_hidden_states_length = hidden_states.shape[1]
hidden_states_ref = dit.patchify(hidden_states_ref)
hidden_states_ref = dit.x_embedder(hidden_states_ref)
hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
hidden_states_ref = reference_embedder.proj(hidden_states_ref)
hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
# TeaCache # TeaCache
if tea_cache is not None: if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning) tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
else: else:
tea_cache_update = False tea_cache_update = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tea_cache_update: if tea_cache_update:
hidden_states = tea_cache.update(hidden_states) hidden_states = tea_cache.update(hidden_states)
else: else:
# Joint Blocks # Joint Blocks
for block_id, block in enumerate(dit.blocks): for block_id, block in enumerate(dit.blocks):
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block( hidden_states, prompt_emb = block(
hidden_states, hidden_states,
prompt_emb, prompt_emb,
@@ -699,6 +751,13 @@ def lets_dance_flux(
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks) num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks): for block_id, block in enumerate(dit.single_blocks):
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block( hidden_states, prompt_emb = block(
hidden_states, hidden_states,
prompt_emb, prompt_emb,
@@ -715,6 +774,8 @@ def lets_dance_flux(
if tea_cache is not None: if tea_cache is not None:
tea_cache.store(hidden_states) tea_cache.store(hidden_states)
if hidden_states_ref is not None:
hidden_states = hidden_states[:, :original_hidden_states_length]
hidden_states = dit.final_norm_out(hidden_states, conditioning) hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states) hidden_states = dit.final_proj_out(hidden_states)
hidden_states = dit.unpatchify(hidden_states, height, width) hidden_states = dit.unpatchify(hidden_states, height, width)

241
train_flux_reference.py Normal file
View File

@@ -0,0 +1,241 @@
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
import torch, os, argparse
import lightning as pl
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
from diffsynth.pipelines.flux_image import lets_dance_flux
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
os.environ["TOKENIZERS_PARALLELISM"] = "True"
class LightningModel(LightningModelForT2ILoRA):
def __init__(
self,
torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
state_dict_converter=None, quantize = None
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
# Load models
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
if quantize is None:
model_manager.load_models(pretrained_weights)
else:
model_manager.load_models(pretrained_weights[1:])
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
if preset_lora_path is not None:
preset_lora_path = preset_lora_path.split(",")
for path in preset_lora_path:
model_manager.load_lora(path)
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
self.pipe.reference_embedder = FluxReferenceEmbedder()
self.pipe.reference_embedder.init()
if quantize is not None:
self.pipe.dit.quantize()
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
self.pipe.reference_embedder.requires_grad_(True)
self.pipe.reference_embedder.train()
self.pipe.dit.requires_grad_(True)
self.pipe.dit.train()
# self.add_lora_to_model(
# self.pipe.denoising_model(),
# lora_rank=lora_rank,
# lora_alpha=lora_alpha,
# lora_target_modules=lora_target_modules,
# init_lora_weights=init_lora_weights,
# pretrained_lora_path=pretrained_lora_path,
# state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
# )
def training_step(self, batch, batch_idx):
# Data
text, image = batch["instruction"], batch["image_2"]
image_ref = batch["image_1"]
# Prepare input parameters
self.pipe.device = self.device
prompt_emb = self.pipe.encode_prompt(text, positive=True)
if "latents" in batch:
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
else:
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Reference image
hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))
# Compute loss
noise_pred = lets_dance_flux(
self.pipe.denoising_model(),
reference_embedder=self.pipe.reference_embedder,
hidden_states_ref=hidden_states_ref,
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.state_dict()
lora_state_dict = {}
for name, param in state_dict.items():
if name in trainable_param_names:
lora_state_dict[name] = param
if self.state_dict_converter is not None:
lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
checkpoint.update(lora_state_dict)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_text_encoder_path",
type=str,
default=None,
required=True,
help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
)
parser.add_argument(
"--pretrained_text_encoder_2_path",
type=str,
default=None,
required=True,
help="Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
)
parser.add_argument(
"--pretrained_dit_path",
type=str,
default=None,
required=True,
help="Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
)
parser.add_argument(
"--pretrained_vae_path",
type=str,
default=None,
required=True,
help="Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--align_to_opensource_format",
default=False,
action="store_true",
help="Whether to export lora files aligned with other opensource format.",
)
parser.add_argument(
"--quantize",
type=str,
default=None,
choices=["float8_e4m3fn"],
help="Whether to use quantization when training the model, and in which format.",
)
parser.add_argument(
"--preset_lora_path",
type=str,
default=None,
help="Preset LoRA path.",
)
parser = add_general_parsers(parser)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
model = LightningModel(
torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
preset_lora_path=args.preset_lora_path,
learning_rate=args.learning_rate,
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
pretrained_lora_path=args.pretrained_lora_path,
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
)
# dataset and data loader
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
height=512, width=512,
),
],
dataset_weight=(4, 1, 4, 1),
steps_per_epoch=args.steps_per_epoch,
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=args.batch_size,
num_workers=args.dataloader_num_workers
)
# train
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
precision=args.precision,
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=None,
)
trainer.fit(model=model, train_dataloaders=train_loader)

View File

@@ -0,0 +1,248 @@
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
import torch, os, argparse
import lightning as pl
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
from diffsynth.pipelines.flux_image import lets_dance_flux
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
os.environ["TOKENIZERS_PARALLELISM"] = "True"
class LightningModel(LightningModelForT2ILoRA):
def __init__(
self,
torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
state_dict_converter=None, quantize = None
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
# Load models
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
if quantize is None:
model_manager.load_models(pretrained_weights)
else:
model_manager.load_models(pretrained_weights[1:])
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
if preset_lora_path is not None:
preset_lora_path = preset_lora_path.split(",")
for path in preset_lora_path:
model_manager.load_lora(path)
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
self.pipe.reference_embedder = FluxReferenceEmbedder()
self.pipe.reference_embedder.init()
if quantize is not None:
self.pipe.dit.quantize()
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
self.pipe.reference_embedder.requires_grad_(True)
self.pipe.reference_embedder.train()
self.pipe.dit.requires_grad_(True)
self.pipe.dit.train()
# self.add_lora_to_model(
# self.pipe.denoising_model(),
# lora_rank=lora_rank,
# lora_alpha=lora_alpha,
# lora_target_modules=lora_target_modules,
# init_lora_weights=init_lora_weights,
# pretrained_lora_path=pretrained_lora_path,
# state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
# )
def training_step(self, batch, batch_idx):
# Data
text, image = batch["instruction"], batch["image_2"]
image_ref = batch["image_1"]
# Prepare input parameters
self.pipe.device = self.device
prompt_emb = self.pipe.encode_prompt(text, positive=True)
if "latents" in batch:
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
else:
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Reference image
hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))
# Compute loss
noise_pred = lets_dance_flux(
self.pipe.denoising_model(),
reference_embedder=self.pipe.reference_embedder,
hidden_states_ref=hidden_states_ref,
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.state_dict()
lora_state_dict = {}
for name, param in state_dict.items():
if name in trainable_param_names:
lora_state_dict[name] = param
if self.state_dict_converter is not None:
lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
checkpoint.update(lora_state_dict)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_text_encoder_path",
type=str,
default=None,
required=True,
help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
)
parser.add_argument(
"--pretrained_text_encoder_2_path",
type=str,
default=None,
required=True,
help="Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
)
parser.add_argument(
"--pretrained_dit_path",
type=str,
default=None,
required=True,
help="Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
)
parser.add_argument(
"--pretrained_vae_path",
type=str,
default=None,
required=True,
help="Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--align_to_opensource_format",
default=False,
action="store_true",
help="Whether to export lora files aligned with other opensource format.",
)
parser.add_argument(
"--quantize",
type=str,
default=None,
choices=["float8_e4m3fn"],
help="Whether to use quantization when training the model, and in which format.",
)
parser.add_argument(
"--preset_lora_path",
type=str,
default=None,
help="Preset LoRA path.",
)
parser.add_argument(
"--num_nodes",
type=int,
default=1,
help="Num nodes.",
)
parser = add_general_parsers(parser)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
model = LightningModel(
torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
preset_lora_path=args.preset_lora_path,
learning_rate=args.learning_rate,
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
pretrained_lora_path=args.pretrained_lora_path,
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
)
# dataset and data loader
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
height=512, width=512,
),
],
dataset_weight=(4, 1, 4, 1),
steps_per_epoch=args.steps_per_epoch,
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=args.batch_size,
num_workers=args.dataloader_num_workers
)
# train
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
num_nodes=args.num_nodes,
precision=args.precision,
strategy="ddp",
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=None,
)
trainer.fit(model=model, train_dataloaders=train_loader)