diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 24c2b7c..1fade41 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -285,6 +285,34 @@ flux_series = [ "model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5", "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter", }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors") + "model_hash": "21ea55f476dfc4fd135587abb59dfe5d", + "model_name": "flux_vae_encoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors") + "model_hash": "21ea55f476dfc4fd135587abb59dfe5d", + "model_name": "flux_vae_decoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors") + "model_hash": "d02f41c13549fa5093d3521f62a5570a", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "extra_kwargs": {'input_dim': 196, 'num_blocks': 8}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors") + "model_hash": "0629116fce1472503a66992f96f3eb1a", + "model_name": "flux_value_controller", + "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder", + } ] MODEL_CONFIGS = qwen_image_series + wan_series + flux_series diff --git a/diffsynth/models/flux_ipadapter.py 
b/diffsynth/models/flux_ipadapter.py index 575c752..798c45a 100644 --- a/diffsynth/models/flux_ipadapter.py +++ b/diffsynth/models/flux_ipadapter.py @@ -1,9 +1,38 @@ -from .svd_image_encoder import SVDImageEncoder -from .sd3_dit import RMSNorm -from transformers import CLIPImageProcessor +from .general_modules import RMSNorm +from transformers import SiglipVisionModel, SiglipVisionConfig import torch +class SiglipVisionModelSO400M(SiglipVisionModel): + def __init__(self): + config = SiglipVisionConfig(**{ + "architectures": [ + "SiglipModel" + ], + "initializer_factor": 1.0, + "model_type": "siglip", + "text_config": { + "hidden_size": 1152, + "intermediate_size": 4304, + "model_type": "siglip_text_model", + "num_attention_heads": 16, + "num_hidden_layers": 27 + }, + "torch_dtype": "float32", + "transformers_version": "4.37.0.dev0", + "vision_config": { + "hidden_size": 1152, + "image_size": 384, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14 + } + }) + super().__init__(config) + + class MLPProjModel(torch.nn.Module): def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): super().__init__() diff --git a/diffsynth/models/flux_vae.py b/diffsynth/models/flux_vae.py index cbb7038..ded3047 100644 --- a/diffsynth/models/flux_vae.py +++ b/diffsynth/models/flux_vae.py @@ -106,7 +106,7 @@ class TileWorker: return model_output -class Attention(torch.nn.Module): +class ConvAttention(torch.nn.Module): def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): super().__init__() @@ -115,10 +115,10 @@ class Attention(torch.nn.Module): self.num_heads = num_heads self.head_dim = head_dim - self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) - self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) - self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) - self.to_out = torch.nn.Linear(dim_inner, 
q_dim, bias=bias_out) + self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q) + self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv) + self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv) + self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out) def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): if encoder_hidden_states is None: @@ -126,9 +126,14 @@ class Attention(torch.nn.Module): batch_size = encoder_hidden_states.shape[0] - q = self.to_q(hidden_states) - k = self.to_k(encoder_hidden_states) - v = self.to_v(encoder_hidden_states) + conv_input = rearrange(hidden_states, "B L C -> B C L 1") + q = self.to_q(conv_input) + q = rearrange(q[:, :, :, 0], "B C L -> B L C") + conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1") + k = self.to_k(conv_input) + v = self.to_v(conv_input) + k = rearrange(k[:, :, :, 0], "B C L -> B L C") + v = rearrange(v[:, :, :, 0], "B C L -> B L C") q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) @@ -138,7 +143,9 @@ class Attention(torch.nn.Module): hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) - hidden_states = self.to_out(hidden_states) + conv_input = rearrange(hidden_states, "B L C -> B C L 1") + hidden_states = self.to_out(conv_input) + hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C") return hidden_states @@ -152,7 +159,7 @@ class VAEAttentionBlock(torch.nn.Module): self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) self.transformer_blocks = torch.nn.ModuleList([ - Attention( + ConvAttention( inner_dim, num_attention_heads, attention_head_dim, @@ -236,7 +243,7 @@ class DownSampler(torch.nn.Module): return hidden_states, 
time_emb, text_emb, res_stack -class SD3VAEDecoder(torch.nn.Module): +class FluxVAEDecoder(torch.nn.Module): def __init__(self): super().__init__() self.scaling_factor = 0.3611 @@ -308,7 +315,7 @@ class SD3VAEDecoder(torch.nn.Module): return hidden_states -class SD3VAEEncoder(torch.nn.Module): +class FluxVAEEncoder(torch.nn.Module): def __init__(self): super().__init__() self.scaling_factor = 0.3611 diff --git a/diffsynth/models/flux_value_control.py b/diffsynth/models/flux_value_control.py index 6981344..691f9ca 100644 --- a/diffsynth/models/flux_value_control.py +++ b/diffsynth/models/flux_value_control.py @@ -1,10 +1,12 @@ import torch -from diffsynth.models.svd_unet import TemporalTimesteps +from .general_modules import TemporalTimesteps class MultiValueEncoder(torch.nn.Module): def __init__(self, encoders=()): super().__init__() + if not isinstance(encoders, list): + encoders = [encoders] self.encoders = torch.nn.ModuleList(encoders) def __call__(self, values, dtype): diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py new file mode 100644 index 0000000..3c0151d --- /dev/null +++ b/diffsynth/pipelines/flux_image.py @@ -0,0 +1,1163 @@ +import torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora.flux import FluxLoRALoader + +from ..models.flux_dit import FluxDiT +from ..models.flux_text_encoder_clip import FluxTextEncoderClip +from ..models.flux_text_encoder_t5 import FluxTextEncoderT5 +from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder +from ..models.flux_value_control import MultiValueEncoder + + +class MultiControlNet(torch.nn.Module): + def 
__init__(self, models: list[torch.nn.Module]): + super().__init__() + if not isinstance(models, list): + models = [models] + self.models = torch.nn.ModuleList(models) + + def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs): + model = self.models[controlnet_input.controlnet_id] + res_stack, single_res_stack = model( + controlnet_conditioning=conditioning, + processor_id=controlnet_input.processor_id, + **kwargs + ) + res_stack = [res * controlnet_input.scale for res in res_stack] + single_res_stack = [res * controlnet_input.scale for res in single_res_stack] + return res_stack, single_res_stack + + def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs): + res_stack, single_res_stack = None, None + for controlnet_input, conditioning in zip(controlnet_inputs, conditionings): + progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1) + if progress > controlnet_input.start or progress < controlnet_input.end: + continue + res_stack_, single_res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs) + if res_stack is None: + res_stack = res_stack_ + single_res_stack = single_res_stack_ + else: + res_stack = [i + j for i, j in zip(res_stack, res_stack_)] + single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)] + return res_stack, single_res_stack + + +class FluxImagePipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler() + self.tokenizer_1: CLIPTokenizer = None + self.tokenizer_2: T5TokenizerFast = None + self.text_encoder_1: FluxTextEncoderClip = None + self.text_encoder_2: FluxTextEncoderT5 = None + self.dit: FluxDiT = None + self.vae_decoder: FluxVAEDecoder = 
None + self.vae_encoder: FluxVAEEncoder = None + self.controlnet = None + self.ipadapter = None + self.ipadapter_image_encoder = None + self.qwenvl = None + self.step1x_connector = None + self.nexus_gen = None + self.nexus_gen_generation_adapter = None + self.nexus_gen_editing_adapter = None + self.value_controller = None + self.infinityou_processor = None + self.image_proj_model = None + self.lora_patcher = None + self.lora_encoder = None + self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher") + self.units = [ + FluxImageUnit_ShapeChecker(), + FluxImageUnit_NoiseInitializer(), + FluxImageUnit_PromptEmbedder(), + FluxImageUnit_InputImageEmbedder(), + FluxImageUnit_ImageIDs(), + FluxImageUnit_EmbeddedGuidanceEmbedder(), + FluxImageUnit_Kontext(), + FluxImageUnit_InfiniteYou(), + FluxImageUnit_ControlNet(), + FluxImageUnit_IPAdapter(), + FluxImageUnit_EntityControl(), + FluxImageUnit_NexusGen(), + FluxImageUnit_TeaCache(), + FluxImageUnit_Flex(), + FluxImageUnit_Step1x(), + FluxImageUnit_ValueControl(), + FluxImageUnit_LoRAEncode(), + ] + self.model_fn = model_fn_flux_image + self.lora_loader = FluxLoRALoader + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"), + tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"), + nexus_gen_processor_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = FluxImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder_1 = model_pool.fetch_model("flux_text_encoder_clip") + pipe.text_encoder_2 = model_pool.fetch_model("flux_text_encoder_t5") + pipe.dit 
= model_pool.fetch_model("flux_dit") + pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder") + pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") + if tokenizer_1_config is not None: + tokenizer_1_config.download_if_necessary() + pipe.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_config.path) + if tokenizer_2_config is not None: + tokenizer_2_config.download_if_necessary() + pipe.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_config.path) + + value_controllers = model_pool.fetch_model("flux_value_controller") + if value_controllers is not None: pipe.value_controller = MultiValueEncoder(value_controllers) + controlnets = model_pool.fetch_model("flux_controlnet") + if controlnets is not None: pipe.controlnet = MultiControlNet(controlnets) + pipe.ipadapter = model_pool.fetch_model("flux_ipadapter") + pipe.ipadapter_image_encoder = model_pool.fetch_model("siglip_vision_model") + pipe.qwenvl = model_pool.fetch_model("qwenvl") + pipe.step1x_connector = model_pool.fetch_model("step1x_connector") + pipe.image_proj_model = model_pool.fetch_model("infiniteyou_image_projector") + if pipe.image_proj_model is not None: + pipe.infinityou_processor = InfinitYou(device=device) + pipe.lora_patcher = model_pool.fetch_model("flux_lora_patcher") + pipe.lora_encoder = model_pool.fetch_model("flux_lora_encoder") + pipe.nexus_gen = model_pool.fetch_model("nexus_gen_llm") + pipe.nexus_gen_generation_adapter = model_pool.fetch_model("nexus_gen_generation_adapter") + pipe.nexus_gen_editing_adapter = model_pool.fetch_model("nexus_gen_editing_adapter") + if nexus_gen_processor_config is not None and pipe.nexus_gen is not None: + nexus_gen_processor_config.download_if_necessary() + pipe.nexus_gen.load_processor(nexus_gen_processor_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + 
cfg_scale: float = 1.0, + embedded_guidance: float = 3.5, + t5_sequence_length: int = 512, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Scheduler + sigma_shift: float = None, + # Steps + num_inference_steps: int = 30, + # local prompts + multidiffusion_prompts=(), + multidiffusion_masks=(), + multidiffusion_scales=(), + # Kontext + kontext_images: Union[list[Image.Image], Image.Image] = None, + # ControlNet + controlnet_inputs: list[ControlNetInput] = None, + # IP-Adapter + ipadapter_images: Union[list[Image.Image], Image.Image] = None, + ipadapter_scale: float = 1.0, + # EliGen + eligen_entity_prompts: list[str] = None, + eligen_entity_masks: list[Image.Image] = None, + eligen_enable_on_negative: bool = False, + eligen_enable_inpaint: bool = False, + # InfiniteYou + infinityou_id_image: Image.Image = None, + infinityou_guidance: float = 1.0, + # Flex + flex_inpaint_image: Image.Image = None, + flex_inpaint_mask: Image.Image = None, + flex_control_image: Image.Image = None, + flex_control_strength: float = 0.5, + flex_control_stop: float = 0.5, + # Value Controller + value_controller_inputs: Union[list[float], float] = None, + # Step1x + step1x_reference_image: Image.Image = None, + # NexusGen + nexus_gen_reference_image: Image.Image = None, + # LoRA Encoder + lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None, + lora_encoder_scale: float = 1.0, + # TeaCache + tea_cache_l1_thresh: float = None, + # Tile + tiled: bool = False, + tile_size: int = 128, + tile_stride: int = 64, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, 
"embedded_guidance": embedded_guidance, "t5_sequence_length": t5_sequence_length, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps, + "multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales, + "kontext_images": kontext_images, + "controlnet_inputs": controlnet_inputs, + "ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale, + "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint, + "infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance, + "flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop, + "value_controller_inputs": value_controller_inputs, + "step1x_reference_image": step1x_reference_image, + "nexus_gen_reference_image": nexus_gen_reference_image, + "lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale, + "tea_cache_l1_thresh": tea_cache_l1_thresh, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "progress_bar_cmd": progress_bar_cmd, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = 
self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae_decoder']) + image = self.vae_decoder(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class FluxImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width")) + + def process(self, pipe: FluxImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + + +class FluxImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width", "seed", "rand_device")) + + def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device) + return {"noise": noise} + + + +class FluxImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, input_image, noise, tiled, tile_size, tile_stride): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae_encoder']) + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = 
pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": None} + + + +class FluxImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + input_params=("t5_sequence_length",), + onload_model_names=("text_encoder_1", "text_encoder_2") + ) + + def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ).input_ids.to(device) + pooled_prompt_emb, _ = text_encoder(input_ids) + return pooled_prompt_emb + + def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True, + ).input_ids.to(device) + prompt_emb = text_encoder(input_ids) + return prompt_emb + + def encode_prompt( + self, + tokenizer_1, + tokenizer_2, + text_encoder_1, + text_encoder_2, + prompt, + positive=True, + device="cuda", + t5_sequence_length=512, + ): + pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) + prompt_emb = self.encode_prompt_using_t5(prompt, text_encoder_2, tokenizer_2, t5_sequence_length, device) + text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype) + return prompt_emb, pooled_prompt_emb, text_ids + + def process(self, pipe: FluxImagePipeline, prompt, t5_sequence_length, positive) -> dict: + if pipe.text_encoder_1 is not None and pipe.text_encoder_2 is not None: + prompt_emb, pooled_prompt_emb, text_ids = self.encode_prompt( + tokenizer_1=pipe.tokenizer_1, tokenizer_2=pipe.tokenizer_2, + text_encoder_1=pipe.text_encoder_1, 
text_encoder_2=pipe.text_encoder_2, + prompt=prompt, device=pipe.device, positive=positive, t5_sequence_length=t5_sequence_length, + ) + return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} + else: + return {} + + +class FluxImageUnit_ImageIDs(PipelineUnit): + def __init__(self): + super().__init__(input_params=("latents",)) + + def process(self, pipe: FluxImagePipeline, latents): + latent_image_ids = pipe.dit.prepare_image_ids(latents) + return {"image_ids": latent_image_ids} + + + +class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit): + def __init__(self): + super().__init__(input_params=("embedded_guidance", "latents")) + + def process(self, pipe: FluxImagePipeline, embedded_guidance, latents): + guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + return {"guidance": guidance} + + + +class FluxImageUnit_Kontext(PipelineUnit): + def __init__(self): + super().__init__(input_params=("kontext_images", "tiled", "tile_size", "tile_stride")) + + def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride): + if kontext_images is None: + return {} + if not isinstance(kontext_images, list): + kontext_images = [kontext_images] + + kontext_latents = [] + kontext_image_ids = [] + for kontext_image in kontext_images: + kontext_image = pipe.preprocess_image(kontext_image) + kontext_latent = pipe.vae_encoder(kontext_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image_ids = pipe.dit.prepare_image_ids(kontext_latent) + image_ids[..., 0] = 1 + kontext_image_ids.append(image_ids) + kontext_latent = pipe.dit.patchify(kontext_latent) + kontext_latents.append(kontext_latent) + kontext_latents = torch.concat(kontext_latents, dim=1) + kontext_image_ids = torch.concat(kontext_image_ids, dim=-2) + return {"kontext_latents": kontext_latents, "kontext_image_ids": kontext_image_ids} + + + +class 
FluxImageUnit_ControlNet(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae_encoder",) + ) + + def apply_controlnet_mask_on_latents(self, pipe, latents, mask): + mask = (pipe.preprocess_image(mask) + 1) / 2 + mask = mask.mean(dim=1, keepdim=True) + mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:]) + latents = torch.concat([latents, mask], dim=1) + return latents + + def apply_controlnet_mask_on_image(self, pipe, image, mask): + mask = mask.resize(image.size) + mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu() + image = np.array(image) + image[mask > 0] = 0 + image = Image.fromarray(image) + return image + + def process(self, pipe: FluxImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride): + if controlnet_inputs is None: + return {} + pipe.load_models_to_device(['vae_encoder']) + conditionings = [] + for controlnet_input in controlnet_inputs: + image = controlnet_input.image + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask) + + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask) + conditionings.append(image) + return {"controlnet_conditionings": conditionings} + + + +class FluxImageUnit_IPAdapter(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("ipadapter_image_encoder", "ipadapter") + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + ipadapter_images, ipadapter_scale = inputs_shared.get("ipadapter_images", None), inputs_shared.get("ipadapter_scale", 1.0) + 
if ipadapter_images is None: + return inputs_shared, inputs_posi, inputs_nega + if not isinstance(ipadapter_images, list): + ipadapter_images = [ipadapter_images] + + pipe.load_models_to_device(self.onload_model_names) + images = [image.convert("RGB").resize((384, 384), resample=3) for image in ipadapter_images] + images = [pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) for image in images] + ipadapter_images = torch.cat(images, dim=0) + ipadapter_image_encoding = pipe.ipadapter_image_encoder(ipadapter_images).pooler_output + + inputs_posi.update({"ipadapter_kwargs_list": pipe.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update({"ipadapter_kwargs_list": pipe.ipadapter(torch.zeros_like(ipadapter_image_encoding))}) + return inputs_shared, inputs_posi, inputs_nega + + + +class FluxImageUnit_EntityControl(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("text_encoder_1", "text_encoder_2") + ) + + def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ).input_ids.to(device) + pooled_prompt_emb, _ = text_encoder(input_ids) + return pooled_prompt_emb + + def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True, + ).input_ids.to(device) + prompt_emb = text_encoder(input_ids) + return prompt_emb + + def encode_prompt( + self, + tokenizer_1, + tokenizer_2, + text_encoder_1, + text_encoder_2, + prompt, + positive=True, + device="cuda", + t5_sequence_length=512, + ): + pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) + prompt_emb = 
self.encode_prompt_using_t5(prompt, text_encoder_2, tokenizer_2, t5_sequence_length, device) + text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype) + return prompt_emb, pooled_prompt_emb, text_ids + + def preprocess_masks(self, pipe, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height, t5_sequence_length=512): + entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + + prompt_emb, _, _ = self.encode_prompt( + tokenizer_1=pipe.tokenizer_1, tokenizer_2=pipe.tokenizer_2, + text_encoder_1=pipe.text_encoder_1, text_encoder_2=pipe.text_encoder_2, + prompt=entity_prompts, device=pipe.device, t5_sequence_length=t5_sequence_length, + ) + return prompt_emb.unsqueeze(0), entity_masks + + def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_on_negative, cfg_scale): + entity_prompt_emb_posi, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1) + entity_masks_nega = entity_masks_posi + else: + entity_prompt_emb_nega, entity_masks_nega = None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega} + 
return eligen_kwargs_posi, eligen_kwargs_nega + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None) + if eligen_entity_prompts is None or eligen_entity_masks is None: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False) + eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega, + eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"], + inputs_shared["t5_sequence_length"], eligen_enable_on_negative, inputs_shared["cfg_scale"]) + inputs_posi.update(eligen_kwargs_posi) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update(eligen_kwargs_nega) + return inputs_shared, inputs_posi, inputs_nega + + +class FluxImageUnit_NexusGen(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"), + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + if pipe.nexus_gen is None: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + if inputs_shared.get("nexus_gen_reference_image", None) is None: + assert pipe.nexus_gen_generation_adapter is not None, "NexusGen requires a generation adapter to be set." + embed = pipe.nexus_gen(inputs_posi["prompt"])[0].unsqueeze(0) + inputs_posi["prompt_emb"] = pipe.nexus_gen_generation_adapter(embed) + inputs_posi['text_ids'] = torch.zeros(embed.shape[0], embed.shape[1], 3).to(device=pipe.device, dtype=pipe.torch_dtype) + else: + assert pipe.nexus_gen_editing_adapter is not None, "NexusGen requires an editing adapter to be set." 
+ embed, ref_embed, grids = pipe.nexus_gen(inputs_posi["prompt"], inputs_shared["nexus_gen_reference_image"]) + embeds_grid = grids[1:2].to(device=pipe.device, dtype=torch.long) + ref_embeds_grid = grids[0:1].to(device=pipe.device, dtype=torch.long) + + inputs_posi["prompt_emb"] = pipe.nexus_gen_editing_adapter(embed.unsqueeze(0), embeds_grid, ref_embed.unsqueeze(0), ref_embeds_grid) + inputs_posi["text_ids"] = self.get_editing_text_ids( + inputs_shared["latents"], + embeds_grid[0][1].item(), embeds_grid[0][2].item(), + ref_embeds_grid[0][1].item(), ref_embeds_grid[0][2].item(), + ) + return inputs_shared, inputs_posi, inputs_nega + + + def get_editing_text_ids(self, latents, target_embed_height, target_embed_width, ref_embed_height, ref_embed_width): + # prepare text ids for target and reference embeddings + batch_size, height, width = latents.shape[0], target_embed_height, target_embed_width + embed_ids = torch.zeros(height // 2, width // 2, 3) + scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width + embed_ids[..., 1] = embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height + embed_ids[..., 2] = embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width + embed_ids = embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3) + embed_text_ids = embed_ids.to(device=latents.device, dtype=latents.dtype) + + batch_size, height, width = latents.shape[0], ref_embed_height, ref_embed_width + ref_embed_ids = torch.zeros(height // 2, width // 2, 3) + scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width + ref_embed_ids[..., 0] = ref_embed_ids[..., 0] + 1.0 + ref_embed_ids[..., 1] = ref_embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height + ref_embed_ids[..., 2] = ref_embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width + ref_embed_ids = ref_embed_ids[None, 
:].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3) + ref_embed_text_ids = ref_embed_ids.to(device=latents.device, dtype=latents.dtype) + + text_ids = torch.cat([embed_text_ids, ref_embed_text_ids], dim=1) + return text_ids + + +class FluxImageUnit_Step1x(PipelineUnit): + def __init__(self): + super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder")) + + def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict): + image = inputs_shared.get("step1x_reference_image",None) + if image is None: + return inputs_shared, inputs_posi, inputs_nega + else: + pipe.load_models_to_device(self.onload_model_names) + prompt = inputs_posi["prompt"] + nega_prompt = inputs_nega["negative_prompt"] + captions = [prompt, nega_prompt] + ref_images = [image, image] + embs, masks = pipe.qwenvl(captions, ref_images) + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae_encoder(image) + inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}) + if inputs_shared.get("cfg_scale", 1) != 1: + inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image}) + return inputs_shared, inputs_posi, inputs_nega + + +class FluxImageUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh")) + + def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh): + if tea_cache_l1_thresh is None: + return {} + else: + return {"tea_cache": TeaCache(num_inference_steps=num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh)} + +class FluxImageUnit_Flex(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", 
"tile_stride"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride): + if pipe.dit.input_dim == 196: + if flex_control_stop is None: + flex_control_stop = 1 + pipe.load_models_to_device(self.onload_model_names) + if flex_inpaint_image is None: + flex_inpaint_image = torch.zeros_like(latents) + else: + flex_inpaint_image = pipe.preprocess_image(flex_inpaint_image).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_inpaint_image = pipe.vae_encoder(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if flex_inpaint_mask is None: + flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :] + else: + flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2])) + flex_inpaint_mask = pipe.preprocess_image(flex_inpaint_mask).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2 + flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask) + if flex_control_image is None: + flex_control_image = torch.zeros_like(latents) + else: + flex_control_image = pipe.preprocess_image(flex_control_image).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_control_image = pipe.vae_encoder(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength + flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1) + flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1) + flex_control_stop_timestep = pipe.scheduler.timesteps[int(flex_control_stop * (len(pipe.scheduler.timesteps) - 1))] + return {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep} + else: + return {} + + + +class 
class FluxImageUnit_InfiniteYou(PipelineUnit):
    """Prepares InfiniteYou identity embeddings from a reference face image."""

    def __init__(self):
        super().__init__(
            input_params=("infinityou_id_image", "infinityou_guidance"),
            onload_model_names=("infinityou_processor",)
        )

    def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance):
        # Fix: pass the declared tuple of model names instead of a bare string,
        # matching every other pipeline unit's use of load_models_to_device.
        pipe.load_models_to_device(self.onload_model_names)
        if infinityou_id_image is None:
            return {}
        return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device)


