From 9377214518a9f599331670eddabfe1214533c00b Mon Sep 17 00:00:00 2001 From: tc2000731 Date: Thu, 31 Oct 2024 17:38:57 +0800 Subject: [PATCH] update controlnet_frames, downloads --- diffsynth/pipelines/flux_image.py | 3 +- .../flux_controlnet_quantization.py | 65 +++++++++++-------- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index f038113..69664b3 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -258,6 +258,7 @@ def lets_dance_flux( ): if tiled: def flux_forward_fn(hl, hr, wl, wr): + tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None return lets_dance_flux( dit=dit, controlnet=controlnet, @@ -268,7 +269,7 @@ def lets_dance_flux( guidance=guidance, text_ids=text_ids, image_ids=None, - controlnet_frames=[f[:, :, hl: hr, wl: wr] for f in controlnet_frames], + controlnet_frames=tiled_controlnet_frames, tiled=False, **kwargs ) diff --git a/examples/ControlNet/flux_controlnet_quantization.py b/examples/ControlNet/flux_controlnet_quantization.py index 4753c87..c4aa3c1 100644 --- a/examples/ControlNet/flux_controlnet_quantization.py +++ b/examples/ControlNet/flux_controlnet_quantization.py @@ -6,10 +6,10 @@ import numpy as np def example_1(): + download_models(["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -18,11 +18,11 @@ def example_1(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -55,10 +55,10 @@ def example_1(): def example_2(): + download_models(["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -67,11 +67,11 @@ def example_2(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -102,10 +102,10 @@ def example_2(): def example_3(): + download_models(["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -114,11 +114,11 @@ def example_3(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -153,10 +153,10 @@ def example_3(): def example_4(): + download_models(["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -165,11 +165,11 @@ def example_4(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -205,10 +205,10 @@ def example_4(): def example_5(): + download_models(["FLUX.1-dev", "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -217,11 +217,11 @@ def example_5(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -257,10 +257,14 @@ def example_5(): def example_6(): + download_models([ + "FLUX.1-dev", + "jasperai/Flux.1-dev-Controlnet-Surface-Normals", + "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta" + ]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -269,12 +273,12 @@ def example_6(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -314,10 +318,15 @@ def example_6(): def example_7(): + download_models([ + "FLUX.1-dev", + "InstantX/FLUX.1-dev-Controlnet-Union-alpha", + "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", + "jasperai/Flux.1-dev-Controlnet-Upscaler", + ]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -326,13 +335,13 @@ def example_7(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit(