diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py
index f7dc5ab..b731bc8 100644
--- a/diffsynth/diffusion/base_pipeline.py
+++ b/diffsynth/diffusion/base_pipeline.py
@@ -3,7 +3,7 @@ import torch
 import numpy as np
 from einops import repeat, reduce
 from typing import Union
-from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
+from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type, enable_vram_management
 from ..core.device.npu_compatible_device import get_device_type
 from ..utils.lora import GeneralLoRALoader
 from ..models.model_loader import ModelPool
@@ -354,6 +354,30 @@ class BasePipeline(torch.nn.Module):
             self.template_model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
             self.template_data_processor = load_template_data_processor(model_config.path)()
 
+
+    def enable_lora_hot_loading(self, model: torch.nn.Module):
+        """Prepare `model` for LoRA hot-loading and return it.
+
+        Calls `enable_vram_management` with a module map from `torch.nn.Linear`
+        to `AutoWrappedLinear`. A model whose `vram_management_enabled`
+        attribute is already truthy is returned unchanged.
+        """
+        if getattr(model, "vram_management_enabled", False):
+            return model
+        module_map = {torch.nn.Linear: AutoWrappedLinear}
+        # Every dtype/device entry is the pipeline's own dtype and device.
+        vram_config = {
+            "offload_dtype": self.torch_dtype,
+            "offload_device": self.device,
+            "onload_dtype": self.torch_dtype,
+            "onload_device": self.device,
+            "preparing_dtype": self.torch_dtype,
+            "preparing_device": self.device,
+            "computation_dtype": self.torch_dtype,
+            "computation_device": self.device,
+        }
+        model = enable_vram_management(model, module_map, vram_config=vram_config)
+        return model
 
 
 class PipelineUnitGraph:
diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py
index 10ad3a0..ee3e988 100644
--- a/diffsynth/diffusion/loss.py
+++ b/diffsynth/diffusion/loss.py
@@ -3,6 +3,11 @@ import torch
 
 
 def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
+    if inputs.get("lora") is not None:
+        # Image-to-LoRA models need to load lora here.
+        pipe.clear_lora(verbose=0)
+        pipe.load_lora(pipe.dit, state_dict=inputs["lora"], hotload=True, verbose=0)
+
     max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
     min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
 
diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py
index 3bcbe4b..cf72460 100644
--- a/diffsynth/diffusion/parsers.py
+++ b/diffsynth/diffusion/parsers.py
@@ -62,6 +62,7 @@ def add_gradient_config(parser: argparse.ArgumentParser):
 
 def add_template_model_config(parser: argparse.ArgumentParser):
     parser.add_argument("--template_model_id_or_path", type=str, default=None, help="Model ID of path of template models.")
+    parser.add_argument("--enable_lora_hot_loading", default=False, action="store_true", help="Whether to enable LoRA hot-loading. Only available for image-to-lora models.")
     return parser
 
 def add_general_config(parser: argparse.ArgumentParser):
diff --git a/diffsynth/diffusion/template.py b/diffsynth/diffusion/template.py
index 6b9a53f..9277dd7 100644
--- a/diffsynth/diffusion/template.py
+++ b/diffsynth/diffusion/template.py
@@ -2,6 +2,7 @@ import torch, os, importlib, warnings, json, inspect
 from typing import Dict, List, Tuple, Union
 from ..core import ModelConfig, load_model
 from ..core.device.npu_compatible_device import get_device_type
+from ..utils.lora.merge import merge_lora
 
 KVCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]]
 
@@ -74,9 +75,34 @@ def load_template_data_processor(path):
 
 
 class TemplatePipeline(torch.nn.Module):
-    def __init__(self, models: List[TemplateModel]):
+    def __init__(
+        self,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: Union[str, torch.device] = get_device_type(),
+        model_configs: list[ModelConfig] = [],
+        lazy_loading: bool = False,
+    ):
+        """Build a pipeline over the template models described by `model_configs`.
+
+        With `lazy_loading` no model is loaded here; otherwise each config is
+        checked, downloaded if necessary, and loaded with `torch_dtype` on `device`.
+        """
         super().__init__()
