rearrange lora loading modules

This commit is contained in:
Artiprocher
2025-07-24 18:56:25 +08:00
parent 783c435d88
commit 4ad6bd4e23
5 changed files with 354 additions and 302 deletions

View File

@@ -22,8 +22,8 @@ from ..models.flux_value_control import MultiValueEncoder
from ..models.flux_infiniteyou import InfiniteYouImageProjector
from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock
from ..models.tiler import FastTileWorker
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher, FluxLoRAFuser
from ..models.flux_dit import RMSNorm
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
@@ -125,18 +125,20 @@ class FluxImagePipeline(BasePipeline):
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str],
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=False,
local_model_path="./models",
skip_download=False
state_dict=None,
):
if isinstance(lora_config, str):
lora_config = ModelConfig(path=lora_config)
if state_dict is None:
if isinstance(lora_config, str):
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary()
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary(local_model_path, skip_download=skip_download)
lora = state_dict
loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
lora = loader.convert_state_dict(lora)
if hotload:
for name, module in module.named_modules():
@@ -150,19 +152,21 @@ class FluxImagePipeline(BasePipeline):
loader.load(module, lora, alpha=alpha)
def enable_lora_patcher(self):
if not (hasattr(self, "vram_management_enabled") and self.vram_management_enabled):
print("Please enable VRAM management using `enable_vram_management()` before `enable_lora_patcher()`.")
return
if self.lora_patcher is None:
print("Please load lora patcher models before `enable_lora_patcher()`.")
return
for name, module in self.dit.named_modules():
if isinstance(module, AutoWrappedLinear):
merger_name = name.replace(".", "___")
if merger_name in self.lora_patcher.model_dict:
module.lora_merger = self.lora_patcher.model_dict[merger_name]
def load_loras(
self,
module: torch.nn.Module,
lora_configs: list[Union[ModelConfig, str]],
alpha=1,
hotload=False,
extra_fused_lora=False,
):
for lora_config in lora_configs:
self.load_lora(module, lora_config, hotload=hotload, alpha=alpha)
if extra_fused_lora:
lora_fuser = FluxLoRAFuser(device="cuda", torch_dtype=torch.bfloat16)
fused_lora = lora_fuser(lora_configs)
self.load_lora(module, state_dict=fused_lora, hotload=hotload, alpha=alpha)
def clear_lora(self):
for name, module in self.named_modules():
@@ -365,16 +369,11 @@ class FluxImagePipeline(BasePipeline):
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
local_model_path: str = "./models",
skip_download: bool = False,
redirect_common_files: bool = True,
use_usp=False,
):
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
model_config.download_if_necessary(local_model_path, skip_download=skip_download)
model_config.download_if_necessary()
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,

View File

