From 9cb887015be703f9344542efe8f6c723d14a29f9 Mon Sep 17 00:00:00 2001 From: "lzw478614@alibaba-inc.com" Date: Wed, 2 Jul 2025 13:32:24 +0800 Subject: [PATCH 1/3] lora hotload and merge --- diffsynth/lora/flux_lora.py | 60 ++++++++++++++++++++++++++- diffsynth/pipelines/flux_image_new.py | 42 ++++++++++++++++++- diffsynth/vram_management/layers.py | 15 ++++++- 3 files changed, 114 insertions(+), 3 deletions(-) diff --git a/diffsynth/lora/flux_lora.py b/diffsynth/lora/flux_lora.py index 899160f..b0f17aa 100644 --- a/diffsynth/lora/flux_lora.py +++ b/diffsynth/lora/flux_lora.py @@ -10,4 +10,62 @@ class FluxLoRALoader(GeneralLoRALoader): def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): lora_prefix, model_resource = self.loader.match(model, state_dict_lora) - self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource) \ No newline at end of file + self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource) + +class LoraMerger(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.weight_base = torch.nn.Parameter(torch.randn((dim,))) + self.weight_lora = torch.nn.Parameter(torch.randn((dim,))) + self.weight_cross = torch.nn.Parameter(torch.randn((dim,))) + self.weight_out = torch.nn.Parameter(torch.ones((dim,))) + self.bias = torch.nn.Parameter(torch.randn((dim,))) + self.activation = torch.nn.Sigmoid() + self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5) + self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5) + + def forward(self, base_output, lora_outputs): + norm_base_output = self.norm_base(base_output) + norm_lora_outputs = self.norm_lora(lora_outputs) + gate = self.activation( + norm_base_output * self.weight_base \ + + norm_lora_outputs * self.weight_lora \ + + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias + ) + output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) + return output + +class LoraPatcher(torch.nn.Module): + def __init__(self, lora_patterns=None): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoraMerger(dim) + self.model_dict = torch.nn.ModuleDict(model_dict) + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432, + "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432, + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + return lora_patterns + + def forward(self, base_output, lora_outputs, name): + return self.model_dict[name.replace(".", "___")](base_output, lora_outputs) diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index fe651f9..9c07f89 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -21,7 +21,8 @@ from ..models.flux_ipadapter import FluxIpAdapter from ..models.flux_infiniteyou import InfiniteYouImageProjector from ..models.tiler import FastTileWorker from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit -from ..lora.flux_lora import FluxLoRALoader +from ..lora.flux_lora import FluxLoRALoader,LoraPatcher +from ..models.lora import FluxLoRAConverter from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense from ..models.flux_dit import RMSNorm @@ -121,6 +122,45 @@ class FluxImagePipeline(BasePipeline): lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device) loader.load(module, lora, alpha=alpha) + def enable_lora_hotload(self, lora_paths): + # load lora state dict and align format + lora_state_dicts = [ + FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)) for path in lora_paths + ] + lora_state_dicts = [l for l in lora_state_dicts if l != {}] + + for name, module in self.dit.named_modules(): + if isinstance(module, torch.nn.Linear): + lora_a_name = f'{name}.lora_A.default.weight' + lora_b_name = f'{name}.lora_B.default.weight' + lora_A_weights = [] + lora_B_weights = [] + for lora_dict in lora_state_dicts: + if lora_a_name in lora_dict and lora_b_name in lora_dict: + lora_A_weights.append(lora_dict[lora_a_name]) + lora_B_weights.append(lora_dict[lora_b_name]) + module.lora_A_weights = lora_A_weights + module.lora_B_weights = lora_B_weights + + + def enable_lora_patcher(self, lora_patcher_path): + # load lora patcher + lora_patcher = LoraPatcher().to(dtype=self.torch_dtype, device=self.device) + lora_patcher.load_state_dict(load_state_dict(lora_patcher_path)) + + for name, module in self.dit.named_modules(): + if isinstance(module, torch.nn.Linear): + merger_name = name.replace(".", "___") + if merger_name in lora_patcher.model_dict: + module.lora_merger = lora_patcher.model_dict[merger_name] + + + def off_lora_hotload(self): + for name, module in self.dit.named_modules(): + if isinstance(module, torch.nn.Linear): + module.lora_A_weights = [] + module.lora_B_weights = [] + def training_loss(self, **inputs): timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py index c0beaf8..4dfec12 100644 --- a/diffsynth/vram_management/layers.py +++ b/diffsynth/vram_management/layers.py @@ -107,6 +107,9 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): self.vram_limit = vram_limit self.state = 0 self.name = name + self.lora_A_weights = [] + self.lora_B_weights = [] + self.lora_merger = None def forward(self, x, *args, **kwargs): if self.state == 2: @@ -120,7 +123,17 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): else: weight = cast_to(self.weight, self.computation_dtype, self.computation_device) bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) - return torch.nn.functional.linear(x, weight, bias) + out = torch.nn.functional.linear(x, weight, bias) + lora_output = [] + for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): + out_lora = x @ lora_A.T @ lora_B.T + if self.lora_merger is None: + out = out + out_lora + lora_output.append(out_lora) + if self.lora_merger is not None and len(lora_output) > 0: + lora_output = torch.stack(lora_output) + out = self.lora_merger(out, lora_output) + return out def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""): From 8a9dbbd3ba4acc71d34a5d3f0786b6c2392db5b8 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 3 Jul 2025 18:49:46 +0800 Subject: [PATCH 2/3] support lora fusion --- diffsynth/configs/model_config.py | 3 + diffsynth/lora/flux_lora.py | 81 +++++++++++++++-- diffsynth/pipelines/flux_image_new.py | 90 ++++++++++--------- diffsynth/vram_management/layers.py | 20 +++-- .../model_inference/FLUX.1-dev-LoRAFusion.py | 35 ++++++++ 5 files changed, 175 insertions(+), 54 deletions(-) create mode 100644 examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index f39b87a..0713b7c 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -64,6 +64,8 @@ from ..models.wan_video_vace import VaceWanModel from ..models.step1x_connector import Qwen2Connector +from ..lora.flux_lora import FluxLoraPatcher + model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -144,6 +146,7 @@ model_loader_configs = [ (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"), (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"), + (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/lora/flux_lora.py b/diffsynth/lora/flux_lora.py index b0f17aa..cc9d725 100644 --- a/diffsynth/lora/flux_lora.py +++ b/diffsynth/lora/flux_lora.py @@ -1,4 +1,4 @@ -import torch +import torch, math from diffsynth.lora import GeneralLoRALoader from diffsynth.models.lora import FluxLoRAFromCivitai @@ -6,11 +6,69 @@ from diffsynth.models.lora import FluxLoRAFromCivitai class FluxLoRALoader(GeneralLoRALoader): def __init__(self, device="cpu", torch_dtype=torch.float32): super().__init__(device=device, torch_dtype=torch_dtype) - self.loader = FluxLoRAFromCivitai() def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): - lora_prefix, model_resource = self.loader.match(model, state_dict_lora) - self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource) + super().load(model, state_dict_lora, alpha) + + def convert_state_dict(self, state_dict): + # TODO: support other lora format + rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight", + } + def guess_block_id(name): + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + return None, None + def guess_alpha(state_dict): + for name, param in state_dict.items(): + if ".alpha" in name: + name_ = name.replace(".alpha", ".lora_down.weight") + if name_ in state_dict: + lora_alpha = param.item() / state_dict[name_].shape[0] + lora_alpha = math.sqrt(lora_alpha) + return lora_alpha + return 1 + alpha = guess_alpha(state_dict) + state_dict_ = {} + for name, param in state_dict.items(): + block_id, source_name = guess_block_id(name) + if alpha != 1: + param *= alpha + if source_name in rename_dict: + target_name = rename_dict[source_name] + target_name = target_name.replace(".blockid.", f".{block_id}.") + state_dict_[target_name] = param + else: + state_dict_[name] = param + return state_dict_ + class LoraMerger(torch.nn.Module): def __init__(self, dim): @@ -35,7 +93,8 @@ class LoraMerger(torch.nn.Module): output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) return output -class LoraPatcher(torch.nn.Module): + +class FluxLoraPatcher(torch.nn.Module): def __init__(self, lora_patterns=None): super().__init__() if lora_patterns is None: @@ -69,3 +128,15 @@ class LoraPatcher(torch.nn.Module): def forward(self, base_output, lora_outputs, name): return self.model_dict[name.replace(".", "___")](base_output, lora_outputs) + + @staticmethod + def state_dict_converter(): + return FluxLoraPatcherStateDictConverter() + + +class FluxLoraPatcherStateDictConverter: + def __init__(self): + pass + + def from_civitai(self, state_dict): + return state_dict diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index 9c07f89..c8985dc 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -21,8 +21,7 @@ from ..models.flux_ipadapter import FluxIpAdapter from ..models.flux_infiniteyou import InfiniteYouImageProjector from ..models.tiler import FastTileWorker from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit -from ..lora.flux_lora import FluxLoRALoader,LoraPatcher -from ..models.lora import FluxLoRAConverter +from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense from ..models.flux_dit import RMSNorm @@ -92,12 +91,13 @@ class FluxImagePipeline(BasePipeline): self.controlnet: MultiControlNet = None self.ipadapter: FluxIpAdapter = None self.ipadapter_image_encoder = None - self.unit_runner = PipelineUnitRunner() self.qwenvl = None self.step1x_connector: Qwen2Connector = None self.infinityou_processor: InfinitYou = None self.image_proj_model: InfiniteYouImageProjector = None - self.in_iteration_models = ("dit", "step1x_connector", "controlnet") + self.lora_patcher: FluxLoraPatcher = None + self.unit_runner = PipelineUnitRunner() + self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher") self.units = [ FluxImageUnit_ShapeChecker(), FluxImageUnit_NoiseInitializer(), @@ -117,49 +117,55 @@ class FluxImagePipeline(BasePipeline): self.model_fn = model_fn_flux_image - def load_lora(self, module, path, alpha=1): + def load_lora( + self, + module: torch.nn.Module, + lora_config: Union[ModelConfig, str], + alpha=1, + hotload=False, + local_model_path="./models", + skip_download=False + ): + if isinstance(lora_config, str): + lora_config = ModelConfig(path=lora_config) + else: + lora_config.download_if_necessary(local_model_path, skip_download=skip_download) loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device) - lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device) - loader.load(module, lora, alpha=alpha) - - def enable_lora_hotload(self, lora_paths): - # load lora state dict and align format - lora_state_dicts = [ - FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)) for path in lora_paths - ] - lora_state_dicts = [l for l in lora_state_dicts if l != {}] + lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device) + lora = loader.convert_state_dict(lora) + if hotload: + for name, module in module.named_modules(): + if isinstance(module, AutoWrappedLinear): + lora_a_name = f'{name}.lora_A.default.weight' + lora_b_name = f'{name}.lora_B.default.weight' + if lora_a_name in lora and lora_b_name in lora: + module.lora_A_weights.append(lora[lora_a_name] * alpha) + module.lora_B_weights.append(lora[lora_b_name]) + else: + loader.load(module, lora, alpha=alpha) + + def enable_lora_patcher(self): + if not (hasattr(self, "vram_management_enabled") and self.vram_management_enabled): + print("Please enable VRAM management using `enable_vram_management()` before `enable_lora_patcher()`.") + return + if self.lora_patcher is None: + print("Please load lora patcher models before `enable_lora_patcher()`.") + return for name, module in self.dit.named_modules(): - if isinstance(module, torch.nn.Linear): - lora_a_name = f'{name}.lora_A.default.weight' - lora_b_name = f'{name}.lora_B.default.weight' - lora_A_weights = [] - lora_B_weights = [] - for lora_dict in lora_state_dicts: - if lora_a_name in lora_dict and lora_b_name in lora_dict: - lora_A_weights.append(lora_dict[lora_a_name]) - lora_B_weights.append(lora_dict[lora_b_name]) - module.lora_A_weights = lora_A_weights - module.lora_B_weights = lora_B_weights - - - def enable_lora_patcher(self, lora_patcher_path): - # load lora patcher - lora_patcher = LoraPatcher().to(dtype=self.torch_dtype, device=self.device) - lora_patcher.load_state_dict(load_state_dict(lora_patcher_path)) - - for name, module in self.dit.named_modules(): - if isinstance(module, torch.nn.Linear): + if isinstance(module, AutoWrappedLinear): merger_name = name.replace(".", "___") - if merger_name in lora_patcher.model_dict: - module.lora_merger = lora_patcher.model_dict[merger_name] + if merger_name in self.lora_patcher.model_dict: + module.lora_merger = self.lora_patcher.model_dict[merger_name] - def off_lora_hotload(self): - for name, module in self.dit.named_modules(): - if isinstance(module, torch.nn.Linear): - module.lora_A_weights = [] - module.lora_B_weights = [] + def clear_lora(self): + for name, module in self.named_modules(): + if isinstance(module, AutoWrappedLinear): + if hasattr(module, "lora_A_weights"): + module.lora_A_weights.clear() + if hasattr(module, "lora_B_weights"): + module.lora_B_weights.clear() def training_loss(self, **inputs): @@ -325,10 +331,10 @@ class FluxImagePipeline(BasePipeline): pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model") pipe.qwenvl = model_manager.fetch_model("qwenvl") pipe.step1x_connector = model_manager.fetch_model("step1x_connector") - pipe.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector") if pipe.image_proj_model is not None: pipe.infinityou_processor = InfinitYou(device=device) + pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher") # ControlNet controlnets = [] diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py index 4dfec12..0ebb054 100644 --- a/diffsynth/vram_management/layers.py +++ b/diffsynth/vram_management/layers.py @@ -124,13 +124,19 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): weight = cast_to(self.weight, self.computation_dtype, self.computation_device) bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) out = torch.nn.functional.linear(x, weight, bias) - lora_output = [] - for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): - out_lora = x @ lora_A.T @ lora_B.T - if self.lora_merger is None: - out = out + out_lora - lora_output.append(out_lora) - if self.lora_merger is not None and len(lora_output) > 0: + + if len(self.lora_A_weights) == 0: + # No LoRA + return out + elif self.lora_merger is None: + # Native LoRA inference + for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): + out = out + x @ lora_A.T @ lora_B.T + else: + # LoRA fusion + lora_output = [] + for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): + lora_output.append(x @ lora_A.T @ lora_B.T) lora_output = torch.stack(lora_output) out = self.lora_merger(out, lora_output) return out diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py b/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py new file mode 100644 index 0000000..d6039df --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py @@ -0,0 +1,35 @@ +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/FLUX.1-dev-LoRAFusion", origin_file_pattern="model.safetensors") + ], +) +pipe.enable_vram_management() +pipe.enable_lora_patcher() +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="yangyufeng/fgao", origin_file_pattern="30.safetensors"), + hotload=True +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="bobooblue/LoRA-bling-mai", origin_file_pattern="10.safetensors"), + hotload=True +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="JIETANGAB/E", origin_file_pattern="17.safetensors"), + hotload=True +) + +image = pipe(prompt="a beautiful Asian girl", seed=0) +image.save("flux.jpg") From 77676b5ceace4e20748887c8f22146f66353dcaa Mon Sep 17 00:00:00 2001 From: lzws <63908509+lzws@users.noreply.github.com> Date: Mon, 7 Jul 2025 10:54:49 +0800 Subject: [PATCH 3/3] Update FLUX.1-dev-LoRAFusion.py --- examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py b/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py index d6039df..68116d0 100644 --- a/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py +++ b/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py @@ -31,5 +31,5 @@ pipe.load_lora( hotload=True ) -image = pipe(prompt="a beautiful Asian girl", seed=0) +image = pipe(prompt="This is a digital painting in a soft, ethereal style. a beautiful Asian girl Shine like a diamond. Everywhere is shining with bling bling luster.The background is a textured blue with visible brushstrokes, giving the image an impressionistic style reminiscent of Vincent van Gogh's work", seed=0) image.save("flux.jpg")