mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
8 Commits
value-cont ... flux-ref
| SHA1 |
|---|
| a10635818a |
| 9ea29a769d |
| 2a5355b7cb |
| 7a06a58f49 |
| a3b4f235a0 |
| a572254a1d |
| 9e78bf5e89 |
| d21676b4dc |
125 diffsynth/data/image_pulse.py Normal file
@@ -0,0 +1,125 @@
```python
import torch, os, json, torchvision
from PIL import Image
from torchvision.transforms import v2


class SingleTaskDataset(torch.utils.data.Dataset):
    def __init__(self, base_path, keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")), height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None):
        self.base_path = base_path
        self.keys = keys
        self.metadata = []
        self.bad_data = []
        self.height = height
        self.width = width
        self.random = random
        self.steps_per_epoch = steps_per_epoch
        self.image_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        if metadata_path is None:
            self.search_for_data("", report_data_log=True)
            self.report_data_log()
        else:
            with open(metadata_path, "r", encoding="utf-8-sig") as f:
                self.metadata = json.load(f)

    def report_data_log(self):
        print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")

    def dump_metadata(self, path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=4)

    def parse_json_file(self, absolute_path, relative_path):
        data_list = []
        with open(absolute_path, "r") as f:
            metadata = json.load(f)
        for image_1, image_2, instruction in self.keys:
            image_1 = os.path.join(relative_path, metadata[image_1])
            image_2 = os.path.join(relative_path, metadata[image_2])
            instruction = metadata[instruction]
            data_list.append((image_1, image_2, instruction))
        return data_list

    def search_for_data(self, path, report_data_log=False):
        now_path = os.path.join(self.base_path, path)
        if os.path.isfile(now_path) and path.endswith(".json"):
            try:
                data_list = self.parse_json_file(now_path, os.path.dirname(path))
                self.metadata.extend(data_list)
            except:
                self.bad_data.append(now_path)
        elif os.path.isdir(now_path):
            for sub_path in os.listdir(now_path):
                self.search_for_data(os.path.join(path, sub_path))
                if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
                    self.report_data_log()

    def load_image(self, image_path):
        image_path = os.path.join(self.base_path, image_path)
        image = Image.open(image_path).convert("RGB")
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        image = self.image_process(image)
        return image

    def load_data(self, data_id):
        image_1, image_2, instruction = self.metadata[data_id]
        image_1 = self.load_image(image_1)
        image_2 = self.load_image(image_2)
        return {"image_1": image_1, "image_2": image_2, "instruction": instruction}

    def __getitem__(self, data_id):
        if self.random:
            while True:
                try:
                    data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
                    data = self.load_data(data_id)
                    return data
                except:
                    continue
        else:
            return self.load_data(data_id)

    def __len__(self):
        return self.steps_per_epoch if self.random else len(self.metadata)


class MultiTaskDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
        self.dataset_list = dataset_list
        self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
        self.steps_per_epoch = steps_per_epoch

    def __getitem__(self, data_id):
        while True:
            try:
                dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
                data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
                data = self.dataset_list[dataset_id][data_id]
                return data
            except:
                continue

    def __len__(self):
        return self.steps_per_epoch
```
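A minimal usage sketch for the two classes above (the dataset roots are hypothetical; each JSON file under a root is assumed to map the key names in `keys` to relative image paths and instruction strings):

```python
import torch
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset

# Hypothetical roots; SingleTaskDataset scans them recursively for *.json
# metadata unless an explicit metadata_path is given.
dataset = MultiTaskDataset(
    dataset_list=[
        SingleTaskDataset("data/task_a", height=512, width=512),
        SingleTaskDataset("data/task_b", height=512, width=512),
    ],
    dataset_weight=(3, 1),  # task_a is sampled three times as often as task_b
    steps_per_epoch=1000,
)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=2)
batch = next(iter(loader))  # {"image_1": ..., "image_2": ..., "instruction": ...}
```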
diffsynth/models/flux_dit.py (filename inferred from the `RoPEEmbedding` import in the new embedder below)

```diff
@@ -20,10 +20,11 @@ class RoPEEmbedding(torch.nn.Module):
         self.axes_dim = axes_dim


-    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+    def rope(self, pos: torch.Tensor, dim: int, theta: int, device="cpu") -> torch.Tensor:
         assert dim % 2 == 0, "The dimension must be even."

         scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+        scale = scale.to(device)
         omega = 1.0 / (theta**scale)

         batch_size, seq_length = pos.shape
@@ -36,9 +37,9 @@ class RoPEEmbedding(torch.nn.Module):
         return out.float()


-    def forward(self, ids):
+    def forward(self, ids, device="cpu"):
         n_axes = ids.shape[-1]
-        emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
+        emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta, device) for i in range(n_axes)], dim=-3)
         return emb.unsqueeze(1)


```
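The two hunks above thread a `device` argument through the RoPE table so the frequency math can land on the same device as the hidden states. A small sketch of the pattern (shapes and the final product are illustrative, not the file's exact code):

```python
import torch

dim, theta = 56, 10000
pos = torch.zeros(1, 4)  # (batch_size, seq_length) token positions
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
scale = scale.to("cpu")  # "cpu" stands in for the target device
omega = 1.0 / (theta ** scale)
out = pos.unsqueeze(-1) * omega  # outer product: position x frequency
print(out.shape)  # torch.Size([1, 4, 28])
```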
31 diffsynth/models/flux_reference_embedder.py Normal file
@@ -0,0 +1,31 @@
```python
from .sd3_dit import TimestepEmbeddings
from .flux_dit import RoPEEmbedding
import torch
from einops import repeat


class FluxReferenceEmbedder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.idx_embedder = TimestepEmbeddings(256, 256)
        self.proj = torch.nn.Linear(3072, 3072)

    def forward(self, image_ids, idx, dtype, device):
        pos_emb = self.pos_embedder(image_ids, device=device)
        idx_emb = self.idx_embedder(idx, dtype=dtype).to(device)
        length = pos_emb.shape[2]
        pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
        idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
        image_rotary_emb = pos_emb + idx_emb
        return image_rotary_emb

    def init(self):
        self.idx_embedder.timestep_embedder[-1].load_state_dict({
            "weight": torch.zeros((256, 256)),
            "bias": torch.zeros((256,))
        })
        self.proj.load_state_dict({
            "weight": torch.eye(3072),
            "bias": torch.zeros((3072,))
        })
```
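`init()` is worth a note: it zero-fills the last linear layer of the index embedder and sets `proj` to an identity map, so a freshly initialized `FluxReferenceEmbedder` is a no-op and training starts from the base model's behavior. A tiny sketch verifying the identity part (sizes reduced for brevity):

```python
import torch

proj = torch.nn.Linear(8, 8)
proj.load_state_dict({"weight": torch.eye(8), "bias": torch.zeros(8)})
x = torch.randn(2, 8)
assert torch.allclose(proj(x), x)  # identity init: output == input
```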
diffsynth/pipelines/flux_image.py (filename inferred from the `lets_dance_flux` import in the training scripts below)

```diff
@@ -1,10 +1,12 @@
 from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
 from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..models.flux_reference_embedder import FluxReferenceEmbedder
 from ..prompters import FluxPrompter
 from ..schedulers import FlowMatchScheduler
 from .base import BasePipeline
 from typing import List
 import torch
+from einops import rearrange
 from tqdm import tqdm
 import numpy as np
 from PIL import Image
@@ -32,6 +34,7 @@ class FluxImagePipeline(BasePipeline):
         self.ipadapter: FluxIpAdapter = None
         self.ipadapter_image_encoder: SiglipVisionModel = None
         self.infinityou_processor: InfinitYou = None
+        self.reference_embedder: FluxReferenceEmbedder = None
         self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']


@@ -360,6 +363,20 @@ class FluxImagePipeline(BasePipeline):
             return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
         else:
             return {}, controlnet_image
+
+
+    def prepare_reference_images(self, reference_images, tiled=False, tile_size=64, tile_stride=32):
+        if reference_images is not None:
+            hidden_states_ref = []
+            for reference_image in reference_images:
+                self.load_models_to_device(['vae_encoder'])
+                reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
+                latents = self.encode_image(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+                hidden_states_ref.append(latents)
+            hidden_states_ref = torch.concat(hidden_states_ref, dim=0)
+            return {"hidden_states_ref": hidden_states_ref}
+        else:
+            return {"hidden_states_ref": None}


     @torch.no_grad()
```
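`prepare_reference_images` VAE-encodes each reference image and stacks the latents along the batch dimension; `lets_dance_flux` later flattens that batch into one long token sequence. A shape sketch (latent sizes assumed for a 512x512 input with a 16-channel, 8x-downsampling VAE):

```python
import torch

latents_per_image = [torch.randn(1, 16, 64, 64) for _ in range(3)]  # three references
hidden_states_ref = torch.concat(latents_per_image, dim=0)
print(hidden_states_ref.shape)  # torch.Size([3, 16, 64, 64])
```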
```diff
@@ -398,6 +415,8 @@ class FluxImagePipeline(BasePipeline):
         # InfiniteYou
         infinityou_id_image=None,
         infinityou_guidance=1.0,
+        # Reference images
+        reference_images=None,
         # TeaCache
         tea_cache_l1_thresh=None,
         # Tile
@@ -436,6 +455,9 @@ class FluxImagePipeline(BasePipeline):

         # ControlNets
         controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)

+        # Reference images
+        reference_kwargs = self.prepare_reference_images(reference_images, **tiler_kwargs)
+
         # TeaCache
         tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
@@ -447,9 +469,9 @@ class FluxImagePipeline(BasePipeline):

             # Positive side
             inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
-                dit=self.dit, controlnet=self.controlnet,
+                dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
                 hidden_states=latents, timestep=timestep,
-                **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
+                **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **reference_kwargs,
             )
             noise_pred_posi = self.control_noise_via_local_prompts(
                 prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -464,9 +486,9 @@ class FluxImagePipeline(BasePipeline):
             if cfg_scale != 1.0:
                 # Negative side
                 noise_pred_nega = lets_dance_flux(
-                    dit=self.dit, controlnet=self.controlnet,
+                    dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
                     hidden_states=latents, timestep=timestep,
-                    **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
+                    **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **reference_kwargs,
                 )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
```
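Putting the pipeline changes together, inference would look roughly like this. A hypothetical sketch: the model layout follows the help strings in the training script below, while the `prompt` argument name and the need to load trained `reference_embedder` weights are assumptions:

```python
import torch
from PIL import Image
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/text_encoder_2",
    "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.reference_embedder = FluxReferenceEmbedder().to(dtype=torch.bfloat16, device="cuda")
# ... load trained reference_embedder / DiT weights here ...
image = pipe(
    prompt="turn the sky into a sunset",
    reference_images=[Image.open("reference.jpg")],
    height=512, width=512,
)
image.save("edited.png")
```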
```diff
@@ -586,6 +608,7 @@ class TeaCache:
 def lets_dance_flux(
     dit: FluxDiT,
     controlnet: FluxMultiControlNetManager = None,
+    reference_embedder: FluxReferenceEmbedder = None,
     hidden_states=None,
     timestep=None,
     prompt_emb=None,
@@ -594,6 +617,7 @@ def lets_dance_flux(
     text_ids=None,
     image_ids=None,
     controlnet_frames=None,
+    hidden_states_ref=None,
     tiled=False,
     tile_size=128,
     tile_stride=64,
@@ -603,6 +627,7 @@ def lets_dance_flux(
     id_emb=None,
     infinityou_guidance=None,
     tea_cache: TeaCache = None,
+    use_gradient_checkpointing=False,
     **kwargs
 ):
     if tiled:
```
```diff
@@ -669,28 +694,55 @@ def lets_dance_flux(
         prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
     else:
         prompt_emb = dit.context_embedder(prompt_emb)
-        image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+        image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1), device=hidden_states.device)
         attention_mask = None

+    # Reference images
+    if hidden_states_ref is not None:
+        # RoPE
+        image_ids_ref = dit.prepare_image_ids(hidden_states_ref)
+        idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100
+        image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype, device=hidden_states.device)
+        image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2)
+        # hidden_states
+        original_hidden_states_length = hidden_states.shape[1]
+        hidden_states_ref = dit.patchify(hidden_states_ref)
+        hidden_states_ref = dit.x_embedder(hidden_states_ref)
+        hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
+        hidden_states_ref = reference_embedder.proj(hidden_states_ref)
+        hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
+
     # TeaCache
     if tea_cache is not None:
         tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
     else:
         tea_cache_update = False

+    def create_custom_forward(module):
+        def custom_forward(*inputs):
+            return module(*inputs)
+        return custom_forward
+
     if tea_cache_update:
         hidden_states = tea_cache.update(hidden_states)
     else:
         # Joint Blocks
         for block_id, block in enumerate(dit.blocks):
-            hidden_states, prompt_emb = block(
-                hidden_states,
-                prompt_emb,
-                conditioning,
-                image_rotary_emb,
-                attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
-            )
+            if use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    image_rotary_emb,
+                    attention_mask,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
+                )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
                 hidden_states = hidden_states + controlnet_res_stack[block_id]
```
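The `use_gradient_checkpointing` branch above wraps each block so activations are recomputed during backward instead of stored, trading compute for memory; `use_reentrant=False` selects PyTorch's non-reentrant checkpointing. A self-contained sketch of the same pattern:

```python
import torch

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

block = torch.nn.Linear(16, 16)  # stand-in for a DiT block
x = torch.randn(2, 16, requires_grad=True)
y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
y.sum().backward()  # activations are recomputed here rather than kept from forward
print(x.grad.shape)  # torch.Size([2, 16])
```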
```diff
@@ -699,14 +751,21 @@ def lets_dance_flux(
         hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
         num_joint_blocks = len(dit.blocks)
         for block_id, block in enumerate(dit.single_blocks):
-            hidden_states, prompt_emb = block(
-                hidden_states,
-                prompt_emb,
-                conditioning,
-                image_rotary_emb,
-                attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
-            )
+            if use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    image_rotary_emb,
+                    attention_mask,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
+                )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
                 hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
@@ -715,6 +774,8 @@ def lets_dance_flux(
     if tea_cache is not None:
         tea_cache.store(hidden_states)

+    if hidden_states_ref is not None:
+        hidden_states = hidden_states[:, :original_hidden_states_length]
     hidden_states = dit.final_norm_out(hidden_states, conditioning)
     hidden_states = dit.final_proj_out(hidden_states)
     hidden_states = dit.unpatchify(hidden_states, height, width)
```
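The reference mechanism in `lets_dance_flux` reduces to: patchify the reference latents into extra tokens, append them after the image tokens so the blocks attend across both, then slice them off before the final projection. A toy sketch with assumed sizes:

```python
import torch

image_tokens = torch.randn(1, 1024, 64)  # (B, L, C), assumed sizes
reference_tokens = torch.randn(1, 512, 64)
original_hidden_states_length = image_tokens.shape[1]

hidden_states = torch.cat((image_tokens, reference_tokens), dim=1)  # joint sequence
# ... transformer blocks attend over image + reference tokens here ...
hidden_states = hidden_states[:, :original_hidden_states_length]  # drop reference tokens
print(hidden_states.shape)  # torch.Size([1, 1024, 64])
```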
241 train_flux_reference.py Normal file
@@ -0,0 +1,241 @@
```python
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
import torch, os, argparse
import lightning as pl
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
from diffsynth.pipelines.flux_image import lets_dance_flux
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
os.environ["TOKENIZERS_PARALLELISM"] = "True"


class LightningModel(LightningModelForT2ILoRA):
    def __init__(
        self,
        torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
        learning_rate=1e-4, use_gradient_checkpointing=True,
        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
        state_dict_converter=None, quantize=None
    ):
        super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
        # Load models
        model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
        if quantize is None:
            model_manager.load_models(pretrained_weights)
        else:
            model_manager.load_models(pretrained_weights[1:])
            model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
        if preset_lora_path is not None:
            preset_lora_path = preset_lora_path.split(",")
            for path in preset_lora_path:
                model_manager.load_lora(path)

        self.pipe = FluxImagePipeline.from_model_manager(model_manager)
        self.pipe.reference_embedder = FluxReferenceEmbedder()
        self.pipe.reference_embedder.init()

        if quantize is not None:
            self.pipe.dit.quantize()

        self.pipe.scheduler.set_timesteps(1000, training=True)

        self.freeze_parameters()
        self.pipe.reference_embedder.requires_grad_(True)
        self.pipe.reference_embedder.train()
        self.pipe.dit.requires_grad_(True)
        self.pipe.dit.train()
        # self.add_lora_to_model(
        #     self.pipe.denoising_model(),
        #     lora_rank=lora_rank,
        #     lora_alpha=lora_alpha,
        #     lora_target_modules=lora_target_modules,
        #     init_lora_weights=init_lora_weights,
        #     pretrained_lora_path=pretrained_lora_path,
        #     state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
        # )

    def training_step(self, batch, batch_idx):
        # Data
        text, image = batch["instruction"], batch["image_2"]
        image_ref = batch["image_1"]

        # Prepare input parameters
        self.pipe.device = self.device
        prompt_emb = self.pipe.encode_prompt(text, positive=True)
        if "latents" in batch:
            latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
        else:
            latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Reference image
        hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))

        # Compute loss
        noise_pred = lets_dance_flux(
            self.pipe.denoising_model(),
            reference_embedder=self.pipe.reference_embedder,
            hidden_states_ref=hidden_states_ref,
            hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
            use_gradient_checkpointing=self.use_gradient_checkpointing
        )
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        checkpoint.clear()
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        state_dict = self.pipe.state_dict()
        lora_state_dict = {}
        for name, param in state_dict.items():
            if name in trainable_param_names:
                lora_state_dict[name] = param
        if self.state_dict_converter is not None:
            lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
        checkpoint.update(lora_state_dict)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_text_encoder_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
    )
    parser.add_argument(
        "--pretrained_text_encoder_2_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
    )
    parser.add_argument(
        "--pretrained_dit_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
    )
    parser.add_argument(
        "--pretrained_vae_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
    )
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
        help="Layers with LoRA modules.",
    )
    parser.add_argument(
        "--align_to_opensource_format",
        default=False,
        action="store_true",
        help="Whether to export lora files aligned with other opensource format.",
    )
    parser.add_argument(
        "--quantize",
        type=str,
        default=None,
        choices=["float8_e4m3fn"],
        help="Whether to use quantization when training the model, and in which format.",
    )
    parser.add_argument(
        "--preset_lora_path",
        type=str,
        default=None,
        help="Preset LoRA path.",
    )
    parser = add_general_parsers(parser)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    model = LightningModel(
        torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
        pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
        preset_lora_path=args.preset_lora_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_target_modules=args.lora_target_modules,
        init_lora_weights=args.init_lora_weights,
        pretrained_lora_path=args.pretrained_lora_path,
        state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
        quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
    )
    # dataset and data loader
    dataset = MultiTaskDataset(
        dataset_list=[
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
                height=512, width=512,
            ),
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
                height=512, width=512,
            ),
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
                keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
                height=512, width=512,
            ),
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
                height=512, width=512,
            ),
        ],
        dataset_weight=(4, 1, 4, 1),
        steps_per_epoch=args.steps_per_epoch,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=args.dataloader_num_workers
    )
    # train
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision=args.precision,
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
        logger=None,
    )
    trainer.fit(model=model, train_dataloaders=train_loader)
```
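`training_step` relies on `FlowMatchScheduler` for the noising and the target; the exact formulas live in that class, but the common rectified-flow form (assumed here) is the linear interpolation below, with the velocity `noise - latents` as the regression target:

```python
import torch

latents = torch.randn(1, 16, 64, 64)
noise = torch.randn_like(latents)
sigma = 0.7  # noise level for the sampled timestep (scheduler-defined)
noisy_latents = (1 - sigma) * latents + sigma * noise
training_target = noise - latents  # velocity target (rectified flow)
pred = torch.zeros_like(training_target)  # stand-in for the DiT prediction
loss = torch.nn.functional.mse_loss(pred.float(), training_target.float())
```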
248 train_flux_reference_multi_node.py Normal file
@@ -0,0 +1,248 @@

This file is a near-verbatim copy of train_flux_reference.py above. The only differences are a new `--num_nodes` argument in `parse_args()` and the multi-node `Trainer` configuration; relative to train_flux_reference.py:

```diff
     # in parse_args(), before add_general_parsers(parser):
+    parser.add_argument(
+        "--num_nodes",
+        type=int,
+        default=1,
+        help="Num nodes.",
+    )

     # in the __main__ block:
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
         accelerator="gpu",
         devices="auto",
+        num_nodes=args.num_nodes,
         precision=args.precision,
-        strategy=args.training_strategy,
+        strategy="ddp",
         default_root_dir=args.output_path,
         ...
     )
```