Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)
@@ -67,6 +67,7 @@ from ..models.step1x_connector import Qwen2Connector
 from ..models.flux_value_control import SingleValueEncoder

 from ..lora.flux_lora import FluxLoraPatcher
+from ..models.flux_lora_encoder import FluxLoRAEncoder


 model_loader_configs = [
@@ -150,6 +151,7 @@ model_loader_configs = [
     (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
     (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
     (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"),
+    (None, "77c2e4dd2440269eb33bfaa0d004f6ab", ["flux_lora_encoder"], [FluxLoRAEncoder], "civitai"),
 ]
 huggingface_model_loader_configs = [
     # These configs are provided for detecting model type automatically.
diffsynth/models/flux_lora_encoder.py (new file, 111 lines)
@@ -0,0 +1,111 @@
+import torch
+from .sd_text_encoder import CLIPEncoderLayer
+
+
+class LoRALayerBlock(torch.nn.Module):
+    def __init__(self, L, dim_in, dim_out):
+        super().__init__()
+        self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
+        self.layer_norm = torch.nn.LayerNorm(dim_out)
+
+    def forward(self, lora_A, lora_B):
+        # Probe the LoRA update: x @ (B A)^T == x @ A^T @ B^T, evaluated low-rank first.
+        x = self.x @ lora_A.T @ lora_B.T
+        x = self.layer_norm(x)
+        return x
+
+
+class LoRAEmbedder(torch.nn.Module):
+    def __init__(self, lora_patterns=None, L=1, out_dim=2048):
+        super().__init__()
+        if lora_patterns is None:
+            lora_patterns = self.default_lora_patterns()
+
+        model_dict = {}
+        for lora_pattern in lora_patterns:
+            name, dim = lora_pattern["name"], lora_pattern["dim"]
+            model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
+        self.model_dict = torch.nn.ModuleDict(model_dict)
+
+        proj_dict = {}
+        for lora_pattern in lora_patterns:
+            layer_type, dim = lora_pattern["type"], lora_pattern["dim"]
+            if layer_type.replace(".", "___") not in proj_dict:  # keys are stored in their "___" form
+                proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim)
+        self.proj_dict = torch.nn.ModuleDict(proj_dict)
+
+        self.lora_patterns = lora_patterns
+
+    def default_lora_patterns(self):
+        lora_patterns = []
+        lora_dict = {
+            "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
+            "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
+        }
+        for i in range(19):
+            for suffix in lora_dict:
+                lora_patterns.append({
+                    "name": f"blocks.{i}.{suffix}",
+                    "dim": lora_dict[suffix],
+                    "type": suffix,
+                })
+        lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
+        for i in range(38):
+            for suffix in lora_dict:
+                lora_patterns.append({
+                    "name": f"single_blocks.{i}.{suffix}",
+                    "dim": lora_dict[suffix],
+                    "type": suffix,
+                })
+        return lora_patterns
+
+    def forward(self, lora):
+        lora_emb = []
+        for lora_pattern in self.lora_patterns:
+            name, layer_type = lora_pattern["name"], lora_pattern["type"]
+            lora_A = lora[name + ".lora_A.default.weight"]
+            lora_B = lora[name + ".lora_B.default.weight"]
+            lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
+            lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
+            lora_emb.append(lora_out)
+        lora_emb = torch.concat(lora_emb, dim=1)
+        return lora_emb
+
+
+class FluxLoRAEncoder(torch.nn.Module):
+    def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
+        super().__init__()
+        self.num_embeds_per_lora = num_embeds_per_lora
+        # embedder
+        self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)
+
+        # encoders
+        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])
+
+        # special embedding
+        self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
+        self.num_special_embeds = num_special_embeds
+
+        # final layer
+        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
+        self.final_linear = torch.nn.Linear(embed_dim, embed_dim)
+
+    def forward(self, lora):
+        lora_embeds = self.embedder(lora)
+        special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
+        embeds = torch.concat([special_embeds, lora_embeds], dim=1)
+        for encoder in self.encoders:
+            embeds = encoder(embeds)
+        # Keep only the special tokens as the LoRA's summary embedding.
+        embeds = embeds[:, :self.num_special_embeds]
+        embeds = self.final_layer_norm(embeds)
+        embeds = self.final_linear(embeds)
+        return embeds
+
+    @staticmethod
+    def state_dict_converter():
+        return FluxLoRAEncoderStateDictConverter()
+
+
+class FluxLoRAEncoderStateDictConverter:
+    def from_civitai(self, state_dict):
+        return state_dict
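For intuition, here is a minimal, self-contained sketch of what LoRALayerBlock computes, assuming PEFT's weight layout (which the ".lora_A.default.weight" key names suggest): lora_A is (rank, dim_in), lora_B is (dim_out, rank), and the learned probe x is pushed through the low-rank update. All dimensions below are illustrative, not library code.

import torch

L, dim_in, dim_out, rank = 16, 3072, 9216, 16
x = torch.randn(1, L, dim_in)        # learned probe tokens, as in LoRALayerBlock.x
lora_A = torch.randn(rank, dim_in)   # LoRA down-projection
lora_B = torch.randn(dim_out, rank)  # LoRA up-projection

out = x @ lora_A.T @ lora_B.T        # (1, L, dim_in) -> (1, L, rank) -> (1, L, dim_out)
print(out.shape)                     # torch.Size([1, 16, 9216])

Multiplying by A^T first keeps the intermediate at rank width instead of materializing the full (dim_in, dim_out) weight update.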
@@ -71,7 +71,7 @@ def load_state_dict(file_path, torch_dtype=None, device="cpu"):

 def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
     state_dict = {}
-    with safe_open(file_path, framework="pt", device=device) as f:
+    with safe_open(file_path, framework="pt", device=str(device)) as f:
         for k in f.keys():
             state_dict[k] = f.get_tensor(k)
     if torch_dtype is not None:
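The only change here is the str(device) cast. safetensors' safe_open takes its device argument as a string such as "cpu" or "cuda:0", so a torch.device object passed straight through can be rejected by the binding; the cast keeps both call styles working. A quick sketch (the file path is a placeholder):

import torch
from safetensors import safe_open

device = torch.device("cuda", 0)
# str(torch.device("cuda", 0)) == "cuda:0", which safe_open accepts.
with safe_open("model.safetensors", framework="pt", device=str(device)) as f:
    state_dict = {k: f.get_tensor(k) for k in f.keys()}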
@@ -20,6 +20,7 @@ from ..models.flux_controlnet import FluxControlNet
 from ..models.flux_ipadapter import FluxIpAdapter
 from ..models.flux_value_control import MultiValueEncoder
 from ..models.flux_infiniteyou import InfiniteYouImageProjector
+from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock
 from ..models.tiler import FastTileWorker
 from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
 from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
@@ -97,6 +98,7 @@ class FluxImagePipeline(BasePipeline):
         self.infinityou_processor: InfinitYou = None
         self.image_proj_model: InfiniteYouImageProjector = None
         self.lora_patcher: FluxLoraPatcher = None
+        self.lora_encoder: FluxLoRAEncoder = None
         self.unit_runner = PipelineUnitRunner()
         self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher")
         self.units = [
@@ -115,6 +117,7 @@ class FluxImagePipeline(BasePipeline):
             FluxImageUnit_Flex(),
             FluxImageUnit_Step1x(),
             FluxImageUnit_ValueControl(),
+            FluxImageUnit_LoRAEncode(),
         ]
         self.model_fn = model_fn_flux_image

@@ -196,6 +199,7 @@ class FluxImagePipeline(BasePipeline):
                 torch.nn.Conv2d: AutoWrappedModule,
                 torch.nn.GroupNorm: AutoWrappedModule,
                 RMSNorm: AutoWrappedModule,
+                LoRALayerBlock: AutoWrappedModule,
             },
             module_config = dict(
                 offload_dtype=dtype,
@@ -207,6 +211,33 @@ class FluxImagePipeline(BasePipeline):
             ),
             vram_limit=vram_limit,
         )

+    def enable_lora_magic(self):
+        if self.dit is not None:
+            # Wrap every Linear in the DiT so LoRAs can be hot-loaded per layer.
+            if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
+                dtype = next(iter(self.dit.parameters())).dtype
+                enable_vram_management(
+                    self.dit,
+                    module_map = {
+                        torch.nn.Linear: AutoWrappedLinear,
+                    },
+                    module_config = dict(
+                        offload_dtype=dtype,
+                        offload_device=self.device,
+                        onload_dtype=dtype,
+                        onload_device=self.device,
+                        computation_dtype=self.torch_dtype,
+                        computation_device=self.device,
+                    ),
+                    vram_limit=None,
+                )
+            # Attach the matching LoRA merger to each wrapped Linear, keyed by module path.
+            if self.lora_patcher is not None:
+                for name, module in self.dit.named_modules():
+                    if isinstance(module, AutoWrappedLinear):
+                        merger_name = name.replace(".", "___")
+                        if merger_name in self.lora_patcher.model_dict:
+                            module.lora_merger = self.lora_patcher.model_dict[merger_name]
+

     def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
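enable_lora_magic relies on the same name.replace(".", "___") convention used by LoRAEmbedder above: torch.nn.ModuleDict keys may not contain dots (nn.Module.add_module rejects them because dots delimit submodule paths), so dotted module names are stored with "___" as a separator and translated back at lookup time. A minimal demonstration:

import torch

d = torch.nn.ModuleDict()
d["blocks___0___attn___a_to_qkv"] = torch.nn.Linear(8, 8)  # accepted
try:
    d["blocks.0.attn.a_to_qkv"] = torch.nn.Linear(8, 8)    # rejected
except KeyError as e:
    print(e)  # module name can't contain "."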
@@ -219,7 +250,7 @@ class FluxImagePipeline(BasePipeline):
             vram_limit = vram_limit - vram_buffer

         # Default config
-        default_vram_management_models = ["text_encoder_1", "vae_decoder", "vae_encoder", "controlnet", "image_proj_model", "ipadapter", "lora_patcher", "value_controller", "step1x_connector"]
+        default_vram_management_models = ["text_encoder_1", "vae_decoder", "vae_encoder", "controlnet", "image_proj_model", "ipadapter", "lora_patcher", "value_controller", "step1x_connector", "lora_encoder"]
         for model_name in default_vram_management_models:
             self._enable_vram_management_with_default_config(getattr(self, model_name), vram_limit)

@@ -366,6 +397,7 @@ class FluxImagePipeline(BasePipeline):
         if pipe.image_proj_model is not None:
             pipe.infinityou_processor = InfinitYou(device=device)
         pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
+        pipe.lora_encoder = model_manager.fetch_model("flux_lora_encoder")

         # ControlNet
         controlnets = []
@@ -437,6 +469,9 @@ class FluxImagePipeline(BasePipeline):
         value_controller_inputs: list[float] = None,
         # Step1x
         step1x_reference_image: Image.Image = None,
+        # LoRA Encoder
+        lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None,
+        lora_encoder_scale: float = 1.0,
         # TeaCache
         tea_cache_l1_thresh: float = None,
         # Tile
@@ -470,6 +505,7 @@ class FluxImagePipeline(BasePipeline):
             "flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
             "value_controller_inputs": value_controller_inputs,
             "step1x_reference_image": step1x_reference_image,
+            "lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale,
             "tea_cache_l1_thresh": tea_cache_l1_thresh,
             "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
             "progress_bar_cmd": progress_bar_cmd,
@@ -884,6 +920,66 @@ class InfinitYou(torch.nn.Module):
         return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}


+class FluxImageUnit_LoRAEncode(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            take_over=True,
+            onload_model_names=("lora_encoder",)
+        )
+
+    def parse_lora_encoder_inputs(self, lora_encoder_inputs):
+        if not isinstance(lora_encoder_inputs, list):
+            lora_encoder_inputs = [lora_encoder_inputs]
+        lora_configs = []
+        for lora_encoder_input in lora_encoder_inputs:
+            if isinstance(lora_encoder_input, str):
+                lora_encoder_input = ModelConfig(path=lora_encoder_input)
+            lora_encoder_input.download_if_necessary()
+            lora_configs.append(lora_encoder_input)
+        return lora_configs
+
+    def load_lora(self, lora_config, dtype, device):
+        loader = FluxLoRALoader(torch_dtype=dtype, device=device)
+        lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device)
+        lora = loader.convert_state_dict(lora)
+        return lora
+
+    def lora_embedding(self, pipe, lora_encoder_inputs):
+        lora_emb = []
+        for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs):
+            lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device)
+            lora_emb.append(pipe.lora_encoder(lora))
+        lora_emb = torch.concat(lora_emb, dim=1)
+        return lora_emb
+
+    def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb):
+        prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1)
+        extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype)
+        text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
+        return prompt_emb, text_ids
+
+    def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
+        if inputs_shared.get("lora_encoder_inputs", None) is None:
+            return inputs_shared, inputs_posi, inputs_nega
+
+        # Encode
+        pipe.load_models_to_device(["lora_encoder"])
+        lora_encoder_inputs = inputs_shared["lora_encoder_inputs"]
+        lora_emb = self.lora_embedding(pipe, lora_encoder_inputs)
+
+        # Scale
+        lora_encoder_scale = inputs_shared.get("lora_encoder_scale", None)
+        if lora_encoder_scale is not None:
+            lora_emb = lora_emb * lora_encoder_scale
+
+        # Add to prompt embedding
+        inputs_posi["prompt_emb"], inputs_posi["text_ids"] = self.add_to_text_embedding(
+            inputs_posi["prompt_emb"], inputs_posi["text_ids"], lora_emb)
+        return inputs_shared, inputs_posi, inputs_nega
+
+
 class TeaCache:
     def __init__(self, num_inference_steps, rel_l1_thresh):
         self.num_inference_steps = num_inference_steps
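Shape-wise, add_to_text_embedding appends the encoded LoRA tokens to the prompt sequence and gives them all-zero position ids, matching what FLUX assigns to ordinary text tokens. A sketch on dummy tensors (the 512-token prompt length and 4096-dim embedding are illustrative):

import torch

prompt_emb = torch.randn(1, 512, 4096)  # (batch, text tokens, dim)
text_ids = torch.zeros(1, 512, 3)       # FLUX text position ids are zeros
lora_emb = torch.randn(1, 1, 4096)      # one special embed per encoded LoRA

prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1)  # (1, 513, 4096)
extra_ids = torch.zeros(1, lora_emb.shape[1], 3)
text_ids = torch.concat([text_ids, extra_ids], dim=1)     # (1, 513, 3)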
@@ -165,6 +165,7 @@ class ModelConfig:
     download_resource: str = "ModelScope"
     offload_device: Optional[Union[str, torch.device]] = None
     offload_dtype: Optional[torch.dtype] = None
+    skip_download: bool = False

     def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False):
         if self.path is None:
@@ -190,6 +191,7 @@ class ModelConfig:
             is_folder = False

         # Download
+        skip_download = skip_download or self.skip_download
        if not skip_download:
             downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
             snapshot_download(
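With the new field, skipping the download check can be configured per model rather than only per download_if_necessary call. A sketch, assuming the file already exists under ./models (paths are illustrative):

config = ModelConfig(
    model_id="black-forest-labs/FLUX.1-dev",
    origin_file_pattern="flux1-dev.safetensors",
    skip_download=True,  # trust the local copy; skip the snapshot_download call
)
config.download_if_necessary()  # resolves the path from the local models directory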
examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+import torch
+from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+
+
+pipe = FluxImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
+        ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"),
+    ],
+)
+pipe.enable_lora_magic()
+
+lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
+pipe.load_lora(pipe.dit, lora, hotload=True)  # Use `pipe.clear_lora()` to drop the loaded LoRA.
+
+# An empty prompt can automatically activate the LoRA's capabilities.
+image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
+image.save("image_1.jpg")
+
+image = pipe(prompt="", seed=0)
+image.save("image_1_origin.jpg")
+
+# A prompt without trigger words can also activate the LoRA's capabilities.
+image = pipe(prompt="a car", seed=0, lora_encoder_inputs=lora)
+image.save("image_2.jpg")
+
+image = pipe(prompt="a car", seed=0)
+image.save("image_2_origin.jpg")
+
+# Adjust the activation intensity through the scale parameter.
+image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0)
+image.save("image_3.jpg")
+
+image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5)
+image.save("image_3_scale.jpg")
@@ -0,0 +1,41 @@
+import torch
+from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+
+
+pipe = FluxImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+    ],
+)
+pipe.enable_vram_management()
+pipe.enable_lora_magic()
+
+lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
+pipe.load_lora(pipe.dit, lora, hotload=True)  # Use `pipe.clear_lora()` to drop the loaded LoRA.
+
+# An empty prompt can automatically activate the LoRA's capabilities.
+image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
+image.save("image_1.jpg")
+
+image = pipe(prompt="", seed=0)
+image.save("image_1_origin.jpg")
+
+# A prompt without trigger words can also activate the LoRA's capabilities.
+image = pipe(prompt="a car", seed=0, lora_encoder_inputs=lora)
+image.save("image_2.jpg")
+
+image = pipe(prompt="a car", seed=0)
+image.save("image_2_origin.jpg")
+
+# Adjust the activation intensity through the scale parameter.
+image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0)
+image.save("image_3.jpg")
+
+image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5)
+image.save("image_3_scale.jpg")