@@ -12,6 +12,7 @@ from tqdm import tqdm
from typing import Optional
from typing_extensions import Literal
from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
from ..models import ModelManager, load_state_dict
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
@@ -26,196 +27,6 @@ from ..lora import GeneralLoRALoader
class BasePipeline(torch.nn.Module):
def __init__(
self,
device="cuda", torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):
super().__init__()
# The device and torch_dtype is used for the storage of intermediate variables, not models.
self.device = device
self.torch_dtype = torch_dtype
# The following parameters are used for shape check.
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
self.vram_management_enabled = False
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def check_resize_height_width(self, height, width, num_frames=None):
# Shape check
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
if num_frames is None:
return height, width
else:
if num_frames % self.time_division_factor != self.time_division_remainder:
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
return height, width, num_frames
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
# Transform a PIL.Image to torch.Tensor
image = torch.Tensor(np.array(image, dtype=np.float32))
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
image = image * ((max_value - min_value) / 255) + min_value
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
return image
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a list of PIL.Image to torch.Tensor
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
video = torch.stack(video, dim=pattern.index("T") // 2)
return video
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to PIL.Image
if pattern != "H W C":
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
image = image.to(device="cpu", dtype=torch.uint8)
image = Image.fromarray(image.numpy())
return image
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to list of PIL.Image
if pattern != "T H W C":
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
return video
def load_models_to_device(self, model_names=[]):
if self.vram_management_enabled:
# offload models
for name, model in self.named_children():
if name not in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
else:
model.cpu()
torch.cuda.empty_cache()
# onload models
for name, model in self.named_children():
if name in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
else:
model.to(self.device)
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
# Initialize Gaussian noise
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
return noise
def enable_cpu_offload(self):
warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
self.vram_management_enabled = True
def get_vram(self):
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
def freeze_except(self, model_names):
for name, model in self.named_children():
if name in model_names:
model.train()
model.requires_grad_(True)
else:
model.eval()
model.requires_grad_(False)
@dataclass
class ModelConfig:
path: Union[str, list[str]] = None
model_id: str = None
origin_file_pattern: Union[str, list[str]] = None
download_resource: str = "ModelScope"
offload_device: Optional[Union[str, torch.device]] = None
offload_dtype: Optional[torch.dtype] = None
skip_download: bool = False
def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False):
if self.path is None:
# Check model_id and origin_file_pattern
if self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
# Skip if not in rank 0
if use_usp:
import torch.distributed as dist
skip_download = dist.get_rank() != 0
# Check whether the origin path is a folder
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.origin_file_pattern = ""
allow_file_pattern = None
is_folder = True
elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
allow_file_pattern = self.origin_file_pattern + "*"
is_folder = True
else:
allow_file_pattern = self.origin_file_pattern
is_folder = False
# Download
skip_download = skip_download or self.skip_download
if not skip_download:
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
snapshot_download(
self.model_id,
local_dir=os.path.join(local_model_path, self.model_id),
allow_file_pattern=allow_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
# Let rank 1, 2, ... wait for rank 0
if use_usp:
import torch.distributed as dist
dist.barrier(device_ids=[dist.get_rank()])
# Return downloaded files
if is_folder:
self.path = os.path.join(local_model_path, self.model_id, self.origin_file_pattern)
else:
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
if isinstance(self.path, list) and len(self.path) == 1:
self.path = self.path[0]
class WanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
@@ -438,8 +249,6 @@ class WanVideoPipeline(BasePipeline):
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
local_model_path: str = "./models",
skip_download: bool = False,
redirect_common_files: bool = True,
use_usp=False,
):
@@ -464,7 +273,7 @@ class WanVideoPipeline(BasePipeline):
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
model_config.download_if_necessary(local_model_path, skip_download=skip_download, use_usp=use_usp)
model_config.download_if_necessary(use_usp=use_usp)
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
@@ -480,7 +289,7 @@ class WanVideoPipeline(BasePipeline):
pipe.vace = model_manager.fetch_model("wan_video_vace")
# Initialize tokenizer
tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download)
tokenizer_config.download_if_necessary(use_usp=use_usp)
pipe.prompter.fetch_models(pipe.text_encoder)
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
@@ -606,63 +415,6 @@ class WanVideoPipeline(BasePipeline):
class PipelineUnit:
def __init__(
self,
seperate_cfg: bool = False,
take_over: bool = False,
input_params: tuple[str] = None,
input_params_posi: dict[str, str] = None,
input_params_nega: dict[str, str] = None,
onload_model_names: tuple[str] = None
):
self.seperate_cfg = seperate_cfg
self.take_over = take_over
self.input_params = input_params
self.input_params_posi = input_params_posi
self.input_params_nega = input_params_nega
self.onload_model_names = onload_model_names
def process(self, pipe: WanVideoPipeline, inputs: dict, positive=True, **kwargs) -> dict:
raise NotImplementedError("`process` is not implemented.")
class PipelineUnitRunner:
def __init__(self):
pass
def __call__(self, unit: PipelineUnit, pipe: WanVideoPipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
if unit.take_over:
# Let the pipeline unit take over this function.
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
elif unit.seperate_cfg:
# Positive side
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_posi.update(processor_outputs)
# Negative side
if inputs_shared["cfg_scale"] != 1:
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_nega.update(processor_outputs)
else:
inputs_nega.update(processor_outputs)
else:
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_shared.update(processor_outputs)
return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width", "num_frames"))