mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
26 lines
1.1 KiB
Python
26 lines
1.1 KiB
Python
"""Text-to-image example for FLUX.1-dev with VRAM management enabled.

The script loads the FLUX.1-dev model files through ``ModelManager`` with
``torch.float8_e4m3fn`` on the CPU, builds a ``FluxImagePipeline`` with
``torch.bfloat16`` on ``cuda``, enables VRAM management, and saves one
generated image to ``image.jpg``.
"""

import torch
from diffsynth import ModelManager, FluxImagePipeline


# Model files are expected under ``models/FLUX/FLUX.1-dev/`` relative to the
# working directory; they are loaded on the CPU as float8_e4m3fn.
model_manager = ModelManager(
    file_path_list=[
        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
        "models/FLUX/FLUX.1-dev/text_encoder_2",
        "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
        "models/FLUX/FLUX.1-dev/ae.safetensors",
    ],
    torch_dtype=torch.float8_e4m3fn,
    device="cpu",
)
# The pipeline itself is created with bfloat16 on the GPU.
pipe = FluxImagePipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

# Enable VRAM management
# `num_persistent_param_in_dit` indicates the number of parameters that reside persistently in VRAM within the DiT model.
# When `num_persistent_param_in_dit=None`, all parameters reside persistently in VRAM.
# When `num_persistent_param_in_dit=7*10**9`, 7 billion parameters reside persistently in VRAM.
# When `num_persistent_param_in_dit=0`, no parameters reside persistently in VRAM; they are loaded layer by layer during inference.
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Fixed seed so the example output is reproducible between runs.
image = pipe(prompt="a beautiful orange cat", seed=0)
image.save("image.jpg")