diff --git a/diffsynth/controlnets/controlnet_unit.py b/diffsynth/controlnets/controlnet_unit.py index 63129d4..f03fec5 100644 --- a/diffsynth/controlnets/controlnet_unit.py +++ b/diffsynth/controlnets/controlnet_unit.py @@ -23,6 +23,14 @@ class MultiControlNetManager: self.models = [unit.model for unit in controlnet_units] self.scales = [unit.scale for unit in controlnet_units] + def cpu(self): + for model in self.models: + model.cpu() + + def to(self, device): + for model in self.models: + model.to(device) + def process_image(self, image, processor_id=None): if processor_id is None: processed_image = [processor(image) for processor in self.processors] diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index 2feb405..faa1d86 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -10,6 +10,8 @@ class BasePipeline(torch.nn.Module): super().__init__() self.device = device self.torch_dtype = torch_dtype + self.cpu_offload = False + self.model_names = [] def preprocess_image(self, image): @@ -59,4 +61,24 @@ class BasePipeline(torch.nn.Module): masks += extended_prompt_dict.get("masks", []) mask_scales += [5.0] * len(extended_prompt_dict.get("masks", [])) return prompt, local_prompts, masks, mask_scales - \ No newline at end of file + + def enable_cpu_offload(self): + self.cpu_offload = True + + def load_models_to_device(self, loadmodel_names=[]): + # only load models to device if cpu_offload is enabled + if not self.cpu_offload: + return + # offload the unneeded models to cpu + for model_name in self.model_names: + if model_name not in loadmodel_names: + model = getattr(self, model_name) + if model is not None: + model.cpu() + # load the needed models to device + for model_name in loadmodel_names: + model = getattr(self, model_name) + if model is not None: + model.to(self.device) + # refresh the CUDA cache + torch.cuda.empty_cache() diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 
8d6a246..67d961a 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -19,6 +19,7 @@ class FluxImagePipeline(BasePipeline): self.dit: FluxDiT = None self.vae_decoder: FluxVAEDecoder = None self.vae_encoder: FluxVAEEncoder = None + self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder'] def denoising_model(self): @@ -37,9 +38,9 @@ class FluxImagePipeline(BasePipeline): @staticmethod - def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[],prompt_extender_classes=[]): + def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[], device=None): pipe = FluxImagePipeline( - device=model_manager.device, + device=model_manager.device if device is None else device, torch_dtype=model_manager.torch_dtype, ) pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes) @@ -99,6 +100,7 @@ class FluxImagePipeline(BasePipeline): # Prepare latent tensors if input_image is not None: + self.load_models_to_device(['vae_encoder']) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) latents = self.encode_image(image, **tiler_kwargs) noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) @@ -107,6 +109,7 @@ class FluxImagePipeline(BasePipeline): latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) # Extend prompt + self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales) # Encode prompts @@ -119,6 +122,7 @@ class FluxImagePipeline(BasePipeline): extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) # Denoise + self.load_models_to_device(['dit']) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) 
@@ -143,6 +147,9 @@ class FluxImagePipeline(BasePipeline): progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) # Decode image + self.load_models_to_device(['vae_decoder']) image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + # Offload all models + self.load_models_to_device([]) return image diff --git a/diffsynth/pipelines/hunyuan_image.py b/diffsynth/pipelines/hunyuan_image.py index 9181431..3407e54 100644 --- a/diffsynth/pipelines/hunyuan_image.py +++ b/diffsynth/pipelines/hunyuan_image.py @@ -135,6 +135,7 @@ class HunyuanDiTImagePipeline(BasePipeline): self.dit: HunyuanDiT = None self.vae_decoder: SDXLVAEDecoder = None self.vae_encoder: SDXLVAEEncoder = None + self.model_names = ['text_encoder', 'text_encoder_t5', 'dit', 'vae_decoder', 'vae_encoder'] def denoising_model(self): @@ -153,9 +154,9 @@ class HunyuanDiTImagePipeline(BasePipeline): @staticmethod - def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]): + def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None): pipe = HunyuanDiTImagePipeline( - device=model_manager.device, + device=model_manager.device if device is None else device, torch_dtype=model_manager.torch_dtype, ) pipe.fetch_models(model_manager, prompt_refiner_classes) @@ -234,6 +235,7 @@ class HunyuanDiTImagePipeline(BasePipeline): # Prepare latent tensors noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) if input_image is not None: + self.load_models_to_device(['vae_encoder']) image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32) latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) @@ -241,6 +243,7 @@ class HunyuanDiTImagePipeline(BasePipeline): latents = noise.clone() # Encode prompts + 
self.load_models_to_device(['text_encoder', 'text_encoder_t5']) prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) if cfg_scale != 1.0: prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) @@ -250,6 +253,7 @@ class HunyuanDiTImagePipeline(BasePipeline): extra_input = self.prepare_extra_input(latents, tiled, tile_size) # Denoise + self.load_models_to_device(['dit']) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device) @@ -273,6 +277,9 @@ class HunyuanDiTImagePipeline(BasePipeline): progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) # Decode image + self.load_models_to_device(['vae_decoder']) image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + # Offload all models + self.load_models_to_device([]) return image diff --git a/diffsynth/pipelines/sd3_image.py b/diffsynth/pipelines/sd3_image.py index d7dd371..c624ce4 100644 --- a/diffsynth/pipelines/sd3_image.py +++ b/diffsynth/pipelines/sd3_image.py @@ -20,6 +20,7 @@ class SD3ImagePipeline(BasePipeline): self.dit: SD3DiT = None self.vae_decoder: SD3VAEDecoder = None self.vae_encoder: SD3VAEEncoder = None + self.model_names = ['text_encoder_1', 'text_encoder_2', 'text_encoder_3', 'dit', 'vae_decoder', 'vae_encoder'] def denoising_model(self): @@ -38,9 +39,9 @@ class SD3ImagePipeline(BasePipeline): @staticmethod - def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]): + def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None): pipe = SD3ImagePipeline( - device=model_manager.device, + device=model_manager.device if device is None else device, torch_dtype=model_manager.torch_dtype, ) pipe.fetch_models(model_manager, prompt_refiner_classes) @@ -97,6 +98,7 @@ 
class SD3ImagePipeline(BasePipeline): # Prepare latent tensors if input_image is not None: + self.load_models_to_device(['vae_encoder']) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) latents = self.encode_image(image, **tiler_kwargs) noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) @@ -105,11 +107,13 @@ class SD3ImagePipeline(BasePipeline): latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) # Encode prompts + self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3']) prompt_emb_posi = self.encode_prompt(prompt, positive=True) prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts] # Denoise + self.load_models_to_device(['dit']) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) @@ -131,6 +135,9 @@ class SD3ImagePipeline(BasePipeline): progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) # Decode image + self.load_models_to_device(['vae_decoder']) image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + # offload all models + self.load_models_to_device([]) return image diff --git a/diffsynth/pipelines/sd_image.py b/diffsynth/pipelines/sd_image.py index 016720d..8847027 100644 --- a/diffsynth/pipelines/sd_image.py +++ b/diffsynth/pipelines/sd_image.py @@ -25,6 +25,7 @@ class SDImagePipeline(BasePipeline): self.controlnet: MultiControlNetManager = None self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None self.ipadapter: SDIpAdapter = None + self.model_names = ['text_encoder', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter'] def denoising_model(self): @@ -57,9 +58,9 @@ class SDImagePipeline(BasePipeline): @staticmethod - def 
from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]): + def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None): pipe = SDImagePipeline( - device=model_manager.device, + device=model_manager.device if device is None else device, torch_dtype=model_manager.torch_dtype, ) pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=[]) @@ -118,6 +119,7 @@ class SDImagePipeline(BasePipeline): # Prepare latent tensors if input_image is not None: + self.load_models_to_device(['vae_encoder']) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) latents = self.encode_image(image, **tiler_kwargs) noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) @@ -126,13 +128,16 @@ class SDImagePipeline(BasePipeline): latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) # Encode prompts + self.load_models_to_device(['text_encoder']) prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True) prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False) prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts] # IP-Adapter if ipadapter_images is not None: + self.load_models_to_device(['ipadapter_image_encoder']) ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) + self.load_models_to_device(['ipadapter']) ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)} ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))} else: @@ -140,6 +145,7 @@ class SDImagePipeline(BasePipeline): # Prepare ControlNets if controlnet_image is not 
None: + self.load_models_to_device(['controlnet']) controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype) controlnet_image = controlnet_image.unsqueeze(1) controlnet_kwargs = {"controlnet_frames": controlnet_image} @@ -147,6 +153,7 @@ class SDImagePipeline(BasePipeline): controlnet_kwargs = {"controlnet_frames": None} # Denoise + self.load_models_to_device(['controlnet', 'unet']) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) @@ -173,6 +180,9 @@ class SDImagePipeline(BasePipeline): progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) # Decode image + self.load_models_to_device(['vae_decoder']) image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + # offload all models + self.load_models_to_device([]) return image diff --git a/diffsynth/pipelines/sdxl_image.py b/diffsynth/pipelines/sdxl_image.py index 2cd73d8..81f6b40 100644 --- a/diffsynth/pipelines/sdxl_image.py +++ b/diffsynth/pipelines/sdxl_image.py @@ -28,6 +28,7 @@ class SDXLImagePipeline(BasePipeline): self.controlnet: MultiControlNetManager = None self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None self.ipadapter: SDXLIpAdapter = None + self.model_names = ['text_encoder', 'text_encoder_2', 'text_encoder_kolors', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter'] def denoising_model(self): @@ -70,9 +71,9 @@ class SDXLImagePipeline(BasePipeline): @staticmethod - def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]): + def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None): pipe = SDXLImagePipeline( - device=model_manager.device, + device=model_manager.device if device is None else 
device, torch_dtype=model_manager.torch_dtype, ) pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes) @@ -139,6 +140,7 @@ class SDXLImagePipeline(BasePipeline): # Prepare latent tensors if input_image is not None: + self.load_models_to_device(['vae_encoder']) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) latents = self.encode_image(image, **tiler_kwargs) noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) @@ -147,6 +149,7 @@ class SDXLImagePipeline(BasePipeline): latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) # Encode prompts + self.load_models_to_device(['text_encoder', 'text_encoder_2', 'text_encoder_kolors']) prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False) prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts] @@ -157,7 +160,9 @@ class SDXLImagePipeline(BasePipeline): self.ipadapter.set_less_adapter() else: self.ipadapter.set_full_adapter() + self.load_models_to_device(['ipadapter_image_encoder']) ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) + self.load_models_to_device(['ipadapter']) ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)} ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))} else: @@ -165,6 +170,7 @@ class SDXLImagePipeline(BasePipeline): # Prepare ControlNets if controlnet_image is not None: + self.load_models_to_device(['controlnet']) controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype) controlnet_image = 
controlnet_image.unsqueeze(1) controlnet_kwargs = {"controlnet_frames": controlnet_image} @@ -175,6 +181,7 @@ class SDXLImagePipeline(BasePipeline): extra_input = self.prepare_extra_input(latents) # Denoise + self.load_models_to_device(['controlnet', 'unet']) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) @@ -206,6 +213,9 @@ class SDXLImagePipeline(BasePipeline): progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) # Decode image + self.load_models_to_device(['vae_decoder']) image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + # offload all models + self.load_models_to_device([]) return image diff --git a/examples/image_synthesis/flux_text_to_image_cpu_offload.py b/examples/image_synthesis/flux_text_to_image_cpu_offload.py new file mode 100644 index 0000000..4298a2f --- /dev/null +++ b/examples/image_synthesis/flux_text_to_image_cpu_offload.py @@ -0,0 +1,42 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_models + + +download_models(["FLUX.1-dev"]) +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") +model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" +]) +pipe = FluxImagePipeline.from_model_manager(model_manager, device='cuda') +pipe.enable_cpu_offload() + +prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." 
+negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," + +# Disable classifier-free guidance (consistent with the original implementation of FLUX.1) +torch.manual_seed(6) +image = pipe( + prompt=prompt, + num_inference_steps=30, embedded_guidance=3.5 +) +image.save("image_1024.jpg") + +# Enable classifier-free guidance +torch.manual_seed(6) +image = pipe( + prompt=prompt, negative_prompt=negative_prompt, + num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5 +) +image.save("image_1024_cfg.jpg") + +# Highres-fix +torch.manual_seed(7) +image = pipe( + prompt=prompt, + num_inference_steps=30, embedded_guidance=3.5, + input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True +) +image.save("image_2048_highres.jpg")