import torch, warnings, glob, os
import numpy as np
from PIL import Image
from einops import repeat, reduce
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download


class BasePipeline(torch.nn.Module):

    def __init__(
        self,
        device="cuda", torch_dtype=torch.float16,
        height_division_factor=64, width_division_factor=64,
        time_division_factor=None, time_division_remainder=None,
    ):
        super().__init__()
        # The device and torch_dtype are used for the storage of intermediate variables, not models.
        self.device = device
        self.torch_dtype = torch_dtype
        # The following parameters are used for the shape check.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.vram_management_enabled = False

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def check_resize_height_width(self, height, width, num_frames=None):
        # Shape check
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
        if num_frames is None:
            return height, width
        else:
            if num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
                print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
            return height, width, num_frames

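    # Worked example (hypothetical values): with the defaults
    # height_division_factor=64 and width_division_factor=64,
    # check_resize_height_width(720, 1000) returns (768, 1024). With
    # time_division_factor=4 and time_division_remainder=1, num_frames=30
    # is rounded up to 8 * 4 + 1 = 33, which satisfies 33 % 4 == 1.
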
    def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
        # Transform a PIL.Image to torch.Tensor
        image = torch.Tensor(np.array(image, dtype=np.float32))
        image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        image = image * ((max_value - min_value) / 255) + min_value
        image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
        return image

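    # Example (hypothetical input): for a 512x512 RGB PIL.Image `img`,
    # pipe.preprocess_image(img) has shape (1, 3, 512, 512); uint8 pixels in
    # [0, 255] map linearly into [min_value, max_value], so 0 -> -1 and
    # 255 -> 1 with the defaults.
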
    def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
        # Transform a list of PIL.Image to torch.Tensor
        video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
        # Each axis label in `pattern` occupies two characters ("X "), so the
        # stacking dim of "T" is its character index divided by 2.
        video = torch.stack(video, dim=pattern.index("T") // 2)
        return video

    def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
        # Transform a torch.Tensor to PIL.Image
        if pattern != "H W C":
            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
        image = image.to(device="cpu", dtype=torch.uint8)
        image = Image.fromarray(image.numpy())
        return image

    def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
        # Transform a torch.Tensor to a list of PIL.Image
        if pattern != "T H W C":
            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
        video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
        return video

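    # Example (hypothetical shapes): a decoder output of shape (1, 3, 17, 512, 512)
    # with pattern "B C T H W" becomes a list of 17 PIL.Image frames of size 512x512.
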
    def load_models_to_device(self, model_names=()):
        if self.vram_management_enabled:
            # Offload the models that are not requested.
            for name, model in self.named_children():
                if name not in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        for module in model.modules():
                            if hasattr(module, "offload"):
                                module.offload()
                    else:
                        model.cpu()
            torch.cuda.empty_cache()
            # Onload the requested models.
            for name, model in self.named_children():
                if name in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        for module in model.modules():
                            if hasattr(module, "onload"):
                                module.onload()
                    else:
                        model.to(self.device)

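    # Example (hypothetical child names): pipe.load_models_to_device(["text_encoder"])
    # onloads `text_encoder` and offloads every other registered child module.
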
    def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
        # Initialize Gaussian noise
        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
        noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        return noise

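    # Noise is sampled on `rand_device` (CPU by default) and only then moved to
    # the target device, so a fixed seed reproduces the same latents across
    # machines, e.g. pipe.generate_noise((1, 4, 64, 64), seed=42).
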
    def enable_cpu_offload(self):
        warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
        self.vram_management_enabled = True

    def get_vram(self):
        # torch.cuda.mem_get_info returns (free, total); report the total VRAM in GiB.
        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)

    def freeze_except(self, model_names):
        for name, model in self.named_children():
            if name in model_names:
                model.train()
                model.requires_grad_(True)
            else:
                model.eval()
                model.requires_grad_(False)


@dataclass
class ModelConfig:
    path: Optional[Union[str, list[str]]] = None
    model_id: Optional[str] = None
    origin_file_pattern: Optional[Union[str, list[str]]] = None
    download_resource: str = "ModelScope"
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None
    local_model_path: Optional[str] = None
    skip_download: bool = False

    def download_if_necessary(self, use_usp=False):
        if self.path is None:
            # Check model_id and origin_file_pattern
            if self.model_id is None:
                raise ValueError("""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")

            # Skip the download on every rank except rank 0.
            if use_usp:
                import torch.distributed as dist
                skip_download = self.skip_download or dist.get_rank() != 0
            else:
                skip_download = self.skip_download

            # Check whether the origin path is a folder
            if self.origin_file_pattern is None or self.origin_file_pattern == "":
                self.origin_file_pattern = ""
                allow_file_pattern = None
                is_folder = True
            elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
                allow_file_pattern = self.origin_file_pattern + "*"
                is_folder = True
            else:
                allow_file_pattern = self.origin_file_pattern
                is_folder = False

            # Download; files that already exist locally are ignored.
            if not skip_download:
                if self.local_model_path is None:
                    self.local_model_path = "./models"
                downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
                snapshot_download(
                    self.model_id,
                    local_dir=os.path.join(self.local_model_path, self.model_id),
                    allow_file_pattern=allow_file_pattern,
                    ignore_file_pattern=downloaded_files,
                    local_files_only=False,
                )

            # Let rank 1, 2, ... wait for rank 0
            if use_usp:
                import torch.distributed as dist
                dist.barrier(device_ids=[dist.get_rank()])

            # Return downloaded files
            if is_folder:
                self.path = os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)
            else:
                self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
                if isinstance(self.path, list) and len(self.path) == 1:
                    self.path = self.path[0]


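# Example usage (hypothetical model id and file pattern):
#     config = ModelConfig(model_id="xxx/yyy", origin_file_pattern="transformer/")
#     config.download_if_necessary()
#     # config.path -> "./models/xxx/yyy/transformer/"

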
class PipelineUnit:
    def __init__(
        self,
        seperate_cfg: bool = False,
        take_over: bool = False,
        input_params: Optional[tuple[str]] = None,
        input_params_posi: Optional[dict[str, str]] = None,
        input_params_nega: Optional[dict[str, str]] = None,
        onload_model_names: Optional[tuple[str]] = None
    ):
        self.seperate_cfg = seperate_cfg
        self.take_over = take_over
        self.input_params = input_params
        self.input_params_posi = input_params_posi
        self.input_params_nega = input_params_nega
        self.onload_model_names = onload_model_names

    def process(self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs) -> dict:
        raise NotImplementedError("`process` is not implemented.")


class PipelineUnitRunner:
    def __init__(self):
        pass

    def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict, dict]:
        if unit.take_over:
            # Let the pipeline unit take over this function.
            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
        elif unit.seperate_cfg:
            # Positive side
            processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
            if unit.input_params is not None:
                for name in unit.input_params:
                    processor_inputs[name] = inputs_shared.get(name)
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_posi.update(processor_outputs)
            # Negative side
            if inputs_shared["cfg_scale"] != 1:
                processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
                if unit.input_params is not None:
                    for name in unit.input_params:
                        processor_inputs[name] = inputs_shared.get(name)
                processor_outputs = unit.process(pipe, **processor_inputs)
                inputs_nega.update(processor_outputs)
            else:
                # When cfg_scale == 1, classifier-free guidance is disabled,
                # so the positive outputs are reused for the negative side.
                inputs_nega.update(processor_outputs)
        else:
            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_shared.update(processor_outputs)
        return inputs_shared, inputs_posi, inputs_nega
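

# A minimal, self-contained sketch of how a unit plugs into the runner. The
# unit below is hypothetical and only illustrates the `seperate_cfg` data flow;
# real units call models on `pipe` instead of transforming strings.
if __name__ == "__main__":
    class MyPromptUnit(PipelineUnit):
        def __init__(self):
            super().__init__(
                seperate_cfg=True,
                input_params_posi={"prompt": "prompt"},
                input_params_nega={"prompt": "negative_prompt"},
            )

        def process(self, pipe, prompt):
            # Stand-in for a text encoder.
            return {"prompt_emb": prompt.upper()}

    runner = PipelineUnitRunner()
    shared, posi, nega = runner(
        MyPromptUnit(), pipe=None,
        inputs_shared={"cfg_scale": 7.5},
        inputs_posi={"prompt": "a cat"},
        inputs_nega={"negative_prompt": "blurry"},
    )
    print(posi["prompt_emb"])  # A CAT
    print(nega["prompt_emb"])  # BLURRY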