mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
24 Commits
qwen-image
...
fp8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
17714a8cc8 | ||
|
|
a947459bda | ||
|
|
a0eec8c673 | ||
|
|
4f7c3b6a1e | ||
|
|
57128dc89f | ||
|
|
d20680baae | ||
|
|
970403f78e | ||
|
|
bee2a969e5 | ||
|
|
2803ffcb38 | ||
|
|
d3224e1fdc | ||
|
|
3c2f85606f | ||
|
|
1f25ad416b | ||
|
|
d0b9b25db7 | ||
|
|
ef09db69cd | ||
|
|
84ede171fd | ||
|
|
6f4e38276e | ||
|
|
a3b67436a6 | ||
|
|
829ca3414b | ||
|
|
3915bc3ee6 | ||
|
|
4299c999b5 | ||
|
|
6bae70eee0 | ||
|
|
6452edb738 | ||
|
|
26461c1963 | ||
|
|
0412fc7232 |
@@ -91,7 +91,7 @@ image.save("image.jpg")
|
||||
|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
</details>
|
||||
|
||||
### FLUX Series
|
||||
@@ -363,6 +363,7 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
|
||||
|
||||
|
||||
## Update History
|
||||
- **August 7, 2025** We open-sourced the entity control LoRA of Qwen-Image, [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen). Qwen-Image-EliGen is able to achieve entity-level controlled text-to-image generation. See the [paper](https://arxiv.org/abs/2501.01097) for technical details. Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).
|
||||
|
||||
- **August 5, 2025** We open-sourced the distilled acceleration model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), achieving approximately 5x speedup.
|
||||
|
||||
|
||||
@@ -93,6 +93,7 @@ image.save("image.jpg")
|
||||
|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -379,6 +380,7 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
|
||||
|
||||
|
||||
## 更新历史
|
||||
- **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。
|
||||
|
||||
- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。
|
||||
|
||||
|
||||
382
apps/gradio/qwen_image_eligen.py
Normal file
382
apps/gradio/qwen_image_eligen.py
Normal file
@@ -0,0 +1,382 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import random
|
||||
import json
|
||||
import gradio as gr
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
|
||||
# pip install pydantic==2.10.6
|
||||
# pip install gradio==5.4.0
|
||||
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors")
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/*")
|
||||
example_json = 'data/examples/eligen/qwen-image/ui_examples.json'
|
||||
with open(example_json, 'r') as f:
|
||||
examples = json.load(f)['examples']
|
||||
|
||||
for idx in range(len(examples)):
|
||||
example_id = examples[idx]['example_id']
|
||||
entity_prompts = examples[idx]['local_prompt_list']
|
||||
examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
|
||||
def create_canvas_data(background, masks):
|
||||
if background.shape[-1] == 3:
|
||||
background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
|
||||
layers = []
|
||||
for mask in masks:
|
||||
if mask is not None:
|
||||
mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
|
||||
layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
|
||||
layer[..., -1] = mask_single_channel
|
||||
layers.append(layer)
|
||||
else:
|
||||
layers.append(np.zeros_like(background))
|
||||
|
||||
composite = background.copy()
|
||||
for layer in layers:
|
||||
if layer.size > 0:
|
||||
composite = np.where(layer[..., -1:] > 0, layer, composite)
|
||||
return {
|
||||
"background": background,
|
||||
"layers": layers,
|
||||
"composite": composite,
|
||||
}
|
||||
|
||||
def load_example(load_example_button):
|
||||
example_idx = int(load_example_button.split()[-1]) - 1
|
||||
example = examples[example_idx]
|
||||
result = [
|
||||
50,
|
||||
example["global_prompt"],
|
||||
example["negative_prompt"],
|
||||
example["seed"],
|
||||
*example["local_prompt_list"],
|
||||
]
|
||||
num_entities = len(example["local_prompt_list"])
|
||||
result += [""] * (config["max_num_painter_layers"] - num_entities)
|
||||
masks = []
|
||||
for mask in example["mask_lists"]:
|
||||
mask_single_channel = np.array(mask.convert("L"))
|
||||
masks.append(mask_single_channel)
|
||||
for _ in range(config["max_num_painter_layers"] - len(masks)):
|
||||
blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
|
||||
masks.append(blank_mask)
|
||||
background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
|
||||
canvas_data_list = []
|
||||
for mask in masks:
|
||||
canvas_data = create_canvas_data(background, [mask])
|
||||
canvas_data_list.append(canvas_data)
|
||||
result.extend(canvas_data_list)
|
||||
return result
|
||||
|
||||
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
|
||||
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
|
||||
print(f'save to {save_dir}')
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
for i, mask in enumerate(masks):
|
||||
save_path = os.path.join(save_dir, f'{i}.png')
|
||||
mask.save(save_path)
|
||||
sample = {
|
||||
"global_prompt": global_prompt,
|
||||
"mask_prompts": mask_prompts,
|
||||
"seed": seed,
|
||||
}
|
||||
with open(os.path.join(save_dir, f"prompts.json"), 'w', encoding='utf-8') as f:
|
||||
json.dump(sample, f, ensure_ascii=False, indent=4)
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
if mask is None:
|
||||
continue
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
if mask_bbox is None:
|
||||
continue
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
return result
|
||||
|
||||
config = {
|
||||
"max_num_painter_layers": 8,
|
||||
"max_num_model_cache": 1,
|
||||
}
|
||||
|
||||
model_dict = {}
|
||||
|
||||
def load_model(model_type='qwen-image'):
|
||||
global model_dict
|
||||
model_key = f"{model_type}"
|
||||
if model_key in model_dict:
|
||||
return model_dict[model_key]
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors")
|
||||
model_dict[model_key] = pipe
|
||||
return pipe
|
||||
|
||||
load_model('qwen-image')
|
||||
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown(
|
||||
"""## EliGen: Entity-Level Controllable Text-to-Image Model
|
||||
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
|
||||
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
|
||||
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
|
||||
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
|
||||
"""
|
||||
)
|
||||
|
||||
loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
|
||||
main_interface = gr.Column(visible=False)
|
||||
|
||||
def initialize_model():
|
||||
try:
|
||||
load_model('qwen-image')
|
||||
return {
|
||||
loading_status: gr.update(value="Model loaded successfully!", visible=False),
|
||||
main_interface: gr.update(visible=True),
|
||||
}
|
||||
except Exception as e:
|
||||
print(f'Failed to load model with error: {e}')
|
||||
return {
|
||||
loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
|
||||
main_interface: gr.update(visible=True),
|
||||
}
|
||||
|
||||
app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
|
||||
|
||||
with main_interface:
|
||||
with gr.Row():
|
||||
local_prompt_list = []
|
||||
canvas_list = []
|
||||
random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
|
||||
with gr.Column(scale=382, min_width=100):
|
||||
model_type = gr.State('qwen-image')
|
||||
with gr.Accordion(label="Global prompt"):
|
||||
prompt = gr.Textbox(label="Global Prompt", lines=3)
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", value="", lines=3)
|
||||
with gr.Accordion(label="Inference Options", open=True):
|
||||
seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
|
||||
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
|
||||
cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=4.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
||||
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
||||
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
||||
with gr.Accordion(label="Inpaint Input Image", open=False, visible=False):
|
||||
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
|
||||
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
|
||||
|
||||
with gr.Column():
|
||||
reset_input_button = gr.Button(value="Reset Inpaint Input")
|
||||
send_input_to_painter = gr.Button(value="Set as painter's background")
|
||||
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
|
||||
def reset_input_image(input_image):
|
||||
return None
|
||||
|
||||
with gr.Column(scale=618, min_width=100):
|
||||
with gr.Accordion(label="Entity Painter"):
|
||||
for painter_layer_id in range(config["max_num_painter_layers"]):
|
||||
with gr.Tab(label=f"Entity {painter_layer_id}"):
|
||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
||||
canvas = gr.ImageEditor(
|
||||
canvas_size=(1024, 1024),
|
||||
sources=None,
|
||||
layers=False,
|
||||
interactive=True,
|
||||
image_mode="RGBA",
|
||||
brush=gr.Brush(
|
||||
default_size=50,
|
||||
default_color="#000000",
|
||||
colors=["#000000"],
|
||||
),
|
||||
label="Entity Mask Painter",
|
||||
key=f"canvas_{painter_layer_id}",
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
|
||||
def resize_canvas(height, width, canvas):
|
||||
if canvas is None or canvas["background"] is None:
|
||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||
h, w = canvas["background"].shape[:2]
|
||||
if h != height or width != w:
|
||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||
else:
|
||||
return canvas
|
||||
local_prompt_list.append(local_prompt)
|
||||
canvas_list.append(canvas)
|
||||
with gr.Accordion(label="Results"):
|
||||
run_button = gr.Button(value="Generate", variant="primary")
|
||||
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
output_to_painter_button = gr.Button(value="Set as painter's background")
|
||||
with gr.Column():
|
||||
return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
|
||||
output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
|
||||
real_output = gr.State(None)
|
||||
mask_out = gr.State(None)
|
||||
|
||||
@gr.on(
|
||||
inputs=[model_type, prompt, negative_prompt, cfg_scale, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
|
||||
outputs=[output_image, real_output, mask_out],
|
||||
triggers=run_button.click
|
||||
)
|
||||
def generate_image(model_type, prompt, negative_prompt, cfg_scale, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
|
||||
pipe = load_model(model_type)
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"cfg_scale": cfg_scale,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"progress_bar_cmd": progress.tqdm,
|
||||
}
|
||||
# if input_image is not None:
|
||||
# input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
|
||||
# input_params["enable_eligen_inpaint"] = True
|
||||
|
||||
local_prompt_list, canvas_list = (
|
||||
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
||||
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
||||
)
|
||||
local_prompts, masks = [], []
|
||||
for local_prompt, canvas in zip(local_prompt_list, canvas_list):
|
||||
if isinstance(local_prompt, str) and len(local_prompt) > 0:
|
||||
local_prompts.append(local_prompt)
|
||||
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
||||
entity_prompts = None if len(local_prompts) == 0 else local_prompts
|
||||
entity_masks = None if len(masks) == 0 or entity_prompts is None else masks
|
||||
input_params.update({
|
||||
"eligen_entity_prompts": entity_prompts,
|
||||
"eligen_entity_masks": entity_masks,
|
||||
})
|
||||
torch.manual_seed(seed)
|
||||
save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
|
||||
image = pipe(**input_params)
|
||||
masks = [mask.resize(image.size) for mask in masks]
|
||||
image_with_mask = visualize_masks(image, masks, local_prompts)
|
||||
|
||||
real_output = gr.State(image)
|
||||
mask_out = gr.State(image_with_mask)
|
||||
|
||||
if return_with_mask:
|
||||
return image_with_mask, real_output, mask_out
|
||||
return image, real_output, mask_out
|
||||
|
||||
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
|
||||
def send_input_to_painter_background(input_image, *canvas_list):
|
||||
if input_image is None:
|
||||
return tuple(canvas_list)
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = input_image.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
||||
def send_output_to_painter_background(real_output, *canvas_list):
|
||||
if real_output is None:
|
||||
return tuple(canvas_list)
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = real_output.value.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
|
||||
def show_output(return_with_mask, real_output, mask_out):
|
||||
if return_with_mask:
|
||||
return mask_out.value
|
||||
else:
|
||||
return real_output.value
|
||||
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
|
||||
def send_output_to_pipe_input(real_output):
|
||||
return real_output.value
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("## Examples")
|
||||
for i in range(0, len(examples), 2):
|
||||
with gr.Row():
|
||||
if i < len(examples):
|
||||
example = examples[i]
|
||||
with gr.Column():
|
||||
example_image = gr.Image(
|
||||
value=f"data/examples/eligen/qwen-image/example_{example['example_id']}/example_image.png",
|
||||
label=example["description"],
|
||||
interactive=False,
|
||||
width=1024,
|
||||
height=512
|
||||
)
|
||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||
load_example_button.click(
|
||||
load_example,
|
||||
inputs=[load_example_button],
|
||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||
)
|
||||
|
||||
if i + 1 < len(examples):
|
||||
example = examples[i + 1]
|
||||
with gr.Column():
|
||||
example_image = gr.Image(
|
||||
value=f"data/examples/eligen/qwen-image/example_{example['example_id']}/example_image.png",
|
||||
label=example["description"],
|
||||
interactive=False,
|
||||
width=1024,
|
||||
height=512
|
||||
)
|
||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||
load_example_button.click(
|
||||
load_example,
|
||||
inputs=[load_example_button],
|
||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||
)
|
||||
app.config["show_progress"] = "hidden"
|
||||
app.launch(share=False)
|
||||
@@ -75,7 +75,6 @@ from ..models.nexus_gen import NexusGenAutoregressiveModel
|
||||
from ..models.qwen_image_dit import QwenImageDiT
|
||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
from ..models.qwen_image_controlnet import QwenImageControlNet
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -168,7 +167,6 @@ model_loader_configs = [
|
||||
(None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
|
||||
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
|
||||
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
|
||||
(None, "6834bf9ef1a6723291d82719fb02e953", ["qwen_image_controlnet"], [QwenImageControlNet], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
@@ -375,8 +375,7 @@ class FluxDiT(torch.nn.Module):
|
||||
return attention_mask
|
||||
|
||||
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
|
||||
repeat_dim = hidden_states.shape[1]
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
|
||||
max_masks = 0
|
||||
attention_mask = None
|
||||
prompt_embs = [prompt_emb]
|
||||
|
||||
@@ -383,5 +383,20 @@ class WanLoRAConverter:
|
||||
return state_dict
|
||||
|
||||
|
||||
class QwenImageLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, **kwargs):
|
||||
state_dict = {name.replace(".default.", "."): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
@staticmethod
|
||||
def align_to_diffsynth_format(state_dict, **kwargs):
|
||||
state_dict = {name.replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .qwen_image_dit import QwenEmbedRope, QwenImageTransformerBlock
|
||||
from ..vram_management import gradient_checkpoint_forward
|
||||
from einops import rearrange
|
||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
|
||||
|
||||
class QwenImageControlNet(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
num_controlnet_layers: int = 6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
||||
|
||||
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
|
||||
self.txt_norm = RMSNorm(3584, eps=1e-6)
|
||||
|
||||
self.img_in = nn.Linear(64 * 2, 3072)
|
||||
self.txt_in = nn.Linear(3584, 3072)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
QwenImageTransformerBlock(
|
||||
dim=3072,
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
)
|
||||
for _ in range(num_controlnet_layers)
|
||||
]
|
||||
)
|
||||
self.alpha = torch.nn.Parameter(torch.zeros((num_layers,)))
|
||||
self.num_layers = num_layers
|
||||
self.num_controlnet_layers = num_controlnet_layers
|
||||
self.align_map = {i: i // (num_layers // num_controlnet_layers) for i in range(num_layers)}
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
prompt_emb_mask=None,
|
||||
height=None,
|
||||
width=None,
|
||||
controlnet_conditioning=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2)
|
||||
controlnet_conditioning = rearrange(controlnet_conditioning, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2)
|
||||
image = torch.concat([image, controlnet_conditioning], dim=-1)
|
||||
|
||||
image = self.img_in(image)
|
||||
text = self.txt_in(self.txt_norm(prompt_emb))
|
||||
|
||||
conditioning = self.time_text_embed(timestep, image.dtype)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
|
||||
outputs = []
|
||||
for block in self.transformer_blocks:
|
||||
text, image = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
outputs.append(image)
|
||||
|
||||
alpha = self.alpha.to(dtype=image.dtype, device=image.device)
|
||||
outputs_aligned = [outputs[self.align_map[i]] * alpha[i] for i in range(self.num_layers)]
|
||||
return outputs_aligned
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageControlNetStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageControlNetStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
@@ -1,10 +1,44 @@
|
||||
import torch
|
||||
import torch, math
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Optional, Union, List
|
||||
from einops import rearrange
|
||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
from .flux_dit import AdaLayerNorm
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_3_AVAILABLE = False
|
||||
|
||||
|
||||
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):
|
||||
if FLASH_ATTN_3_AVAILABLE:
|
||||
if not enable_fp8_attention:
|
||||
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
|
||||
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
|
||||
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
|
||||
x = flash_attn_interface.flash_attn_func(q, k, v)
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
else:
|
||||
origin_dtype = q.dtype
|
||||
q_std, k_std, v_std = q.std(), k.std(), v.std()
|
||||
q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
|
||||
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
|
||||
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
|
||||
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
|
||||
x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
x = x.to(origin_dtype) * v_std
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
else:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
return x
|
||||
|
||||
|
||||
class ApproximateGELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
||||
@@ -158,7 +192,9 @@ class QwenDoubleStreamAttention(nn.Module):
|
||||
self,
|
||||
image: torch.FloatTensor,
|
||||
text: torch.FloatTensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
enable_fp8_attention: bool = False,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
|
||||
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
|
||||
@@ -186,9 +222,7 @@ class QwenDoubleStreamAttention(nn.Module):
|
||||
joint_k = torch.cat([txt_k, img_k], dim=2)
|
||||
joint_v = torch.cat([txt_v, img_v], dim=2)
|
||||
|
||||
joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v)
|
||||
|
||||
joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype)
|
||||
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
|
||||
|
||||
txt_attn_output = joint_attn_out[:, :seq_txt, :]
|
||||
img_attn_output = joint_attn_out[:, seq_txt:, :]
|
||||
@@ -245,6 +279,8 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
text: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
enable_fp8_attention = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||
@@ -260,6 +296,8 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
image=img_modulated,
|
||||
text=txt_modulated,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
enable_fp8_attention = False,
|
||||
)
|
||||
|
||||
image = image + img_gate * img_attn_out
|
||||
@@ -309,6 +347,69 @@ class QwenImageDiT(torch.nn.Module):
|
||||
self.proj_out = nn.Linear(3072, 64)
|
||||
|
||||
|
||||
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
|
||||
# prompt_emb
|
||||
all_prompt_emb = entity_prompt_emb + [prompt_emb]
|
||||
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
|
||||
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
|
||||
|
||||
# image_rotary_emb
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]
|
||||
entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
|
||||
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
|
||||
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
|
||||
|
||||
# attention_mask
|
||||
repeat_dim = latents.shape[1]
|
||||
max_masks = entity_masks.shape[1]
|
||||
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)
|
||||
entity_masks = entity_masks + [global_mask]
|
||||
|
||||
N = len(entity_masks)
|
||||
batch_size = entity_masks[0].shape[0]
|
||||
seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
|
||||
total_seq_len = sum(seq_lens) + image.shape[1]
|
||||
patched_masks = []
|
||||
for i in range(N):
|
||||
patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
patched_masks.append(patched_mask)
|
||||
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
|
||||
|
||||
# prompt-image attention mask
|
||||
image_start = sum(seq_lens)
|
||||
image_end = total_seq_len
|
||||
cumsum = [0]
|
||||
for length in seq_lens:
|
||||
cumsum.append(cumsum[-1] + length)
|
||||
for i in range(N):
|
||||
prompt_start = cumsum[i]
|
||||
prompt_end = cumsum[i+1]
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
|
||||
# prompt update with image
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# image update with prompt
|
||||
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
|
||||
# prompt-prompt attention mask, let the prompt tokens not attend to each other
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if i == j:
|
||||
continue
|
||||
start_i, end_i = cumsum[i], cumsum[i+1]
|
||||
start_j, end_j = cumsum[j], cumsum[j+1]
|
||||
attention_mask[:, start_i:end_i, start_j:end_j] = False
|
||||
|
||||
attention_mask = attention_mask.float()
|
||||
attention_mask[attention_mask == 0] = float('-inf')
|
||||
attention_mask[attention_mask == 1] = 0
|
||||
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
|
||||
|
||||
return all_prompt_emb, image_rotary_emb, attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents=None,
|
||||
|
||||
@@ -335,7 +335,7 @@ class WanModel(torch.nn.Module):
|
||||
else:
|
||||
self.control_adapter = None
|
||||
|
||||
def patchify(self, x: torch.Tensor,control_camera_latents_input: torch.Tensor = None):
|
||||
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
|
||||
x = self.patch_embedding(x)
|
||||
if self.control_adapter is not None and control_camera_latents_input is not None:
|
||||
y_camera = self.control_adapter(control_camera_latents_input)
|
||||
|
||||
@@ -762,7 +762,7 @@ def lets_dance_flux(
|
||||
hidden_states = dit.x_embedder(hidden_states)
|
||||
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, 16)
|
||||
else:
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
@@ -1233,7 +1233,7 @@ def model_fn_flux_image(
|
||||
|
||||
# EliGen
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1])
|
||||
else:
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
@@ -4,54 +4,19 @@ from typing import Union
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
|
||||
from ..models import ModelManager, load_state_dict
|
||||
from ..models.qwen_image_dit import QwenImageDiT
|
||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
from ..models.qwen_image_controlnet import QwenImageControlNet
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora import GeneralLoRALoader
|
||||
from .flux_image_new import ControlNetInput
|
||||
|
||||
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
|
||||
|
||||
|
||||
class QwenImageMultiControlNet(torch.nn.Module):
|
||||
def __init__(self, models: list[QwenImageControlNet]):
|
||||
super().__init__()
|
||||
if not isinstance(models, list):
|
||||
models = [models]
|
||||
self.models = torch.nn.ModuleList(models)
|
||||
|
||||
def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
|
||||
model = self.models[controlnet_input.controlnet_id]
|
||||
res_stack = model(
|
||||
controlnet_conditioning=conditioning,
|
||||
processor_id=controlnet_input.processor_id,
|
||||
**kwargs
|
||||
)
|
||||
res_stack = [res * controlnet_input.scale for res in res_stack]
|
||||
return res_stack
|
||||
|
||||
def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
|
||||
res_stack = None
|
||||
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
|
||||
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
|
||||
if progress > controlnet_input.start or progress < controlnet_input.end:
|
||||
continue
|
||||
res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
|
||||
if res_stack is None:
|
||||
res_stack = res_stack_
|
||||
else:
|
||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
||||
return res_stack
|
||||
|
||||
|
||||
|
||||
class QwenImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
@@ -65,16 +30,15 @@ class QwenImagePipeline(BasePipeline):
|
||||
self.text_encoder: QwenImageTextEncoder = None
|
||||
self.dit: QwenImageDiT = None
|
||||
self.vae: QwenImageVAE = None
|
||||
self.controlnet: QwenImageMultiControlNet = None
|
||||
self.tokenizer: Qwen2Tokenizer = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("dit", "controlnet")
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
QwenImageUnit_ShapeChecker(),
|
||||
QwenImageUnit_NoiseInitializer(),
|
||||
QwenImageUnit_InputImageEmbedder(),
|
||||
QwenImageUnit_PromptEmbedder(),
|
||||
QwenImageUnit_ControlNet(),
|
||||
QwenImageUnit_EntityControl(),
|
||||
]
|
||||
self.model_fn = model_fn_qwen_image
|
||||
|
||||
@@ -99,14 +63,12 @@ class QwenImagePipeline(BasePipeline):
|
||||
return loss
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
|
||||
self.vram_management_enabled = True
|
||||
if num_persistent_param_in_dit is not None:
|
||||
vram_limit = None
|
||||
else:
|
||||
if vram_limit is None:
|
||||
vram_limit = self.get_vram()
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
if vram_limit is None:
|
||||
vram_limit = self.get_vram()
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
|
||||
if self.text_encoder is not None:
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
|
||||
dtype = next(iter(self.text_encoder.parameters())).dtype
|
||||
@@ -132,31 +94,54 @@ class QwenImagePipeline(BasePipeline):
|
||||
from ..models.qwen_image_dit import RMSNorm
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
device = "cpu" if vram_limit is not None else self.device
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if not enable_dit_fp8_computation:
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
else:
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae is not None:
|
||||
from ..models.qwen_image_vae import QwenImageRMS_norm
|
||||
dtype = next(iter(self.vae.parameters())).dtype
|
||||
@@ -202,7 +187,6 @@ class QwenImagePipeline(BasePipeline):
|
||||
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
|
||||
pipe.dit = model_manager.fetch_model("qwen_image_dit")
|
||||
pipe.vae = model_manager.fetch_model("qwen_image_vae")
|
||||
pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
|
||||
if tokenizer_config is not None and pipe.text_encoder is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
from transformers import Qwen2Tokenizer
|
||||
@@ -228,8 +212,10 @@ class QwenImagePipeline(BasePipeline):
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# ControlNet
|
||||
controlnet_inputs: list[ControlNetInput] = None,
|
||||
# EliGen
|
||||
eligen_entity_prompts: list[str] = None,
|
||||
eligen_entity_masks: list[Image.Image] = None,
|
||||
eligen_enable_on_negative: bool = False,
|
||||
# Tile
|
||||
tiled: bool = False,
|
||||
tile_size: int = 128,
|
||||
@@ -252,9 +238,9 @@ class QwenImagePipeline(BasePipeline):
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"controlnet_inputs": controlnet_inputs,
|
||||
"enable_fp8_attention": enable_fp8_attention,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -364,83 +350,101 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
class QwenImageUnit_ControlNet(PipelineUnit):
|
||||
class QwenImageUnit_EntityControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
take_over=True,
|
||||
onload_model_names=("text_encoder")
|
||||
)
|
||||
|
||||
def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
|
||||
mask = (pipe.preprocess_image(mask) + 1) / 2
|
||||
mask = mask.mean(dim=1, keepdim=True)
|
||||
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
|
||||
latents = torch.concat([latents, mask], dim=1)
|
||||
return latents
|
||||
|
||||
def apply_controlnet_mask_on_image(self, pipe, image, mask):
|
||||
mask = mask.resize(image.size)
|
||||
mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
|
||||
image = np.array(image)
|
||||
image[mask > 0] = 0
|
||||
image = Image.fromarray(image)
|
||||
return image
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
|
||||
if controlnet_inputs is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
conditionings = []
|
||||
for controlnet_input in controlnet_inputs:
|
||||
image = controlnet_input.image
|
||||
if controlnet_input.inpaint_mask is not None:
|
||||
image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
|
||||
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
return split_result
|
||||
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict:
|
||||
if pipe.text_encoder is not None:
|
||||
prompt = [prompt]
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 34
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
|
||||
|
||||
if controlnet_input.inpaint_mask is not None:
|
||||
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
|
||||
conditionings.append(image)
|
||||
return {"controlnet_conditionings": conditionings}
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
|
||||
encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
|
||||
prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def preprocess_masks(self, pipe, masks, height, width, dim):
|
||||
out_masks = []
|
||||
for mask in masks:
|
||||
mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
|
||||
mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
out_masks.append(mask)
|
||||
return out_masks
|
||||
|
||||
def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height):
|
||||
entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)
|
||||
entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
|
||||
prompt_embs, prompt_emb_masks = [], []
|
||||
for entity_prompt in entity_prompts:
|
||||
prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt)
|
||||
prompt_embs.append(prompt_emb_dict['prompt_emb'])
|
||||
prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask'])
|
||||
return prompt_embs, prompt_emb_masks, entity_masks
|
||||
|
||||
def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, enable_eligen_on_negative, cfg_scale):
|
||||
entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height)
|
||||
if enable_eligen_on_negative and cfg_scale != 1.0:
|
||||
entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi)
|
||||
entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi)
|
||||
entity_masks_nega = entity_masks_posi
|
||||
else:
|
||||
entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None
|
||||
eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi, "entity_prompt_emb_mask": entity_prompt_emb_posi_mask}
|
||||
eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega, "entity_prompt_emb_mask": entity_prompt_emb_nega_mask}
|
||||
return eligen_kwargs_posi, eligen_kwargs_nega
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None)
|
||||
if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
|
||||
eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,
|
||||
eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
|
||||
eligen_enable_on_negative, inputs_shared["cfg_scale"])
|
||||
inputs_posi.update(eligen_kwargs_posi)
|
||||
if inputs_shared.get("cfg_scale", 1.0) != 1.0:
|
||||
inputs_nega.update(eligen_kwargs_nega)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
def model_fn_qwen_image(
|
||||
dit: QwenImageDiT = None,
|
||||
controlnet: QwenImageMultiControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
prompt_emb_mask=None,
|
||||
height=None,
|
||||
width=None,
|
||||
controlnet_inputs=None,
|
||||
controlnet_conditionings=None,
|
||||
progress_id=0,
|
||||
num_inference_steps=1,
|
||||
entity_prompt_emb=None,
|
||||
entity_prompt_emb_mask=None,
|
||||
entity_masks=None,
|
||||
enable_fp8_attention=False,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs
|
||||
):
|
||||
# ControlNet
|
||||
if controlnet_conditionings is not None:
|
||||
controlnet_extra_kwargs = {
|
||||
"latents": latents,
|
||||
"timestep": timestep,
|
||||
"prompt_emb": prompt_emb,
|
||||
"prompt_emb_mask": prompt_emb_mask,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"use_gradient_checkpointing": use_gradient_checkpointing,
|
||||
"use_gradient_checkpointing_offload": use_gradient_checkpointing_offload,
|
||||
}
|
||||
res_stack = controlnet(
|
||||
controlnet_conditionings, controlnet_inputs, progress_id, num_inference_steps,
|
||||
**controlnet_extra_kwargs
|
||||
)
|
||||
|
||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
timestep = timestep / 1000
|
||||
@@ -448,11 +452,19 @@ def model_fn_qwen_image(
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
|
||||
image = dit.img_in(image)
|
||||
text = dit.txt_in(dit.txt_norm(prompt_emb))
|
||||
conditioning = dit.time_text_embed(timestep, image.dtype)
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
|
||||
for block_id, block in enumerate(dit.transformer_blocks):
|
||||
if entity_prompt_emb is not None:
|
||||
text, image_rotary_emb, attention_mask = dit.process_entity_masks(
|
||||
latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask,
|
||||
entity_masks, height, width, image, img_shapes,
|
||||
)
|
||||
else:
|
||||
text = dit.txt_in(dit.txt_norm(prompt_emb))
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
attention_mask = None
|
||||
|
||||
for block in dit.transformer_blocks:
|
||||
text, image = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
@@ -461,9 +473,9 @@ def model_fn_qwen_image(
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
enable_fp8_attention=enable_fp8_attention,
|
||||
)
|
||||
if controlnet_inputs is not None:
|
||||
image = image + res_stack[block_id]
|
||||
|
||||
image = dit.norm_out(image, conditioning)
|
||||
image = dit.proj_out(image)
|
||||
|
||||
@@ -4,6 +4,7 @@ from PIL import Image
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
|
||||
|
||||
@@ -364,12 +365,15 @@ class ModelLogger:
|
||||
self.output_path = output_path
|
||||
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
||||
self.state_dict_converter = state_dict_converter
|
||||
|
||||
|
||||
def on_step_end(self, loss):
|
||||
pass
|
||||
|
||||
|
||||
self.num_steps = 0
|
||||
|
||||
|
||||
def on_step_end(self, accelerator, model, save_steps=None):
|
||||
self.num_steps += 1
|
||||
if save_steps is not None and self.num_steps % save_steps == 0:
|
||||
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||
|
||||
|
||||
def on_epoch_end(self, accelerator, model, epoch_id):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
@@ -381,6 +385,21 @@ class ModelLogger:
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
|
||||
|
||||
def on_training_end(self, accelerator, model, save_steps=None):
|
||||
if save_steps is not None and self.num_steps % save_steps != 0:
|
||||
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||
|
||||
|
||||
def save_model(self, accelerator, model, file_name):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
path = os.path.join(self.output_path, file_name)
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
|
||||
|
||||
def launch_training_task(
|
||||
dataset: torch.utils.data.Dataset,
|
||||
@@ -388,11 +407,17 @@ def launch_training_task(
|
||||
model_logger: ModelLogger,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
||||
num_workers: int = 8,
|
||||
save_steps: int = None,
|
||||
num_epochs: int = 1,
|
||||
gradient_accumulation_steps: int = 1,
|
||||
find_unused_parameters: bool = False,
|
||||
):
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
|
||||
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)],
|
||||
)
|
||||
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||
|
||||
for epoch_id in range(num_epochs):
|
||||
@@ -402,10 +427,11 @@ def launch_training_task(
|
||||
loss = model(data)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
model_logger.on_step_end(loss)
|
||||
model_logger.on_step_end(accelerator, model, save_steps)
|
||||
scheduler.step()
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
|
||||
if save_steps is None:
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
model_logger.on_training_end(accelerator, model, save_steps)
|
||||
|
||||
|
||||
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
|
||||
@@ -446,6 +472,9 @@ def wan_parser():
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
|
||||
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
|
||||
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -474,6 +503,9 @@ def flux_parser():
|
||||
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -503,4 +535,7 @@ def qwen_image_parser():
|
||||
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
return parser
|
||||
|
||||
@@ -110,8 +110,48 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
self.lora_A_weights = []
|
||||
self.lora_B_weights = []
|
||||
self.lora_merger = None
|
||||
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
|
||||
|
||||
def fp8_linear(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = input.device
|
||||
origin_dtype = input.dtype
|
||||
origin_shape = input.shape
|
||||
input = input.reshape(-1, origin_shape[-1])
|
||||
|
||||
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
|
||||
fp8_max = 448.0
|
||||
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
|
||||
# To avoid overflow and ensure numerical compatibility during FP8 computation,
|
||||
# we scale down the input by 2.0 in advance.
|
||||
# This scaling will be compensated later during the final result scaling.
|
||||
if self.computation_dtype == torch.float8_e4m3fnuz:
|
||||
fp8_max = fp8_max / 2.0
|
||||
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
|
||||
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
|
||||
input = input / (scale_a + 1e-8)
|
||||
input = input.to(self.computation_dtype)
|
||||
weight = weight.to(self.computation_dtype)
|
||||
bias = bias.to(torch.bfloat16)
|
||||
|
||||
result = torch._scaled_mm(
|
||||
input,
|
||||
weight.T,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b.T,
|
||||
bias=bias,
|
||||
out_dtype=origin_dtype,
|
||||
)
|
||||
new_shape = origin_shape[:-1] + result.shape[-1:]
|
||||
result = result.reshape(new_shape)
|
||||
return result
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
# VRAM management
|
||||
if self.state == 2:
|
||||
weight, bias = self.weight, self.bias
|
||||
else:
|
||||
@@ -123,8 +163,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
else:
|
||||
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||
out = torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
# Linear forward
|
||||
if self.enable_fp8:
|
||||
out = self.fp8_linear(x, weight, bias)
|
||||
else:
|
||||
out = torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
# LoRA
|
||||
if len(self.lora_A_weights) == 0:
|
||||
# No LoRA
|
||||
return out
|
||||
|
||||
@@ -249,6 +249,7 @@ The script includes the following parameters:
|
||||
* `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--data_file_keys`: Data file keys in the metadata. Separate with commas.
|
||||
* `--dataset_repeat`: Number of times the dataset repeats per epoch.
|
||||
* `--dataset_num_workers`: Number of workers for data loading.
|
||||
* Model
|
||||
* `--model_paths`: Paths to load models. In JSON format.
|
||||
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
|
||||
@@ -257,6 +258,8 @@ The script includes the following parameters:
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--output_path`: Save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
|
||||
* `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.
|
||||
* `--find_unused_parameters`: Whether to find unused parameters in DDP.
|
||||
* Trainable Modules
|
||||
* `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder.
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
|
||||
@@ -249,6 +249,7 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra
|
||||
* `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
|
||||
* 模型
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。
|
||||
@@ -257,6 +258,8 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--output_path`: 保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
|
||||
* `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数
|
||||
* 可训练模块
|
||||
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
|
||||
@@ -121,4 +121,7 @@ if __name__ == "__main__":
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
num_epochs=args.num_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
save_steps=args.save_steps,
|
||||
find_unused_parameters=args.find_unused_parameters,
|
||||
num_workers=args.dataset_num_workers,
|
||||
)
|
||||
|
||||
@@ -44,7 +44,7 @@ image.save("image.jpg")
|
||||
|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image )|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|
||||
## Model Inference
|
||||
|
||||
@@ -164,6 +164,7 @@ After enabling VRAM management, the framework will automatically choose a memory
|
||||
* `vram_limit`: VRAM usage limit in GB. By default, it uses all free VRAM on the device. Note that this is not a strict limit. If the set limit is too low but actual free VRAM is enough, the model will run with minimal VRAM use. Set it to 0 for the smallest possible VRAM use.
|
||||
* `vram_buffer`: VRAM buffer size in GB. Default is 0.5GB. A buffer is needed because large network layers may use more VRAM than expected during loading. The best value is the VRAM size of the largest model layer.
|
||||
* `num_persistent_param_in_dit`: Number of parameters to keep in VRAM in the DiT model. Default is no limit. This option will be removed in the future. Do not rely on it.
|
||||
* `enable_dit_fp8_computation`: Whether to enable FP8 computation in the DiT model. This is only applicable to GPUs that support FP8 operations (e.g., H200, etc.). Disabled by default.
|
||||
|
||||
</details>
|
||||
|
||||
@@ -172,7 +173,11 @@ After enabling VRAM management, the framework will automatically choose a memory
|
||||
|
||||
<summary>Inference Acceleration</summary>
|
||||
|
||||
Inference acceleration for Qwen-Image is under development. Please stay tuned!
|
||||
* FP8 Quantization: Choose the appropriate quantization method based on your hardware and requirements.
|
||||
* GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
|
||||
* GPUs that support FP8 operations (e.g., H200, etc.): Please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention). Otherwise, FP8 acceleration will only apply to Linear layers.
|
||||
* Faster inference but higher VRAM usage: Use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
|
||||
* Slightly slower inference but lower VRAM usage: Use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -219,6 +224,7 @@ The script includes the following parameters:
|
||||
* `--width`: Width of image or video. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--data_file_keys`: Data file keys in metadata. Separate with commas.
|
||||
* `--dataset_repeat`: Number of times the dataset repeats per epoch.
|
||||
* `--dataset_num_workers`: Number of workers for data loading.
|
||||
* Model
|
||||
* `--model_paths`: Model paths to load. In JSON format.
|
||||
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas.
|
||||
@@ -228,6 +234,8 @@ The script includes the following parameters:
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--output_path`: Save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
|
||||
* `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.
|
||||
* `--find_unused_parameters`: Whether to find unused parameters in DDP.
|
||||
* Trainable Modules
|
||||
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
|
||||
@@ -44,7 +44,7 @@ image.save("image.jpg")
|
||||
|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|
||||
## 模型推理
|
||||
|
||||
@@ -164,6 +164,7 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
|
||||
* `vram_limit`: 显存占用量限制(GB),默认占用设备上的剩余显存。注意这不是一个绝对限制,当设置的显存不足以支持模型进行推理,但实际可用显存足够时,将会以最小化显存占用的形式进行推理。将其设置为0时,将会实现理论最小显存占用。
|
||||
* `vram_buffer`: 显存缓冲区大小(GB),默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。
|
||||
* `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。
|
||||
* `enable_dit_fp8_computation`: 是否启用 DiT 模型中的 FP8 计算,仅适用于支持 FP8 运算的 GPU(例如 H200 等),默认不启用。
|
||||
|
||||
</details>
|
||||
|
||||
@@ -172,7 +173,11 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
|
||||
|
||||
<summary>推理加速</summary>
|
||||
|
||||
Qwen-Image 的推理加速技术正在开发中,敬请期待!
|
||||
* FP8 量化:根据您的硬件与需求,请选择合适的量化方式
|
||||
* GPU 不支持 FP8 计算(例如 A100、4090 等):FP8 量化仅能降低显存占用,无法加速,代码:[./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
|
||||
* GPU 支持 FP8 运算(例如 H200 等):请安装 [Flash Attention 3](https://github.com/Dao-AILab/flash-attention),否则 FP8 加速仅对 Linear 层生效
|
||||
* 更快的速度,但更大的显存:请使用 [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
|
||||
* 稍慢的速度,但更小的显存:请使用 [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -219,6 +224,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
|
||||
* `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
|
||||
* 模型
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors。用逗号分隔。
|
||||
@@ -228,6 +234,8 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--output_path`: 保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
|
||||
* `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数
|
||||
* 可训练模块
|
||||
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
|
||||
18
examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py
Normal file
18
examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management(enable_dit_fp8_computation=True)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
|
||||
image.save("image.jpg")
|
||||
51
examples/qwen_image/accelerate/Qwen-Image-FP8.py
Normal file
51
examples/qwen_image/accelerate/Qwen-Image-FP8.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.models.qwen_image_dit import RMSNorm
|
||||
from diffsynth.vram_management.layers import enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
enable_vram_management(
|
||||
pipe.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=torch.bfloat16,
|
||||
offload_device="cuda",
|
||||
onload_dtype=torch.bfloat16,
|
||||
onload_device="cuda",
|
||||
computation_dtype=torch.bfloat16,
|
||||
computation_device="cuda",
|
||||
),
|
||||
vram_limit=None,
|
||||
)
|
||||
enable_vram_management(
|
||||
pipe.dit,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=torch.float8_e4m3fn,
|
||||
offload_device="cuda",
|
||||
onload_dtype=torch.float8_e4m3fn,
|
||||
onload_device="cuda",
|
||||
computation_dtype=torch.float8_e4m3fn,
|
||||
computation_device="cuda",
|
||||
),
|
||||
vram_limit=None,
|
||||
)
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
|
||||
image.save("image.jpg")
|
||||
128
examples/qwen_image/model_inference/Qwen-Image-EliGen.py
Normal file
128
examples/qwen_image/model_inference/Qwen-Image-EliGen.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
import random
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
negative_prompt = ""
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=4.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=30,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors")
|
||||
|
||||
# example 1
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||
|
||||
# example 2
|
||||
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "yellow belt"]
|
||||
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||
|
||||
# example 3
|
||||
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||
|
||||
# example 4
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||
|
||||
# example 5
|
||||
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||
|
||||
# example 7, same prompt with different seeds
|
||||
seeds = range(5, 9)
|
||||
global_prompt = "A beautiful asia woman wearing white dress, holding a mirror, with a forest background."
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
@@ -1,34 +0,0 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
|
||||
--data_file_keys "image,controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 8000 \
|
||||
--model_paths '[
|
||||
[
|
||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
|
||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
|
||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
|
||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
|
||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
|
||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
|
||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
|
||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
|
||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
|
||||
],
|
||||
[
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
],
|
||||
"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
|
||||
"models/controlnet.safetensors"
|
||||
]' \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
|
||||
--output_path "./models/train/Qwen-Image-ControlNet_full" \
|
||||
--trainable_models "controlnet" \
|
||||
--extra_inputs "controlnet_image" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,33 +0,0 @@
|
||||
# This script is for initializing a Qwen-Image-ControlNet
|
||||
from diffsynth import load_state_dict, hash_state_dict_keys
|
||||
from diffsynth.pipelines.qwen_image import QwenImageControlNet
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
state_dict_dit = {}
|
||||
for i in range(1, 10):
|
||||
state_dict_dit.update(load_state_dict(f"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-0000{i}-of-00009.safetensors", torch_dtype=torch.bfloat16, device="cuda"))
|
||||
|
||||
controlnet = QwenImageControlNet().to(dtype=torch.bfloat16, device="cuda")
|
||||
state_dict_controlnet = controlnet.state_dict()
|
||||
|
||||
state_dict_init = {}
|
||||
for k in state_dict_controlnet:
|
||||
if k in state_dict_dit:
|
||||
if state_dict_dit[k].shape == state_dict_controlnet[k].shape:
|
||||
state_dict_init[k] = state_dict_dit[k]
|
||||
elif k == "img_in.weight":
|
||||
state_dict_init[k] = torch.concat(
|
||||
[
|
||||
state_dict_dit[k],
|
||||
state_dict_dit[k],
|
||||
],
|
||||
dim=-1
|
||||
)
|
||||
elif k == "alpha":
|
||||
state_dict_init[k] = torch.zeros_like(state_dict_controlnet[k])
|
||||
controlnet.load_state_dict(state_dict_init)
|
||||
|
||||
print(hash_state_dict_keys(state_dict_controlnet))
|
||||
save_file(state_dict_controlnet, "models/controlnet.safetensors")
|
||||
18
examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh
Normal file
18
examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh
Normal file
@@ -0,0 +1,18 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path "data/example_image_dataset" \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_eligen.json \
|
||||
--data_file_keys "image,eligen_entity_masks" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-EliGen_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--extra_inputs "eligen_entity_masks,eligen_entity_prompts" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
@@ -12,4 +12,6 @@ accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch, os, json
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
|
||||
from diffsynth.models.lora import QwenImageLoRAConverter
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@@ -29,7 +30,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=ModelConfig(tokenizer_path))
|
||||
else:
|
||||
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
||||
|
||||
|
||||
# Reset training scheduler (do it in each training step)
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
@@ -49,7 +50,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
|
||||
|
||||
|
||||
def forward_preprocess(self, data):
|
||||
# CFG-sensitive parameters
|
||||
@@ -72,14 +73,8 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
}
|
||||
|
||||
# Extra inputs
|
||||
controlnet_input = {}
|
||||
for extra_input in self.extra_inputs:
|
||||
if extra_input.startswith("controlnet_"):
|
||||
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
|
||||
else:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
if len(controlnet_input) > 0:
|
||||
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
|
||||
# Pipeline units will automatically process the input parameters.
|
||||
for unit in self.pipe.units:
|
||||
@@ -114,6 +109,7 @@ if __name__ == "__main__":
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
|
||||
)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
@@ -121,4 +117,7 @@ if __name__ == "__main__":
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
num_epochs=args.num_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
save_steps=args.save_steps,
|
||||
find_unused_parameters=args.find_unused_parameters,
|
||||
num_workers=args.dataset_num_workers,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-EliGen_lora/epoch-4.safetensors")
|
||||
|
||||
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
|
||||
image = pipe(global_prompt,
|
||||
seed=0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks)
|
||||
image.save("Qwen-Image_EliGen.jpg")
|
||||
@@ -280,6 +280,7 @@ The script includes the following parameters:
|
||||
* `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
|
||||
* `--data_file_keys`: Data file keys in the metadata. Comma-separated.
|
||||
* `--dataset_repeat`: Number of times to repeat the dataset per epoch.
|
||||
* `--dataset_num_workers`: Number of workers for data loading.
|
||||
* Models
|
||||
* `--model_paths`: Paths to load models. In JSON format.
|
||||
* `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
|
||||
@@ -290,6 +291,8 @@ The script includes the following parameters:
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--output_path`: Output save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
|
||||
* `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.
|
||||
* `--find_unused_parameters`: Whether to find unused parameters in DDP.
|
||||
* Trainable Modules
|
||||
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
|
||||
* `--lora_base_model`: Which model LoRA is added to.
|
||||
|
||||
@@ -282,6 +282,7 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
|
||||
* `--num_frames`: 每个视频中的帧数。帧从视频前缀中采样。
|
||||
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
|
||||
* 模型
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。
|
||||
@@ -292,6 +293,8 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--output_path`: 保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
|
||||
* `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数
|
||||
* 可训练模块
|
||||
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
|
||||
@@ -127,4 +127,7 @@ if __name__ == "__main__":
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
num_epochs=args.num_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
save_steps=args.save_steps,
|
||||
find_unused_parameters=args.find_unused_parameters,
|
||||
num_workers=args.dataset_num_workers,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user