mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
flex t2i
This commit is contained in:
@@ -98,6 +98,7 @@ model_loader_configs = [
|
||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
|
||||
@@ -276,20 +276,22 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False):
|
||||
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
self.x_embedder = torch.nn.Linear(input_dim, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
||||
|
||||
self.final_norm_out = AdaLayerNormContinuous(3072)
|
||||
self.final_proj_out = torch.nn.Linear(3072, 64)
|
||||
|
||||
self.input_dim = input_dim
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
@@ -738,5 +740,7 @@ class FluxDiTStateDictConverter:
|
||||
pass
|
||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||
return state_dict_, {"disable_guidance_embedder": True}
|
||||
elif "double_blocks.8.img_attn.norm.key_norm.scale" not in state_dict_:
|
||||
return state_dict_, {"input_dim": 196, "num_blocks": 8}
|
||||
else:
|
||||
return state_dict_
|
||||
|
||||
@@ -360,6 +360,27 @@ class FluxImagePipeline(BasePipeline):
|
||||
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
|
||||
else:
|
||||
return {}, controlnet_image
|
||||
|
||||
|
||||
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None):
|
||||
if self.dit.input_dim == 196:
|
||||
if flex_inpaint_image is None:
|
||||
flex_inpaint_image = torch.zeros_like(latents)
|
||||
else:
|
||||
pass # TODO
|
||||
if flex_inpaint_mask is None:
|
||||
flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
|
||||
else:
|
||||
pass # TODO
|
||||
if flex_control_image is None:
|
||||
flex_control_image = torch.zeros_like(latents)
|
||||
else:
|
||||
pass # TODO
|
||||
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
|
||||
flex_kwargs = {"flex_condition": flex_condition}
|
||||
else:
|
||||
flex_kwargs = {}
|
||||
return flex_kwargs
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -398,6 +419,10 @@ class FluxImagePipeline(BasePipeline):
|
||||
# InfiniteYou
|
||||
infinityou_id_image=None,
|
||||
infinityou_guidance=1.0,
|
||||
# Flex
|
||||
flex_inpaint_image=None,
|
||||
flex_inpaint_mask=None,
|
||||
flex_control_image=None,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
# Tile
|
||||
@@ -436,6 +461,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
# ControlNets
|
||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||
|
||||
# Flex
|
||||
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||
@@ -449,7 +477,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs,
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||
@@ -466,7 +494,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -602,6 +630,7 @@ def lets_dance_flux(
|
||||
ipadapter_kwargs_list={},
|
||||
id_emb=None,
|
||||
infinityou_guidance=None,
|
||||
flex_condition=None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -652,6 +681,9 @@ def lets_dance_flux(
|
||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||
controlnet_frames, **controlnet_extra_kwargs
|
||||
)
|
||||
|
||||
if flex_condition is not None:
|
||||
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = dit.prepare_image_ids(hidden_states)
|
||||
|
||||
22
examples/image_synthesis/flex_text_to_image.py
Normal file
22
examples/image_synthesis/flex_text_to_image.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_models
|
||||
|
||||
|
||||
download_models(["FLUX.1-dev"])
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
"models/ostris/Flex.2-preview/Flex.2-preview.safetensors"
|
||||
])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
|
||||
torch.manual_seed(9)
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=50, embedded_guidance=3.5
|
||||
)
|
||||
image.save("image_1024.jpg")
|
||||
Reference in New Issue
Block a user