import torch
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np

from ..models import ModelManager, load_state_dict
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageControlNet, QwenImageBlockWiseControlNet
from ..schedulers import FlowMatchScheduler
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora import GeneralLoRALoader
from .flux_image_new import ControlNetInput
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear


# Chat template that every prompt is formatted into before tokenization.
_PROMPT_TEMPLATE = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
# Number of leading token positions dropped from the text-encoder output.
# NOTE(review): presumably the token length of the template prefix -- confirm
# against the tokenizer before changing the template.
_PROMPT_DROP_IDX = 34


def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
    """Split `hidden_states` into one tensor per batch row, keeping only masked-in positions.

    Positions where `mask` is non-zero are selected, then split back per row
    using the per-row count of selected positions.
    """
    bool_mask = mask.bool()
    valid_lengths = bool_mask.sum(dim=1)
    selected = hidden_states[bool_mask]
    return torch.split(selected, valid_lengths.tolist(), dim=0)


def _encode_prompt(pipe, prompt, extract_masked_hidden=_extract_masked_hidden) -> dict:
    """Encode a single prompt with the pipeline's tokenizer and text encoder.

    Shared by the prompt embedder and the entity-control unit so both produce
    embeddings in exactly the same way.

    Args:
        pipe: Pipeline providing `text_encoder`, `tokenizer`, `device` and `torch_dtype`.
        prompt: Prompt text inserted into `_PROMPT_TEMPLATE`.
        extract_masked_hidden: Callable `(hidden_states, mask)` returning the per-row
            valid hidden states; units pass their own method so overrides are honored.

    Returns:
        `{"prompt_emb": ..., "prompt_emb_mask": ...}`, or `{}` when no text encoder is loaded.
    """
    if pipe.text_encoder is None:
        return {}
    txt = [_PROMPT_TEMPLATE.format(prompt)]
    txt_tokens = pipe.tokenizer(
        txt, max_length=1024 + _PROMPT_DROP_IDX, padding=True, truncation=True, return_tensors="pt"
    ).to(pipe.device)
    hidden_states = pipe.text_encoder(
        input_ids=txt_tokens.input_ids,
        attention_mask=txt_tokens.attention_mask,
        output_hidden_states=True,
    )[-1]
    split_hidden_states = extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
    split_hidden_states = [e[_PROMPT_DROP_IDX:] for e in split_hidden_states]
    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
    # Zero-pad every sequence (and its mask) up to the longest one so they can be stacked.
    max_seq_len = max(e.size(0) for e in split_hidden_states)
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
    )
    encoder_attention_mask = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
    )
    prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
    return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}


class QwenImageMultiControlNet(torch.nn.Module):
    """Container that runs several ControlNets and sums their residual stacks."""

    def __init__(self, models: list[QwenImageControlNet]):
        super().__init__()
        # Accept a single model as well as a list of models.
        if not isinstance(models, list):
            models = [models]
        self.models = torch.nn.ModuleList(models)

    def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
        """Run the ControlNet selected by `controlnet_input.controlnet_id` and scale its residuals."""
        model = self.models[controlnet_input.controlnet_id]
        res_stack = model(
            controlnet_conditioning=conditioning,
            processor_id=controlnet_input.processor_id,
            **kwargs
        )
        return [res * controlnet_input.scale for res in res_stack]

    def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
        """Sum the residual stacks of all ControlNets active at the current step.

        Returns:
            A list of summed residuals, or `None` when no ControlNet is active
            at this step (callers must handle the `None` case).
        """
        # Progress goes from 1.0 at the first step down to 0.0 at the last step;
        # it is the same for every ControlNet, so compute it once.
        progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
        res_stack = None
        for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
            # Skip ControlNets whose [end, start] window does not contain the current progress.
            if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
                continue
            res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
            if res_stack is None:
                res_stack = res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack


