Artiprocher
2025-05-06 17:53:56 +08:00
parent f7737aff98
commit 1ed676b076
3 changed files with 517 additions and 42 deletions


@@ -680,6 +680,7 @@ def lets_dance_flux(
step1x_mask=None,
step1x_reference_latents=None,
tea_cache: TeaCache = None,
use_gradient_checkpointing=False,
**kwargs
):
if tiled:
@@ -774,20 +775,32 @@ def lets_dance_flux(
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
else:
tea_cache_update = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tea_cache_update:
hidden_states = tea_cache.update(hidden_states)
else:
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -796,14 +809,21 @@ def lets_dance_flux(
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
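The change above threads a new use_gradient_checkpointing flag through lets_dance_flux: when it is set, every joint block and single block is run through torch.utils.checkpoint.checkpoint(..., use_reentrant=False), so block activations are recomputed during the backward pass instead of being kept in memory. A minimal, self-contained sketch of the same pattern (the toy Block module and shapes are illustrative, not the DiffSynth classes):

import torch
import torch.utils.checkpoint

class Block(torch.nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
    def forward(self, x, cond):
        # stands in for a transformer block that takes extra conditioning inputs
        return torch.relu(self.proj(x) + cond), cond

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

blocks = torch.nn.ModuleList([Block() for _ in range(4)])
x = torch.randn(2, 16, 64, requires_grad=True)
cond = torch.randn(2, 16, 64)
use_gradient_checkpointing = True

for block in blocks:
    if use_gradient_checkpointing:
        # activations inside the block are dropped and recomputed during backward
        x, cond = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block), x, cond, use_reentrant=False,
        )
    else:
        x, cond = block(x, cond)
x.mean().backward()

The flag costs roughly one extra forward pass per block during training, but it lets full fine-tuning of the DiT fit in much less activation memory; train.py below enables it.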

test.py

@@ -2,9 +2,10 @@ from transformers import AutoConfig, AutoTokenizer
import torch
from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict, hash_state_dict_keys
from qwen_vl_utils import smart_resize
from PIL import Image
import numpy as np
@@ -15,6 +16,8 @@ class NexusGenQwenVLEncoder(torch.nn.Module):
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, config=model_config, trust_remote_code=True, torch_dtype=torch_dtype, device_map=device)
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
self.t2i_template = "Here is an image based on the description: <|vision_start|><|image_pad|><|vision_end|>"
self.i2i_template = "Here is the image: <|vision_start|><|image_pad|><|vision_end|>"
@staticmethod
def from_pretrained(model_path, torch_dtype="auto", device="cpu"):
@@ -35,34 +38,62 @@ class NexusGenQwenVLEncoder(torch.nn.Module):
images[j] = input_image.resize((resized_width, resized_height))
return images
def forward(self, prompt, images=None):
messages = [{
"role": "user",
"content": [{
"type": "text",
"text": prompt
},],
}]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def forward(self, prompt, images=None, num_img_tokens=81):
messages = [
{
"role": "user",
"content": [{
"type": "text",
"text": prompt
},],
},
{
"role": "assistant",
"content": [{
"type": "text",
"text": self.t2i_template if images is None else self.i2i_template
},],
}
]
images = self.process_images(images)
target_image = Image.fromarray(np.zeros((252, 252, 3), dtype=np.uint8))
if images is None:
images = [target_image]
else:
images = images + [target_image]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
inputs = self.processor(
text=[text],
images=self.process_images(images),
images=images,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)
generation_image_grid_thw = torch.tensor([[1, 18, 18]]).to(self.model.device)
outputs = self.model.generate(**inputs,
max_new_tokens=1024,
return_dict_in_generate=True,
generation_image_grid_thw=generation_image_grid_thw)
output_image_embeddings = outputs['output_image_embeddings']
input_embeds = self.model.model.embed_tokens(inputs['input_ids'])
image_embeds = self.model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
input_image_embeds = image_embeds[:-num_img_tokens]
image_mask = inputs['input_ids'] == self.model.config.image_token_id
indices = image_mask.cumsum(dim=1)
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
position_ids, _ = self.model.get_rope_index(inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids = position_ids.contiguous()
outputs = self.model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
output_image_embeddings = outputs.image_embeddings[:, :-1, :] # shift by one: the output at position i predicts the embedding of token i+1
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
output_image_embeddings = output_image_embeddings.unsqueeze(0)
return output_image_embeddings
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
@@ -70,16 +101,32 @@ model_manager.load_models([
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.dit.load_state_dict(load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16), strict=False)
state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
pipe.dit.load_state_dict(state_dict, strict=False)
adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
adapter.load_state_dict(state_dict, strict=False)
qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda")
adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
adapter.load_state_dict(load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16), strict=False)
with torch.no_grad():
instruction = "<|vision_start|><|image_pad|><|vision_end|> Transform the style to flat anime. Keep the color."
emb = qwenvl(instruction, images=[Image.open("image_3.jpg").convert('RGB')])
instruction = "Generate an image according to the following description: a beautiful Asian girl"
emb = qwenvl(instruction, images=None)
emb = adapter(emb)
image = pipe("", image_emb=emb)
image.save("image_4.jpg")
image.save("image_1.jpg")
with torch.no_grad():
instruction = "<|vision_start|><|image_pad|><|vision_end|> Add sunglasses."
emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
emb = adapter(emb)
image = pipe("", image_emb=emb)
image.save("image_2.jpg")
with torch.no_grad():
instruction = "<|vision_start|><|image_pad|><|vision_end|> Let her smile."
emb = qwenvl(instruction, images=[Image.open("image_2.jpg")])
emb = adapter(emb)
image = pipe("", image_emb=emb)
image.save("image_3.jpg")

train.py (new file)

@@ -0,0 +1,408 @@
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict, hash_state_dict_keys
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
from transformers import AutoConfig, AutoTokenizer
from qwen_vl_utils import smart_resize
from accelerate import Accelerator
from torchvision.transforms import v2
from tqdm import tqdm
from PIL import Image
import torch, os, json, argparse, torchvision
import numpy as np
import lightning as pl
os.environ["TOKENIZERS_PARALLELISM"] = "True"
class SingleTaskDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path,
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None
):
self.base_path = base_path
self.keys = keys
self.metadata = []
self.bad_data = []
self.height = height
self.width = width
self.random = random
self.steps_per_epoch = steps_per_epoch
self.image_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
if metadata_path is None:
self.search_for_data("", report_data_log=True)
self.report_data_log()
else:
with open(metadata_path, "r", encoding="utf-8-sig") as f:
self.metadata = json.load(f)
def report_data_log(self):
print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
def dump_metadata(self, path):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=4)
def parse_json_file(self, absolute_path, relative_path):
data_list = []
with open(absolute_path, "r") as f:
metadata = json.load(f)
for image_1, image_2, instruction in self.keys:
image_1 = os.path.join(relative_path, metadata[image_1]) if image_1 is not None else None
image_2 = os.path.join(relative_path, metadata[image_2])
instruction = metadata[instruction]
data_list.append((image_1, image_2, instruction))
return data_list
def search_for_data(self, path, report_data_log=False):
now_path = os.path.join(self.base_path, path)
if os.path.isfile(now_path) and path.endswith(".json"):
try:
data_list = self.parse_json_file(now_path, os.path.dirname(path))
self.metadata.extend(data_list)
except Exception:
self.bad_data.append(now_path)
elif os.path.isdir(now_path):
for sub_path in os.listdir(now_path):
self.search_for_data(os.path.join(path, sub_path))
if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
self.report_data_log()
def load_image(self, image_path, skip_process=False):
image_path = os.path.join(self.base_path, image_path)
image = Image.open(image_path).convert("RGB")
if skip_process:
return image
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = self.image_process(image)
return image
def load_data(self, data_id):
image_1, image_2, instruction = self.metadata[data_id]
image_1 = self.load_image(image_1, skip_process=True) if image_1 is not None else None
image_2 = self.load_image(image_2)
return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
def __getitem__(self, data_id):
if self.random:
while True:
try:
data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
data = self.load_data(data_id)
return data
except Exception:
continue
else:
return self.load_data(data_id)
def __len__(self):
return self.steps_per_epoch if self.random else len(self.metadata)
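# Mixes several SingleTaskDatasets: each step first samples a dataset index from dataset_weight, then a random item from that dataset.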
class MultiTaskDataset(torch.utils.data.Dataset):
def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
self.dataset_list = dataset_list
self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
self.steps_per_epoch = steps_per_epoch
def __getitem__(self, data_id):
dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
data = self.dataset_list[dataset_id][data_id]
return data
def __len__(self):
return self.steps_per_epoch
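# Qwen2.5-VL based encoder that predicts the target image embeddings consumed by the FLUX decoder (same class as in test.py).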
class NexusGenQwenVLEncoder(torch.nn.Module):
def __init__(self, model_path, torch_dtype="auto", device="cpu"):
super().__init__()
model_config = AutoConfig.from_pretrained(model_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, config=model_config, trust_remote_code=True, torch_dtype=torch_dtype, device_map=device)
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
self.t2i_template = "Here is an image based on the description: <|vision_start|><|image_pad|><|vision_end|>"
self.i2i_template = "Here is the image: <|vision_start|><|image_pad|><|vision_end|>"
@staticmethod
def from_pretrained(model_path, torch_dtype="auto", device="cpu"):
return NexusGenQwenVLEncoder(model_path, torch_dtype=torch_dtype, device=device).eval()
def process_images(self, images=None):
if images is None:
return None
# Resize inputs so they stay within max_pixels and avoid OOM
for j in range(len(images)):
input_image = images[j]
input_w, input_h = input_image.size
resized_height, resized_width = smart_resize(
input_h,
input_w,
max_pixels=262640,
)
images[j] = input_image.resize((resized_width, resized_height))
return images
def forward(self, prompt, images=None, num_img_tokens=81):
messages = [
{
"role": "user",
"content": [{
"type": "text",
"text": prompt
},],
},
{
"role": "assistant",
"content": [{
"type": "text",
"text": self.t2i_template if images is None else self.i2i_template
},],
}
]
images = self.process_images(images)
target_image = Image.fromarray(np.zeros((252, 252, 3), dtype=np.uint8))
if images is None:
images = [target_image]
else:
images = images + [target_image]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
inputs = self.processor(
text=[text],
images=images,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)
input_embeds = self.model.model.embed_tokens(inputs['input_ids'])
image_embeds = self.model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
input_image_embeds = image_embeds[:-num_img_tokens]
image_mask = inputs['input_ids'] == self.model.config.image_token_id
indices = image_mask.cumsum(dim=1)
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
position_ids, _ = self.model.get_rope_index(inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids = position_ids.contiguous()
outputs = self.model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
output_image_embeddings = outputs.image_embeddings[:, :-1, :] # shift by one: the output at position i predicts the embedding of token i+1
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
output_image_embeddings = output_image_embeddings.unsqueeze(0)
return output_image_embeddings
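# Bundles the FLUX pipeline, the 3584->4096 adapter, and the frozen Qwen-VL encoder; training_step adds noise to the
# target-image latents and regresses the scheduler's training target with an MSE loss.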
class UnifiedModel(pl.LightningModule):
def __init__(self, flux_text_encoder_path, flux_vae_path, flux_dit_path, flux_decoder_path, qwenvl_path):
super().__init__()
# Load models
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
flux_text_encoder_path,
flux_vae_path,
flux_dit_path
])
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
state_dict = load_state_dict(flux_decoder_path, torch_dtype=torch.bfloat16)
self.pipe.dit.load_state_dict(state_dict, strict=False)
self.adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16)
self.adapter.load_state_dict(state_dict, strict=False)
self.qwenvl = NexusGenQwenVLEncoder.from_pretrained(qwenvl_path)
self.pipe.vae_decoder.requires_grad_(False)
self.pipe.vae_encoder.requires_grad_(False)
self.pipe.text_encoder_1.requires_grad_(False)
self.qwenvl.requires_grad_(False)
self.pipe.scheduler.set_timesteps(1000, training=True)
def training_step(self, batch, batch_idx):
# Data
text, image = batch["instruction"], batch["image_2"]
image_ref = batch["image_1"]
image = image.unsqueeze(0)
# Prepare input parameters
self.pipe.device = self.device
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
if image_ref is None:
instruction = f"Generate an image according to the following description: {text}"
images_ref = None
else:
instruction = f"<|vision_start|><|image_pad|><|vision_end|> {text}"
images_ref = [image_ref]
emb = self.qwenvl(instruction, images=images_ref)
emb = self.adapter(emb)
prompt_emb = self.pipe.encode_prompt("", positive=True, image_emb=emb)
noise_pred = lets_dance_flux(
self.pipe.denoising_model(),
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
image_emb=emb,
use_gradient_checkpointing=True
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
return loss
def forward(self, batch):
return self.training_step(batch, 0)
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
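# Returns the newest epoch-N/model.safetensors under path together with the next epoch id, so interrupted runs can resume.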
def search_for_last_checkpoint(path):
if not os.path.exists(path):
return None, 0
checkpoint_list = os.listdir(path)
checkpoint_list = [int(checkpoint.split("-")[1]) for checkpoint in checkpoint_list if checkpoint.startswith("epoch")]
if len(checkpoint_list) == 0:
return None, 0
else:
max_epoch_id = max(checkpoint_list)
return f"{path}/epoch-{max_epoch_id}/model.safetensors", max_epoch_id + 1
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="gradient_accumulation_steps",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=100,
help="steps_per_epoch",
)
parser.add_argument(
"--output_path",
type=str,
default="./models",
help="output_path",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="learning_rate",
)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
model = UnifiedModel(
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
"models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin",
"models/DiffSynth-Studio/Nexus-Gen",
)
# dataset and data loader
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"data/example_dataset",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=512, width=512,
),
],
dataset_weight=(1,),
steps_per_epoch=args.steps_per_epoch * accelerator.num_processes,
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=1,
num_workers=1,
collate_fn=lambda x: x[0]
)
# train
pretrained_path, start_epoch_id = search_for_last_checkpoint(args.output_path)
if pretrained_path is not None:
print(f"pretrained_path: {pretrained_path}")
model.load_state_dict(load_state_dict(pretrained_path, torch_dtype=torch.bfloat16), assign=True, strict=False)
model.to(torch.bfloat16)
model.to(accelerator.device)
trainable_modules = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=args.learning_rate)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
for epoch in range(start_epoch_id, 100000):
for batch in tqdm(train_loader, desc=f"epoch-{epoch}", disable=not accelerator.is_local_main_process):
with accelerator.accumulate(model):
optimizer.zero_grad()
loss = model(batch)
accelerator.backward(loss)
optimizer.step()
path = args.output_path
os.makedirs(path, exist_ok=True)
accelerator.wait_for_everyone()
accelerator.save_model(model, f"{path}/epoch-{epoch}", max_shard_size="10GB", safe_serialization=True)
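MultiTaskDataset draws the source dataset for every training step with torch.multinomial over dataset_weight, so several editing and generation datasets can be mixed in a fixed ratio; steps_per_epoch is multiplied by accelerator.num_processes so that, after the loader is sharded across processes, each one still sees roughly the requested number of steps. A standalone sketch of that sampling (the weights are illustrative, not values used in this repo):

import torch

dataset_weight = torch.tensor([1.0, 3.0])   # e.g. two datasets mixed 1:3
counts = torch.zeros(2)
for _ in range(1000):
    dataset_id = torch.multinomial(dataset_weight, 1).item()
    counts[dataset_id] += 1
print(counts / counts.sum())                # roughly tensor([0.25, 0.75])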