Merge pull request #197 from mi804/cpuoffload

add cpuoffload support for image pipelines
This commit is contained in:
Zhongjie Duan
2024-09-09 14:48:26 +08:00
committed by GitHub
8 changed files with 124 additions and 11 deletions

View File

@@ -23,6 +23,14 @@ class MultiControlNetManager:
self.models = [unit.model for unit in controlnet_units]
self.scales = [unit.scale for unit in controlnet_units]
def cpu(self):
for model in self.models:
model.cpu()
def to(self, device):
for model in self.models:
model.to(device)
def process_image(self, image, processor_id=None):
if processor_id is None:
processed_image = [processor(image) for processor in self.processors]

View File

@@ -10,6 +10,8 @@ class BasePipeline(torch.nn.Module):
super().__init__()
self.device = device
self.torch_dtype = torch_dtype
self.cpu_offload = False
self.model_names = []
def preprocess_image(self, image):
@@ -59,4 +61,24 @@ class BasePipeline(torch.nn.Module):
masks += extended_prompt_dict.get("masks", [])
mask_scales += [5.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):
self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled
if not self.cpu_offload:
return
# offload the unneeded models to cpu
for model_name in self.model_names:
if model_name not in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
model.cpu()
# load the needed models to device
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
model.to(self.device)
# fresh the cuda cache
torch.cuda.empty_cache()

View File

@@ -19,6 +19,7 @@ class FluxImagePipeline(BasePipeline):
self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
def denoising_model(self):
@@ -37,9 +38,9 @@ class FluxImagePipeline(BasePipeline):
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[],prompt_extender_classes=[]):
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
pipe = FluxImagePipeline(
device=model_manager.device,
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes)
@@ -99,6 +100,7 @@ class FluxImagePipeline(BasePipeline):
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
@@ -107,6 +109,7 @@ class FluxImagePipeline(BasePipeline):
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts
@@ -119,6 +122,7 @@ class FluxImagePipeline(BasePipeline):
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Denoise
self.load_models_to_device(['dit'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
@@ -143,6 +147,9 @@ class FluxImagePipeline(BasePipeline):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# Offload all models
self.load_models_to_device([])
return image

View File

@@ -135,6 +135,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
self.dit: HunyuanDiT = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.model_names = ['text_encoder', 'text_encoder_t5', 'dit', 'vae_decoder', 'vae_encoder']
def denoising_model(self):
@@ -153,9 +154,9 @@ class HunyuanDiTImagePipeline(BasePipeline):
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
pipe = HunyuanDiTImagePipeline(
device=model_manager.device,
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes)
@@ -234,6 +235,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
# Prepare latent tensors
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
@@ -241,6 +243,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
latents = noise.clone()
# Encode prompts
self.load_models_to_device(['text_encoder', 'text_encoder_t5'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
@@ -250,6 +253,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
extra_input = self.prepare_extra_input(latents, tiled, tile_size)
# Denoise
self.load_models_to_device(['dit'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
@@ -273,6 +277,9 @@ class HunyuanDiTImagePipeline(BasePipeline):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# Offload all models
self.load_models_to_device([])
return image

View File

@@ -20,6 +20,7 @@ class SD3ImagePipeline(BasePipeline):
self.dit: SD3DiT = None
self.vae_decoder: SD3VAEDecoder = None
self.vae_encoder: SD3VAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'text_encoder_3', 'dit', 'vae_decoder', 'vae_encoder']
def denoising_model(self):
@@ -38,9 +39,9 @@ class SD3ImagePipeline(BasePipeline):
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
pipe = SD3ImagePipeline(
device=model_manager.device,
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes)
@@ -97,6 +98,7 @@ class SD3ImagePipeline(BasePipeline):
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
@@ -105,11 +107,13 @@ class SD3ImagePipeline(BasePipeline):
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
# Denoise
self.load_models_to_device(['dit'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
@@ -131,6 +135,9 @@ class SD3ImagePipeline(BasePipeline):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -25,6 +25,7 @@ class SDImagePipeline(BasePipeline):
self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
self.ipadapter: SDIpAdapter = None
self.model_names = ['text_encoder', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
def denoising_model(self):
@@ -57,9 +58,9 @@ class SDImagePipeline(BasePipeline):
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
pipe = SDImagePipeline(
device=model_manager.device,
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=[])
@@ -118,6 +119,7 @@ class SDImagePipeline(BasePipeline):
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
@@ -126,13 +128,16 @@ class SDImagePipeline(BasePipeline):
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(['text_encoder'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts]
# IP-Adapter
if ipadapter_images is not None:
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
@@ -140,6 +145,7 @@ class SDImagePipeline(BasePipeline):
# Prepare ControlNets
if controlnet_image is not None:
self.load_models_to_device(['controlnet'])
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
controlnet_kwargs = {"controlnet_frames": controlnet_image}
@@ -147,6 +153,7 @@ class SDImagePipeline(BasePipeline):
controlnet_kwargs = {"controlnet_frames": None}
# Denoise
self.load_models_to_device(['controlnet', 'unet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
@@ -173,6 +180,9 @@ class SDImagePipeline(BasePipeline):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -28,6 +28,7 @@ class SDXLImagePipeline(BasePipeline):
self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
self.model_names = ['text_encoder', 'text_encoder_2', 'text_encoder_kolors', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
def denoising_model(self):
@@ -70,9 +71,9 @@ class SDXLImagePipeline(BasePipeline):
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
pipe = SDXLImagePipeline(
device=model_manager.device,
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
@@ -139,6 +140,7 @@ class SDXLImagePipeline(BasePipeline):
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
@@ -147,6 +149,7 @@ class SDXLImagePipeline(BasePipeline):
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(['text_encoder', 'text_encoder_2', 'text_encoder_kolors'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
@@ -157,7 +160,9 @@ class SDXLImagePipeline(BasePipeline):
self.ipadapter.set_less_adapter()
else:
self.ipadapter.set_full_adapter()
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
@@ -165,6 +170,7 @@ class SDXLImagePipeline(BasePipeline):
# Prepare ControlNets
if controlnet_image is not None:
self.load_models_to_device(['controlnet'])
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
controlnet_kwargs = {"controlnet_frames": controlnet_image}
@@ -175,6 +181,7 @@ class SDXLImagePipeline(BasePipeline):
extra_input = self.prepare_extra_input(latents)
# Denoise
self.load_models_to_device(['controlnet', 'unet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
@@ -206,6 +213,9 @@ class SDXLImagePipeline(BasePipeline):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -0,0 +1,42 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_models
download_models(["FLUX.1-dev"])
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager, device='cuda')
pipe.enable_cpu_offload()
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(6)
image = pipe(
prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5
)
image.save("image_1024.jpg")
# Enable classifier-free guidance
torch.manual_seed(6)
image = pipe(
prompt=prompt, negative_prompt=negative_prompt,
num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5
)
image.save("image_1024_cfg.jpg")
# Highres-fix
torch.manual_seed(7)
image = pipe(
prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5,
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
)
image.save("image_2048_highres.jpg")