This commit is contained in:
Artiprocher
2025-04-24 14:51:40 +08:00
parent e54c0a8468
commit b7a1ac6671
4 changed files with 64 additions and 5 deletions

View File

@@ -98,6 +98,7 @@ model_loader_configs = [
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),

View File

@@ -276,20 +276,22 @@ class AdaLayerNormContinuous(torch.nn.Module):
class FluxDiT(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False):
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.x_embedder = torch.nn.Linear(input_dim, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
self.final_norm_out = AdaLayerNormContinuous(3072)
self.final_proj_out = torch.nn.Linear(3072, 64)
self.input_dim = input_dim
def patchify(self, hidden_states):
@@ -738,5 +740,7 @@ class FluxDiTStateDictConverter:
pass
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
return state_dict_, {"disable_guidance_embedder": True}
elif "double_blocks.8.img_attn.norm.key_norm.scale" not in state_dict_:
return state_dict_, {"input_dim": 196, "num_blocks": 8}
else:
return state_dict_

View File

@@ -360,6 +360,27 @@ class FluxImagePipeline(BasePipeline):
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
else:
return {}, controlnet_image
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None):
if self.dit.input_dim == 196:
if flex_inpaint_image is None:
flex_inpaint_image = torch.zeros_like(latents)
else:
pass # TODO
if flex_inpaint_mask is None:
flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
else:
pass # TODO
if flex_control_image is None:
flex_control_image = torch.zeros_like(latents)
else:
pass # TODO
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
flex_kwargs = {"flex_condition": flex_condition}
else:
flex_kwargs = {}
return flex_kwargs
@torch.no_grad()
@@ -398,6 +419,10 @@ class FluxImagePipeline(BasePipeline):
# InfiniteYou
infinityou_id_image=None,
infinityou_guidance=1.0,
# Flex
flex_inpaint_image=None,
flex_inpaint_mask=None,
flex_control_image=None,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
@@ -436,6 +461,9 @@ class FluxImagePipeline(BasePipeline):
# ControlNets
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# Flex
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
@@ -449,7 +477,7 @@ class FluxImagePipeline(BasePipeline):
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -466,7 +494,7 @@ class FluxImagePipeline(BasePipeline):
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -602,6 +630,7 @@ def lets_dance_flux(
ipadapter_kwargs_list={},
id_emb=None,
infinityou_guidance=None,
flex_condition=None,
tea_cache: TeaCache = None,
**kwargs
):
@@ -652,6 +681,9 @@ def lets_dance_flux(
controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_frames, **controlnet_extra_kwargs
)
if flex_condition is not None:
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
if image_ids is None:
image_ids = dit.prepare_image_ids(hidden_states)

View File

@@ -0,0 +1,22 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_models
download_models(["FLUX.1-dev"])
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/ostris/Flex.2-preview/Flex.2-preview.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
torch.manual_seed(9)
image = pipe(
prompt=prompt,
num_inference_steps=50, embedded_guidance=3.5
)
image.save("image_1024.jpg")