Update LoRA and full-training paths: track the base-image sequence length before edit latents are concatenated, and restrict blockwise ControlNet conditioning to that base-image token span

This commit is contained in:
mi804
2025-08-18 19:09:19 +08:00
parent f9ce261a0e
commit 123f6dbadb
3 changed files with 7 additions and 3 deletions

View File

@@ -623,11 +623,11 @@ def model_fn_qwen_image(
timestep = timestep / 1000
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image_seq_len = image.shape[1]
if edit_latents is not None:
img_shapes += [(edit_latents.shape[0], edit_latents.shape[2]//2, edit_latents.shape[3]//2)]
edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=edit_latents.shape[2]//2, W=edit_latents.shape[3]//2, P=2, Q=2)
image_seq_len = image.shape[1]
image = torch.cat([image, edit_image], dim=1)
image = dit.img_in(image)
@@ -660,8 +660,8 @@ def model_fn_qwen_image(
enable_fp8_attention=enable_fp8_attention,
)
if blockwise_controlnet_conditioning is not None:
image = image + blockwise_controlnet.blockwise_forward(
image=image, conditionings=blockwise_controlnet_conditioning,
image[:, :image_seq_len] = image[:, :image_seq_len] + blockwise_controlnet.blockwise_forward(
image=image[:, :image_seq_len], conditionings=blockwise_controlnet_conditioning,
controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id,
progress_id=progress_id, num_inference_steps=num_inference_steps,
)