class QwenImagePipeline(BasePipeline):
    """Text-to-image / image-to-image pipeline built from a DiT, a text encoder, a VAE and optional ControlNets."""

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        from transformers import Qwen2Tokenizer

        self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
        self.text_encoder: QwenImageTextEncoder = None
        self.dit: QwenImageDiT = None
        self.vae: QwenImageVAE = None
        self.controlnet: QwenImageMultiControlNet = None
        # Must exist even when no block-wise ControlNet is loaded: it is listed in
        # `in_iteration_models` and read via getattr() in __call__ and by the ControlNet unit.
        self.blockwise_controlnet: QwenImageBlockWiseControlNet = None
        self.tokenizer: Qwen2Tokenizer = None
        self.unit_runner = PipelineUnitRunner()
        self.in_iteration_models = ("dit", "controlnet", "blockwise_controlnet")
        self.units = [
            QwenImageUnit_ShapeChecker(),
            QwenImageUnit_NoiseInitializer(),
            QwenImageUnit_InputImageEmbedder(),
            QwenImageUnit_PromptEmbedder(),
            QwenImageUnit_EntityControl(),
            QwenImageUnit_ControlNet(),
        ]
        self.model_fn = model_fn_qwen_image

    def load_lora(self, module, path, alpha=1):
        """Load the LoRA state dict at `path` and apply it to `module` with strength `alpha`."""
        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
        loader.load(module, lora, alpha=alpha)

    def training_loss(self, **inputs):
        """Compute the weighted MSE training loss at one randomly sampled timestep.

        `inputs` must contain `input_latents` and `noise`, plus whatever
        `self.model_fn` requires; `inputs["latents"]` is overwritten with the
        noised latents.
        """
        timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
        timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)

        inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
        training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)

        noise_pred = self.model_fn(**inputs, timestep=timestep)

        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.scheduler.training_weight(timestep)
        return loss

    def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
        """Wrap the loaded models so their weights can be offloaded to CPU.

        Args:
            num_persistent_param_in_dit: Accepted for interface compatibility; not used here.
            vram_limit: VRAM budget; defaults to `self.get_vram()`. `vram_buffer` is
                subtracted from it in either case.
            vram_buffer: Amount subtracted from `vram_limit`.
            enable_dit_fp8_computation: When True, the DiT's linear layers compute in the
                DiT's stored parameter dtype instead of `self.torch_dtype`.
        """
        self.vram_management_enabled = True
        if vram_limit is None:
            vram_limit = self.get_vram()
        vram_limit = vram_limit - vram_buffer

        def module_config(dtype, onload_device, computation_dtype):
            # All wrapped modules share the same offload target and computation device.
            return dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=onload_device,
                computation_dtype=computation_dtype,
                computation_device=self.device,
            )

        if self.text_encoder is not None:
            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
            dtype = next(iter(self.text_encoder.parameters())).dtype
            enable_vram_management(
                self.text_encoder,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Embedding: AutoWrappedModule,
                    Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
                    Qwen2RMSNorm: AutoWrappedModule,
                },
                module_config=module_config(dtype, "cpu", self.torch_dtype),
                vram_limit=vram_limit,
            )
        if self.dit is not None:
            from ..models.qwen_image_dit import RMSNorm
            dtype = next(iter(self.dit.parameters())).dtype
            device = "cpu" if vram_limit is not None else self.device
            if not enable_dit_fp8_computation:
                enable_vram_management(
                    self.dit,
                    module_map={
                        RMSNorm: AutoWrappedModule,
                        torch.nn.Linear: AutoWrappedLinear,
                    },
                    module_config=module_config(dtype, device, self.torch_dtype),
                    vram_limit=vram_limit,
                )
            else:
                # Norm layers still compute in the pipeline dtype ...
                enable_vram_management(
                    self.dit,
                    module_map={
                        RMSNorm: AutoWrappedModule,
                    },
                    module_config=module_config(dtype, device, self.torch_dtype),
                    vram_limit=vram_limit,
                )
                # ... while linear layers compute in the DiT's own parameter dtype.
                enable_vram_management(
                    self.dit,
                    module_map={
                        torch.nn.Linear: AutoWrappedLinear,
                    },
                    module_config=module_config(dtype, device, dtype),
                    vram_limit=vram_limit,
                )
        if self.vae is not None:
            from ..models.qwen_image_vae import QwenImageRMS_norm
            dtype = next(iter(self.vae.parameters())).dtype
            enable_vram_management(
                self.vae,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv3d: AutoWrappedModule,
                    torch.nn.Conv2d: AutoWrappedModule,
                    QwenImageRMS_norm: AutoWrappedModule,
                },
                module_config=module_config(dtype, "cpu", self.torch_dtype),
                vram_limit=vram_limit,
            )

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
    ):
        """Download (if needed) and load the configured models, then build a pipeline.

        `model_configs` is only iterated, never mutated. The tokenizer is loaded
        only when both `tokenizer_config` and a text encoder are available.
        """
        # Download and load models
        model_manager = ModelManager()
        for model_config in model_configs:
            model_config.download_if_necessary()
            model_manager.load_model(
                model_config.path,
                device=model_config.offload_device or device,
                torch_dtype=model_config.offload_dtype or torch_dtype
            )

        # Initialize pipeline
        pipe = QwenImagePipeline(device=device, torch_dtype=torch_dtype)
        pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
        pipe.dit = model_manager.fetch_model("qwen_image_dit")
        pipe.vae = model_manager.fetch_model("qwen_image_vae")
        pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
        pipe.blockwise_controlnet = model_manager.fetch_model("qwen_image_blockwise_controlnet")
        if tokenizer_config is not None and pipe.text_encoder is not None:
            tokenizer_config.download_if_necessary()
            from transformers import Qwen2Tokenizer
            pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 4.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Shape
        height: int = 1328,
        width: int = 1328,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 30,
        # ControlNet
        controlnet_inputs: list[ControlNetInput] = None,
        # EliGen
        eligen_entity_prompts: list[str] = None,
        eligen_entity_masks: list[Image.Image] = None,
        eligen_enable_on_negative: bool = False,
        # FP8
        enable_fp8_attention: bool = False,
        # Tile
        tiled: bool = False,
        tile_size: int = 128,
        tile_stride: int = 64,
        # Progress bar
        progress_bar_cmd = tqdm,
    ):
        """Generate an image: run the pre-processing units, denoise, then decode with the VAE.

        Classifier-free guidance (a second, negative-prompt model call per step)
        is used whenever `cfg_scale != 1.0`.
        """
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))

        # Parameters
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
        }
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image, "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "enable_fp8_attention": enable_fp8_attention,
            "num_inference_steps": num_inference_steps,
            "controlnet_inputs": controlnet_inputs,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
            "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

            # Inference
            noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
            if cfg_scale != 1.0:
                noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            # Scheduler
            inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])

        # Decode
        self.load_models_to_device(['vae'])
        image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])

        return image