-        self.models = torch.nn.ModuleList(models)
+        self.torch_dtype = torch_dtype
+        self.device = device
+        # Copy so neither the shared default list nor the caller's list is aliased.
+        self.model_configs = list(model_configs)
+        self.lazy_loading = lazy_loading
+        if lazy_loading:
+            self.models = None
+        else:
+            models = []
+            for model_config in model_configs:
+                TemplatePipeline.check_vram_config(model_config)
+                model_config.download_if_necessary()
+                model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
+                models.append(model)
+            self.models = torch.nn.ModuleList(models)
 
     def merge_kv_cache(self, kv_cache_list: List[KVCache]) -> KVCache:
         names = {}
@@ -100,6 +126,8 @@ class TemplatePipeline(torch.nn.Module):
             data = [template_cache[param] for template_cache in template_cache_list if param in template_cache]
             if param == "kv_cache":
                 data = self.merge_kv_cache(data)
+            elif param == "lora":
+                data = merge_lora(data)
             elif len(data) == 1:
                 data = data[0]
             else:
@@ -125,30 +153,40 @@ class TemplatePipeline(torch.nn.Module):
         torch_dtype: torch.dtype = torch.bfloat16,
         device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
+        lazy_loading: bool = False,
     ):
-        models = []
-        for model_config in model_configs:
-            TemplatePipeline.check_vram_config(model_config)
-            model_config.download_if_necessary()
-            model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
-            models.append(model)
-        pipe = TemplatePipeline(models)
+        pipe = TemplatePipeline(torch_dtype, device, model_configs, lazy_loading)
         return pipe
 
-    @torch.no_grad()
-    def process_inputs(self, inputs: List[Dict], pipe=None, **kwargs):
-        return [(i.get("model_id", 0), self.models[i.get("model_id", 0)].process_inputs(pipe=pipe, **i)) for i in inputs]
-
-    def forward(self, inputs: List[Tuple[int, Dict]], pipe=None, **kwargs):
-        template_cache = []
-        for model_id, model_inputs in inputs:
-            kv_cache = self.models[model_id](pipe=pipe, **model_inputs)
-            template_cache.append(kv_cache)
-        return template_cache
+    def fetch_model(self, model_id):
+        """Return template model `model_id`, loading it from its config when lazy loading is on."""
+        if self.lazy_loading:
+            model_config = self.model_configs[model_id]
+            # Same validation the eager path in __init__ performs before loading.
+            TemplatePipeline.check_vram_config(model_config)
+            model_config.download_if_necessary()
+            model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
+        else:
+            model = self.models[model_id]
+        return model
 
     def call_single_side(self, pipe=None, inputs: List[Dict] = None):
-        inputs = self.process_inputs(pipe=pipe, inputs=inputs)
-        template_cache = self.forward(pipe=pipe, inputs=inputs)
+        model = None
+        onload_model_id = None
+        template_cache = []
+        for model_inputs in inputs:
+            model_id = model_inputs.get("model_id", 0)
+            if model_id != onload_model_id:
+                # Release our reference first so a lazily loaded model can be freed before the next load.
+                model = None
+                model = self.fetch_model(model_id)
+                onload_model_id = model_id
+            # Input processing ran under no_grad before this refactor; keep that guarantee.
+            with torch.no_grad():
+                cache = model.process_inputs(pipe=pipe, **model_inputs)
+            # Call the module itself (not .forward) so registered hooks still run.
+            cache = model(pipe=pipe, **cache)
+            template_cache.append(cache)
         template_cache = self.merge_template_cache(template_cache)
         return template_cache
 
diff --git a/diffsynth/diffusion/training_module.py b/diffsynth/diffusion/training_module.py
index e1d3852..844573f 100644
--- a/diffsynth/diffusion/training_module.py
+++ b/diffsynth/diffusion/training_module.py
@@ -38,7 +38,7 @@ class GeneralUnit_TemplateProcessInputs(PipelineUnit):
         self.data_processor = data_processor
 
     def process(self, pipe, template_inputs):