class FluxImageUnit_ValueControl(PipelineUnit):
    """Appends AttriCtrl value-controller embeddings as extra prompt tokens."""

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
            input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
            input_params=("value_controller_inputs",),
            onload_model_names=("value_controller",)
        )

    def add_to_text_embedding(self, prompt_emb, text_ids, value_emb):
        # Extra tokens get all-zero position ids so they act as global context.
        prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
        extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
        text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
        return prompt_emb, text_ids

    def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):
        if value_controller_inputs is None:
            return {}
        if not isinstance(value_controller_inputs, list):
            value_controller_inputs = [value_controller_inputs]
        values = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)
        pipe.load_models_to_device(self.onload_model_names)
        value_emb = pipe.value_controller(values, pipe.torch_dtype).unsqueeze(0)
        prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb)
        return {"prompt_emb": prompt_emb, "text_ids": text_ids}


class InfinitYou(torch.nn.Module):
    """Face-identity feature extractor for InfiniteYou.

    Uses insightface face detection (at three fallback resolutions) and an
    ArcFace recognition model to produce a 512-d identity embedding.
    """

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
        super().__init__()
        from facexlib.recognition import init_recognition_model
        from insightface.app import FaceAnalysis
        self.device = device
        self.torch_dtype = torch_dtype
        insightface_root_path = 'models/ByteDance/InfiniteYou/supports/insightface'
        # Three detectors at decreasing resolutions; smaller det_size catches
        # larger/closer faces that the bigger window misses.
        self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app_640.prepare(ctx_id=0, det_size=(640, 640))
        self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app_320.prepare(ctx_id=0, det_size=(320, 320))
        self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app_160.prepare(ctx_id=0, det_size=(160, 160))
        self.arcface_model = init_recognition_model('arcface', device=self.device).to(torch_dtype)

    def _detect_face(self, id_image_cv2):
        """Try each detector in turn; return the first non-empty result
        (or the last, possibly empty, result)."""
        for detector in (self.app_640, self.app_320, self.app_160):
            face_info = detector.get(id_image_cv2)
            if len(face_info) > 0:
                return face_info
        return face_info

    def extract_arcface_bgr_embedding(self, in_image, landmark, device):
        """Align the face crop to 112x112 via the 5-point landmark and run ArcFace."""
        from insightface.utils import face_align
        arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
        arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
        arc_face_image = 2 * arc_face_image - 1  # map to [-1, 1]
        arc_face_image = arc_face_image.contiguous().to(device=device, dtype=self.torch_dtype)
        face_emb = self.arcface_model(arc_face_image)[0]  # [512], normalized
        return face_emb

    def prepare_infinite_you(self, model, id_image, infinityou_guidance, device):
        """Detect the largest face in `id_image` and project its ArcFace
        embedding through the image projection model.

        Raises ValueError when no face is detected.
        """
        import cv2
        if id_image is None:
            return {'id_emb': None}
        id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
        face_info = self._detect_face(id_image_cv2)
        if len(face_info) == 0:
            raise ValueError('No face detected in the input ID image')
        # Only use the maximum-area face.
        landmark = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]['kps']
        id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark, device)
        id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
        infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=device, dtype=self.torch_dtype)
        return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}