class QwenImageUnit_ShapeChecker(PipelineUnit):
    """Normalizes the requested height/width via `pipe.check_resize_height_width`."""

    def __init__(self):
        super().__init__(input_params=("height", "width"))

    def process(self, pipe: QwenImagePipeline, height, width):
        height, width = pipe.check_resize_height_width(height, width)
        return {"height": height, "width": width}


class QwenImageUnit_NoiseInitializer(PipelineUnit):
    """Creates the initial noise tensor of shape (1, 16, height//8, width//8)."""

    def __init__(self):
        super().__init__(input_params=("height", "width", "seed", "rand_device"))

    def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
        noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}


class QwenImageUnit_InputImageEmbedder(PipelineUnit):
    """Encodes the optional input image into latents and derives the starting latents."""

    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
        # Without an input image, denoising starts from pure noise.
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(['vae'])
        image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
        input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        if pipe.scheduler.training:
            # Training keeps the clean latents; noise is added later in training_loss.
            return {"latents": noise, "input_latents": input_latents}
        else:
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents, "input_latents": None}


class QwenImageUnit_PromptEmbedder(PipelineUnit):
    """Encodes the positive and negative prompts into `prompt_emb` / `prompt_emb_mask`."""

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            onload_model_names=("text_encoder",)
        )

    def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
        """Return, per batch row, the hidden states at positions where `mask` is non-zero."""
        return _extract_masked_hidden(hidden_states, mask)

    def process(self, pipe: QwenImagePipeline, prompt) -> dict:
        """Encode `prompt`; returns `{}` when the pipeline has no text encoder."""
        return _encode_prompt(pipe, prompt, self.extract_masked_hidden)


