mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
rearrange lora loading modules
This commit is contained in:
261
diffsynth/utils/__init__.py
Normal file
261
diffsynth/utils/__init__.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import torch, warnings, glob, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import repeat, reduce
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from modelscope import snapshot_download
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BasePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device="cuda", torch_dtype=torch.float16,
|
||||
height_division_factor=64, width_division_factor=64,
|
||||
time_division_factor=None, time_division_remainder=None,
|
||||
):
|
||||
super().__init__()
|
||||
# The device and torch_dtype is used for the storage of intermediate variables, not models.
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# The following parameters are used for shape check.
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
self.vram_management_enabled = False
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.torch_dtype = dtype
|
||||
super().to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
def check_resize_height_width(self, height, width, num_frames=None):
|
||||
# Shape check
|
||||
if height % self.height_division_factor != 0:
|
||||
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
|
||||
if width % self.width_division_factor != 0:
|
||||
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
|
||||
if num_frames is None:
|
||||
return height, width
|
||||
else:
|
||||
if num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
|
||||
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
|
||||
return height, width, num_frames
|
||||
|
||||
|
||||
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
|
||||
# Transform a PIL.Image to torch.Tensor
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32))
|
||||
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||
image = image * ((max_value - min_value) / 255) + min_value
|
||||
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
|
||||
# Transform a list of PIL.Image to torch.Tensor
|
||||
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
|
||||
video = torch.stack(video, dim=pattern.index("T") // 2)
|
||||
return video
|
||||
|
||||
|
||||
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
|
||||
# Transform a torch.Tensor to PIL.Image
|
||||
if pattern != "H W C":
|
||||
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
|
||||
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
|
||||
image = image.to(device="cpu", dtype=torch.uint8)
|
||||
image = Image.fromarray(image.numpy())
|
||||
return image
|
||||
|
||||
|
||||
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
|
||||
# Transform a torch.Tensor to list of PIL.Image
|
||||
if pattern != "T H W C":
|
||||
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
|
||||
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
|
||||
return video
|
||||
|
||||
|
||||
def load_models_to_device(self, model_names=[]):
|
||||
if self.vram_management_enabled:
|
||||
# offload models
|
||||
for name, model in self.named_children():
|
||||
if name not in model_names:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "offload"):
|
||||
module.offload()
|
||||
else:
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
# onload models
|
||||
for name, model in self.named_children():
|
||||
if name in model_names:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "onload"):
|
||||
module.onload()
|
||||
else:
|
||||
model.to(self.device)
|
||||
|
||||
|
||||
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
|
||||
# Initialize Gaussian noise
|
||||
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
|
||||
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||
return noise
|
||||
|
||||
|
||||
def enable_cpu_offload(self):
|
||||
warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
|
||||
self.vram_management_enabled = True
|
||||
|
||||
|
||||
def get_vram(self):
|
||||
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
|
||||
|
||||
|
||||
def freeze_except(self, model_names):
|
||||
for name, model in self.named_children():
|
||||
if name in model_names:
|
||||
model.train()
|
||||
model.requires_grad_(True)
|
||||
else:
|
||||
model.eval()
|
||||
model.requires_grad_(False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
path: Union[str, list[str]] = None
|
||||
model_id: str = None
|
||||
origin_file_pattern: Union[str, list[str]] = None
|
||||
download_resource: str = "ModelScope"
|
||||
offload_device: Optional[Union[str, torch.device]] = None
|
||||
offload_dtype: Optional[torch.dtype] = None
|
||||
local_model_path: str = None
|
||||
skip_download: bool = False
|
||||
|
||||
def download_if_necessary(self, use_usp=False):
|
||||
if self.path is None:
|
||||
# Check model_id and origin_file_pattern
|
||||
if self.model_id is None:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
|
||||
|
||||
# Skip if not in rank 0
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
skip_download = self.skip_download or dist.get_rank() != 0
|
||||
else:
|
||||
skip_download = self.skip_download
|
||||
|
||||
# Check whether the origin path is a folder
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
self.origin_file_pattern = ""
|
||||
allow_file_pattern = None
|
||||
is_folder = True
|
||||
elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
|
||||
allow_file_pattern = self.origin_file_pattern + "*"
|
||||
is_folder = True
|
||||
else:
|
||||
allow_file_pattern = self.origin_file_pattern
|
||||
is_folder = False
|
||||
|
||||
# Download
|
||||
if not skip_download:
|
||||
if self.local_model_path is None:
|
||||
self.local_model_path = "./models"
|
||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_file_pattern=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
|
||||
# Let rank 1, 2, ... wait for rank 0
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
dist.barrier(device_ids=[dist.get_rank()])
|
||||
|
||||
# Return downloaded files
|
||||
if is_folder:
|
||||
self.path = os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)
|
||||
else:
|
||||
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
||||
if isinstance(self.path, list) and len(self.path) == 1:
|
||||
self.path = self.path[0]
|
||||
|
||||
|
||||
|
||||
class PipelineUnit:
|
||||
def __init__(
|
||||
self,
|
||||
seperate_cfg: bool = False,
|
||||
take_over: bool = False,
|
||||
input_params: tuple[str] = None,
|
||||
input_params_posi: dict[str, str] = None,
|
||||
input_params_nega: dict[str, str] = None,
|
||||
onload_model_names: tuple[str] = None
|
||||
):
|
||||
self.seperate_cfg = seperate_cfg
|
||||
self.take_over = take_over
|
||||
self.input_params = input_params
|
||||
self.input_params_posi = input_params_posi
|
||||
self.input_params_nega = input_params_nega
|
||||
self.onload_model_names = onload_model_names
|
||||
|
||||
|
||||
def process(self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs) -> dict:
|
||||
raise NotImplementedError("`process` is not implemented.")
|
||||
|
||||
|
||||
|
||||
class PipelineUnitRunner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
|
||||
if unit.take_over:
|
||||
# Let the pipeline unit take over this function.
|
||||
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
|
||||
elif unit.seperate_cfg:
|
||||
# Positive side
|
||||
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_posi.update(processor_outputs)
|
||||
# Negative side
|
||||
if inputs_shared["cfg_scale"] != 1:
|
||||
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_nega.update(processor_outputs)
|
||||
else:
|
||||
inputs_nega.update(processor_outputs)
|
||||
else:
|
||||
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_shared.update(processor_outputs)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
Reference in New Issue
Block a user