class FluxImageUnit_LoRAEncode(PipelineUnit):
    """Encodes LoRA weight files into extra prompt tokens via the LoRA encoder."""

    def __init__(self):
        super().__init__(take_over=True, onload_model_names=("lora_encoder",))

    def parse_lora_encoder_inputs(self, lora_encoder_inputs):
        """Normalize a path / ModelConfig (or list thereof) into a list of
        downloaded ModelConfigs."""
        if not isinstance(lora_encoder_inputs, list):
            lora_encoder_inputs = [lora_encoder_inputs]
        lora_configs = []
        for entry in lora_encoder_inputs:
            if isinstance(entry, str):
                entry = ModelConfig(path=entry)
            entry.download_if_necessary()
            lora_configs.append(entry)
        return lora_configs

    def load_lora(self, lora_config, dtype, device):
        """Load a LoRA state dict and convert it to this project's naming."""
        loader = FluxLoRALoader(torch_dtype=dtype, device=device)
        lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device)
        return loader.convert_state_dict(lora)

    def lora_embedding(self, pipe, lora_encoder_inputs):
        lora_emb = []
        for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs):
            lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device)
            lora_emb.append(pipe.lora_encoder(lora))
        return torch.concat(lora_emb, dim=1)

    def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb):
        prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1)
        extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype)
        text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
        return prompt_emb, text_ids

    def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
        if inputs_shared.get("lora_encoder_inputs", None) is None:
            return inputs_shared, inputs_posi, inputs_nega

        # Encode
        pipe.load_models_to_device(self.onload_model_names)
        lora_emb = self.lora_embedding(pipe, inputs_shared["lora_encoder_inputs"])

        # Optional scaling
        lora_encoder_scale = inputs_shared.get("lora_encoder_scale", None)
        if lora_encoder_scale is not None:
            lora_emb = lora_emb * lora_encoder_scale

        # Append to the positive prompt embedding only.
        inputs_posi["prompt_emb"], inputs_posi["text_ids"] = self.add_to_text_embedding(
            inputs_posi["prompt_emb"], inputs_posi["text_ids"], lora_emb)
        return inputs_shared, inputs_posi, inputs_nega