class QwenImageUnit_EntityControl(PipelineUnit):
    """Prepares EliGen entity prompts and masks for regional control."""

    def __init__(self):
        super().__init__(
            take_over=True,
            # Trailing comma is required: ("text_encoder") would be a plain string.
            onload_model_names=("text_encoder",)
        )

    def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
        """Return, per batch row, the hidden states at positions where `mask` is non-zero."""
        return _extract_masked_hidden(hidden_states, mask)

    def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict:
        """Encode `prompt`; returns `{}` when the pipeline has no text encoder."""
        return _encode_prompt(pipe, prompt, self.extract_masked_hidden)

    def preprocess_masks(self, pipe, masks, height, width, dim):
        """Resize each mask to (width, height), binarize it and repeat it `dim` times on axis 1."""
        out_masks = []
        for mask in masks:
            mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
            mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)
            out_masks.append(mask)
        return out_masks

    def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height):
        """Build the latent-resolution entity masks and one prompt embedding per entity prompt."""
        entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)
        entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
        prompt_embs, prompt_emb_masks = [], []
        for entity_prompt in entity_prompts:
            prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt)
            prompt_embs.append(prompt_emb_dict['prompt_emb'])
            prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask'])
        return prompt_embs, prompt_emb_masks, entity_masks

    def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, enable_eligen_on_negative, cfg_scale):
        """Return the positive and negative EliGen keyword-argument dicts.

        The negative side reuses the negative prompt embedding for every entity
        when `enable_eligen_on_negative` is set and CFG is active; otherwise all
        of its values are `None`.
        """
        entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height)
        if enable_eligen_on_negative and cfg_scale != 1.0:
            entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi)
            entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi)
            entity_masks_nega = entity_masks_posi
        else:
            entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None
        eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi, "entity_prompt_emb_mask": entity_prompt_emb_posi_mask}
        eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega, "entity_prompt_emb_mask": entity_prompt_emb_nega_mask}
        return eligen_kwargs_posi, eligen_kwargs_nega

    def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):
        """Add entity-control inputs to the positive (and, under CFG, negative) model inputs."""
        eligen_entity_prompts = inputs_shared.get("eligen_entity_prompts", None)
        eligen_entity_masks = inputs_shared.get("eligen_entity_masks", None)
        # Nothing to do unless both entity prompts and entity masks were provided.
        if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0:
            return inputs_shared, inputs_posi, inputs_nega
        pipe.load_models_to_device(self.onload_model_names)
        eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
        eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(
            pipe, inputs_nega,
            eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
            eligen_enable_on_negative, inputs_shared["cfg_scale"]
        )
        inputs_posi.update(eligen_kwargs_posi)
        if inputs_shared.get("cfg_scale", 1.0) != 1.0:
            inputs_nega.update(eligen_kwargs_nega)
        return inputs_shared, inputs_posi, inputs_nega


class QwenImageUnit_ControlNet(PipelineUnit):
    """Encodes ControlNet conditioning images (and optional inpaint masks) into latents."""

    def __init__(self):
        super().__init__(
            input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
        """Append the inverted, latent-resolution mask to `latents` as an extra channel."""
        mask = (pipe.preprocess_image(mask) + 1) / 2
        mask = mask.mean(dim=1, keepdim=True)
        mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
        latents = torch.concat([latents, mask], dim=1)
        return latents

    def apply_controlnet_mask_on_image(self, pipe, image, mask):
        """Zero out the pixels of `image` where the resized, preprocessed mask is positive."""
        mask = mask.resize(image.size)
        mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
        image = np.array(image)
        image[mask > 0] = 0
        image = Image.fromarray(image)
        return image

    def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
        """Return the encoded conditionings under the key matching the loaded ControlNet type."""
        if controlnet_inputs is None:
            return {}
        return_key = "blockwise_controlnet_conditioning" if pipe.blockwise_controlnet is not None else "controlnet_conditionings"
        pipe.load_models_to_device(self.onload_model_names)
        conditionings = []
        for controlnet_input in controlnet_inputs:
            image = controlnet_input.image
            if controlnet_input.inpaint_mask is not None:
                image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)

            image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
            image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

            if controlnet_input.inpaint_mask is not None:
                image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
            conditionings.append(image)
        return {return_key: conditionings}