-        if not hasattr(pipe, "template_model"):
+        if not hasattr(pipe, "template_model") or template_inputs is None:
             return {}
         if self.data_processor is not None:
             template_inputs = self.data_processor(**template_inputs)
@@ -58,7 +58,7 @@ class GeneralUnit_TemplateForward(PipelineUnit):
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
 
     def process(self, pipe, template_inputs):
-        if not hasattr(pipe, "template_model"):
+        if not hasattr(pipe, "template_model") or template_inputs is None:
             return {}
         template_cache = pipe.template_model.forward(
             **template_inputs,
diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py
index 4bec241..a3c7694 100644
--- a/diffsynth/pipelines/flux2_image.py
+++ b/diffsynth/pipelines/flux2_image.py
@@ -100,6 +100,9 @@ class Flux2ImagePipeline(BasePipeline):
         # LoRA
         lora = None,
         negative_lora = None,
+        # Text Embedding
+        extra_text_embedding = None,
+        negative_extra_text_embedding = None,
         # Inpaint
         inpaint_mask: Image.Image = None,
         inpaint_blur_size: int = None,
@@ -113,10 +116,12 @@ class Flux2ImagePipeline(BasePipeline):
         inputs_posi = {
             "prompt": prompt,
             "kv_cache": kv_cache,
+            "extra_text_embedding": extra_text_embedding,
         }
         inputs_nega = {
             "negative_prompt": negative_prompt,
             "kv_cache": negative_kv_cache,
+            "extra_text_embedding": negative_extra_text_embedding,
         }
         inputs_shared = {
             "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
@@ -607,6 +612,7 @@ def model_fn_flux2(
     edit_latents=None,
     edit_image_ids=None,
     kv_cache=None,
+    extra_text_embedding=None,
     use_gradient_checkpointing=False,
     use_gradient_checkpointing_offload=False,
     **kwargs,
@@ -617,6 +623,14 @@ def model_fn_flux2(
         latents = torch.concat([latents, edit_latents], dim=1)
         image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
     embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
+    if extra_text_embedding is not None:
+        # Number the extra tokens in the last id channel, continuing from the prompt length.
+        # Batch size and id width follow text_ids instead of being hard-coded to 1 and 4.
+        num_extra_tokens = extra_text_embedding.shape[1]
+        extra_text_ids = torch.zeros((text_ids.shape[0], num_extra_tokens, text_ids.shape[-1]), dtype=text_ids.dtype, device=text_ids.device)
+        extra_text_ids[:, :, -1] = torch.arange(prompt_embeds.shape[1], prompt_embeds.shape[1] + num_extra_tokens, device=text_ids.device)
+        prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1)
+        text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
     model_output = dit(
         hidden_states=latents,
         timestep=timestep / 1000,
diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py
index 5378da4..144f41e 100644
--- a/examples/flux2/model_training/train.py
+++ b/examples/flux2/model_training/train.py
@@ -19,6 +19,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
         fp8_models=None,
         offload_models=None,
         template_model_id_or_path=None,
+        enable_lora_hot_loading=False,
         device="cpu",
         task="sft",
     ):
@@ -29,6 +30,8 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
         self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
         self.pipe = self.load_training_template_model(self.pipe, template_model_id_or_path, args.use_gradient_checkpointing, args.use_gradient_checkpointing_offload)
         self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
+        if enable_lora_hot_loading:
+            self.pipe.dit = self.pipe.enable_lora_hot_loading(self.pipe.dit)
 
         # Training mode
         self.switch_pipe_to_training_mode(
@@ -129,6 +132,7 @@ if __name__ == "__main__":
         fp8_models=args.fp8_models,
         offload_models=args.offload_models,
         template_model_id_or_path=args.template_model_id_or_path,
+        enable_lora_hot_loading=args.enable_lora_hot_loading,
         task=args.task,
         device="cpu" if args.initialize_model_on_cpu else accelerator.device,
     )