class TeaCache:
    """Residual cache (TeaCache) that skips the DiT block stack when the
    modulated input changed too little since the previous step."""

    def __init__(self, num_inference_steps, rel_l1_thresh):
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

    def check(self, dit: FluxDiT, hidden_states, conditioning):
        """Return True when the cached residual should be reused (blocks skipped)."""
        inp = hidden_states.clone()
        temb_ = conditioning.clone()
        modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)
        # Always compute on the first and last step.
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            # Empirical polynomial rescaling of the relative L1 distance.
            coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0
        if should_calc:
            self.previous_hidden_states = hidden_states.clone()
        return not should_calc

    def store(self, hidden_states):
        """Remember the residual produced by a full forward pass."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Apply the cached residual instead of running the blocks."""
        hidden_states = hidden_states + self.previous_residual
        return hidden_states
coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = hidden_states.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + +class FastTileWorker: + def __init__(self): + pass + + + def build_mask(self, data, is_bound): + _, _, H, W = data.shape + h = repeat(torch.arange(H), "H -> H W", H=H, W=W) + w = repeat(torch.arange(W), "W -> H W", H=H, W=W) + border_width = (H + W) // 4 + pad = torch.ones_like(h) * border_width + mask = torch.stack([ + pad if is_bound[0] else h + 1, + pad if is_bound[1] else H - h, + pad if is_bound[2] else w + 1, + pad if is_bound[3] else W - w + ]).min(dim=0).values + mask = mask.clip(1, border_width) + mask = (mask / border_width).to(dtype=data.dtype, device=data.device) + mask = rearrange(mask, "H W -> 1 H W") + return mask + + + def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None): + # Prepare + B, C, H, W = model_input.shape + border_width = int(tile_stride*0.5) if border_width is None else border_width + weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device) + values = torch.zeros((B, C, H, W), dtype=tile_dtype, 
device=tile_device) + + # Split tasks + tasks = [] + for h in range(0, H, tile_stride): + for w in range(0, W, tile_stride): + if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W): + continue + h_, w_ = h + tile_size, w + tile_size + if h_ > H: h, h_ = H - tile_size, H + if w_ > W: w, w_ = W - tile_size, W + tasks.append((h, h_, w, w_)) + + # Run + for hl, hr, wl, wr in tasks: + # Forward + hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device) + + mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W)) + values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask + weight[:, :, hl:hr, wl:wr] += mask + values /= weight + return values + + +def model_fn_flux_image( + dit: FluxDiT, + controlnet=None, + step1x_connector=None, + latents=None, + timestep=None, + prompt_emb=None, + pooled_prompt_emb=None, + guidance=None, + text_ids=None, + image_ids=None, + kontext_latents=None, + kontext_image_ids=None, + controlnet_inputs=None, + controlnet_conditionings=None, + tiled=False, + tile_size=128, + tile_stride=64, + entity_prompt_emb=None, + entity_masks=None, + ipadapter_kwargs_list={}, + id_emb=None, + infinityou_guidance=None, + flex_condition=None, + flex_uncondition=None, + flex_control_stop_timestep=None, + step1x_llm_embedding=None, + step1x_mask=None, + step1x_reference_latents=None, + tea_cache: TeaCache = None, + progress_id=0, + num_inference_steps=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs +): + if tiled: + def flux_forward_fn(hl, hr, wl, wr): + tiled_controlnet_conditionings = [f[:, :, hl: hr, wl: wr] for f in controlnet_conditionings] if controlnet_conditionings is not None else None + return model_fn_flux_image( + dit=dit, + controlnet=controlnet, + latents=latents[:, :, hl: hr, wl: wr], + timestep=timestep, + prompt_emb=prompt_emb, + pooled_prompt_emb=pooled_prompt_emb, + 
guidance=guidance, + text_ids=text_ids, + image_ids=None, + controlnet_inputs=controlnet_inputs, + controlnet_conditionings=tiled_controlnet_conditionings, + tiled=False, + **kwargs + ) + return FastTileWorker().tiled_forward( + flux_forward_fn, + latents, + tile_size=tile_size, + tile_stride=tile_stride, + tile_device=latents.device, + tile_dtype=latents.dtype + ) + + hidden_states = latents + + # ControlNet + if controlnet is not None and controlnet_conditionings is not None: + controlnet_extra_kwargs = { + "hidden_states": hidden_states, + "timestep": timestep, + "prompt_emb": prompt_emb, + "pooled_prompt_emb": pooled_prompt_emb, + "guidance": guidance, + "text_ids": text_ids, + "image_ids": image_ids, + "controlnet_inputs": controlnet_inputs, + "tiled": tiled, + "tile_size": tile_size, + "tile_stride": tile_stride, + "progress_id": progress_id, + "num_inference_steps": num_inference_steps, + } + if id_emb is not None: + controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype) + controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance}) + controlnet_res_stack, controlnet_single_res_stack = controlnet( + controlnet_conditionings, **controlnet_extra_kwargs + ) + + # Flex + if flex_condition is not None: + if timestep.tolist()[0] >= flex_control_stop_timestep: + hidden_states = torch.concat([hidden_states, flex_condition], dim=1) + else: + hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1) + + # Step1x + if step1x_llm_embedding is not None: + prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask) + text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device) + + if image_ids is None: + image_ids = dit.prepare_image_ids(hidden_states) + + conditioning = dit.time_embedder(timestep, hidden_states.dtype) + 
dit.pooled_text_embedder(pooled_prompt_emb) + if dit.guidance_embedder is not None: + guidance = guidance * 1000 + conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) + + height, width = hidden_states.shape[-2:] + hidden_states = dit.patchify(hidden_states) + + # Kontext + if kontext_latents is not None: + image_ids = torch.concat([image_ids, kontext_image_ids], dim=-2) + hidden_states = torch.concat([hidden_states, kontext_latents], dim=1) + + # Step1x + if step1x_reference_latents is not None: + step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents) + step1x_reference_latents = dit.patchify(step1x_reference_latents) + image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2) + hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1) + + hidden_states = dit.x_embedder(hidden_states) + + # EliGen + if entity_prompt_emb is not None and entity_masks is not None: + prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1]) + else: + prompt_emb = dit.context_embedder(prompt_emb) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + attention_mask = None + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, hidden_states, conditioning) + else: + tea_cache_update = False + + if tea_cache_update: + hidden_states = tea_cache.update(hidden_states) + else: + # Joint Blocks + for block_id, block in enumerate(dit.blocks): + hidden_states, prompt_emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None), + ) + # ControlNet + if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is 
not None: + if kontext_latents is None: + hidden_states = hidden_states + controlnet_res_stack[block_id] + else: + hidden_states[:, :-kontext_latents.shape[1]] = hidden_states[:, :-kontext_latents.shape[1]] + controlnet_res_stack[block_id] + + # Single Blocks + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + num_joint_blocks = len(dit.blocks) + for block_id, block in enumerate(dit.single_blocks): + hidden_states, prompt_emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), + ) + # ControlNet + if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None: + if kontext_latents is None: + hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] + else: + hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] = hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] + controlnet_single_res_stack[block_id] + hidden_states = hidden_states[:, prompt_emb.shape[1]:] + + if tea_cache is not None: + tea_cache.store(hidden_states) + + hidden_states = dit.final_norm_out(hidden_states, conditioning) + hidden_states = dit.final_proj_out(hidden_states) + + # Step1x + if step1x_reference_latents is not None: + hidden_states = hidden_states[:, :hidden_states.shape[1] // 2] + + # Kontext + if kontext_latents is not None: + hidden_states = hidden_states[:, :-kontext_latents.shape[1]] + + hidden_states = dit.unpatchify(hidden_states, height, width) + + return hidden_states diff --git a/diffsynth/utils/lora/flux.py b/diffsynth/utils/lora/flux.py new file mode 100644 index 0000000..502c5fd --- /dev/null +++ b/diffsynth/utils/lora/flux.py @@ -0,0 +1,204 @@ +from .general import GeneralLoRALoader +import torch, math + + 
+class FluxLoRALoader(GeneralLoRALoader): + def __init__(self, device="cpu", torch_dtype=torch.float32): + super().__init__(device=device, torch_dtype=torch_dtype) + + self.diffusers_rename_dict = { + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.weight", + 
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.weight", + 
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.weight", + } + + self.civitai_rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.weight", + 
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.weight", + "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.weight", + } + + def fuse_lora_to_base_model(self, model: 
    def convert_state_dict(self, state_dict):
        """Normalize a FLUX LoRA state dict to this library's parameter names.

        Detects whether the checkpoint uses civitai-style (``lora_unet_*``) or
        diffusers-style (``transformer.*``) key names, renames keys through the
        matching rename table, rescales parameters by an alpha inferred from the
        ``.alpha`` entries, and fuses separate q/k/v (and single-block mlp)
        LoRA tensors into the fused ``to_qkv`` / ``to_qkv_mlp`` layout.

        Args:
            state_dict: raw LoRA state dict as loaded from file.

        Returns:
            dict: converted state dict; returned unchanged if the naming
            scheme cannot be recognized.
        """

        def guess_block_id(name,model_resource):
            # Find the first numeric path component and replace it with the
            # placeholder "blockid" so the name can be looked up in the
            # rename tables (which are block-index agnostic).
            if model_resource == 'civitai':
                names = name.split("_")
                for i in names:
                    if i.isdigit():
                        return i, name.replace(f"_{i}_", "_blockid_")
            if model_resource == 'diffusers':
                names = name.split(".")
                for i in names:
                    if i.isdigit():
                        return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
            return None, None

        def guess_resource(state_dict):
            # Classify the checkpoint by its key naming convention.
            # Falls through (implicitly returning None) when no key matches.
            for k in state_dict:
                if "lora_unet_" in k:
                    return 'civitai'
                elif k.startswith("transformer."):
                    return 'diffusers'
                else:
                    None

        model_resource = guess_resource(state_dict)
        if model_resource is None:
            # Unknown naming scheme: pass the state dict through untouched.
            return state_dict

        rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
        def guess_alpha(state_dict):
            # Derive the effective LoRA scale from the first ".alpha" entry:
            # alpha / rank, square-rooted so that applying it to BOTH the A and
            # B matrices yields the full alpha/rank product. Defaults to 1.
            for name, param in state_dict.items():
                if ".alpha" in name:
                    for suffix in [".lora_down.weight", ".lora_A.weight"]:
                        name_ = name.replace(".alpha", suffix)
                        if name_ in state_dict:
                            lora_alpha = param.item() / state_dict[name_].shape[0]
                            lora_alpha = math.sqrt(lora_alpha)
                            return lora_alpha

            return 1

        alpha = guess_alpha(state_dict)

        # Pass 1: rename keys and bake the alpha scale into the weights.
        # NOTE(review): `param *= alpha` mutates tensors of the input
        # state_dict in place — callers should not reuse it afterwards.
        state_dict_ = {}
        for name, param in state_dict.items():
            block_id, source_name = guess_block_id(name,model_resource)
            if alpha != 1:
                param *= alpha
            if source_name in rename_dict:
                target_name = rename_dict[source_name]
                target_name = target_name.replace(".blockid.", f".{block_id}.")
                state_dict_[target_name] = param
            else:
                # Keys with no rename entry are carried over verbatim.
                state_dict_[name] = param

        if model_resource == 'diffusers':
            # Pass 2 (diffusers only): fuse the single-block q/k/v/mlp LoRA
            # tensors into one ".to_qkv_mlp." tensor per block.
            for name in list(state_dict_.keys()):
                if "single_blocks." in name and ".a_to_q." in name:
                    mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
                    if mlp is None:
                        # No mlp LoRA present: pad with zeros. The mlp branch is
                        # 4x the attention dim in lora_B, but shares the rank
                        # dimension in lora_A (hence dim = 1 there).
                        dim = 4
                        if 'lora_A' in name:
                            dim = 1
                        mlp = torch.zeros(dim * state_dict_[name].shape[0],
                                          *state_dict_[name].shape[1:],
                                          dtype=state_dict_[name].dtype)
                    else:
                        state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
                    if 'lora_A' in name:
                        # lora_A: stack along the rank (output) axis.
                        param = torch.concat([
                            state_dict_.pop(name),
                            state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
                            state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
                            mlp,
                        ], dim=0)
                    elif 'lora_B' in name:
                        # lora_B: build a block-diagonal matrix so each branch
                        # only consumes its own slice of the fused lora_A rank.
                        d, r = state_dict_[name].shape
                        param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
                        param[:d, :r] = state_dict_.pop(name)
                        param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
                        param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
                        param[3*d:, 3*r:] = mlp
                    else:
                        param = torch.concat([
                            state_dict_.pop(name),
                            state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
                            state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
                            mlp,
                        ], dim=0)
                    name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
                    state_dict_[name_] = param
            # Pass 3: fuse the double-block q/k/v LoRA tensors ("a" = image
            # stream, "b" = text stream) into one ".*_to_qkv." tensor.
            for name in list(state_dict_.keys()):
                for component in ["a", "b"]:
                    if f".{component}_to_q." in name:
                        name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
                        concat_dim = 0
                        if 'lora_A' in name:
                            param = torch.concat([
                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
                            ], dim=0)
                        elif 'lora_B' in name:
                            # Block-diagonal layout, mirroring the single-block case.
                            origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
                            d, r = origin.shape
                            param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
                            param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
                            param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
                            param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
                        else:
                            param = torch.concat([
                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
                            ], dim=0)
                        state_dict_[name_] = param
                        # NOTE(review): replacing "_to_q." with "_to_q." is a
                        # no-op — this pop just removes the original q entry.
                        state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
                        state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                        state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
        return state_dict_
def FluxVAEEncoderStateDictConverter(state_dict):
    """Translate FLUX ``ae.safetensors`` encoder weight names to diffsynth's layout.

    Keys present in the checkpoint but absent from the rename table
    (e.g. decoder weights) are dropped; insertion order follows the input.

    Args:
        state_dict: raw state dict loaded from the FLUX autoencoder file.

    Returns:
        dict: state dict keyed by ``FluxVAEEncoder`` parameter names.
    """
    rename_dict = {
        "encoder.conv_in.bias": "conv_in.bias",
        "encoder.conv_in.weight": "conv_in.weight",
        "encoder.conv_out.bias": "conv_out.bias",
        "encoder.conv_out.weight": "conv_out.weight",
        "encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
        "encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
        "encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
        "encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
        "encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
        "encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
        "encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
        "encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
        "encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
        "encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
        "encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
        "encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
        "encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
        "encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
        "encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
        "encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
        "encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
        "encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
        "encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
        "encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
        "encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
        "encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
        "encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
        "encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
        "encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
        "encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
        "encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
        "encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
        "encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
        "encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
        "encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
        "encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
        "encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
        "encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
        "encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
        "encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
        "encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
        "encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
        "encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
        "encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
        "encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
        "encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
        "encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
        "encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
        "encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
        "encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
        "encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
        "encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
        "encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
        "encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
        "encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
        "encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
        "encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
        "encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
        "encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
        "encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
        "encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
        "encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
        "encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
        "encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
        "encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
        "encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
        "encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
        "encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
        "encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
        "encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
        "encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
        "encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
        "encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
        "encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
        "encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
        "encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
        "encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
        "encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
        "encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
        "encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
        "encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
        "encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
        "encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
        "encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
        "encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
        "encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
        "encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
        "encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
        "encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
        "encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
        "encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
        "encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
        "encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
        "encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
        "encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
        "encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
        "encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
        "encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
        "encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
        "encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
        "encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
        "encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
        "encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
        "encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
        "encoder.norm_out.bias": "conv_norm_out.bias",
        "encoder.norm_out.weight": "conv_norm_out.weight",
    }
    # Keep only recognized keys, renamed, preserving the input's order.
    return {rename_dict[key]: value for key, value in state_dict.items() if key in rename_dict}
def FluxVAEDecoderStateDictConverter(state_dict):
    """Translate FLUX ``ae.safetensors`` decoder weight names to diffsynth's layout.

    Keys present in the checkpoint but absent from the rename table
    (e.g. encoder weights) are dropped; insertion order follows the input.

    Args:
        state_dict: raw state dict loaded from the FLUX autoencoder file.

    Returns:
        dict: state dict keyed by ``FluxVAEDecoder`` parameter names.
    """
    rename_dict = {
        "decoder.conv_in.bias": "conv_in.bias",
        "decoder.conv_in.weight": "conv_in.weight",
        "decoder.conv_out.bias": "conv_out.bias",
        "decoder.conv_out.weight": "conv_out.weight",
        "decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
        "decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
        "decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
        "decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
        "decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
        "decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
        "decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
        "decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
        "decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
        "decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
        "decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
        "decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
        "decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
        "decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
        "decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
        "decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
        "decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
        "decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
        "decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
        "decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
        "decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
        "decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
        "decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
        "decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
        "decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
        "decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
        "decoder.norm_out.bias": "conv_norm_out.bias",
        "decoder.norm_out.weight": "conv_norm_out.weight",
        "decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
        "decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
        "decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
        "decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
        "decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
        "decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
        "decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
        "decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
        "decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
        "decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
        "decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
        "decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
        "decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
        "decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
        "decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
        "decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
        "decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
        "decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
        "decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
        "decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
        "decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
        "decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
        "decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
        "decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
        "decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
        "decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
        "decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
        "decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
        "decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
        "decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
        "decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
        "decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
        "decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
        "decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
        "decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
        "decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
        "decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
        "decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
        "decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
        "decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
        "decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
        "decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
        "decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
        "decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
        "decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
        "decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
        "decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
        "decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
        "decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
        "decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
        "decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
        "decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
        "decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
        "decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
        "decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
        "decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
        "decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
        "decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
        "decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
        "decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
        "decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
        "decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
        "decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
        "decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
        "decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
        "decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
        "decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
        "decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
        "decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
        "decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
        "decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
        "decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
        "decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
        "decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
        "decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
        "decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
        "decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
        "decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
        "decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
        "decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
        "decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
        "decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
        "decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
        "decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
        "decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
        "decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
        "decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
        "decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
        "decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
        "decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
        "decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
        "decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
        "decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
        "decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
        "decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
        "decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
        "decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
        "decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
        "decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
        "decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
        "decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
        "decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
        "decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
        "decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
        "decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
        "decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
    }
    # Keep only recognized keys, renamed, preserving the input's order.
    return {rename_dict[key]: value for key, value in state_dict.items() if key in rename_dict}
# Example script: text-to-image, inpainting, and ControlNet-style generation
# with the Flex.2-preview DiT (reuses FLUX.1-dev text encoders and VAE).
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
from diffsynth.utils.controlnet import Annotator
import numpy as np
from PIL import Image


pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
    ],
)

