From ef2a7abad478d6f91c0839be8f491587e36fda9f Mon Sep 17 00:00:00 2001
From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com>
Date: Mon, 28 Apr 2025 10:13:20 +0800
Subject: [PATCH] Step1x vram (#556)

* support step1x vram management
---
 diffsynth/pipelines/flux_image.py | 200 +++++++++++++++---------------
 examples/step1x/step1x.py         |   1 +
 2 files changed, 104 insertions(+), 97 deletions(-)

diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index 0abe26f..90b196d 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -35,105 +35,110 @@ class FluxImagePipeline(BasePipeline):
         self.infinityou_processor: InfinitYou = None
         self.qwenvl = None
         self.step1x_connector: Qwen2Connector = None
-        self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'step1x_connector']
+        self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'qwenvl', 'step1x_connector']
 
 
     def enable_vram_management(self, num_persistent_param_in_dit=None):
-        dtype = next(iter(self.text_encoder_1.parameters())).dtype
-        enable_vram_management(
-            self.text_encoder_1,
-            module_map = {
-                torch.nn.Linear: AutoWrappedLinear,
-                torch.nn.Embedding: AutoWrappedModule,
-                torch.nn.LayerNorm: AutoWrappedModule,
-            },
-            module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cpu",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-        )
-        dtype = next(iter(self.text_encoder_2.parameters())).dtype
-        enable_vram_management(
-            self.text_encoder_2,
-            module_map = {
-                torch.nn.Linear: AutoWrappedLinear,
-                torch.nn.Embedding: AutoWrappedModule,
-                T5LayerNorm: AutoWrappedModule,
-                T5DenseActDense: AutoWrappedModule,
-                T5DenseGatedActDense: AutoWrappedModule,
-            },
-            module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cpu",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-        )
-        dtype = next(iter(self.dit.parameters())).dtype
-        enable_vram_management(
-            self.dit,
-            module_map = {
-                RMSNorm: AutoWrappedModule,
-                torch.nn.Linear: AutoWrappedLinear,
-            },
-            module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cuda",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-            max_num_param=num_persistent_param_in_dit,
-            overflow_module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cpu",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-        )
-        dtype = next(iter(self.vae_decoder.parameters())).dtype
-        enable_vram_management(
-            self.vae_decoder,
-            module_map = {
-                torch.nn.Linear: AutoWrappedLinear,
-                torch.nn.Conv2d: AutoWrappedModule,
-                torch.nn.GroupNorm: AutoWrappedModule,
-            },
-            module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cpu",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-        )
-        dtype = next(iter(self.vae_encoder.parameters())).dtype
-        enable_vram_management(
-            self.vae_encoder,
-            module_map = {
-                torch.nn.Linear: AutoWrappedLinear,
-                torch.nn.Conv2d: AutoWrappedModule,
-                torch.nn.GroupNorm: AutoWrappedModule,
-            },
-            module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cpu",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-        )
+        if self.text_encoder_1 is not None:
+            dtype = next(iter(self.text_encoder_1.parameters())).dtype
+            enable_vram_management(
+                self.text_encoder_1,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                    torch.nn.Embedding: AutoWrappedModule,
+                    torch.nn.LayerNorm: AutoWrappedModule,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+            )
+        if self.text_encoder_2 is not None:
+            dtype = next(iter(self.text_encoder_2.parameters())).dtype
+            enable_vram_management(
+                self.text_encoder_2,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                    torch.nn.Embedding: AutoWrappedModule,
+                    T5LayerNorm: AutoWrappedModule,
+                    T5DenseActDense: AutoWrappedModule,
+                    T5DenseGatedActDense: AutoWrappedModule,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+            )
+        if self.dit is not None:
+            dtype = next(iter(self.dit.parameters())).dtype
+            enable_vram_management(
+                self.dit,
+                module_map = {
+                    RMSNorm: AutoWrappedModule,
+                    torch.nn.Linear: AutoWrappedLinear,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cuda",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+                max_num_param=num_persistent_param_in_dit,
+                overflow_module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+            )
+        if self.vae_decoder is not None:
+            dtype = next(iter(self.vae_decoder.parameters())).dtype
+            enable_vram_management(
+                self.vae_decoder,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                    torch.nn.Conv2d: AutoWrappedModule,
+                    torch.nn.GroupNorm: AutoWrappedModule,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+            )
+        if self.vae_encoder is not None:
+            dtype = next(iter(self.vae_encoder.parameters())).dtype
+            enable_vram_management(
+                self.vae_encoder,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                    torch.nn.Conv2d: AutoWrappedModule,
+                    torch.nn.GroupNorm: AutoWrappedModule,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+            )
         self.enable_cpu_offload()
 
 
@@ -403,6 +408,7 @@ class FluxImagePipeline(BasePipeline):
     def prepare_step1x_kwargs(self, prompt, negative_prompt, image):
         if image is None:
             return {}, {}
+        self.load_models_to_device(["qwenvl", "vae_encoder"])
         captions = [prompt, negative_prompt]
         ref_images = [image, image]
         embs, masks = self.qwenvl(captions, ref_images)
@@ -504,7 +510,7 @@ class FluxImagePipeline(BasePipeline):
         tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
 
         # Denoise
-        self.load_models_to_device(['dit', 'controlnet'])
+        self.load_models_to_device(['dit', 'controlnet', 'step1x_connector'])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
             timestep = timestep.unsqueeze(0).to(self.device)
 
diff --git a/examples/step1x/step1x.py b/examples/step1x/step1x.py
index 6ca974d..80de280 100644
--- a/examples/step1x/step1x.py
+++ b/examples/step1x/step1x.py
@@ -15,6 +15,7 @@ model_manager.load_models([
     "models/stepfun-ai/Step1X-Edit/vae.safetensors",
 ])
 pipe = FluxImagePipeline.from_model_manager(model_manager)
+pipe.enable_vram_management()
 image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
 image = pipe(
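
Usage note (not part of the patch): below is a minimal sketch of how the new
code path can be driven end to end, mirroring examples/step1x/step1x.py above.
It assumes the usual DiffSynth-Studio entry points (ModelManager,
FluxImagePipeline); the DiT checkpoint path and the persistent-parameter
budget are illustrative, not taken from the patch.

    import torch
    from diffsynth import ModelManager, FluxImagePipeline

    # Load weights onto the CPU first; enable_vram_management() then wraps
    # the submodules and decides where each component's parameters live.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/stepfun-ai/Step1X-Edit/step1x-edit.safetensors",  # illustrative path
        "models/stepfun-ai/Step1X-Edit/vae.safetensors",          # path from the diff
    ])
    pipe = FluxImagePipeline.from_model_manager(model_manager)

    # Per the patched enable_vram_management(), text encoders and VAEs are
    # offloaded to the CPU and moved to the computation device only while in
    # use; up to num_persistent_param_in_dit DiT parameters onload to CUDA,
    # with the overflow handled by CPU offloading. The example above passes
    # no argument; the 5e9 cap here is purely illustrative.
    pipe.enable_vram_management(num_persistent_param_in_dit=5 * 10**9)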