mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
step1x, teacache, flex refactor
This commit is contained in:
@@ -15,6 +15,7 @@ from typing_extensions import Literal
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from ..prompters import FluxPrompter
|
||||
from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
from ..models.tiler import FastTileWorker
|
||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora.flux_lora import FluxLoRALoader
|
||||
@@ -36,7 +37,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("dit", )
|
||||
self.qwenvl = None
|
||||
self.step1x_connector: Qwen2Connector = None
|
||||
self.in_iteration_models = ("dit", "step1x_connector")
|
||||
self.units = [
|
||||
FluxImageUnit_ShapeChecker(),
|
||||
FluxImageUnit_NoiseInitializer(),
|
||||
@@ -46,6 +49,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
FluxImageUnit_EmbeddedGuidanceEmbedder(),
|
||||
FluxImageUnit_IPAdapter(),
|
||||
FluxImageUnit_EntityControl(),
|
||||
FluxImageUnit_TeaCache(),
|
||||
FluxImageUnit_Flex(),
|
||||
FluxImageUnit_Step1x(),
|
||||
]
|
||||
self.model_fn = model_fn_flux_image
|
||||
|
||||
@@ -105,6 +111,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
pipe.prompter.fetch_models(pipe.text_encoder_1, pipe.text_encoder_2)
|
||||
pipe.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
||||
pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||
# Step1x
|
||||
pipe.qwenvl = model_manager.fetch_model("qwenvl")
|
||||
pipe.step1x_connector = model_manager.fetch_model("step1x_connector")
|
||||
|
||||
return pipe
|
||||
|
||||
@@ -374,6 +383,73 @@ class FluxImageUnit_EntityControl(PipelineUnit):
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class FluxImageUnit_Step1x(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder"))
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict):
|
||||
image = inputs_shared.get("step1x_reference_image",None)
|
||||
if image is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
else:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt = inputs_posi["prompt"]
|
||||
nega_prompt = inputs_nega["negative_prompt"]
|
||||
captions = [prompt, nega_prompt]
|
||||
ref_images = [image, image]
|
||||
embs, masks = pipe.qwenvl(captions, ref_images)
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
image = pipe.vae_encoder(image)
|
||||
inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image})
|
||||
inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class FluxImageUnit_TeaCache(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"))
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh):
|
||||
if tea_cache_l1_thresh is None:
|
||||
return {}
|
||||
else:
|
||||
return {"tea_cache": TeaCache(num_inference_steps=num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh)}
|
||||
|
||||
class FluxImageUnit_Flex(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride):
|
||||
if pipe.dit.input_dim == 196:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if flex_inpaint_image is None:
|
||||
flex_inpaint_image = torch.zeros_like(latents)
|
||||
else:
|
||||
flex_inpaint_image = pipe.preprocess_image(flex_inpaint_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
flex_inpaint_image = pipe.vae_encoder(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if flex_inpaint_mask is None:
|
||||
flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
|
||||
else:
|
||||
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
|
||||
flex_inpaint_mask = pipe.preprocess_image(flex_inpaint_mask).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
|
||||
flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
|
||||
if flex_control_image is None:
|
||||
flex_control_image = torch.zeros_like(latents)
|
||||
else:
|
||||
flex_control_image = pipe.preprocess_image(flex_control_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
flex_control_image = pipe.vae_encoder(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
|
||||
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
|
||||
flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
|
||||
flex_control_stop_timestep = pipe.scheduler.timesteps[int(flex_control_stop * (len(pipe.scheduler.timesteps) - 1))]
|
||||
return {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
50
examples/flux/flex_text_to_image.py
Normal file
50
examples/flux/flex_text_to_image.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
seed=0
|
||||
)
|
||||
image.save(f"image_1.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[200:400, 400:700] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save(f"image_mask.jpg")
|
||||
|
||||
inpaint_image = image
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
|
||||
seed=4
|
||||
)
|
||||
image.save(f"image_2_new.jpg")
|
||||
|
||||
control_image = Annotator("canny")(image)
|
||||
control_image.save("image_control.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_control_image=control_image,
|
||||
seed=4
|
||||
)
|
||||
image.save(f"image_3_new.jpg")
|
||||
24
examples/flux/flux_teacache.py
Normal file
24
examples/flux/flux_teacache.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
|
||||
for tea_cache_l1_thresh in [None, 0.2, 0.4, 0.6, 0.8]:
|
||||
image = pipe(
|
||||
prompt=prompt, embedded_guidance=3.5, seed=0,
|
||||
num_inference_steps=50, tea_cache_l1_thresh=tea_cache_l1_thresh
|
||||
)
|
||||
image.save(f"image_{tea_cache_l1_thresh}.png")
|
||||
44
examples/flux/step1x.py
Normal file
44
examples/flux/step1x.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from modelscope import snapshot_download
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
snapshot_download("Qwen/Qwen2.5-VL-7B-Instruct", cache_dir="./models")
|
||||
snapshot_download("stepfun-ai/Step1X-Edit", cache_dir="./models")
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(path="models/Qwen/Qwen2.5-VL-7B-Instruct"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
|
||||
image = pipe(
|
||||
prompt="draw red flowers in Chinese ink painting style",
|
||||
step1x_reference_image=image,
|
||||
width=832, height=1248, cfg_scale=6,
|
||||
seed=1,
|
||||
rand_device='cuda'
|
||||
)
|
||||
image.save("image_1.jpg")
|
||||
|
||||
|
||||
|
||||
image = pipe(
|
||||
prompt="add more flowers in Chinese ink painting style",
|
||||
step1x_reference_image=image,
|
||||
width=832, height=1248, cfg_scale=6,
|
||||
seed=2,
|
||||
rand_device='cuda'
|
||||
)
|
||||
image.save("image_2.jpg")
|
||||
|
||||
Reference in New Issue
Block a user