Merge pull request #252 from modelscope/Flux_ControlNet_Quantization

add Flux_ControlNet_Quantization
This commit is contained in:
Zhongjie Duan
2024-11-01 14:51:10 +08:00
committed by GitHub
7 changed files with 569 additions and 3 deletions

View File

@@ -31,6 +31,8 @@ class MultiControlNetManager:
def to(self, device):
    """Move every ControlNet model and every image processor to *device*."""
    for component in (*self.models, *self.processors):
        component.to(device)
def process_image(self, image, processor_id=None):
if processor_id is None:

View File

@@ -37,6 +37,11 @@ class Annotator:
self.processor_id = processor_id
self.detect_resolution = detect_resolution
def to(self, device):
    """Best-effort move of the annotator's underlying model to *device*.

    Processors without a movable ``model`` attribute are left untouched.
    """
    model = getattr(self.processor, "model", None)
    if model is not None and hasattr(model, "to"):
        model.to(device)
def __call__(self, image, mask=None):
width, height = image.size

View File

@@ -1,7 +1,7 @@
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock
from .utils import hash_state_dict_keys
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
from .utils import hash_state_dict_keys, init_weights_on_device
@@ -106,6 +106,107 @@ class FluxControlNet(torch.nn.Module):
def state_dict_converter():
    """Return the converter used to remap external checkpoint keys onto this model."""
    converter = FluxControlNetStateDictConverter()
    return converter
def quantize(self):
    """Enable on-the-fly weight casting for low-precision (e.g. float8) weights.

    Recursively replaces this module's ``nn.Linear``, ``RMSNorm`` and
    ``nn.Embedding`` submodules with variants that keep their parameters in
    the stored dtype and cast them to the activation's dtype/device on each
    forward call. Mutates the module in place; safe to call more than once
    (already-wrapped norms are skipped).
    """
    def cast_to(weight, dtype=None, device=None, copy=False):
        # Fast path: already on the right device — reuse or convert in place.
        if device is None or weight.device == device:
            if not copy:
                if dtype is None or weight.dtype == dtype:
                    return weight
            return weight.to(dtype=dtype, copy=copy)
        # Cross-device: allocate on the target device, then copy the data.
        r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight)
        return r

    def cast_weight(s, input=None, dtype=None, device=None):
        # Infer the target dtype/device from the activation when not given.
        if input is not None:
            if dtype is None:
                dtype = input.dtype
            if device is None:
                device = input.device
        return cast_to(s.weight, dtype, device)

    def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
        if input is not None:
            if dtype is None:
                dtype = input.dtype
            if bias_dtype is None:
                bias_dtype = dtype
            if device is None:
                device = input.device
        weight = cast_to(s.weight, dtype, device)
        # The layer may legitimately have no bias (nn.Linear(..., bias=False));
        # F.linear accepts bias=None.
        bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
        return weight, bias

    class quantized_layer:
        class QLinear(torch.nn.Linear):
            # Linear layer whose weight/bias are cast to the input's
            # dtype/device at call time.
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def forward(self, input, **kwargs):
                weight, bias = cast_bias_weight(self, input)
                return torch.nn.functional.linear(input, weight, bias)

        class QRMSNorm(torch.nn.Module):
            # Wraps an existing RMSNorm and re-implements its forward with
            # the weight cast to the activation's dtype/device.
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hidden_states, **kwargs):
                weight = cast_weight(self.module, hidden_states)
                input_dtype = hidden_states.dtype
                # Variance is accumulated in fp32 for numerical stability.
                variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                hidden_states = hidden_states.to(input_dtype) * weight
                return hidden_states

        class QEmbedding(torch.nn.Embedding):
            # Embedding whose table is cast to the input's device at call time.
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def forward(self, input, **kwargs):
                weight = cast_weight(self, input)
                return torch.nn.functional.embedding(
                    input, weight, self.padding_idx, self.max_norm,
                    self.norm_type, self.scale_grad_by_freq, self.sparse)

    def replace_layer(model):
        for name, module in model.named_children():
            # Never re-wrap an already quantized norm.
            if isinstance(module, quantized_layer.QRMSNorm):
                continue
            if isinstance(module, torch.nn.Linear):
                with init_weights_on_device():
                    # Mirror the source layer's bias: the default bias=True
                    # would otherwise leave an uninitialized (meta-device)
                    # bias parameter when the original layer has no bias.
                    new_layer = quantized_layer.QLinear(
                        module.in_features, module.out_features,
                        bias=module.bias is not None)
                new_layer.weight = module.weight
                if module.bias is not None:
                    new_layer.bias = module.bias
                setattr(model, name, new_layer)
            elif isinstance(module, RMSNorm):
                # Flag the wrapped module so a second quantize() call is a no-op.
                if hasattr(module, "quantized"):
                    continue
                module.quantized = True
                new_layer = quantized_layer.QRMSNorm(module)
                setattr(model, name, new_layer)
            elif isinstance(module, torch.nn.Embedding):
                rows, cols = module.weight.shape
                new_layer = quantized_layer.QEmbedding(
                    num_embeddings=rows,
                    embedding_dim=cols,
                    _weight=module.weight,
                    padding_idx=module.padding_idx,
                    max_norm=module.max_norm,
                    norm_type=module.norm_type,
                    scale_grad_by_freq=module.scale_grad_by_freq,
                    sparse=module.sparse)
                setattr(model, name, new_layer)
            else:
                replace_layer(module)

    replace_layer(self)
