From 900a1c095f37a307b5f623adc929e596e8a9c6f9 Mon Sep 17 00:00:00 2001 From: tc2000731 Date: Tue, 29 Oct 2024 17:29:24 +0800 Subject: [PATCH 1/2] add Flux_ControlNet_Quantization --- diffsynth/controlnets/controlnet_unit.py | 2 + diffsynth/controlnets/processors.py | 5 + diffsynth/models/flux_controlnet.py | 105 ++++- diffsynth/models/flux_dit.py | 3 + diffsynth/models/lora.py | 6 + diffsynth/pipelines/flux_image.py | 1 + .../flux_controlnet_quantization.py | 438 ++++++++++++++++++ 7 files changed, 558 insertions(+), 2 deletions(-) create mode 100644 examples/ControlNet/flux_controlnet_quantization.py diff --git a/diffsynth/controlnets/controlnet_unit.py b/diffsynth/controlnets/controlnet_unit.py index fba09b6..fdb4829 100644 --- a/diffsynth/controlnets/controlnet_unit.py +++ b/diffsynth/controlnets/controlnet_unit.py @@ -31,6 +31,8 @@ class MultiControlNetManager: def to(self, device): for model in self.models: model.to(device) + for processor in self.processors: + processor.to(device) def process_image(self, image, processor_id=None): if processor_id is None: diff --git a/diffsynth/controlnets/processors.py b/diffsynth/controlnets/processors.py index 71e47da..cbc6ff4 100644 --- a/diffsynth/controlnets/processors.py +++ b/diffsynth/controlnets/processors.py @@ -37,6 +37,11 @@ class Annotator: self.processor_id = processor_id self.detect_resolution = detect_resolution + + def to(self,device): + if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"): + + self.processor.model.to(device) def __call__(self, image, mask=None): width, height = image.size diff --git a/diffsynth/models/flux_controlnet.py b/diffsynth/models/flux_controlnet.py index d6053b1..d812e6c 100644 --- a/diffsynth/models/flux_controlnet.py +++ b/diffsynth/models/flux_controlnet.py @@ -1,7 +1,7 @@ import torch from einops import rearrange, repeat -from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock -from .utils import hash_state_dict_keys +from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm +from .utils import hash_state_dict_keys, init_weights_on_device @@ -106,6 +106,107 @@ class FluxControlNet(torch.nn.Module): def state_dict_converter(): return FluxControlNetStateDictConverter() + def quantize(self): + def cast_to(weight, dtype=None, device=None, copy=False): + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + return weight.to(dtype=dtype, copy=copy) + + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + def cast_weight(s, input=None, dtype=None, device=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if device is None: + device = input.device + weight = cast_to(s.weight, dtype, device) + return weight + + def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if bias_dtype is None: + bias_dtype = dtype + if device is None: + device = input.device + bias = None + weight = cast_to(s.weight, dtype, device) + bias = cast_to(s.bias, bias_dtype, device) + return weight, bias + + class quantized_layer: + class QLinear(torch.nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self,input,**kwargs): + weight,bias= cast_bias_weight(self,input) + return torch.nn.functional.linear(input,weight,bias) + + class QRMSNorm(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self,hidden_states,**kwargs): + weight= cast_weight(self.module,hidden_states) + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) + hidden_states = hidden_states.to(input_dtype) * weight + return hidden_states + + class QEmbedding(torch.nn.Embedding): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self,input,**kwargs): + weight= cast_weight(self,input) + return torch.nn.functional.embedding( + input, weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + + def replace_layer(model): + for name, module in model.named_children(): + if isinstance(module,quantized_layer.QRMSNorm): + continue + if isinstance(module, torch.nn.Linear): + with init_weights_on_device(): + new_layer = quantized_layer.QLinear(module.in_features,module.out_features) + new_layer.weight = module.weight + if module.bias is not None: + new_layer.bias = module.bias + setattr(model, name, new_layer) + elif isinstance(module, RMSNorm): + if hasattr(module,"quantized"): + continue + module.quantized= True + new_layer = quantized_layer.QRMSNorm(module) + setattr(model, name, new_layer) + elif isinstance(module,torch.nn.Embedding): + rows, cols = module.weight.shape + new_layer = quantized_layer.QEmbedding( + num_embeddings=rows, + embedding_dim=cols, + _weight=module.weight, + # _freeze=module.freeze, + padding_idx=module.padding_idx, + max_norm=module.max_norm, + norm_type=module.norm_type, + scale_grad_by_freq=module.scale_grad_by_freq, + sparse=module.sparse) + setattr(model, name, new_layer) + else: + replace_layer(module) + + replace_layer(self) + class FluxControlNetStateDictConverter: diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index f308a93..9ab3958 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -475,6 +475,9 @@ class FluxDiT(torch.nn.Module): # del module setattr(model, name, new_layer) elif isinstance(module, RMSNorm): + if hasattr(module,"quantized"): + continue + module.quantized= True new_layer = quantized_layer.RMSNorm(module) setattr(model, name, new_layer) else: diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index e948945..eebc4a2 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -83,8 +83,14 @@ class LoRAFromCivitai: if len(state_dict_lora) > 0: print(f" {len(state_dict_lora)} tensors are updated.") for name in state_dict_lora: + fp8=False + if state_dict_model[name].dtype == torch.float8_e4m3fn: + state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype) + fp8=True state_dict_model[name] += state_dict_lora[name].to( dtype=state_dict_model[name].dtype, device=state_dict_model[name].device) + if fp8: + state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn) model.load_state_dict(state_dict_model) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 89d730f..f038113 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -187,6 +187,7 @@ class FluxImagePipeline(BasePipeline): # Prepare ControlNets if controlnet_image is not None: + self.load_models_to_device(['vae_encoder']) controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)} if len(masks) > 0 and controlnet_inpaint_mask is not None: print("The controlnet_inpaint_mask will be overridden by masks.") diff --git a/examples/ControlNet/flux_controlnet_quantization.py b/examples/ControlNet/flux_controlnet_quantization.py new file mode 100644 index 0000000..4753c87 --- /dev/null +++ b/examples/ControlNet/flux_controlnet_quantization.py @@ -0,0 +1,438 @@ +from diffsynth import ModelManager, FluxImagePipeline, ControlNetConfigUnit, download_models, download_customized_models +import torch +from PIL import Image +import numpy as np + + + +def example_1(): + model_manager = ModelManager( + torch_dtype=torch.bfloat16, + device="cpu" # To reduce VRAM required, we load models to RAM. + # device="cuda" # To reduce VRAM required, we load models to RAM. + ) + model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + ]) + model_manager.load_models( + ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + model_manager.load_models( + ["models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="tile", + model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", + scale=0.7 + ), + ],device="cuda") + pipe.enable_cpu_offload() + pipe.dit.quantize() + for model in pipe.controlnet.models: + model.quantize() + + image_1 = pipe( + prompt="a photo of a cat, highly detailed", + height=768, width=768, + seed=0 + ) + image_1.save("image_1.jpg") + + image_2 = pipe( + prompt="a photo of a cat, highly detailed", + controlnet_image=image_1.resize((2048, 2048)), + input_image=image_1.resize((2048, 2048)), denoising_strength=0.99, + height=2048, width=2048, tiled=True, + seed=1 + ) + image_2.save("image_2.jpg") + + + +def example_2(): + model_manager = ModelManager( + torch_dtype=torch.bfloat16, + device="cpu" # To reduce VRAM required, we load models to RAM. + # device="cuda" # To reduce VRAM required, we load models to RAM. + ) + model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + ]) + model_manager.load_models( + ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + model_manager.load_models( + ["models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="tile", + model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", + scale=0.7 + ), + ],device="cuda") + pipe.enable_cpu_offload() + pipe.dit.quantize() + for model in pipe.controlnet.models: + model.quantize() + image_1 = pipe( + prompt="a beautiful Chinese girl, delicate skin texture", + height=768, width=768, + seed=2 + ) + image_1.save("image_3.jpg") + + image_2 = pipe( + prompt="a beautiful Chinese girl, delicate skin texture", + controlnet_image=image_1.resize((2048, 2048)), + input_image=image_1.resize((2048, 2048)), denoising_strength=0.99, + height=2048, width=2048, tiled=True, + seed=3 + ) + image_2.save("image_4.jpg") + + +def example_3(): + model_manager = ModelManager( + torch_dtype=torch.bfloat16, + device="cpu" # To reduce VRAM required, we load models to RAM. + # device="cuda" # To reduce VRAM required, we load models to RAM. + ) + model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + ]) + model_manager.load_models( + ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + model_manager.load_models( + ["models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="canny", + model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + scale=0.3 + ), + ControlNetConfigUnit( + processor_id="depth", + model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + scale=0.3 + ), + ],device="cuda") + pipe.enable_cpu_offload() + pipe.dit.quantize() + for model in pipe.controlnet.models: + model.quantize() + image_1 = pipe( + prompt="a cat is running", + height=1024, width=1024, + seed=4 + ) + image_1.save("image_5.jpg") + + image_2 = pipe( + prompt="sunshine, a cat is running", + controlnet_image=image_1, + height=1024, width=1024, + seed=5 + ) + image_2.save("image_6.jpg") + + +def example_4(): + model_manager = ModelManager( + torch_dtype=torch.bfloat16, + device="cpu" # To reduce VRAM required, we load models to RAM. + # device="cuda" # To reduce VRAM required, we load models to RAM. + ) + model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + ]) + model_manager.load_models( + ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + model_manager.load_models( + ["models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="canny", + model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + scale=0.3 + ), + ControlNetConfigUnit( + processor_id="depth", + model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + scale=0.3 + ), + ],device="cuda") + pipe.enable_cpu_offload() + pipe.dit.quantize() + for model in pipe.controlnet.models: + model.quantize() + image_1 = pipe( + prompt="a beautiful Asian girl, full body, red dress, summer", + height=1024, width=1024, + seed=6 + ) + image_1.save("image_7.jpg") + + image_2 = pipe( + prompt="a beautiful Asian girl, full body, red dress, winter", + controlnet_image=image_1, + height=1024, width=1024, + seed=7 + ) + image_2.save("image_8.jpg") + + + +def example_5(): + model_manager = ModelManager( + torch_dtype=torch.bfloat16, + device="cpu" # To reduce VRAM required, we load models to RAM. + # device="cuda" # To reduce VRAM required, we load models to RAM. + ) + model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + ]) + model_manager.load_models( + ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + model_manager.load_models( + ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="inpaint", + model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", + scale=0.9 + ), + ],device="cuda") + pipe.enable_cpu_offload() + pipe.dit.quantize() + for model in pipe.controlnet.models: + model.quantize() + image_1 = pipe( + prompt="a cat sitting on a chair", + height=1024, width=1024, + seed=8 + ) + image_1.save("image_9.jpg") + + mask = np.zeros((1024, 1024, 3), dtype=np.uint8) + mask[100:350, 350: -300] = 255 + mask = Image.fromarray(mask) + mask.save("mask_9.jpg") + + image_2 = pipe( + prompt="a cat sitting on a chair, wearing sunglasses", + controlnet_image=image_1, controlnet_inpaint_mask=mask, + height=1024, width=1024, + seed=9 + ) + image_2.save("image_10.jpg") + + + +def example_6(): + model_manager = ModelManager( + torch_dtype=torch.bfloat16, + device="cpu" # To reduce VRAM required, we load models to RAM. + # device="cuda" # To reduce VRAM required, we load models to RAM. + ) + model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + ]) + model_manager.load_models( + ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + model_manager.load_models( + ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", + "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="inpaint", + model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", + scale=0.9 + ), + ControlNetConfigUnit( + processor_id="normal", + model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors", + scale=0.6 + ), + ],device="cuda") + pipe.enable_cpu_offload() + pipe.dit.quantize() + for model in pipe.controlnet.models: + model.quantize() + image_1 = pipe( + prompt="a beautiful Asian woman looking at the sky, wearing a blue t-shirt.", + height=1024, width=1024, + seed=10 + ) + image_1.save("image_11.jpg") + + mask = np.zeros((1024, 1024, 3), dtype=np.uint8) + mask[-400:, 10:-40] = 255 + mask = Image.fromarray(mask) + mask.save("mask_11.jpg") + + image_2 = pipe( + prompt="a beautiful Asian woman looking at the sky, wearing a yellow t-shirt.", + controlnet_image=image_1, controlnet_inpaint_mask=mask, + height=1024, width=1024, + seed=11 + ) + image_2.save("image_12.jpg") + + +def example_7(): + model_manager = ModelManager( + torch_dtype=torch.bfloat16, + device="cpu" # To reduce VRAM required, we load models to RAM. + # device="cuda" # To reduce VRAM required, we load models to RAM. + ) + model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + ]) + model_manager.load_models( + ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + model_manager.load_models( + ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", + "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + ) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="inpaint", + model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", + scale=0.9 + ), + ControlNetConfigUnit( + processor_id="canny", + model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + scale=0.5 + ), + ],device="cuda") + pipe.enable_cpu_offload() + pipe.dit.quantize() + for model in pipe.controlnet.models: + model.quantize() + image_1 = pipe( + prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.", + height=1024, width=1024, + seed=100 + ) + image_1.save("image_13.jpg") + + mask_global = np.zeros((1024, 1024, 3), dtype=np.uint8) + mask_global = Image.fromarray(mask_global) + mask_global.save("mask_13_global.jpg") + + mask_1 = np.zeros((1024, 1024, 3), dtype=np.uint8) + mask_1[300:-100, 30: 450] = 255 + mask_1 = Image.fromarray(mask_1) + mask_1.save("mask_13_1.jpg") + + mask_2 = np.zeros((1024, 1024, 3), dtype=np.uint8) + mask_2[500:-100, -400:] = 255 + mask_2[-200:-100, -500:-400] = 255 + mask_2 = Image.fromarray(mask_2) + mask_2.save("mask_13_2.jpg") + + image_2 = pipe( + prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.", + controlnet_image=image_1, controlnet_inpaint_mask=mask_global, + local_prompts=["an orange cat, highly detailed", "a girl wearing a red camisole"], masks=[mask_1, mask_2], mask_scales=[10.0, 10.0], + height=1024, width=1024, + seed=101 + ) + image_2.save("image_14.jpg") + + model_manager.load_lora("models/lora/FLUX-dev-lora-AntiBlur.safetensors", lora_alpha=2) + image_3 = pipe( + prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. clear background.", + negative_prompt="blur, blurry", + input_image=image_2, denoising_strength=0.7, + height=1024, width=1024, + cfg_scale=2.0, num_inference_steps=50, + seed=102 + ) + image_3.save("image_15.jpg") + + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="tile", + model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", + scale=0.7 + ), + ],device="cuda") + pipe.enable_cpu_offload() + pipe.dit.quantize() + for model in pipe.controlnet.models: + model.quantize() + image_4 = pipe( + prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.", + controlnet_image=image_3.resize((2048, 2048)), + input_image=image_3.resize((2048, 2048)), denoising_strength=0.99, + height=2048, width=2048, tiled=True, + seed=103 + ) + image_4.save("image_16.jpg") + + image_5 = pipe( + prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.", + controlnet_image=image_4.resize((4096, 4096)), + input_image=image_4.resize((4096, 4096)), denoising_strength=0.99, + height=4096, width=4096, tiled=True, + seed=104 + ) + image_5.save("image_17.jpg") + + + +download_models(["Annotators:Depth", "Annotators:Normal"]) +download_customized_models( + model_id="LiblibAI/FLUX.1-dev-LoRA-AntiBlur", + origin_file_path="FLUX-dev-lora-AntiBlur.safetensors", + local_dir="models/lora" +) +example_1() +example_2() +example_3() +example_4() +example_5() +example_6() +example_7() From 9377214518a9f599331670eddabfe1214533c00b Mon Sep 17 00:00:00 2001 From: tc2000731 Date: Thu, 31 Oct 2024 17:38:57 +0800 Subject: [PATCH 2/2] update controlnet_frames, downloads --- diffsynth/pipelines/flux_image.py | 3 +- .../flux_controlnet_quantization.py | 65 +++++++++++-------- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index f038113..69664b3 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -258,6 +258,7 @@ def lets_dance_flux( ): if tiled: def flux_forward_fn(hl, hr, wl, wr): + tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None return lets_dance_flux( dit=dit, controlnet=controlnet, @@ -268,7 +269,7 @@ def lets_dance_flux( guidance=guidance, text_ids=text_ids, image_ids=None, - controlnet_frames=[f[:, :, hl: hr, wl: wr] for f in controlnet_frames], + controlnet_frames=tiled_controlnet_frames, tiled=False, **kwargs ) diff --git a/examples/ControlNet/flux_controlnet_quantization.py b/examples/ControlNet/flux_controlnet_quantization.py index 4753c87..c4aa3c1 100644 --- a/examples/ControlNet/flux_controlnet_quantization.py +++ b/examples/ControlNet/flux_controlnet_quantization.py @@ -6,10 +6,10 @@ import numpy as np def example_1(): + download_models(["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -18,11 +18,11 @@ def example_1(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -55,10 +55,10 @@ def example_1(): def example_2(): + download_models(["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -67,11 +67,11 @@ def example_2(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -102,10 +102,10 @@ def example_2(): def example_3(): + download_models(["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -114,11 +114,11 @@ def example_3(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -153,10 +153,10 @@ def example_3(): def example_4(): + download_models(["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -165,11 +165,11 @@ def example_4(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -205,10 +205,10 @@ def example_4(): def example_5(): + download_models(["FLUX.1-dev", "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -217,11 +217,11 @@ def example_5(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -257,10 +257,14 @@ def example_5(): def example_6(): + download_models([ + "FLUX.1-dev", + "jasperai/Flux.1-dev-Controlnet-Surface-Normals", + "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta" + ]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -269,12 +273,12 @@ def example_6(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit( @@ -314,10 +318,15 @@ def example_6(): def example_7(): + download_models([ + "FLUX.1-dev", + "InstantX/FLUX.1-dev-Controlnet-Union-alpha", + "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", + "jasperai/Flux.1-dev-Controlnet-Upscaler", + ]) model_manager = ModelManager( torch_dtype=torch.bfloat16, - device="cpu" # To reduce VRAM required, we load models to RAM. - # device="cuda" # To reduce VRAM required, we load models to RAM. + device="cpu" ) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", @@ -326,13 +335,13 @@ def example_7(): ]) model_manager.load_models( ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) model_manager.load_models( ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"], - torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. + torch_dtype=torch.float8_e4m3fn ) pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ ControlNetConfigUnit(