mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
flex t2i
This commit is contained in:
@@ -98,6 +98,7 @@ model_loader_configs = [
|
|||||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
|
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||||
|
|||||||
@@ -276,21 +276,23 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FluxDiT(torch.nn.Module):
|
class FluxDiT(torch.nn.Module):
|
||||||
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
    """Assemble the sub-modules of the FluxDiT network.

    Args:
        disable_guidance_embedder: When true, no guidance embedder is built
            and ``self.guidance_embedder`` is set to ``None``.
        input_dim: Input feature size of ``x_embedder``; also stored on the
            instance as ``self.input_dim``.
        num_blocks: Number of ``FluxJointTransformerBlock`` layers placed in
            ``self.blocks``.
    """
    super().__init__()

    # Width and head count shared by every sub-module built below.
    hidden_dim, num_heads = 3072, 24

    # Positional / conditioning embedders.
    self.pos_embedder = RoPEEmbedding(hidden_dim, 10000, [16, 56, 56])
    self.time_embedder = TimestepEmbeddings(256, hidden_dim)
    if disable_guidance_embedder:
        self.guidance_embedder = None
    else:
        self.guidance_embedder = TimestepEmbeddings(256, hidden_dim)
    self.pooled_text_embedder = torch.nn.Sequential(
        torch.nn.Linear(768, hidden_dim),
        torch.nn.SiLU(),
        torch.nn.Linear(hidden_dim, hidden_dim),
    )
    self.context_embedder = torch.nn.Linear(4096, hidden_dim)
    self.x_embedder = torch.nn.Linear(input_dim, hidden_dim)

    # Transformer stacks: a configurable number of joint blocks followed by
    # a fixed stack of 38 single blocks.
    joint_layers = [FluxJointTransformerBlock(hidden_dim, num_heads) for _ in range(num_blocks)]
    self.blocks = torch.nn.ModuleList(joint_layers)
    single_layers = [FluxSingleTransformerBlock(hidden_dim, num_heads) for _ in range(38)]
    self.single_blocks = torch.nn.ModuleList(single_layers)

    # Output head; the projection width stays 64 regardless of input_dim.
    self.final_norm_out = AdaLayerNormContinuous(hidden_dim)
    self.final_proj_out = torch.nn.Linear(hidden_dim, 64)

    # Kept so callers can inspect which input width this model was built with.
    self.input_dim = input_dim
|
||||||
|
|
||||||
|
|
||||||
def patchify(self, hidden_states):
|
def patchify(self, hidden_states):
|
||||||
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||||
@@ -738,5 +740,7 @@ class FluxDiTStateDictConverter:
|
|||||||
pass
|
pass
|
||||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||||
return state_dict_, {"disable_guidance_embedder": True}
|
return state_dict_, {"disable_guidance_embedder": True}
|
||||||
|
elif "double_blocks.8.img_attn.norm.key_norm.scale" not in state_dict_:
|
||||||
|
return state_dict_, {"input_dim": 196, "num_blocks": 8}
|
||||||
else:
|
else:
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|||||||
@@ -362,6 +362,27 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
return {}, controlnet_image
|
return {}, controlnet_image
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None):
    """Build the extra ``flex_condition`` keyword argument for the DiT call.

    When ``self.dit.input_dim`` is 196, the inpaint image, inpaint mask and
    control image are concatenated along dim 1 and returned under the key
    ``"flex_condition"``. For any other ``input_dim`` an empty dict is
    returned.

    Args:
        latents: Reference tensor; defaults are created with its shape,
            dtype and device. It is indexed with four indices, so it must
            be 4-D.
        flex_inpaint_image: Optional inpaint image; defaults to zeros shaped
            like ``latents``.
        flex_inpaint_mask: Optional inpaint mask; defaults to ones shaped
            like a single-channel slice of ``latents``.
        flex_control_image: Optional control image; defaults to zeros shaped
            like ``latents``.

    Returns:
        ``{"flex_condition": tensor}`` or ``{}``.
    """
    # Guard: only a DiT built with the 196-wide input takes the condition.
    if self.dit.input_dim != 196:
        return {}

    # TODO: caller-supplied inputs are forwarded unchanged; no preprocessing
    # is implemented for them yet.
    inpaint_image = torch.zeros_like(latents) if flex_inpaint_image is None else flex_inpaint_image
    inpaint_mask = torch.ones_like(latents[:, 0:1, :, :]) if flex_inpaint_mask is None else flex_inpaint_mask
    control_image = torch.zeros_like(latents) if flex_control_image is None else flex_control_image

    # Order matters: inpaint image, then mask, then control image.
    stacked = torch.concat([inpaint_image, inpaint_mask, control_image], dim=1)
    return {"flex_condition": stacked}
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -398,6 +419,10 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
# InfiniteYou
|
# InfiniteYou
|
||||||
infinityou_id_image=None,
|
infinityou_id_image=None,
|
||||||
infinityou_guidance=1.0,
|
infinityou_guidance=1.0,
|
||||||
|
# Flex
|
||||||
|
flex_inpaint_image=None,
|
||||||
|
flex_inpaint_mask=None,
|
||||||
|
flex_control_image=None,
|
||||||
# TeaCache
|
# TeaCache
|
||||||
tea_cache_l1_thresh=None,
|
tea_cache_l1_thresh=None,
|
||||||
# Tile
|
# Tile
|
||||||
@@ -437,6 +462,9 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
# ControlNets
|
# ControlNets
|
||||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||||
|
|
||||||
|
# Flex
|
||||||
|
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image)
|
||||||
|
|
||||||
# TeaCache
|
# TeaCache
|
||||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||||
|
|
||||||
@@ -449,7 +477,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||||
dit=self.dit, controlnet=self.controlnet,
|
dit=self.dit, controlnet=self.controlnet,
|
||||||
hidden_states=latents, timestep=timestep,
|
hidden_states=latents, timestep=timestep,
|
||||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
|
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs,
|
||||||
)
|
)
|
||||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||||
@@ -466,7 +494,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
noise_pred_nega = lets_dance_flux(
|
noise_pred_nega = lets_dance_flux(
|
||||||
dit=self.dit, controlnet=self.controlnet,
|
dit=self.dit, controlnet=self.controlnet,
|
||||||
hidden_states=latents, timestep=timestep,
|
hidden_states=latents, timestep=timestep,
|
||||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
|
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs,
|
||||||
)
|
)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
@@ -602,6 +630,7 @@ def lets_dance_flux(
|
|||||||
ipadapter_kwargs_list={},
|
ipadapter_kwargs_list={},
|
||||||
id_emb=None,
|
id_emb=None,
|
||||||
infinityou_guidance=None,
|
infinityou_guidance=None,
|
||||||
|
flex_condition=None,
|
||||||
tea_cache: TeaCache = None,
|
tea_cache: TeaCache = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -653,6 +682,9 @@ def lets_dance_flux(
|
|||||||
controlnet_frames, **controlnet_extra_kwargs
|
controlnet_frames, **controlnet_extra_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if flex_condition is not None:
|
||||||
|
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
||||||
|
|
||||||
if image_ids is None:
|
if image_ids is None:
|
||||||
image_ids = dit.prepare_image_ids(hidden_states)
|
image_ids = dit.prepare_image_ids(hidden_states)
|
||||||
|
|
||||||
|
|||||||
22
examples/image_synthesis/flex_text_to_image.py
Normal file
22
examples/image_synthesis/flex_text_to_image.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth import ModelManager, FluxImagePipeline, download_models
|
||||||
|
|
||||||
|
|
||||||
|
# Text-to-image example: FLUX.1-dev text encoders and VAE combined with the
# Flex.2-preview DiT checkpoint.
download_models(["FLUX.1-dev"])

# NOTE(review): the Flex.2-preview checkpoint is loaded from a local path and
# this script only requests "FLUX.1-dev" above — confirm the file is present
# before running.
checkpoint_paths = [
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/text_encoder_2",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
    "models/ostris/Flex.2-preview/Flex.2-preview.safetensors",
]

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models(checkpoint_paths)
pipe = FluxImagePipeline.from_model_manager(model_manager)

prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."

# Fix the seed immediately before sampling so the output is reproducible.
torch.manual_seed(9)
image = pipe(prompt=prompt, num_inference_steps=50, embedded_guidance=3.5)
image.save("image_1024.jpg")
|
||||||
Reference in New Issue
Block a user