support flux-fp8

This commit is contained in:
Artiprocher
2024-09-19 10:32:16 +08:00
parent a9fbfa108f
commit 091df1f1e7
4 changed files with 91 additions and 125 deletions

View File

@@ -0,0 +1,51 @@
import torch
from diffsynth import download_models, ModelManager, FluxImagePipeline
download_models(["FLUX.1-dev"])
model_manager = ModelManager(
torch_dtype=torch.bfloat16,
device="cpu" # To reduce VRAM required, we load models to RAM.
)
model_manager.load_models([
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
])
model_manager.load_models(
["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"],
torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format.
)
pipe = FluxImagePipeline.from_model_manager(model_manager, device="cuda")
pipe.enable_cpu_offload()
pipe.dit.quantize()
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(6)
image = pipe(
prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5
)
image.save("image_1024.jpg")
# Enable classifier-free guidance
torch.manual_seed(6)
image = pipe(
prompt=prompt, negative_prompt=negative_prompt,
num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5
)
image.save("image_1024_cfg.jpg")
# Highres-fix
torch.manual_seed(7)
image = pipe(
prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5,
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
)
image.save("image_2048_highres.jpg")