diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 7a0b72b..c2e050d 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -67,6 +67,7 @@ from ..models.step1x_connector import Qwen2Connector from ..models.flux_value_control import SingleValueEncoder from ..lora.flux_lora import FluxLoraPatcher +from ..models.flux_lora_encoder import FluxLoRAEncoder model_loader_configs = [ @@ -150,6 +151,7 @@ model_loader_configs = [ (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"), (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"), (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"), + (None, "77c2e4dd2440269eb33bfaa0d004f6ab", ["flux_lora_encoder"], [FluxLoRAEncoder], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/models/flux_lora_encoder.py b/diffsynth/models/flux_lora_encoder.py new file mode 100644 index 0000000..695640a --- /dev/null +++ b/diffsynth/models/flux_lora_encoder.py @@ -0,0 +1,111 @@ +import torch +from .sd_text_encoder import CLIPEncoderLayer + + +class LoRALayerBlock(torch.nn.Module): + def __init__(self, L, dim_in, dim_out): + super().__init__() + self.x = torch.nn.Parameter(torch.randn(1, L, dim_in)) + self.layer_norm = torch.nn.LayerNorm(dim_out) + + def forward(self, lora_A, lora_B): + x = self.x @ lora_A.T @ lora_B.T + x = self.layer_norm(x) + return x + + +class LoRAEmbedder(torch.nn.Module): + def __init__(self, lora_patterns=None, L=1, out_dim=2048): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1]) + self.model_dict = torch.nn.ModuleDict(model_dict) + + proj_dict = {} + for lora_pattern in lora_patterns: + layer_type, dim = lora_pattern["type"], lora_pattern["dim"] + if layer_type not in proj_dict: + proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim) + self.proj_dict = torch.nn.ModuleDict(proj_dict) + + self.lora_patterns = lora_patterns + + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432), + "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432), + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + return lora_patterns + + def forward(self, lora): + lora_emb = [] + for lora_pattern in self.lora_patterns: + name, layer_type = lora_pattern["name"], lora_pattern["type"] + lora_A = lora[name + ".lora_A.default.weight"] + lora_B = lora[name + ".lora_B.default.weight"] + lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B) + lora_out = 
self.proj_dict[layer_type.replace(".", "___")](lora_out) + lora_emb.append(lora_out) + lora_emb = torch.concat(lora_emb, dim=1) + return lora_emb + + +class FluxLoRAEncoder(torch.nn.Module): + def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1): + super().__init__() + self.num_embeds_per_lora = num_embeds_per_lora + # embedder + self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)]) + + # special embedding + self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim)) + self.num_special_embeds = num_special_embeds + + # final layer + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + self.final_linear = torch.nn.Linear(embed_dim, embed_dim) + + def forward(self, lora): + lora_embeds = self.embedder(lora) + special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device) + embeds = torch.concat([special_embeds, lora_embeds], dim=1) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds) + embeds = embeds[:, :self.num_special_embeds] + embeds = self.final_layer_norm(embeds) + embeds = self.final_linear(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return FluxLoRAEncoderStateDictConverter() + + +class FluxLoRAEncoderStateDictConverter: + def from_civitai(self, state_dict): + return state_dict diff --git a/diffsynth/models/utils.py b/diffsynth/models/utils.py index 0d58e4e..86104d0 100644 --- a/diffsynth/models/utils.py +++ b/diffsynth/models/utils.py @@ -71,7 +71,7 @@ def load_state_dict(file_path, torch_dtype=None, device="cpu"): def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): state_dict = {} - with safe_open(file_path, framework="pt", device=device) as f: + with safe_open(file_path, framework="pt", device=str(device)) as f: for k in f.keys(): state_dict[k] = f.get_tensor(k) if torch_dtype is not None: diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index 3dbb9b8..811b119 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -20,6 +20,7 @@ from ..models.flux_controlnet import FluxControlNet from ..models.flux_ipadapter import FluxIpAdapter from ..models.flux_value_control import MultiValueEncoder from ..models.flux_infiniteyou import InfiniteYouImageProjector +from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock from ..models.tiler import FastTileWorker from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher @@ -97,6 +98,7 @@ class FluxImagePipeline(BasePipeline): self.infinityou_processor: InfinitYou = None self.image_proj_model: InfiniteYouImageProjector = None self.lora_patcher: FluxLoraPatcher = None + self.lora_encoder: FluxLoRAEncoder = None self.unit_runner = PipelineUnitRunner() self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher") self.units = [ @@ -115,6 +117,7 @@ class FluxImagePipeline(BasePipeline): FluxImageUnit_Flex(), FluxImageUnit_Step1x(), FluxImageUnit_ValueControl(), + FluxImageUnit_LoRAEncode(), ] self.model_fn = model_fn_flux_image @@ -196,6 +199,7 @@ class FluxImagePipeline(BasePipeline): torch.nn.Conv2d: AutoWrappedModule, 
torch.nn.GroupNorm: AutoWrappedModule, RMSNorm: AutoWrappedModule, + LoRALayerBlock: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, @@ -207,6 +211,33 @@ class FluxImagePipeline(BasePipeline): ), vram_limit=vram_limit, ) + + + def enable_lora_magic(self): + if self.dit is not None: + if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled): + dtype = next(iter(self.dit.parameters())).dtype + enable_vram_management( + self.dit, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + }, + module_config = dict( + offload_dtype=dtype, + offload_device=self.device, + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=None, + ) + if self.lora_patcher is not None: + for name, module in self.dit.named_modules(): + if isinstance(module, AutoWrappedLinear): + merger_name = name.replace(".", "___") + if merger_name in self.lora_patcher.model_dict: + module.lora_merger = self.lora_patcher.model_dict[merger_name] def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): @@ -219,7 +250,7 @@ class FluxImagePipeline(BasePipeline): vram_limit = vram_limit - vram_buffer # Default config - default_vram_management_models = ["text_encoder_1", "vae_decoder", "vae_encoder", "controlnet", "image_proj_model", "ipadapter", "lora_patcher", "value_controller", "step1x_connector"] + default_vram_management_models = ["text_encoder_1", "vae_decoder", "vae_encoder", "controlnet", "image_proj_model", "ipadapter", "lora_patcher", "value_controller", "step1x_connector", "lora_encoder"] for model_name in default_vram_management_models: self._enable_vram_management_with_default_config(getattr(self, model_name), vram_limit) @@ -366,6 +397,7 @@ class FluxImagePipeline(BasePipeline): if pipe.image_proj_model is not None: pipe.infinityou_processor = InfinitYou(device=device) pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher") + pipe.lora_encoder = model_manager.fetch_model("flux_lora_encoder") # ControlNet controlnets = [] @@ -437,6 +469,9 @@ class FluxImagePipeline(BasePipeline): value_controller_inputs: list[float] = None, # Step1x step1x_reference_image: Image.Image = None, + # LoRA Encoder + lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None, + lora_encoder_scale: float = 1.0, # TeaCache tea_cache_l1_thresh: float = None, # Tile @@ -470,6 +505,7 @@ class FluxImagePipeline(BasePipeline): "flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop, "value_controller_inputs": value_controller_inputs, "step1x_reference_image": step1x_reference_image, + "lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale, "tea_cache_l1_thresh": tea_cache_l1_thresh, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "progress_bar_cmd": progress_bar_cmd, @@ -884,6 +920,66 @@ class InfinitYou(torch.nn.Module): return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance} + +class FluxImageUnit_LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("lora_encoder",) + ) + + def parse_lora_encoder_inputs(self, lora_encoder_inputs): + if not isinstance(lora_encoder_inputs, list): + lora_encoder_inputs = [lora_encoder_inputs] + lora_configs = [] + for lora_encoder_input in 
lora_encoder_inputs: + if isinstance(lora_encoder_input, str): + lora_encoder_input = ModelConfig(path=lora_encoder_input) + lora_encoder_input.download_if_necessary() + lora_configs.append(lora_encoder_input) + return lora_configs + + def load_lora(self, lora_config, dtype, device): + loader = FluxLoRALoader(torch_dtype=dtype, device=device) + lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device) + lora = loader.convert_state_dict(lora) + return lora + + def lora_embedding(self, pipe, lora_encoder_inputs): + lora_emb = [] + for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs): + lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device) + lora_emb.append(pipe.lora_encoder(lora)) + lora_emb = torch.concat(lora_emb, dim=1) + return lora_emb + + def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb): + prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1) + extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype) + text_ids = torch.concat([text_ids, extra_text_ids], dim=1) + return prompt_emb, text_ids + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("lora_encoder_inputs", None) is None: + return inputs_shared, inputs_posi, inputs_nega + + # Encode + pipe.load_models_to_device(["lora_encoder"]) + lora_encoder_inputs = inputs_shared["lora_encoder_inputs"] + lora_emb = self.lora_embedding(pipe, lora_encoder_inputs) + + # Scale + lora_encoder_scale = inputs_shared.get("lora_encoder_scale", None) + if lora_encoder_scale is not None: + lora_emb = lora_emb * lora_encoder_scale + + # Add to prompt embedding + inputs_posi["prompt_emb"], inputs_posi["text_ids"] = self.add_to_text_embedding( + inputs_posi["prompt_emb"], inputs_posi["text_ids"], lora_emb) + return inputs_shared, inputs_posi, inputs_nega + + + class TeaCache: def __init__(self, num_inference_steps, rel_l1_thresh): self.num_inference_steps = num_inference_steps diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 9f52ddc..6902011 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -165,6 +165,7 @@ class ModelConfig: download_resource: str = "ModelScope" offload_device: Optional[Union[str, torch.device]] = None offload_dtype: Optional[torch.dtype] = None + skip_download: bool = False def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False): if self.path is None: @@ -190,6 +191,7 @@ class ModelConfig: is_folder = False # Download + skip_download = skip_download or self.skip_download if not skip_download: downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id)) snapshot_download( diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py b/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py new file mode 100644 index 0000000..d133024 --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py @@ -0,0 +1,40 @@ +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + 
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
+        ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"),
+    ],
+)
+pipe.enable_lora_magic()
+
+lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
+pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA.
+
+# An empty prompt can automatically activate the LoRA's capabilities.
+image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
+image.save("image_1.jpg")
+
+image = pipe(prompt="", seed=0)
+image.save("image_1_origin.jpg")
+
+# A prompt without trigger words can also activate the LoRA's capabilities.
+image = pipe(prompt="a car", seed=0, lora_encoder_inputs=lora)
+image.save("image_2.jpg")
+
+image = pipe(prompt="a car", seed=0)
+image.save("image_2_origin.jpg")
+
+# Adjust the activation intensity through the `lora_encoder_scale` parameter.
+image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0)
+image.save("image_3.jpg")
+
+image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5)
+image.save("image_3_scale.jpg")
diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py
new file mode 100644
index 0000000..54322bd
--- /dev/null
+++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py
@@ -0,0 +1,41 @@
+import torch
+from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+
+
+pipe = FluxImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+    ],
+)
+pipe.enable_vram_management()
+pipe.enable_lora_magic()
+
+lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
+pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA.
+
+# An empty prompt can automatically activate the LoRA's capabilities.
+image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
+image.save("image_1.jpg")
+
+image = pipe(prompt="", seed=0)
+image.save("image_1_origin.jpg")
+
+# A prompt without trigger words can also activate the LoRA's capabilities.
+image = pipe(prompt="a car", seed=0, lora_encoder_inputs=lora)
+image.save("image_2.jpg")
+
+image = pipe(prompt="a car", seed=0)
+image.save("image_2_origin.jpg")
+
+# Adjust the activation intensity through the `lora_encoder_scale` parameter.
+image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0)
+image.save("image_3.jpg")
+
+image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5)
+image.save("image_3_scale.jpg")
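
Usage note: `FluxImageUnit_LoRAEncode.parse_lora_encoder_inputs` accepts either a single LoRA or a list, and `lora_embedding` concatenates the per-LoRA embeddings along the sequence dimension, so several LoRAs can be encoded into the same prompt. A minimal sketch of that call, assuming the pipeline from the examples above is already constructed; the second checkpoint path below is a hypothetical placeholder and not part of this patch:

# Sketch only: encoding two LoRAs at once (second path is a hypothetical local checkpoint).
lora_a = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
lora_b = "./models/another_flux_lora.safetensors"

# `lora_encoder_inputs` takes a ModelConfig, a plain path string, or a list of them.
image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=[lora_a, lora_b], lora_encoder_scale=0.8)
image.save("image_multi_lora.jpg")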