Merge pull request #215 from modelscope/flux-fp8

Support FLUX fp8
This commit is contained in:
Zhongjie Duan
2024-09-19 10:36:16 +08:00
committed by GitHub
4 changed files with 92 additions and 10 deletions

View File

@@ -404,6 +404,77 @@ class FluxDiT(torch.nn.Module):
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
def quantize(self):
def cast_to(weight, dtype=None, device=None, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def cast_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
return weight
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = None
weight = cast_to(s.weight, dtype, device)
bias = cast_to(s.bias, bias_dtype, device)
return weight, bias
class quantized_layer:
class Linear(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self,input,**kwargs):
weight,bias= cast_bias_weight(self.module,input)
return torch.nn.functional.linear(input,weight,bias)
class RMSNorm(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self,hidden_states,**kwargs):
weight= cast_weight(self.module,hidden_states)
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype) * weight
return hidden_states
def replace_layer(model):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
new_layer = quantized_layer.Linear(module)
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):
new_layer = quantized_layer.RMSNorm(module)
setattr(model, name, new_layer)
else:
replace_layer(module)
replace_layer(self)
@staticmethod

View File

@@ -415,8 +415,10 @@ class ModelManager:
break
def load_model(self, file_path, model_names=None):
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
print(f"Loading models from: {file_path}")
if device is None: device = self.device
if torch_dtype is None: torch_dtype = self.torch_dtype
if os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
@@ -425,7 +427,7 @@ class ModelManager:
if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load(
file_path, state_dict,
device=self.device, torch_dtype=self.torch_dtype,
device=device, torch_dtype=torch_dtype,
allowed_model_names=model_names, model_manager=self
)
for model_name, model in zip(model_names, models):
@@ -438,9 +440,9 @@ class ModelManager:
print(f" We cannot detect the model type. No models are loaded.")
def load_models(self, file_path_list, model_names=None):
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
for file_path in file_path_list:
self.load_model(file_path, model_names)
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
def fetch_model(self, model_name, file_path=None, require_model_path=False):

View File

@@ -4,7 +4,7 @@ Image synthesis is the base feature of DiffSynth Studio. We can generate images
### Example: FLUX
Example script: [`flux_text_to_image.py`](./flux_text_to_image.py)
Example script: [`flux_text_to_image.py`](./flux_text_to_image.py) and [`flux_text_to_image_low_vram.py`](./flux_text_to_image_low_vram.py)(low VRAM).
The original version of FLUX doesn't support classifier-free guidance; however, we believe that this guidance mechanism is an important feature for synthesizing beautiful images. You can enable it using the parameter `cfg_scale`, and the extra guidance scale introduced by FLUX is `embedded_guidance`.

View File

@@ -1,17 +1,26 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_models
from diffsynth import download_models, ModelManager, FluxImagePipeline
download_models(["FLUX.1-dev"])
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager = ModelManager(
torch_dtype=torch.bfloat16,
device="cpu" # To reduce VRAM required, we load models to RAM.
)
model_manager.load_models([
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager, device='cuda')
model_manager.load_models(
["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"],
torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format.
)
pipe = FluxImagePipeline.from_model_manager(model_manager, device="cuda")
pipe.enable_cpu_offload()
pipe.dit.quantize()
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
@@ -39,4 +48,4 @@ image = pipe(
num_inference_steps=30, embedded_guidance=3.5,
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
)
image.save("image_2048_highres.jpg")
image.save("image_2048_highres.jpg")