mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge pull request #197 from mi804/cpuoffload
add cpuoffload support for image pipelines
This commit is contained in:
@@ -23,6 +23,14 @@ class MultiControlNetManager:
|
||||
self.models = [unit.model for unit in controlnet_units]
|
||||
self.scales = [unit.scale for unit in controlnet_units]
|
||||
|
||||
def cpu(self):
|
||||
for model in self.models:
|
||||
model.cpu()
|
||||
|
||||
def to(self, device):
|
||||
for model in self.models:
|
||||
model.to(device)
|
||||
|
||||
def process_image(self, image, processor_id=None):
|
||||
if processor_id is None:
|
||||
processed_image = [processor(image) for processor in self.processors]
|
||||
|
||||
@@ -10,6 +10,8 @@ class BasePipeline(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
self.cpu_offload = False
|
||||
self.model_names = []
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
@@ -59,4 +61,24 @@ class BasePipeline(torch.nn.Module):
|
||||
masks += extended_prompt_dict.get("masks", [])
|
||||
mask_scales += [5.0] * len(extended_prompt_dict.get("masks", []))
|
||||
return prompt, local_prompts, masks, mask_scales
|
||||
|
||||
|
||||
def enable_cpu_offload(self):
|
||||
self.cpu_offload = True
|
||||
|
||||
def load_models_to_device(self, loadmodel_names=[]):
|
||||
# only load models to device if cpu_offload is enabled
|
||||
if not self.cpu_offload:
|
||||
return
|
||||
# offload the unneeded models to cpu
|
||||
for model_name in self.model_names:
|
||||
if model_name not in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.cpu()
|
||||
# load the needed models to device
|
||||
for model_name in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.to(self.device)
|
||||
# fresh the cuda cache
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -19,6 +19,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.dit: FluxDiT = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
@@ -37,9 +38,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[],prompt_extender_classes=[]):
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
|
||||
pipe = FluxImagePipeline(
|
||||
device=model_manager.device,
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes)
|
||||
@@ -99,6 +100,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
@@ -107,6 +109,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Extend prompt
|
||||
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
|
||||
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
|
||||
|
||||
# Encode prompts
|
||||
@@ -119,6 +122,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(['dit'])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
@@ -143,6 +147,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# Offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
@@ -135,6 +135,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
self.dit: HunyuanDiT = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
self.model_names = ['text_encoder', 'text_encoder_t5', 'dit', 'vae_decoder', 'vae_encoder']
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
@@ -153,9 +154,9 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
|
||||
pipe = HunyuanDiTImagePipeline(
|
||||
device=model_manager.device,
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes)
|
||||
@@ -234,6 +235,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
# Prepare latent tensors
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
if input_image is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
@@ -241,6 +243,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
latents = noise.clone()
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(['text_encoder', 'text_encoder_t5'])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||
@@ -250,6 +253,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
extra_input = self.prepare_extra_input(latents, tiled, tile_size)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(['dit'])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
@@ -273,6 +277,9 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# Offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
@@ -20,6 +20,7 @@ class SD3ImagePipeline(BasePipeline):
|
||||
self.dit: SD3DiT = None
|
||||
self.vae_decoder: SD3VAEDecoder = None
|
||||
self.vae_encoder: SD3VAEEncoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'text_encoder_3', 'dit', 'vae_decoder', 'vae_encoder']
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
@@ -38,9 +39,9 @@ class SD3ImagePipeline(BasePipeline):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
|
||||
pipe = SD3ImagePipeline(
|
||||
device=model_manager.device,
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes)
|
||||
@@ -97,6 +98,7 @@ class SD3ImagePipeline(BasePipeline):
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
@@ -105,11 +107,13 @@ class SD3ImagePipeline(BasePipeline):
|
||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(['dit'])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
@@ -131,6 +135,9 @@ class SD3ImagePipeline(BasePipeline):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
@@ -25,6 +25,7 @@ class SDImagePipeline(BasePipeline):
|
||||
self.controlnet: MultiControlNetManager = None
|
||||
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
|
||||
self.ipadapter: SDIpAdapter = None
|
||||
self.model_names = ['text_encoder', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
@@ -57,9 +58,9 @@ class SDImagePipeline(BasePipeline):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
|
||||
pipe = SDImagePipeline(
|
||||
device=model_manager.device,
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=[])
|
||||
@@ -118,6 +119,7 @@ class SDImagePipeline(BasePipeline):
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
@@ -126,13 +128,16 @@ class SDImagePipeline(BasePipeline):
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(['text_encoder'])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts]
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
self.load_models_to_device(['ipadapter_image_encoder'])
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
self.load_models_to_device(['ipadapter'])
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
else:
|
||||
@@ -140,6 +145,7 @@ class SDImagePipeline(BasePipeline):
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_image is not None:
|
||||
self.load_models_to_device(['controlnet'])
|
||||
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
controlnet_image = controlnet_image.unsqueeze(1)
|
||||
controlnet_kwargs = {"controlnet_frames": controlnet_image}
|
||||
@@ -147,6 +153,7 @@ class SDImagePipeline(BasePipeline):
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(['controlnet', 'unet'])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
@@ -173,6 +180,9 @@ class SDImagePipeline(BasePipeline):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
@@ -28,6 +28,7 @@ class SDXLImagePipeline(BasePipeline):
|
||||
self.controlnet: MultiControlNetManager = None
|
||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
||||
self.ipadapter: SDXLIpAdapter = None
|
||||
self.model_names = ['text_encoder', 'text_encoder_2', 'text_encoder_kolors', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
@@ -70,9 +71,9 @@ class SDXLImagePipeline(BasePipeline):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
|
||||
pipe = SDXLImagePipeline(
|
||||
device=model_manager.device,
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
|
||||
@@ -139,6 +140,7 @@ class SDXLImagePipeline(BasePipeline):
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
@@ -147,6 +149,7 @@ class SDXLImagePipeline(BasePipeline):
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(['text_encoder', 'text_encoder_2', 'text_encoder_kolors'])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
|
||||
@@ -157,7 +160,9 @@ class SDXLImagePipeline(BasePipeline):
|
||||
self.ipadapter.set_less_adapter()
|
||||
else:
|
||||
self.ipadapter.set_full_adapter()
|
||||
self.load_models_to_device(['ipadapter_image_encoder'])
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
self.load_models_to_device(['ipadapter'])
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
else:
|
||||
@@ -165,6 +170,7 @@ class SDXLImagePipeline(BasePipeline):
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_image is not None:
|
||||
self.load_models_to_device(['controlnet'])
|
||||
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
controlnet_image = controlnet_image.unsqueeze(1)
|
||||
controlnet_kwargs = {"controlnet_frames": controlnet_image}
|
||||
@@ -175,6 +181,7 @@ class SDXLImagePipeline(BasePipeline):
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(['controlnet', 'unet'])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
@@ -206,6 +213,9 @@ class SDXLImagePipeline(BasePipeline):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
42
examples/image_synthesis/flux_text_to_image_cpu_offload.py
Normal file
42
examples/image_synthesis/flux_text_to_image_cpu_offload.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_models
|
||||
|
||||
|
||||
download_models(["FLUX.1-dev"])
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
||||
])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, device='cuda')
|
||||
pipe.enable_cpu_offload()
|
||||
|
||||
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
|
||||
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
|
||||
|
||||
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
|
||||
torch.manual_seed(6)
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=30, embedded_guidance=3.5
|
||||
)
|
||||
image.save("image_1024.jpg")
|
||||
|
||||
# Enable classifier-free guidance
|
||||
torch.manual_seed(6)
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt,
|
||||
num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5
|
||||
)
|
||||
image.save("image_1024_cfg.jpg")
|
||||
|
||||
# Highres-fix
|
||||
torch.manual_seed(7)
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=30, embedded_guidance=3.5,
|
||||
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
|
||||
)
|
||||
image.save("image_2048_highres.jpg")
|
||||
Reference in New Issue
Block a user