mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-15 14:48:23 +00:00
update template framework
This commit is contained in:
@@ -3,7 +3,7 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from einops import repeat, reduce
|
from einops import repeat, reduce
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
|
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type, enable_vram_management
|
||||||
from ..core.device.npu_compatible_device import get_device_type
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
from ..utils.lora import GeneralLoRALoader
|
from ..utils.lora import GeneralLoRALoader
|
||||||
from ..models.model_loader import ModelPool
|
from ..models.model_loader import ModelPool
|
||||||
@@ -355,6 +355,23 @@ class BasePipeline(torch.nn.Module):
|
|||||||
self.template_data_processor = load_template_data_processor(model_config.path)()
|
self.template_data_processor = load_template_data_processor(model_config.path)()
|
||||||
|
|
||||||
|
|
||||||
|
def enable_lora_hot_loading(self, model: torch.nn.Module):
|
||||||
|
if hasattr(model, "vram_management_enabled") and getattr(model, "vram_management_enabled"):
|
||||||
|
return model
|
||||||
|
module_map = {torch.nn.Linear: AutoWrappedLinear}
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": self.torch_dtype,
|
||||||
|
"offload_device": self.device,
|
||||||
|
"onload_dtype": self.torch_dtype,
|
||||||
|
"onload_device": self.device,
|
||||||
|
"preparing_dtype": self.torch_dtype,
|
||||||
|
"preparing_device": self.device,
|
||||||
|
"computation_dtype": self.torch_dtype,
|
||||||
|
"computation_device": self.device,
|
||||||
|
}
|
||||||
|
model = enable_vram_management(model, module_map, vram_config=vram_config)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
class PipelineUnitGraph:
|
class PipelineUnitGraph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@@ -3,6 +3,11 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||||
|
if "lora" in inputs:
|
||||||
|
# Image-to-LoRA models need to load lora here.
|
||||||
|
pipe.clear_lora(verbose=0)
|
||||||
|
pipe.load_lora(pipe.dit, state_dict=inputs["lora"], hotload=True, verbose=0)
|
||||||
|
|
||||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||||
|
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ def add_gradient_config(parser: argparse.ArgumentParser):
|
|||||||
|
|
||||||
def add_template_model_config(parser: argparse.ArgumentParser):
|
def add_template_model_config(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument("--template_model_id_or_path", type=str, default=None, help="Model ID of path of template models.")
|
parser.add_argument("--template_model_id_or_path", type=str, default=None, help="Model ID of path of template models.")
|
||||||
|
parser.add_argument("--enable_lora_hot_loading", default=False, action="store_true", help="Whether to enable LoRA hot-loading. Only available for image-to-lora models.")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def add_general_config(parser: argparse.ArgumentParser):
|
def add_general_config(parser: argparse.ArgumentParser):
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import torch, os, importlib, warnings, json, inspect
|
|||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Tuple, Union
|
||||||
from ..core import ModelConfig, load_model
|
from ..core import ModelConfig, load_model
|
||||||
from ..core.device.npu_compatible_device import get_device_type
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
from ..utils.lora.merge import merge_lora
|
||||||
|
|
||||||
|
|
||||||
KVCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]]
|
KVCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]]
|
||||||
@@ -74,8 +75,27 @@ def load_template_data_processor(path):
|
|||||||
|
|
||||||
|
|
||||||
class TemplatePipeline(torch.nn.Module):
|
class TemplatePipeline(torch.nn.Module):
|
||||||
def __init__(self, models: List[TemplateModel]):
|
def __init__(
|
||||||
|
self,
|
||||||
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
device: Union[str, torch.device] = get_device_type(),
|
||||||
|
model_configs: list[ModelConfig] = [],
|
||||||
|
lazy_loading: bool = False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
self.device = device
|
||||||
|
self.model_configs = model_configs
|
||||||
|
self.lazy_loading = lazy_loading
|
||||||
|
if lazy_loading:
|
||||||
|
self.models = None
|
||||||
|
else:
|
||||||
|
models = []
|
||||||
|
for model_config in model_configs:
|
||||||
|
TemplatePipeline.check_vram_config(model_config)
|
||||||
|
model_config.download_if_necessary()
|
||||||
|
model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
|
||||||
|
models.append(model)
|
||||||
self.models = torch.nn.ModuleList(models)
|
self.models = torch.nn.ModuleList(models)
|
||||||
|
|
||||||
def merge_kv_cache(self, kv_cache_list: List[KVCache]) -> KVCache:
|
def merge_kv_cache(self, kv_cache_list: List[KVCache]) -> KVCache:
|
||||||
@@ -100,6 +120,8 @@ class TemplatePipeline(torch.nn.Module):
|
|||||||
data = [template_cache[param] for template_cache in template_cache_list if param in template_cache]
|
data = [template_cache[param] for template_cache in template_cache_list if param in template_cache]
|
||||||
if param == "kv_cache":
|
if param == "kv_cache":
|
||||||
data = self.merge_kv_cache(data)
|
data = self.merge_kv_cache(data)
|
||||||
|
elif param == "lora":
|
||||||
|
data = merge_lora(data)
|
||||||
elif len(data) == 1:
|
elif len(data) == 1:
|
||||||
data = data[0]
|
data = data[0]
|
||||||
else:
|
else:
|
||||||
@@ -125,30 +147,32 @@ class TemplatePipeline(torch.nn.Module):
|
|||||||
torch_dtype: torch.dtype = torch.bfloat16,
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
device: Union[str, torch.device] = get_device_type(),
|
device: Union[str, torch.device] = get_device_type(),
|
||||||
model_configs: list[ModelConfig] = [],
|
model_configs: list[ModelConfig] = [],
|
||||||
|
lazy_loading: bool = False,
|
||||||
):
|
):
|
||||||
models = []
|
pipe = TemplatePipeline(torch_dtype, device, model_configs, lazy_loading)
|
||||||
for model_config in model_configs:
|
|
||||||
TemplatePipeline.check_vram_config(model_config)
|
|
||||||
model_config.download_if_necessary()
|
|
||||||
model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
|
|
||||||
models.append(model)
|
|
||||||
pipe = TemplatePipeline(models)
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
@torch.no_grad()
|
def fetch_model(self, model_id):
|
||||||
def process_inputs(self, inputs: List[Dict], pipe=None, **kwargs):
|
if self.lazy_loading:
|
||||||
return [(i.get("model_id", 0), self.models[i.get("model_id", 0)].process_inputs(pipe=pipe, **i)) for i in inputs]
|
model_config = self.model_configs[model_id]
|
||||||
|
model_config.download_if_necessary()
|
||||||
def forward(self, inputs: List[Tuple[int, Dict]], pipe=None, **kwargs):
|
model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||||
template_cache = []
|
else:
|
||||||
for model_id, model_inputs in inputs:
|
model = self.models[model_id]
|
||||||
kv_cache = self.models[model_id](pipe=pipe, **model_inputs)
|
return model
|
||||||
template_cache.append(kv_cache)
|
|
||||||
return template_cache
|
|
||||||
|
|
||||||
def call_single_side(self, pipe=None, inputs: List[Dict] = None):
|
def call_single_side(self, pipe=None, inputs: List[Dict] = None):
|
||||||
inputs = self.process_inputs(pipe=pipe, inputs=inputs)
|
model = None
|
||||||
template_cache = self.forward(pipe=pipe, inputs=inputs)
|
onload_model_id = -1
|
||||||
|
template_cache = []
|
||||||
|
for i in inputs:
|
||||||
|
model_id = i.get("model_id", 0)
|
||||||
|
if model_id != onload_model_id:
|
||||||
|
model = self.fetch_model(model_id)
|
||||||
|
onload_model_id = model_id
|
||||||
|
cache = model.process_inputs(pipe=pipe, **i)
|
||||||
|
cache = model.forward(pipe=pipe, **cache)
|
||||||
|
template_cache.append(cache)
|
||||||
template_cache = self.merge_template_cache(template_cache)
|
template_cache = self.merge_template_cache(template_cache)
|
||||||
return template_cache
|
return template_cache
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class GeneralUnit_TemplateProcessInputs(PipelineUnit):
|
|||||||
self.data_processor = data_processor
|
self.data_processor = data_processor
|
||||||
|
|
||||||
def process(self, pipe, template_inputs):
|
def process(self, pipe, template_inputs):
|
||||||
if not hasattr(pipe, "template_model"):
|
if not hasattr(pipe, "template_model") or template_inputs is None:
|
||||||
return {}
|
return {}
|
||||||
if self.data_processor is not None:
|
if self.data_processor is not None:
|
||||||
template_inputs = self.data_processor(**template_inputs)
|
template_inputs = self.data_processor(**template_inputs)
|
||||||
@@ -58,7 +58,7 @@ class GeneralUnit_TemplateForward(PipelineUnit):
|
|||||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
|
||||||
def process(self, pipe, template_inputs):
|
def process(self, pipe, template_inputs):
|
||||||
if not hasattr(pipe, "template_model"):
|
if not hasattr(pipe, "template_model") or template_inputs is None:
|
||||||
return {}
|
return {}
|
||||||
template_cache = pipe.template_model.forward(
|
template_cache = pipe.template_model.forward(
|
||||||
**template_inputs,
|
**template_inputs,
|
||||||
|
|||||||
@@ -100,6 +100,9 @@ class Flux2ImagePipeline(BasePipeline):
|
|||||||
# LoRA
|
# LoRA
|
||||||
lora = None,
|
lora = None,
|
||||||
negative_lora = None,
|
negative_lora = None,
|
||||||
|
# Text Embedding
|
||||||
|
extra_text_embedding = None,
|
||||||
|
negative_extra_text_embedding = None,
|
||||||
# Inpaint
|
# Inpaint
|
||||||
inpaint_mask: Image.Image = None,
|
inpaint_mask: Image.Image = None,
|
||||||
inpaint_blur_size: int = None,
|
inpaint_blur_size: int = None,
|
||||||
@@ -113,10 +116,12 @@ class Flux2ImagePipeline(BasePipeline):
|
|||||||
inputs_posi = {
|
inputs_posi = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"kv_cache": kv_cache,
|
"kv_cache": kv_cache,
|
||||||
|
"extra_text_embedding": extra_text_embedding,
|
||||||
}
|
}
|
||||||
inputs_nega = {
|
inputs_nega = {
|
||||||
"negative_prompt": negative_prompt,
|
"negative_prompt": negative_prompt,
|
||||||
"kv_cache": negative_kv_cache,
|
"kv_cache": negative_kv_cache,
|
||||||
|
"extra_text_embedding": negative_extra_text_embedding,
|
||||||
}
|
}
|
||||||
inputs_shared = {
|
inputs_shared = {
|
||||||
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
|
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
|
||||||
@@ -607,6 +612,7 @@ def model_fn_flux2(
|
|||||||
edit_latents=None,
|
edit_latents=None,
|
||||||
edit_image_ids=None,
|
edit_image_ids=None,
|
||||||
kv_cache=None,
|
kv_cache=None,
|
||||||
|
extra_text_embedding=None,
|
||||||
use_gradient_checkpointing=False,
|
use_gradient_checkpointing=False,
|
||||||
use_gradient_checkpointing_offload=False,
|
use_gradient_checkpointing_offload=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -617,6 +623,11 @@ def model_fn_flux2(
|
|||||||
latents = torch.concat([latents, edit_latents], dim=1)
|
latents = torch.concat([latents, edit_latents], dim=1)
|
||||||
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
|
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
|
||||||
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
|
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
|
||||||
|
if extra_text_embedding is not None:
|
||||||
|
extra_text_ids = torch.zeros((1, extra_text_embedding.shape[1], 4), dtype=text_ids.dtype, device=text_ids.device)
|
||||||
|
extra_text_ids[:, :, -1] = torch.arange(prompt_embeds.shape[1], prompt_embeds.shape[1] + extra_text_embedding.shape[1])
|
||||||
|
prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1)
|
||||||
|
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
|
||||||
model_output = dit(
|
model_output = dit(
|
||||||
hidden_states=latents,
|
hidden_states=latents,
|
||||||
timestep=timestep / 1000,
|
timestep=timestep / 1000,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
|
|||||||
fp8_models=None,
|
fp8_models=None,
|
||||||
offload_models=None,
|
offload_models=None,
|
||||||
template_model_id_or_path=None,
|
template_model_id_or_path=None,
|
||||||
|
enable_lora_hot_loading=False,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
task="sft",
|
task="sft",
|
||||||
):
|
):
|
||||||
@@ -29,6 +30,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
|
|||||||
self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||||
self.pipe = self.load_training_template_model(self.pipe, template_model_id_or_path, args.use_gradient_checkpointing, args.use_gradient_checkpointing_offload)
|
self.pipe = self.load_training_template_model(self.pipe, template_model_id_or_path, args.use_gradient_checkpointing, args.use_gradient_checkpointing_offload)
|
||||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||||
|
if enable_lora_hot_loading: self.pipe.dit = self.pipe.enable_lora_hot_loading(self.pipe.dit)
|
||||||
|
|
||||||
# Training mode
|
# Training mode
|
||||||
self.switch_pipe_to_training_mode(
|
self.switch_pipe_to_training_mode(
|
||||||
@@ -129,6 +131,7 @@ if __name__ == "__main__":
|
|||||||
fp8_models=args.fp8_models,
|
fp8_models=args.fp8_models,
|
||||||
offload_models=args.offload_models,
|
offload_models=args.offload_models,
|
||||||
template_model_id_or_path=args.template_model_id_or_path,
|
template_model_id_or_path=args.template_model_id_or_path,
|
||||||
|
enable_lora_hot_loading=args.enable_lora_hot_loading,
|
||||||
task=args.task,
|
task=args.task,
|
||||||
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user