Step1x vram (#556)

* support step1x vram management
Zhongjie Duan
2025-04-28 10:13:20 +08:00
committed by GitHub
parent 32f630ff5f
commit ef2a7abad4
2 changed files with 104 additions and 97 deletions


@@ -35,105 +35,110 @@ class FluxImagePipeline(BasePipeline):
         self.infinityou_processor: InfinitYou = None
         self.qwenvl = None
         self.step1x_connector: Qwen2Connector = None
-        self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'step1x_connector']
+        self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'qwenvl', 'step1x_connector']

     def enable_vram_management(self, num_persistent_param_in_dit=None):
-        dtype = next(iter(self.text_encoder_1.parameters())).dtype
-        enable_vram_management(
-            self.text_encoder_1,
-            module_map = {
-                torch.nn.Linear: AutoWrappedLinear,
-                torch.nn.Embedding: AutoWrappedModule,
-                torch.nn.LayerNorm: AutoWrappedModule,
-            },
-            module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cpu",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-        )
-        dtype = next(iter(self.text_encoder_2.parameters())).dtype
-        enable_vram_management(
-            self.text_encoder_2,
-            module_map = {
-                torch.nn.Linear: AutoWrappedLinear,
-                torch.nn.Embedding: AutoWrappedModule,
-                T5LayerNorm: AutoWrappedModule,
-                T5DenseActDense: AutoWrappedModule,
-                T5DenseGatedActDense: AutoWrappedModule,
-            },
-            module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cpu",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-        )
-        dtype = next(iter(self.dit.parameters())).dtype
-        enable_vram_management(
-            self.dit,
-            module_map = {
-                RMSNorm: AutoWrappedModule,
-                torch.nn.Linear: AutoWrappedLinear,
-            },
-            module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cuda",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-            max_num_param=num_persistent_param_in_dit,
-            overflow_module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cpu",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-        )
-        dtype = next(iter(self.vae_decoder.parameters())).dtype
-        enable_vram_management(
-            self.vae_decoder,
-            module_map = {
-                torch.nn.Linear: AutoWrappedLinear,
-                torch.nn.Conv2d: AutoWrappedModule,
-                torch.nn.GroupNorm: AutoWrappedModule,
-            },
-            module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cpu",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-        )
-        dtype = next(iter(self.vae_encoder.parameters())).dtype
-        enable_vram_management(
-            self.vae_encoder,
-            module_map = {
-                torch.nn.Linear: AutoWrappedLinear,
-                torch.nn.Conv2d: AutoWrappedModule,
-                torch.nn.GroupNorm: AutoWrappedModule,
-            },
-            module_config = dict(
-                offload_dtype=dtype,
-                offload_device="cpu",
-                onload_dtype=dtype,
-                onload_device="cpu",
-                computation_dtype=self.torch_dtype,
-                computation_device=self.device,
-            ),
-        )
+        if self.text_encoder_1 is not None:
+            dtype = next(iter(self.text_encoder_1.parameters())).dtype
+            enable_vram_management(
+                self.text_encoder_1,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                    torch.nn.Embedding: AutoWrappedModule,
+                    torch.nn.LayerNorm: AutoWrappedModule,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+            )
+        if self.text_encoder_2 is not None:
+            dtype = next(iter(self.text_encoder_2.parameters())).dtype
+            enable_vram_management(
+                self.text_encoder_2,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                    torch.nn.Embedding: AutoWrappedModule,
+                    T5LayerNorm: AutoWrappedModule,
+                    T5DenseActDense: AutoWrappedModule,
+                    T5DenseGatedActDense: AutoWrappedModule,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+            )
+        if self.dit is not None:
+            dtype = next(iter(self.dit.parameters())).dtype
+            enable_vram_management(
+                self.dit,
+                module_map = {
+                    RMSNorm: AutoWrappedModule,
+                    torch.nn.Linear: AutoWrappedLinear,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cuda",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+                max_num_param=num_persistent_param_in_dit,
+                overflow_module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+            )
+        if self.vae_decoder is not None:
+            dtype = next(iter(self.vae_decoder.parameters())).dtype
+            enable_vram_management(
+                self.vae_decoder,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                    torch.nn.Conv2d: AutoWrappedModule,
+                    torch.nn.GroupNorm: AutoWrappedModule,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+            )
+        if self.vae_encoder is not None:
+            dtype = next(iter(self.vae_encoder.parameters())).dtype
+            enable_vram_management(
+                self.vae_encoder,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                    torch.nn.Conv2d: AutoWrappedModule,
+                    torch.nn.GroupNorm: AutoWrappedModule,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+            )
         self.enable_cpu_offload()
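Note on the hunk above: the method body is unchanged except that every submodel is now wrapped in an `if ... is not None` guard, so a pipeline loaded without some component (a Step1X-Edit setup that skips a text encoder, say) no longer crashes on `next(iter(None.parameters()))`. A minimal sketch of the same guard pattern against a generic pipeline object, not DiffSynth's actual implementation:

def guarded_offload_configs(pipeline, names):
    # Visit each optional submodel, skip the ones that were never loaded,
    # and derive the offload dtype from the first parameter, exactly as
    # the guarded blocks above do.
    configs = {}
    for name in names:
        model = getattr(pipeline, name, None)
        if model is None:
            continue  # component absent in this pipeline variant
        dtype = next(iter(model.parameters())).dtype
        configs[name] = dict(
            offload_dtype=dtype, offload_device="cpu",
            onload_dtype=dtype, onload_device="cpu",
        )
    return configs

In the real method, the per-model `module_config` additionally separates onload and computation devices, which is what lets the DiT sit on `cuda` while the encoders and VAEs wait on the CPU between calls.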
@@ -403,6 +408,7 @@ class FluxImagePipeline(BasePipeline):
     def prepare_step1x_kwargs(self, prompt, negative_prompt, image):
         if image is None:
             return {}, {}
+        self.load_models_to_device(["qwenvl", "vae_encoder"])
         captions = [prompt, negative_prompt]
         ref_images = [image, image]
         embs, masks = self.qwenvl(captions, ref_images)
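One detail in `prepare_step1x_kwargs` worth calling out: the positive and negative prompts are paired with the same reference image and sent through `qwenvl` as a single batch, so both classifier-free-guidance branches come back from one encoder call. Restated as a standalone helper (illustrative only, not part of the commit):

def batch_step1x_captions(prompt, negative_prompt, image):
    # Same pairing as above: one batched encoder call serves both
    # classifier-free-guidance branches against one reference image.
    captions = [prompt, negative_prompt]
    ref_images = [image, image]
    return captions, ref_images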
@@ -504,7 +510,7 @@ class FluxImagePipeline(BasePipeline):
         tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
         # Denoise
-        self.load_models_to_device(['dit', 'controlnet'])
+        self.load_models_to_device(['dit', 'controlnet', 'step1x_connector'])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
             timestep = timestep.unsqueeze(0).to(self.device)
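Taken together, the two `load_models_to_device` edits give Step1X a per-stage device schedule: Qwen-VL and the VAE encoder come up only while the kwargs are prepared, while the connector joins the DiT and ControlNet for the whole denoising loop because it runs at every step. Sketched as data, with stage names that are descriptive labels rather than DiffSynth identifiers:

# GPU residency per pipeline stage, as implied by this commit's
# load_models_to_device calls.
STEP1X_STAGE_MODELS = {
    "prepare_step1x_kwargs": ["qwenvl", "vae_encoder"],
    "denoise": ["dit", "controlnet", "step1x_connector"],
}

def models_for_stage(stage):
    return STEP1X_STAGE_MODELS.get(stage, [])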


@@ -15,6 +15,7 @@ model_manager.load_models([
"models/stepfun-ai/Step1X-Edit/vae.safetensors", "models/stepfun-ai/Step1X-Edit/vae.safetensors",
]) ])
pipe = FluxImagePipeline.from_model_manager(model_manager) pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.enable_vram_management()
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255) image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
image = pipe( image = pipe(
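With the added `pipe.enable_vram_management()` call, the example runs with every submodel offloaded to the CPU between uses. As the first hunk shows, the method also takes `num_persistent_param_in_dit`, so a tighter VRAM budget can cap how many DiT parameters stay resident; the figure below is illustrative, not a recommendation from this commit:

# Illustrative budget: keep about 7B DiT parameters resident on the GPU;
# the remainder goes through the overflow_module_config (CPU) path above.
pipe.enable_vram_management(num_persistent_param_in_dit=7 * 10**9)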