mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
@@ -98,6 +98,7 @@ model_loader_configs = [
|
||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
|
||||
@@ -276,20 +276,22 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False):
|
||||
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
self.x_embedder = torch.nn.Linear(input_dim, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
||||
|
||||
self.final_norm_out = AdaLayerNormContinuous(3072)
|
||||
self.final_proj_out = torch.nn.Linear(3072, 64)
|
||||
|
||||
self.input_dim = input_dim
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
@@ -738,5 +740,7 @@ class FluxDiTStateDictConverter:
|
||||
pass
|
||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||
return state_dict_, {"disable_guidance_embedder": True}
|
||||
elif "blocks.8.attn.norm_k_a.weight" not in state_dict_:
|
||||
return state_dict_, {"input_dim": 196, "num_blocks": 8}
|
||||
else:
|
||||
return state_dict_
|
||||
|
||||
@@ -360,6 +360,34 @@ class FluxImagePipeline(BasePipeline):
|
||||
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
|
||||
else:
|
||||
return {}, controlnet_image
|
||||
|
||||
|
||||
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32):
|
||||
if self.dit.input_dim == 196:
|
||||
if flex_inpaint_image is None:
|
||||
flex_inpaint_image = torch.zeros_like(latents)
|
||||
else:
|
||||
flex_inpaint_image = self.preprocess_image(flex_inpaint_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
flex_inpaint_image = self.encode_image(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if flex_inpaint_mask is None:
|
||||
flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
|
||||
else:
|
||||
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
|
||||
flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
|
||||
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
|
||||
flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
|
||||
if flex_control_image is None:
|
||||
flex_control_image = torch.zeros_like(latents)
|
||||
else:
|
||||
flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
|
||||
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
|
||||
flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
|
||||
flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))]
|
||||
flex_kwargs = {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
|
||||
else:
|
||||
flex_kwargs = {}
|
||||
return flex_kwargs
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -398,6 +426,12 @@ class FluxImagePipeline(BasePipeline):
|
||||
# InfiniteYou
|
||||
infinityou_id_image=None,
|
||||
infinityou_guidance=1.0,
|
||||
# Flex
|
||||
flex_inpaint_image=None,
|
||||
flex_inpaint_mask=None,
|
||||
flex_control_image=None,
|
||||
flex_control_strength=0.5,
|
||||
flex_control_stop=0.5,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
# Tile
|
||||
@@ -436,6 +470,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
# ControlNets
|
||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||
|
||||
# Flex
|
||||
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, **tiler_kwargs)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||
@@ -449,7 +486,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs,
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||
@@ -466,7 +503,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -602,6 +639,9 @@ def lets_dance_flux(
|
||||
ipadapter_kwargs_list={},
|
||||
id_emb=None,
|
||||
infinityou_guidance=None,
|
||||
flex_condition=None,
|
||||
flex_uncondition=None,
|
||||
flex_control_stop_timestep=None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -652,6 +692,13 @@ def lets_dance_flux(
|
||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||
controlnet_frames, **controlnet_extra_kwargs
|
||||
)
|
||||
|
||||
# Flex
|
||||
if flex_condition is not None:
|
||||
if timestep.tolist()[0] >= flex_control_stop_timestep:
|
||||
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
||||
else:
|
||||
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = dit.prepare_image_ids(hidden_states)
|
||||
|
||||
49
examples/image_synthesis/flex_text_to_image.py
Normal file
49
examples/image_synthesis/flex_text_to_image.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_models
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
download_models(["FLUX.1-dev"])
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
"models/ostris/Flex.2-preview/Flex.2-preview.safetensors"
|
||||
])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
seed=0
|
||||
)
|
||||
image.save("image_1.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[200:400, 400:700] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save("image_mask.jpg")
|
||||
|
||||
inpaint_image = image
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
|
||||
seed=4
|
||||
)
|
||||
image.save("image_2.jpg")
|
||||
|
||||
control_image = Annotator("canny")(image)
|
||||
control_image.save("image_control.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_control_image=control_image,
|
||||
seed=4
|
||||
)
|
||||
image.save("image_3.jpg")
|
||||
Reference in New Issue
Block a user