class FluxControlNetStateDictConverter:

View File

@@ -475,6 +475,9 @@ class FluxDiT(torch.nn.Module):
# del module
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):
if hasattr(module,"quantized"):
continue
module.quantized= True
new_layer = quantized_layer.RMSNorm(module)
setattr(model, name, new_layer)
else:

View File

@@ -83,8 +83,14 @@ class LoRAFromCivitai:
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
fp8=False
if state_dict_model[name].dtype == torch.float8_e4m3fn:
state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype)
fp8=True
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
if fp8:
state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
model.load_state_dict(state_dict_model)

View File

@@ -187,6 +187,7 @@ class FluxImagePipeline(BasePipeline):
# Prepare ControlNets
if controlnet_image is not None:
self.load_models_to_device(['vae_encoder'])
controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
if len(masks) > 0 and controlnet_inpaint_mask is not None:
print("The controlnet_inpaint_mask will be overridden by masks.")
@@ -257,6 +258,7 @@ def lets_dance_flux(
):
if tiled:
def flux_forward_fn(hl, hr, wl, wr):
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
return lets_dance_flux(
dit=dit,
controlnet=controlnet,
@@ -267,7 +269,7 @@ def lets_dance_flux(
guidance=guidance,
text_ids=text_ids,
image_ids=None,
controlnet_frames=[f[:, :, hl: hr, wl: wr] for f in controlnet_frames],
controlnet_frames=tiled_controlnet_frames,
tiled=False,
**kwargs
)

View File

@@ -0,0 +1,447 @@
from diffsynth import ModelManager, FluxImagePipeline, ControlNetConfigUnit, download_models, download_customized_models
import torch
from PIL import Image
import numpy as np
def example_1():
    """Tile upscaling demo: generate a 768px cat photo, then upscale it to
    2048px with the jasperai Upscaler ControlNet, all with fp8 weights."""
    download_models(["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"])
    # Text encoders and VAE stay in bf16; DiT and ControlNet load as float8.
    manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    manager.load_models([
        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
        "models/FLUX/FLUX.1-dev/text_encoder_2",
        "models/FLUX/FLUX.1-dev/ae.safetensors",
    ])
    manager.load_models(["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], torch_dtype=torch.float8_e4m3fn)
    manager.load_models(
        ["models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"],
        torch_dtype=torch.float8_e4m3fn,
    )
    pipe = FluxImagePipeline.from_model_manager(
        manager,
        controlnet_config_units=[
            ControlNetConfigUnit(
                processor_id="tile",
                model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
                scale=0.7,
            ),
        ],
        device="cuda",
    )
    pipe.enable_cpu_offload()
    # Swap Linear/RMSNorm layers for on-the-fly casting variants.
    pipe.dit.quantize()
    for controlnet_model in pipe.controlnet.models:
        controlnet_model.quantize()
    base_image = pipe(prompt="a photo of a cat, highly detailed", height=768, width=768, seed=0)
    base_image.save("image_1.jpg")
    upscaled = pipe(
        prompt="a photo of a cat, highly detailed",
        controlnet_image=base_image.resize((2048, 2048)),
        input_image=base_image.resize((2048, 2048)), denoising_strength=0.99,
        height=2048, width=2048, tiled=True,
        seed=1,
    )
    upscaled.save("image_2.jpg")
def example_2():
    """Same upscaling pipeline as example_1, applied to a portrait."""
    download_models(["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"])
    manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    manager.load_models([
        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
        "models/FLUX/FLUX.1-dev/text_encoder_2",
        "models/FLUX/FLUX.1-dev/ae.safetensors",
    ])
    manager.load_models(["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], torch_dtype=torch.float8_e4m3fn)
    manager.load_models(
        ["models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"],
        torch_dtype=torch.float8_e4m3fn,
    )
    pipe = FluxImagePipeline.from_model_manager(
        manager,
        controlnet_config_units=[
            ControlNetConfigUnit(
                processor_id="tile",
                model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
                scale=0.7,
            ),
        ],
        device="cuda",
    )
    pipe.enable_cpu_offload()
    pipe.dit.quantize()
    for controlnet_model in pipe.controlnet.models:
        controlnet_model.quantize()
    base_image = pipe(prompt="a beautiful Chinese girl, delicate skin texture", height=768, width=768, seed=2)
    base_image.save("image_3.jpg")
    upscaled = pipe(
        prompt="a beautiful Chinese girl, delicate skin texture",
        controlnet_image=base_image.resize((2048, 2048)),
        input_image=base_image.resize((2048, 2048)), denoising_strength=0.99,
        height=2048, width=2048, tiled=True,
        seed=3,
    )
    upscaled.save("image_4.jpg")
def example_3():
    """Structure-guided regeneration with the Union ControlNet: canny + depth
    conditioning from a first render steer a second one."""
    download_models(["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"])
    manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    manager.load_models([
        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
        "models/FLUX/FLUX.1-dev/text_encoder_2",
        "models/FLUX/FLUX.1-dev/ae.safetensors",
    ])
    manager.load_models(["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], torch_dtype=torch.float8_e4m3fn)
    manager.load_models(
        ["models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors"],
        torch_dtype=torch.float8_e4m3fn,
    )
    pipe = FluxImagePipeline.from_model_manager(
        manager,
        controlnet_config_units=[
            ControlNetConfigUnit(
                processor_id="canny",
                model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
                scale=0.3,
            ),
            ControlNetConfigUnit(
                processor_id="depth",
                model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
                scale=0.3,
            ),
        ],
        device="cuda",
    )
    pipe.enable_cpu_offload()
    pipe.dit.quantize()
    for controlnet_model in pipe.controlnet.models:
        controlnet_model.quantize()
    base_image = pipe(prompt="a cat is running", height=1024, width=1024, seed=4)
    base_image.save("image_5.jpg")
    guided = pipe(
        prompt="sunshine, a cat is running",
        controlnet_image=base_image,
        height=1024, width=1024,
        seed=5,
    )
    guided.save("image_6.jpg")
def example_4():
    """Same canny + depth guidance as example_3, changing season in a portrait."""
    download_models(["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"])
    manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    manager.load_models([
        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
        "models/FLUX/FLUX.1-dev/text_encoder_2",
        "models/FLUX/FLUX.1-dev/ae.safetensors",
    ])
    manager.load_models(["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], torch_dtype=torch.float8_e4m3fn)
    manager.load_models(
        ["models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors"],
        torch_dtype=torch.float8_e4m3fn,
    )
    pipe = FluxImagePipeline.from_model_manager(
        manager,
        controlnet_config_units=[
            ControlNetConfigUnit(
                processor_id="canny",
                model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
                scale=0.3,
            ),
            ControlNetConfigUnit(
                processor_id="depth",
                model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
                scale=0.3,
            ),
        ],
        device="cuda",
    )
    pipe.enable_cpu_offload()
    pipe.dit.quantize()
    for controlnet_model in pipe.controlnet.models:
        controlnet_model.quantize()
    base_image = pipe(
        prompt="a beautiful Asian girl, full body, red dress, summer",
        height=1024, width=1024,
        seed=6,
    )
    base_image.save("image_7.jpg")
    guided = pipe(
        prompt="a beautiful Asian girl, full body, red dress, winter",
        controlnet_image=base_image,
        height=1024, width=1024,
        seed=7,
    )
    guided.save("image_8.jpg")
def example_5():
    """Inpainting demo: regenerate a masked region of a generated image."""
    download_models(["FLUX.1-dev", "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"])
    manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    manager.load_models([
        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
        "models/FLUX/FLUX.1-dev/text_encoder_2",
        "models/FLUX/FLUX.1-dev/ae.safetensors",
    ])
    manager.load_models(["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], torch_dtype=torch.float8_e4m3fn)
    manager.load_models(
        ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors"],
        torch_dtype=torch.float8_e4m3fn,
    )
    pipe = FluxImagePipeline.from_model_manager(
        manager,
        controlnet_config_units=[
            ControlNetConfigUnit(
                processor_id="inpaint",
                model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
                scale=0.9,
            ),
        ],
        device="cuda",
    )
    pipe.enable_cpu_offload()
    pipe.dit.quantize()
    for controlnet_model in pipe.controlnet.models:
        controlnet_model.quantize()
    base_image = pipe(prompt="a cat sitting on a chair", height=1024, width=1024, seed=8)
    base_image.save("image_9.jpg")
    # White rectangle marks the region to repaint.
    mask_array = np.zeros((1024, 1024, 3), dtype=np.uint8)
    mask_array[100:350, 350:-300] = 255
    mask = Image.fromarray(mask_array)
    mask.save("mask_9.jpg")
    inpainted = pipe(
        prompt="a cat sitting on a chair, wearing sunglasses",
        controlnet_image=base_image, controlnet_inpaint_mask=mask,
        height=1024, width=1024,
        seed=9,
    )
    inpainted.save("image_10.jpg")
def example_6():
    """Inpainting constrained by surface normals: recolor the shirt while
    preserving the geometry of the original image."""
    download_models([
        "FLUX.1-dev",
        "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"
    ])
    manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    manager.load_models([
        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
        "models/FLUX/FLUX.1-dev/text_encoder_2",
        "models/FLUX/FLUX.1-dev/ae.safetensors",
    ])
    manager.load_models(["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], torch_dtype=torch.float8_e4m3fn)
    manager.load_models(
        ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
         "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors"],
        torch_dtype=torch.float8_e4m3fn,
    )
    pipe = FluxImagePipeline.from_model_manager(
        manager,
        controlnet_config_units=[
            ControlNetConfigUnit(
                processor_id="inpaint",
                model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
                scale=0.9,
            ),
            ControlNetConfigUnit(
                processor_id="normal",
                model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors",
                scale=0.6,
            ),
        ],
        device="cuda",
    )
    pipe.enable_cpu_offload()
    pipe.dit.quantize()
    for controlnet_model in pipe.controlnet.models:
        controlnet_model.quantize()
    base_image = pipe(
        prompt="a beautiful Asian woman looking at the sky, wearing a blue t-shirt.",
        height=1024, width=1024,
        seed=10,
    )
    base_image.save("image_11.jpg")
    mask_array = np.zeros((1024, 1024, 3), dtype=np.uint8)
    mask_array[-400:, 10:-40] = 255
    mask = Image.fromarray(mask_array)
    mask.save("mask_11.jpg")
    repainted = pipe(
        prompt="a beautiful Asian woman looking at the sky, wearing a yellow t-shirt.",
        controlnet_image=base_image, controlnet_inpaint_mask=mask,
        height=1024, width=1024,
        seed=11,
    )
    repainted.save("image_12.jpg")
def example_7():
    """Full workflow: generate, regionally repaint with local prompts,
    sharpen with an AntiBlur LoRA, then tile-upscale to 2K and 4K."""
    download_models([
        "FLUX.1-dev",
        "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
    ])
    manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    manager.load_models([
        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
        "models/FLUX/FLUX.1-dev/text_encoder_2",
        "models/FLUX/FLUX.1-dev/ae.safetensors",
    ])
    manager.load_models(["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], torch_dtype=torch.float8_e4m3fn)
    manager.load_models(
        ["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
         "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
         "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"],
        torch_dtype=torch.float8_e4m3fn,
    )
    pipe = FluxImagePipeline.from_model_manager(
        manager,
        controlnet_config_units=[
            ControlNetConfigUnit(
                processor_id="inpaint",
                model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
                scale=0.9,
            ),
            ControlNetConfigUnit(
                processor_id="canny",
                model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
                scale=0.5,
            ),
        ],
        device="cuda",
    )
    pipe.enable_cpu_offload()
    pipe.dit.quantize()
    for controlnet_model in pipe.controlnet.models:
        controlnet_model.quantize()
    draft = pipe(
        prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.",
        height=1024, width=1024,
        seed=100,
    )
    draft.save("image_13.jpg")
    # Empty global mask: the whole frame is preserved, only local masks repaint.
    global_mask = Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8))
    global_mask.save("mask_13_global.jpg")
    region_1 = np.zeros((1024, 1024, 3), dtype=np.uint8)
    region_1[300:-100, 30:450] = 255
    region_1 = Image.fromarray(region_1)
    region_1.save("mask_13_1.jpg")
    region_2 = np.zeros((1024, 1024, 3), dtype=np.uint8)
    region_2[500:-100, -400:] = 255
    region_2[-200:-100, -500:-400] = 255
    region_2 = Image.fromarray(region_2)
    region_2.save("mask_13_2.jpg")
    recomposed = pipe(
        prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.",
        controlnet_image=draft, controlnet_inpaint_mask=global_mask,
        local_prompts=["an orange cat, highly detailed", "a girl wearing a red camisole"], masks=[region_1, region_2], mask_scales=[10.0, 10.0],
        height=1024, width=1024,
        seed=101,
    )
    recomposed.save("image_14.jpg")
    manager.load_lora("models/lora/FLUX-dev-lora-AntiBlur.safetensors", lora_alpha=2)
    deblurred = pipe(
        prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. clear background.",
        negative_prompt="blur, blurry",
        input_image=recomposed, denoising_strength=0.7,
        height=1024, width=1024,
        cfg_scale=2.0, num_inference_steps=50,
        seed=102,
    )
    deblurred.save("image_15.jpg")
    # Rebuild the pipeline with the tile/upscaler ControlNet for super-resolution.
    pipe = FluxImagePipeline.from_model_manager(
        manager,
        controlnet_config_units=[
            ControlNetConfigUnit(
                processor_id="tile",
                model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
                scale=0.7,
            ),
        ],
        device="cuda",
    )
    pipe.enable_cpu_offload()
    pipe.dit.quantize()
    for controlnet_model in pipe.controlnet.models:
        controlnet_model.quantize()
    upscaled_2k = pipe(
        prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.",
        controlnet_image=deblurred.resize((2048, 2048)),
        input_image=deblurred.resize((2048, 2048)), denoising_strength=0.99,
        height=2048, width=2048, tiled=True,
        seed=103,
    )
    upscaled_2k.save("image_16.jpg")
    upscaled_4k = pipe(
        prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.",
        controlnet_image=upscaled_2k.resize((4096, 4096)),
        input_image=upscaled_2k.resize((4096, 4096)), denoising_strength=0.99,
        height=4096, width=4096, tiled=True,
        seed=104,
    )
    upscaled_4k.save("image_17.jpg")
# Shared assets used by the examples above.
download_models(["Annotators:Depth", "Annotators:Normal"])
download_customized_models(
    model_id="LiblibAI/FLUX.1-dev-LoRA-AntiBlur",
    origin_file_path="FLUX-dev-lora-AntiBlur.safetensors",
    local_dir="models/lora"
)
# Run every demo in order.
for example in (example_1, example_2, example_3, example_4,
                example_5, example_6, example_7):
    example()