mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
@@ -35,105 +35,110 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.infinityou_processor: InfinitYou = None
|
self.infinityou_processor: InfinitYou = None
|
||||||
self.qwenvl = None
|
self.qwenvl = None
|
||||||
self.step1x_connector: Qwen2Connector = None
|
self.step1x_connector: Qwen2Connector = None
|
||||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'step1x_connector']
|
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'qwenvl', 'step1x_connector']
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
if self.text_encoder_1 is not None:
|
||||||
enable_vram_management(
|
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||||
self.text_encoder_1,
|
enable_vram_management(
|
||||||
module_map = {
|
self.text_encoder_1,
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
module_map = {
|
||||||
torch.nn.Embedding: AutoWrappedModule,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
torch.nn.LayerNorm: AutoWrappedModule,
|
torch.nn.Embedding: AutoWrappedModule,
|
||||||
},
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
module_config = dict(
|
},
|
||||||
offload_dtype=dtype,
|
module_config = dict(
|
||||||
offload_device="cpu",
|
offload_dtype=dtype,
|
||||||
onload_dtype=dtype,
|
offload_device="cpu",
|
||||||
onload_device="cpu",
|
onload_dtype=dtype,
|
||||||
computation_dtype=self.torch_dtype,
|
onload_device="cpu",
|
||||||
computation_device=self.device,
|
computation_dtype=self.torch_dtype,
|
||||||
),
|
computation_device=self.device,
|
||||||
)
|
),
|
||||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
)
|
||||||
enable_vram_management(
|
if self.text_encoder_2 is not None:
|
||||||
self.text_encoder_2,
|
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||||
module_map = {
|
enable_vram_management(
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
self.text_encoder_2,
|
||||||
torch.nn.Embedding: AutoWrappedModule,
|
module_map = {
|
||||||
T5LayerNorm: AutoWrappedModule,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
T5DenseActDense: AutoWrappedModule,
|
torch.nn.Embedding: AutoWrappedModule,
|
||||||
T5DenseGatedActDense: AutoWrappedModule,
|
T5LayerNorm: AutoWrappedModule,
|
||||||
},
|
T5DenseActDense: AutoWrappedModule,
|
||||||
module_config = dict(
|
T5DenseGatedActDense: AutoWrappedModule,
|
||||||
offload_dtype=dtype,
|
},
|
||||||
offload_device="cpu",
|
module_config = dict(
|
||||||
onload_dtype=dtype,
|
offload_dtype=dtype,
|
||||||
onload_device="cpu",
|
offload_device="cpu",
|
||||||
computation_dtype=self.torch_dtype,
|
onload_dtype=dtype,
|
||||||
computation_device=self.device,
|
onload_device="cpu",
|
||||||
),
|
computation_dtype=self.torch_dtype,
|
||||||
)
|
computation_device=self.device,
|
||||||
dtype = next(iter(self.dit.parameters())).dtype
|
),
|
||||||
enable_vram_management(
|
)
|
||||||
self.dit,
|
if self.dit is not None:
|
||||||
module_map = {
|
dtype = next(iter(self.dit.parameters())).dtype
|
||||||
RMSNorm: AutoWrappedModule,
|
enable_vram_management(
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
self.dit,
|
||||||
},
|
module_map = {
|
||||||
module_config = dict(
|
RMSNorm: AutoWrappedModule,
|
||||||
offload_dtype=dtype,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
offload_device="cpu",
|
},
|
||||||
onload_dtype=dtype,
|
module_config = dict(
|
||||||
onload_device="cuda",
|
offload_dtype=dtype,
|
||||||
computation_dtype=self.torch_dtype,
|
offload_device="cpu",
|
||||||
computation_device=self.device,
|
onload_dtype=dtype,
|
||||||
),
|
onload_device="cuda",
|
||||||
max_num_param=num_persistent_param_in_dit,
|
computation_dtype=self.torch_dtype,
|
||||||
overflow_module_config = dict(
|
computation_device=self.device,
|
||||||
offload_dtype=dtype,
|
),
|
||||||
offload_device="cpu",
|
max_num_param=num_persistent_param_in_dit,
|
||||||
onload_dtype=dtype,
|
overflow_module_config = dict(
|
||||||
onload_device="cpu",
|
offload_dtype=dtype,
|
||||||
computation_dtype=self.torch_dtype,
|
offload_device="cpu",
|
||||||
computation_device=self.device,
|
onload_dtype=dtype,
|
||||||
),
|
onload_device="cpu",
|
||||||
)
|
computation_dtype=self.torch_dtype,
|
||||||
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
computation_device=self.device,
|
||||||
enable_vram_management(
|
),
|
||||||
self.vae_decoder,
|
)
|
||||||
module_map = {
|
if self.vae_decoder is not None:
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
||||||
torch.nn.Conv2d: AutoWrappedModule,
|
enable_vram_management(
|
||||||
torch.nn.GroupNorm: AutoWrappedModule,
|
self.vae_decoder,
|
||||||
},
|
module_map = {
|
||||||
module_config = dict(
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
offload_dtype=dtype,
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
offload_device="cpu",
|
torch.nn.GroupNorm: AutoWrappedModule,
|
||||||
onload_dtype=dtype,
|
},
|
||||||
onload_device="cpu",
|
module_config = dict(
|
||||||
computation_dtype=self.torch_dtype,
|
offload_dtype=dtype,
|
||||||
computation_device=self.device,
|
offload_device="cpu",
|
||||||
),
|
onload_dtype=dtype,
|
||||||
)
|
onload_device="cpu",
|
||||||
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
computation_dtype=self.torch_dtype,
|
||||||
enable_vram_management(
|
computation_device=self.device,
|
||||||
self.vae_encoder,
|
),
|
||||||
module_map = {
|
)
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
if self.vae_encoder is not None:
|
||||||
torch.nn.Conv2d: AutoWrappedModule,
|
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
||||||
torch.nn.GroupNorm: AutoWrappedModule,
|
enable_vram_management(
|
||||||
},
|
self.vae_encoder,
|
||||||
module_config = dict(
|
module_map = {
|
||||||
offload_dtype=dtype,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
offload_device="cpu",
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
onload_dtype=dtype,
|
torch.nn.GroupNorm: AutoWrappedModule,
|
||||||
onload_device="cpu",
|
},
|
||||||
computation_dtype=self.torch_dtype,
|
module_config = dict(
|
||||||
computation_device=self.device,
|
offload_dtype=dtype,
|
||||||
),
|
offload_device="cpu",
|
||||||
)
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
self.enable_cpu_offload()
|
self.enable_cpu_offload()
|
||||||
|
|
||||||
|
|
||||||
@@ -403,6 +408,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
def prepare_step1x_kwargs(self, prompt, negative_prompt, image):
|
def prepare_step1x_kwargs(self, prompt, negative_prompt, image):
|
||||||
if image is None:
|
if image is None:
|
||||||
return {}, {}
|
return {}, {}
|
||||||
|
self.load_models_to_device(["qwenvl", "vae_encoder"])
|
||||||
captions = [prompt, negative_prompt]
|
captions = [prompt, negative_prompt]
|
||||||
ref_images = [image, image]
|
ref_images = [image, image]
|
||||||
embs, masks = self.qwenvl(captions, ref_images)
|
embs, masks = self.qwenvl(captions, ref_images)
|
||||||
@@ -504,7 +510,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
self.load_models_to_device(['dit', 'controlnet'])
|
self.load_models_to_device(['dit', 'controlnet', 'step1x_connector'])
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
timestep = timestep.unsqueeze(0).to(self.device)
|
timestep = timestep.unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ model_manager.load_models([
|
|||||||
"models/stepfun-ai/Step1X-Edit/vae.safetensors",
|
"models/stepfun-ai/Step1X-Edit/vae.safetensors",
|
||||||
])
|
])
|
||||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
|
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
|
|||||||
Reference in New Issue
Block a user