# Plain text-to-image generation.
image = pipe(
    prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach",
    num_inference_steps=50, embedded_guidance=3.5,
    seed=0
)
image.save(f"image_1.jpg")

# Build a rectangular white-on-black inpainting mask (white = repaint region).
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[200:400, 400:700] = 255
mask = Image.fromarray(mask)
mask.save(f"image_mask.jpg")

inpaint_image = image

# Inpainting: regenerate only the masked region of the first image.
image = pipe(
    prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
    num_inference_steps=50, embedded_guidance=3.5,
    flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
    seed=4
)
image.save(f"image_2_new.jpg")

# Structure-guided generation from a canny edge map of the inpainted image.
control_image = Annotator("canny")(image)
control_image.save("image_control.jpg")

image = pipe(
    prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
    num_inference_steps=50, embedded_guidance=3.5,
    flex_control_image=control_image,
    seed=4
)
image.save(f"image_3_new.jpg")
# Example script: FLUX.1-Krea-dev text-to-image, with and without classifier-free
# guidance (reuses FLUX.1-dev text encoders and VAE).
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig


pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-Krea-dev", origin_file_pattern="flux1-krea-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
    ],
)

prompt = "An beautiful woman is riding a bicycle in a park, wearing a red dress"
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"

# Embedded (distilled) guidance only.
image = pipe(prompt=prompt, seed=0, embedded_guidance=4.5)
image.save("flux_krea.jpg")