def model_fn_qwen_image(
    dit: QwenImageDiT = None,
    controlnet: QwenImageMultiControlNet = None,
    blockwise_controlnet: QwenImageBlockWiseControlNet = None,
    latents=None,
    timestep=None,
    prompt_emb=None,
    prompt_emb_mask=None,
    height=None,
    width=None,
    controlnet_inputs=None,
    controlnet_conditionings=None,
    blockwise_controlnet_conditioning=None,
    progress_id=0,
    num_inference_steps=1,
    entity_prompt_emb=None,
    entity_prompt_emb_mask=None,
    entity_masks=None,
    enable_fp8_attention=False,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs
):
    """Run one forward pass of the DiT, with optional ControlNet and entity control.

    Args:
        dit: The diffusion transformer.
        controlnet: Multi-ControlNet wrapper; used only when `controlnet_conditionings` is given.
        blockwise_controlnet: Block-wise ControlNet; used only when
            `blockwise_controlnet_conditioning` is given.
        latents: Latent tensor laid out as "B C (H P) (W Q)" with H=height//16, W=width//16, P=Q=2.
        timestep: Current timestep; divided by 1000 before being embedded.
        prompt_emb, prompt_emb_mask: Text embeddings and their mask.
        height, width: Image size in pixels.
        controlnet_inputs, controlnet_conditionings: Per-ControlNet settings and conditionings.
        blockwise_controlnet_conditioning: List of conditionings; only the first entry is used.
        progress_id, num_inference_steps: Forwarded to `controlnet` to decide which
            ControlNets are active at this step.
        entity_prompt_emb, entity_prompt_emb_mask, entity_masks: Entity-control inputs;
            when `entity_prompt_emb` is not None, `dit.process_entity_masks` is used.
        enable_fp8_attention: Forwarded to every transformer block.
        use_gradient_checkpointing, use_gradient_checkpointing_offload: Forwarded to
            `gradient_checkpoint_forward` and to the ControlNet.
        **kwargs: Ignored; lets callers pass a superset of inputs.

    Returns:
        The model output rearranged back to the "B C (H P) (W Q)" latent layout.
    """
    # ControlNet
    # `controlnet(...)` returns None when every ControlNet is outside its active
    # window for this step, so `res_stack` may legitimately stay None.
    res_stack = None
    if controlnet_conditionings is not None:
        controlnet_extra_kwargs = {
            "latents": latents,
            "timestep": timestep,
            "prompt_emb": prompt_emb,
            "prompt_emb_mask": prompt_emb_mask,
            "height": height,
            "width": width,
            "use_gradient_checkpointing": use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": use_gradient_checkpointing_offload,
        }
        res_stack = controlnet(
            controlnet_conditionings, controlnet_inputs, progress_id, num_inference_steps,
            **controlnet_extra_kwargs
        )

    img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
    txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
    timestep = timestep / 1000

    # Pack 2x2 latent patches into a token sequence.
    image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
    image = dit.img_in(image)
    conditioning = dit.time_text_embed(timestep, image.dtype)

    if entity_prompt_emb is not None:
        text, image_rotary_emb, attention_mask = dit.process_entity_masks(
            latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask,
            entity_masks, height, width, image, img_shapes,
        )
    else:
        text = dit.txt_in(dit.txt_norm(prompt_emb))
        image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
        attention_mask = None

    # The block-wise ControlNet is applied only when a conditioning was supplied;
    # a loaded model without conditioning must not be invoked.
    use_blockwise_controlnet = blockwise_controlnet_conditioning is not None
    if use_blockwise_controlnet:
        blockwise_controlnet_conditioning = rearrange(
            blockwise_controlnet_conditioning[0], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2
        )
        blockwise_controlnet_conditioning = blockwise_controlnet.process_controlnet_conditioning(blockwise_controlnet_conditioning)

    for block_id, block in enumerate(dit.transformer_blocks):
        text, image = gradient_checkpoint_forward(
            block,
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            image=image,
            text=text,
            temb=conditioning,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
            enable_fp8_attention=enable_fp8_attention,
        )
        if use_blockwise_controlnet:
            image = image + blockwise_controlnet.blockwise_forward(image, blockwise_controlnet_conditioning, block_id)
        if res_stack is not None:
            image = image + res_stack[block_id]

    image = dit.norm_out(image, conditioning)
    image = dit.proj_out(image)

    # Unpack the token sequence back into the latent layout.
    latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
    return latents