mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
flex t2i
This commit is contained in:
@@ -98,6 +98,7 @@ model_loader_configs = [
|
||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
|
||||
@@ -276,20 +276,22 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False):
|
||||
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
self.x_embedder = torch.nn.Linear(input_dim, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
||||
|
||||
self.final_norm_out = AdaLayerNormContinuous(3072)
|
||||
self.final_proj_out = torch.nn.Linear(3072, 64)
|
||||
|
||||
self.input_dim = input_dim
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
@@ -738,5 +740,7 @@ class FluxDiTStateDictConverter:
|
||||
pass
|
||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||
return state_dict_, {"disable_guidance_embedder": True}
|
||||
elif "double_blocks.8.img_attn.norm.key_norm.scale" not in state_dict_:
|
||||
return state_dict_, {"input_dim": 196, "num_blocks": 8}
|
||||
else:
|
||||
return state_dict_
|
||||
|
||||
@@ -360,6 +360,27 @@ class FluxImagePipeline(BasePipeline):
|
||||
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
|
||||
else:
|
||||
return {}, controlnet_image
|
||||
|
||||
|
||||
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None):
|
||||
if self.dit.input_dim == 196:
|
||||
if flex_inpaint_image is None:
|
||||
flex_inpaint_image = torch.zeros_like(latents)
|
||||
else:
|
||||
pass # TODO
|
||||
if flex_inpaint_mask is None:
|
||||
flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
|
||||
else:
|
||||
pass # TODO
|
||||
if flex_control_image is None:
|
||||
flex_control_image = torch.zeros_like(latents)
|
||||
else:
|
||||
pass # TODO
|
||||
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
|
||||
flex_kwargs = {"flex_condition": flex_condition}
|
||||
else:
|
||||
flex_kwargs = {}
|
||||
return flex_kwargs
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -398,6 +419,10 @@ class FluxImagePipeline(BasePipeline):
|
||||
# InfiniteYou
|
||||
infinityou_id_image=None,
|
||||
infinityou_guidance=1.0,
|
||||
# Flex
|
||||
flex_inpaint_image=None,
|
||||
flex_inpaint_mask=None,
|
||||
flex_control_image=None,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
# Tile
|
||||
@@ -436,6 +461,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
# ControlNets
|
||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||
|
||||
# Flex
|
||||
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||
@@ -449,7 +477,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs,
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||
@@ -466,7 +494,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -602,6 +630,7 @@ def lets_dance_flux(
|
||||
ipadapter_kwargs_list={},
|
||||
id_emb=None,
|
||||
infinityou_guidance=None,
|
||||
flex_condition=None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -652,6 +681,9 @@ def lets_dance_flux(
|
||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||
controlnet_frames, **controlnet_extra_kwargs
|
||||
)
|
||||
|
||||
if flex_condition is not None:
|
||||
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = dit.prepare_image_ids(hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user