# Same seed with an additional CFG pass against the negative prompt.
image = pipe(
    prompt=prompt, negative_prompt=negative_prompt,
    seed=0, cfg_scale=2, num_inference_steps=50,
    embedded_guidance=4.5
)
image.save("flux_krea_cfg.jpg")
# Example script: AttriCtrl value controller for FLUX.1-dev — sweeps the
# brightness attribute from 0.1 to 0.9 at a fixed seed.
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig


pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
        ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
    ],
)

# One image per brightness level; the prompt and seed stay fixed so only the
# controlled attribute varies between outputs.
for i in [0.1, 0.3, 0.5, 0.7, 0.9]:
    image = pipe(prompt="a cat on the beach", seed=2, value_controller_inputs=[i])
    image.save(f"value_control_{i}.jpg")
# Example script: FLUX.1-dev with the ControlNet-Union-alpha adapter, stacking
# canny + depth conditions extracted from a first generation.
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth.utils.controlnet import Annotator
from modelscope import snapshot_download



# The depth annotator needs the MiDaS checkpoint downloaded locally first.
snapshot_download("sd_lora/Annotators", allow_file_pattern="dpt_hybrid-midas-501f0c75.pt", local_dir="models/Annotators")
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
        ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"),
    ],
)

# Reference generation to derive control signals from.
image_1 = pipe(
    prompt="a beautiful Asian girl, full body, red dress, summer",
    height=1024, width=1024,
    seed=6, rand_device="cuda",
)
image_1.save("image_1.jpg")

