Merge branch 'modelscope:main' into main

This commit is contained in:
lzws
2025-08-21 20:53:19 +08:00
committed by GitHub
12 changed files with 428 additions and 7 deletions

View File

@@ -467,6 +467,7 @@ class QwenImageDiT(torch.nn.Module):
image_start = sum(seq_lens)
image_end = total_seq_len
cumsum = [0]
single_image_seq = image_end - image_start
for length in seq_lens:
cumsum.append(cumsum[-1] + length)
for i in range(N):
@@ -474,6 +475,9 @@ class QwenImageDiT(torch.nn.Module):
prompt_end = cumsum[i+1]
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
# repeat image mask to match the single image sequence length
repeat_time = single_image_seq // image_mask.shape[-1]
image_mask = image_mask.repeat(1, 1, repeat_time)
# prompt update with image
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
# image update with prompt
@@ -493,7 +497,8 @@ class QwenImageDiT(torch.nn.Module):
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
return all_prompt_emb, image_rotary_emb, attention_mask
def forward(
self,
latents=None,

View File

@@ -69,6 +69,7 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_Inpaint(),
QwenImageUnit_EditImageEmbedder(),
QwenImageUnit_ContextImageEmbedder(),
QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(),
QwenImageUnit_BlockwiseControlNet(),
@@ -76,10 +77,63 @@ class QwenImagePipeline(BasePipeline):
self.model_fn = model_fn_qwen_image
def load_lora(self, module, path, alpha=1):
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
loader.load(module, lora, alpha=alpha)
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=False,
state_dict=None,
):
if state_dict is None:
if isinstance(lora_config, str):
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary()
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
else:
lora = state_dict
if hotload:
for name, module in module.named_modules():
if isinstance(module, AutoWrappedLinear):
lora_a_name = f'{name}.lora_A.default.weight'
lora_b_name = f'{name}.lora_B.default.weight'
if lora_a_name in lora and lora_b_name in lora:
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
else:
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
loader.load(module, lora, alpha=alpha)
def clear_lora(self):
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
if hasattr(module, "lora_A_weights"):
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
def enable_lora_magic(self):
if self.dit is not None:
if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device=self.device,
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=None,
)
def training_loss(self, **inputs):
@@ -284,6 +338,8 @@ class QwenImagePipeline(BasePipeline):
edit_image: Image.Image = None,
edit_image_auto_resize: bool = True,
edit_rope_interpolation: bool = False,
# In-context control
context_image: Image.Image = None,
# FP8
enable_fp8_attention: bool = False,
# Tile
@@ -315,6 +371,7 @@ class QwenImagePipeline(BasePipeline):
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
"context_image": context_image,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -615,6 +672,21 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit):
edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return {"edit_latents": edit_latents, "edit_image": resized_edit_image}
class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride):
if context_image is None:
return {}
pipe.load_models_to_device(['vae'])
context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)
context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return {"context_latents": context_latents}
def model_fn_qwen_image(
dit: QwenImageDiT = None,
@@ -633,6 +705,7 @@ def model_fn_qwen_image(
entity_prompt_emb_mask=None,
entity_masks=None,
edit_latents=None,
context_latents=None,
enable_fp8_attention=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
@@ -646,6 +719,10 @@ def model_fn_qwen_image(
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image_seq_len = image.shape[1]
if context_latents is not None:
img_shapes += [(context_latents.shape[0], context_latents.shape[2]//2, context_latents.shape[3]//2)]
context_image = rearrange(context_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2)
image = torch.cat([image, context_image], dim=1)
if edit_latents is not None:
img_shapes += [(edit_latents.shape[0], edit_latents.shape[2]//2, edit_latents.shape[3]//2)]
edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=edit_latents.shape[2]//2, W=edit_latents.shape[3]//2, P=2, Q=2)
@@ -692,8 +769,7 @@ def model_fn_qwen_image(
image = dit.norm_out(image, conditioning)
image = dit.proj_out(image)
if edit_latents is not None:
image = image[:, :image_seq_len]
image = image[:, :image_seq_len]
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
return latents