Merge pull request #735 from modelscope/qwen-image

qwen-image
Authored by Zhongjie Duan on 2025-08-04 20:40:32 +08:00, committed via GitHub.
21 changed files with 2903 additions and 12 deletions


@@ -0,0 +1,364 @@
import torch
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange
from ..models import ModelManager, load_state_dict
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
from ..schedulers import FlowMatchScheduler
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora import GeneralLoRALoader
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
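# QwenImagePipeline wires the Qwen-Image text encoder, DiT, and VAE together
# behind a unit-based runner: each PipelineUnit below prepares one group of
# inputs (shape check, noise init, input-image latents, prompt embeddings)
# before the denoising loop repeatedly calls model_fn_qwen_image.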
class QwenImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
from transformers import Qwen2Tokenizer
self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
self.text_encoder: QwenImageTextEncoder = None
self.dit: QwenImageDiT = None
self.vae: QwenImageVAE = None
self.tokenizer: Qwen2Tokenizer = None
self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit",)
self.units = [
QwenImageUnit_ShapeChecker(),
QwenImageUnit_NoiseInitializer(),
QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_PromptEmbedder(),
]
self.model_fn = model_fn_qwen_image
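# Load a LoRA checkpoint from disk and apply it to a submodule (e.g. pipe.dit);
# alpha scales the low-rank update when it is applied.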
def load_lora(self, module, path, alpha=1):
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
loader.load(module, lora, alpha=alpha)
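# Flow-matching training objective: sample a random timestep, noise the clean
# input latents with scheduler.add_noise, and regress the model output onto
# scheduler.training_target, with a per-timestep weight.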
def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
noise_pred = self.model_fn(**inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.scheduler.training_weight(timestep)
return loss
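# Two mutually exclusive budgeting modes: an explicit count of DiT parameters
# kept resident on-device (num_persistent_param_in_dit), or a VRAM limit
# defaulting to the detected VRAM minus vram_buffer. The unit appears to be
# GB (inferred from the 0.5 default buffer); the diff does not state it.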
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
dtype = next(iter(self.text_encoder.parameters())).dtype
enable_vram_management(
self.text_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
Qwen2RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.dit is not None:
from ..models.qwen_image_dit import RMSNorm
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
from ..models.qwen_image_vae import QwenImageRMS_norm
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
self.vae,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
QwenImageRMS_norm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
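# Build a pipeline from ModelConfig entries: each config is downloaded if
# necessary, loaded through ModelManager, then fetched by its registered name
# (qwen_image_text_encoder / qwen_image_dit / qwen_image_vae).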
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
):
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
model_config.download_if_necessary()
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
torch_dtype=model_config.offload_dtype or torch_dtype
)
# Initialize pipeline
pipe = QwenImagePipeline(device=device, torch_dtype=torch_dtype)
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
pipe.dit = model_manager.fetch_model("qwen_image_dit")
pipe.vae = model_manager.fetch_model("qwen_image_vae")
if tokenizer_config is not None and pipe.text_encoder is not None:
tokenizer_config.download_if_necessary()
from transformers import Qwen2Tokenizer
pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)
return pipe
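# Usage sketch; the origin_file_pattern values below are illustrative
# assumptions, not confirmed by this diff:
#   pipe = QwenImagePipeline.from_pretrained(
#       torch_dtype=torch.bfloat16,
#       device="cuda",
#       model_configs=[
#           ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/"),   # assumed pattern
#           ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/"),  # assumed pattern
#           ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/"),           # assumed pattern
#       ],
#   )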
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 4.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1328,
width: int = 1328,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# Tile
tiled: bool = False,
tile_size: int = 128,
tile_stride: int = 64,
# Progress bar
progress_bar_cmd = tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))
# Parameters
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
if cfg_scale != 1.0:
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Scheduler
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
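# Minimal inference sketch; every argument shown exists in __call__ above,
# and the return value is a PIL image:
#   image = pipe(
#       prompt="a cup of coffee on a wooden table",
#       negative_prompt="",
#       seed=0,
#       num_inference_steps=30,
#       height=1328, width=1328,
#   )
#   image.save("qwen_image.png")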
class QwenImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width"))
def process(self, pipe: QwenImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class QwenImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width", "seed", "rand_device"))
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class QwenImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
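# Image-to-image entry point: with no input image, denoising starts from pure
# noise. In training mode the clean latents are returned for the loss; at
# inference they are noised to timesteps[0], which reflects denoising_strength.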
def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": None}
class QwenImageUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
onload_model_names=("text_encoder",)
)
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)
selected = hidden_states[bool_mask]
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
return split_result
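# The prompt is wrapped in a chat template before encoding; drop_idx = 34
# strips the template-prefix tokens from the hidden states so only the
# user-prompt portion conditions the DiT. The variable-length results are
# re-padded into a batch with an explicit attention mask.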
def process(self, pipe: QwenImagePipeline, prompt) -> dict:
if pipe.text_encoder is not None:
prompt = [prompt]
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
drop_idx = 34
txt = [template.format(e) for e in prompt]
txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
else:
return {}
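# DiT forward pass. The rearrange below packs 2x2 latent patches into tokens:
# e.g. a 1328x1328 image gives a (1, 16, 166, 166) latent, packed to
# (1, 83*83, 16*2*2) = (1, 6889, 64) before img_in.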
def model_fn_qwen_image(
dit: QwenImageDiT = None,
latents=None,
timestep=None,
prompt_emb=None,
prompt_emb_mask=None,
height=None,
width=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs
):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
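# Scheduler timesteps appear to span [0, 1000]; the DiT time embedding takes
# the normalized value in [0, 1].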
timestep = timestep / 1000
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image = dit.img_in(image)
text = dit.txt_in(dit.txt_norm(prompt_emb))
conditioning = dit.time_text_embed(timestep, image.dtype)
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
for block in dit.transformer_blocks:
text, image = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
image=image,
text=text,
temb=conditioning,
image_rotary_emb=image_rotary_emb,
)
image = dit.norm_out(image, conditioning)
image = dit.proj_out(image)
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
return latents