image_canny = Annotator("canny")(image_1)
image_depth = Annotator("depth")(image_1)

# Re-generate with a changed prompt, constrained by both control maps.
image_2 = pipe(
    prompt="a beautiful Asian girl, full body, red dress, winter",
    controlnet_inputs=[
        ControlNetInput(image=image_canny, scale=0.3, processor_id="canny"),
        ControlNetInput(image=image_depth, scale=0.3, processor_id="depth"),
    ],
    height=1024, width=1024,
    seed=7, rand_device="cuda",
)
image_2.save("image_2.jpg")
import random
import torch
from PIL import Image, ImageDraw, ImageFont
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download


def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
    """Overlay entity masks (tinted, labeled with their prompts) on an image.

    Args:
        image: PIL image the masks belong to.
        masks: list of PIL mask images; white pixels mark the entity region.
        mask_prompts: list of label strings, one per mask.
        output_path: where the composited PNG is saved.
        font_size: label font size.
        use_random_colors: if True, pick a random RGBA tint per mask instead
            of the fixed palette.

    Returns:
        PIL.Image: the composited RGBA image (also saved to output_path).
    """
    # Transparent canvas the tinted masks are composited onto.
    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))

    # Fixed semi-transparent palette (repeated so up to 16 masks are covered).
    colors = [
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
    ]
    if use_random_colors:
        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]

    # Font settings. ImageFont.load_default only accepts a size on
    # Pillow >= 10.1, so fall back to the unsized default on older versions
    # instead of crashing with a TypeError.
    try:
        font = ImageFont.truetype("arial", font_size)  # Adjust as needed
    except IOError:
        try:
            font = ImageFont.load_default(font_size)
        except TypeError:
            font = ImageFont.load_default()

    # Tint each mask, draw its label near the mask's top-left corner, and
    # composite it onto the overlay.
    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
        mask_rgba = mask.convert('RGBA')
        mask_data = mask_rgba.getdata()
        # White pixels become the tint color; everything else becomes transparent.
        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
        mask_rgba.putdata(new_data)

        draw = ImageDraw.Draw(mask_rgba)
        mask_bbox = mask.getbbox()
        # getbbox() returns None for an all-black (empty) mask; skip the label
        # in that case rather than crashing on a None subscript.
        if mask_bbox is not None:
            text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)
            draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)

        overlay = Image.alpha_composite(overlay, mask_rgba)

    # Composite the accumulated overlay onto the original image and save.
    result = Image.alpha_composite(image.convert('RGBA'), overlay)
    result.save(output_path)

    return result

def example(pipe, seeds, example_id, global_prompt, entity_prompts):
    """Run one EliGen entity-control example: fetch its masks, generate an image
    per seed, and save both the image and a mask visualization."""
    dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
    masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
    negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
    for seed in seeds:
        # generate image
        image = pipe(
            prompt=global_prompt,
            cfg_scale=3.0,
            negative_prompt=negative_prompt,
            num_inference_steps=50,
            embedded_guidance=3.5,
            seed=seed,
            height=1024,
            width=1024,
            eligen_entity_prompts=entity_prompts,
            eligen_entity_masks=masks,
        )
        image.save(f"eligen_example_{example_id}_{seed}.png")
        visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")


pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
    ],
)
# Attach the EliGen entity-control LoRA to the DiT.
pipe.load_lora(pipe.dit, ModelConfig(model_id="DiffSynth-Studio/Eligen", origin_file_pattern="model_bf16.safetensors"), alpha=1)

# example 1: moonlit seascape with six controlled entities
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
example(pipe, [0], 1, global_prompt, entity_prompts)

# example 2: samurai girl; four controlled entities
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
example(pipe, [0], 2, global_prompt, entity_prompts)

# example 3: sky palace; note the same entity prompt ("a traditional monk") used twice
# with two different masks.
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
example(pipe, [27], 3, global_prompt, entity_prompts)

# example 4: text rendering inside a controlled entity (the sign)
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
example(pipe, [21], 4, global_prompt, entity_prompts)

# example 5: dramatic painting; four controlled entities
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
example(pipe, [0], 5, global_prompt, entity_prompts)

# example 6: seven entities (continued below)
global_prompt = "Snow White and the 6 Dwarfs."
+entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"] +example(pipe, [8], 6, global_prompt, entity_prompts) + +# example 7, same prompt with different seeds +seeds = range(5, 9) +global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" +entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] +example(pipe, seeds, 7, global_prompt, entity_prompts) diff --git a/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py b/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py new file mode 100644 index 0000000..1479e1d --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py @@ -0,0 +1,24 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"), + ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors"), + ], +) + +origin_prompt = "a rabbit in a garden, colorful flowers" +image = pipe(prompt=origin_prompt, height=1280, width=960, seed=42) +image.save("style image.jpg") + +image = pipe(prompt="A piggy", height=1280, width=960, seed=42, + ipadapter_images=[image], ipadapter_scale=0.7) +image.save("A piggy.jpg") diff --git a/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py 
# ==== new file: examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py ====
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput
from modelscope import dataset_snapshot_download
from modelscope import snapshot_download
from PIL import Image
import numpy as np


# Fetch the InsightFace antelopev2 face-analysis models that InfiniteYou depends on.
snapshot_download(
    "ByteDance/InfiniteYou",
    allow_file_pattern="supports/insightface/models/antelopev2/*",
    local_dir="models/ByteDance/InfiniteYou",
)
# FLUX.1-dev pipeline extended with InfiniteYou's image projector and InfuseNet weights.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
        ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin"),
        ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors"),
    ],
)

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=f"data/examples/infiniteyou/*",
)

# All-black control image with processor_id "None": the identity signal comes from
# infinityou_id_image below. NOTE(review): presumably the ControlNet branch still
# needs an input of the target size — confirm against the pipeline implementation.
height, width = 1024, 1024
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
controlnet_inputs = [ControlNetInput(image=controlnet_image, scale=1.0, processor_id="None")]

# Identity-preserving generation for a male reference face.
prompt = "A man, portrait, cinematic"
id_image = "data/examples/infiniteyou/man.jpg"
id_image = Image.open(id_image).convert('RGB')
image = pipe(
    prompt=prompt, seed=1,
    infinityou_id_image=id_image, infinityou_guidance=1.0,
    controlnet_inputs=controlnet_inputs,
    num_inference_steps=50, embedded_guidance=3.5,
    height=height, width=width,
)
image.save("man.jpg")

# Same settings for a female reference face.
prompt = "A woman, portrait, cinematic"
id_image = "data/examples/infiniteyou/woman.jpg"
id_image = Image.open(id_image).convert('RGB')
image = pipe(
    prompt=prompt, seed=1,
    infinityou_id_image=id_image, infinityou_guidance=1.0,
    controlnet_inputs=controlnet_inputs,
    num_inference_steps=50, embedded_guidance=3.5,
    height=height, width=width,
)
image.save("woman.jpg")

# ==== new file: examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py ====
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig


# FLUX.1-dev pipeline plus the DiffSynth LoRA-Encoder model.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
        ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"),
    ],
)
pipe.enable_lora_magic()

lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA.

# Empty prompt can automatically activate LoRA capabilities.
# With vs. without the encoded LoRA, same (empty) prompt and seed.
image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
image.save("image_1.jpg")

image = pipe(prompt="", seed=0)
image.save("image_1_origin.jpg")

# Prompt without trigger words can also activate LoRA capabilities.
image = pipe(prompt="a car", seed=0, lora_encoder_inputs=lora)
image.save("image_2.jpg")

image = pipe(prompt="a car", seed=0,)
image.save("image_2_origin.jpg")

# Adjust the activation intensity through the scale parameter.
image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0)
image.save("image_3.jpg")

image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5)
image.save("image_3_scale.jpg")

# ==== new file: examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py ====
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig


# FLUX.1-dev pipeline plus the LoRA-fusion model; two LoRAs are hot-loaded together below.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
        ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors"),
    ],
)
pipe.enable_lora_magic()

pipe.load_lora(
    pipe.dit,
    ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"),
    hotload=True,
)
pipe.load_lora(
    pipe.dit,
    ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"),
    hotload=True,
)
image = pipe(prompt="a cat", seed=0)
image.save("image_fused.jpg")

# ==== new file: examples/flux/model_inference/FLUX.1-dev.py ====
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig


# Plain FLUX.1-dev text-to-image pipeline.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
    ],
)

prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"

# Generation without classifier-free guidance.
image = pipe(prompt=prompt, seed=0)
image.save("flux.jpg")

# Same seed with a negative prompt and CFG (cfg_scale=2).
image = pipe(
    prompt=prompt, negative_prompt=negative_prompt,
    seed=0, cfg_scale=2, num_inference_steps=50,
)
image.save("flux_cfg.jpg")

# ==== new file: examples/flux/model_inference/Nexus-Gen-Editing.py ====
import importlib
import torch
from PIL import Image
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download


# Nexus-GenV2 is pinned to transformers 4.49.0; fail fast with a clear message otherwise.
if importlib.util.find_spec("transformers") is None:
    raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.")
else:
    import transformers
    assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`."


# Nexus-GenV2 editing pipeline: Nexus-Gen weights + edit decoder on top of the
# FLUX.1-dev text encoders and autoencoder.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"),
        ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
    ],
    nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"),
)

# Edit a reference image according to the prompt.
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg")
ref_image = Image.open("data/examples/nexusgen/cat.jpg").convert("RGB")
prompt = "Add a crown."
image = pipe(
    prompt=prompt, negative_prompt="",
    seed=42, cfg_scale=2.0, num_inference_steps=50,
    nexus_gen_reference_image=ref_image,
    height=512, width=512,
)
image.save("cat_crown.jpg")

# ==== new file: examples/flux/model_inference/Nexus-Gen-Generation.py ====
import importlib
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig


# Same transformers version pin as the Nexus-Gen editing example.
if importlib.util.find_spec("transformers") is None:
    raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.")
else:
    import transformers
    assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`."
# Nexus-GenV2 text-to-image pipeline: Nexus-Gen weights + generation decoder on top
# of the FLUX.1-dev text encoders and autoencoder.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"),
        ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
    ],
    nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"),
)

prompt = "一只可爱的猫咪"  # Chinese prompt: "a cute kitten"
image = pipe(
    prompt=prompt, negative_prompt="",
    seed=0, cfg_scale=3, num_inference_steps=50,
    height=1024, width=1024,
)
image.save("cat.jpg")

# ==== new file: examples/flux/model_inference/Step1X-Edit.py ====
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
from PIL import Image
import numpy as np


# Step1X-Edit pipeline: Qwen2.5-VL-7B-Instruct weights plus the Step1X-Edit model
# and its VAE.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"),
        ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
        ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
    ],
)

# First edit starts from a plain white 832x1248 canvas (zeros + 255 per channel).
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
image = pipe(
    prompt="draw red flowers in Chinese ink painting style",
    step1x_reference_image=image,
    width=832, height=1248, cfg_scale=6,
    seed=1, rand_device='cuda'
)
image.save("image_1.jpg")

# Second edit reuses the first result as its reference — iterative editing.
image = pipe(
    prompt="add more flowers in Chinese ink painting style",
    step1x_reference_image=image,
    width=832, height=1248, cfg_scale=6,
    seed=2, rand_device='cuda'
)
image.save("image_2.jpg")