DiffSynth-Studio 2.0 major update

This commit is contained in:
root
2025-12-04 16:33:07 +08:00
parent afd101f345
commit 72af7122b3
758 changed files with 26462 additions and 2221398 deletions

View File

@@ -0,0 +1,5 @@
from .attention import *
from .data import *
from .gradient import *
from .loader import *
from .vram import *

View File

@@ -0,0 +1 @@
from .attention import attention_forward

View File

@@ -0,0 +1,121 @@
import torch, os
from einops import rearrange
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
SAGE_ATTN_AVAILABLE = False
try:
import xformers.ops as xops
XFORMERS_AVAILABLE = True
except ModuleNotFoundError:
XFORMERS_AVAILABLE = False
def initialize_attention_priority():
if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
elif FLASH_ATTN_3_AVAILABLE:
return "flash_attention_3"
elif FLASH_ATTN_2_AVAILABLE:
return "flash_attention_2"
elif SAGE_ATTN_AVAILABLE:
return "sage_attention"
elif XFORMERS_AVAILABLE:
return "xformers"
else:
return "torch"
ATTENTION_IMPLEMENTATION = initialize_attention_priority()
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
dims = {} if dims is None else dims
if q_pattern != required_in_pattern:
q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
if k_pattern != required_in_pattern:
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
if v_pattern != required_in_pattern:
v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims)
return q, k, v
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
dims = {} if dims is None else dims
if out_pattern != required_out_pattern:
out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
return out
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
if isinstance(out, tuple):
out = out[0]
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = sageattn(q, k, v, sm_scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = xops.memory_efficient_attention(q, k, v, scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
if compatibility_mode or (attn_mask is not None):
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
else:
if ATTENTION_IMPLEMENTATION == "flash_attention_3":
return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "sage_attention":
return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "xformers":
return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
else:
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)

View File

@@ -0,0 +1 @@
from .unified_dataset import UnifiedDataset

View File

@@ -0,0 +1,218 @@
import torch, torchvision, imageio, os
import imageio.v3 as iio
from PIL import Image
class DataProcessingPipeline:
def __init__(self, operators=None):
self.operators: list[DataProcessingOperator] = [] if operators is None else operators
def __call__(self, data):
for operator in self.operators:
data = operator(data)
return data
def __rshift__(self, pipe):
if isinstance(pipe, DataProcessingOperator):
pipe = DataProcessingPipeline([pipe])
return DataProcessingPipeline(self.operators + pipe.operators)
class DataProcessingOperator:
def __call__(self, data):
raise NotImplementedError("DataProcessingOperator cannot be called directly.")
def __rshift__(self, pipe):
if isinstance(pipe, DataProcessingOperator):
pipe = DataProcessingPipeline([pipe])
return DataProcessingPipeline([self]).__rshift__(pipe)
class DataProcessingOperatorRaw(DataProcessingOperator):
def __call__(self, data):
return data
class ToInt(DataProcessingOperator):
def __call__(self, data):
return int(data)
class ToFloat(DataProcessingOperator):
def __call__(self, data):
return float(data)
class ToStr(DataProcessingOperator):
def __init__(self, none_value=""):
self.none_value = none_value
def __call__(self, data):
if data is None: data = self.none_value
return str(data)
class LoadImage(DataProcessingOperator):
def __init__(self, convert_RGB=True):
self.convert_RGB = convert_RGB
def __call__(self, data: str):
image = Image.open(data)
if self.convert_RGB: image = image.convert("RGB")
return image
class ImageCropAndResize(DataProcessingOperator):
def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
self.height = height
self.width = width
self.max_pixels = max_pixels
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def get_height_width(self, image):
if self.height is None or self.width is None:
width, height = image.size
if width * height > self.max_pixels:
scale = (width * height / self.max_pixels) ** 0.5
height, width = int(height / scale), int(width / scale)
height = height // self.height_division_factor * self.height_division_factor
width = width // self.width_division_factor * self.width_division_factor
else:
height, width = self.height, self.width
return height, width
def __call__(self, data: Image.Image):
image = self.crop_and_resize(data, *self.get_height_width(data))
return image
class ToList(DataProcessingOperator):
def __call__(self, data):
return [data]
class LoadVideo(DataProcessingOperator):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
self.num_frames = num_frames
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
# frame_processor is build in the video loader for high efficiency.
self.frame_processor = frame_processor
def get_num_frames(self, reader):
num_frames = self.num_frames
if int(reader.count_frames()) < num_frames:
num_frames = int(reader.count_frames())
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
num_frames -= 1
return num_frames
def __call__(self, data: str):
reader = imageio.get_reader(data)
num_frames = self.get_num_frames(reader)
frames = []
for frame_id in range(num_frames):
frame = reader.get_data(frame_id)
frame = Image.fromarray(frame)
frame = self.frame_processor(frame)
frames.append(frame)
reader.close()
return frames
class SequencialProcess(DataProcessingOperator):
def __init__(self, operator=lambda x: x):
self.operator = operator
def __call__(self, data):
return [self.operator(i) for i in data]
class LoadGIF(DataProcessingOperator):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
self.num_frames = num_frames
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
# frame_processor is build in the video loader for high efficiency.
self.frame_processor = frame_processor
def get_num_frames(self, path):
num_frames = self.num_frames
images = iio.imread(path, mode="RGB")
if len(images) < num_frames:
num_frames = len(images)
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
num_frames -= 1
return num_frames
def __call__(self, data: str):
num_frames = self.get_num_frames(data)
frames = []
images = iio.imread(data, mode="RGB")
for img in images:
frame = Image.fromarray(img)
frame = self.frame_processor(frame)
frames.append(frame)
if len(frames) >= num_frames:
break
return frames
class RouteByExtensionName(DataProcessingOperator):
def __init__(self, operator_map):
self.operator_map = operator_map
def __call__(self, data: str):
file_ext_name = data.split(".")[-1].lower()
for ext_names, operator in self.operator_map:
if ext_names is None or file_ext_name in ext_names:
return operator(data)
raise ValueError(f"Unsupported file: {data}")
class RouteByType(DataProcessingOperator):
def __init__(self, operator_map):
self.operator_map = operator_map
def __call__(self, data):
for dtype, operator in self.operator_map:
if dtype is None or isinstance(data, dtype):
return operator(data)
raise ValueError(f"Unsupported data: {data}")
class LoadTorchPickle(DataProcessingOperator):
def __init__(self, map_location="cpu"):
self.map_location = map_location
def __call__(self, data):
return torch.load(data, map_location=self.map_location, weights_only=False)
class ToAbsolutePath(DataProcessingOperator):
def __init__(self, base_path=""):
self.base_path = base_path
def __call__(self, data):
return os.path.join(self.base_path, data)
class LoadAudio(DataProcessingOperator):
def __init__(self, sr=16000):
self.sr = sr
def __call__(self, data: str):
import librosa
input_audio, sample_rate = librosa.load(data, sr=self.sr)
return input_audio

View File

@@ -0,0 +1,112 @@
from .operators import *
import torch, json, pandas
class UnifiedDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path=None, metadata_path=None,
repeat=1,
data_file_keys=tuple(),
main_data_operator=lambda x: x,
special_operator_map=None,
):
self.base_path = base_path
self.metadata_path = metadata_path
self.repeat = repeat
self.data_file_keys = data_file_keys
self.main_data_operator = main_data_operator
self.cached_data_operator = LoadTorchPickle()
self.special_operator_map = {} if special_operator_map is None else special_operator_map
self.data = []
self.cached_data = []
self.load_from_cache = metadata_path is None
self.load_metadata(metadata_path)
@staticmethod
def default_image_operator(
base_path="",
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
):
return RouteByType(operator_map=[
(str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
(list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
])
@staticmethod
def default_video_operator(
base_path="",
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
num_frames=81, time_division_factor=4, time_division_remainder=1,
):
return RouteByType(operator_map=[
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
(("gif",), LoadGIF(
num_frames, time_division_factor, time_division_remainder,
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
)),
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
num_frames, time_division_factor, time_division_remainder,
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
)),
])),
])
def search_for_cached_data_files(self, path):
for file_name in os.listdir(path):
subpath = os.path.join(path, file_name)
if os.path.isdir(subpath):
self.search_for_cached_data_files(subpath)
elif subpath.endswith(".pth"):
self.cached_data.append(subpath)
def load_metadata(self, metadata_path):
if metadata_path is None:
print("No metadata_path. Searching for cached data files.")
self.search_for_cached_data_files(self.base_path)
print(f"{len(self.cached_data)} cached data files found.")
elif metadata_path.endswith(".json"):
with open(metadata_path, "r") as f:
metadata = json.load(f)
self.data = metadata
elif metadata_path.endswith(".jsonl"):
metadata = []
with open(metadata_path, 'r') as f:
for line in f:
metadata.append(json.loads(line.strip()))
self.data = metadata
else:
metadata = pandas.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
def __getitem__(self, data_id):
if self.load_from_cache:
data = self.cached_data[data_id % len(self.cached_data)]
data = self.cached_data_operator(data)
else:
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
if key in self.special_operator_map:
data[key] = self.special_operator_map[key](data[key])
elif key in self.data_file_keys:
data[key] = self.main_data_operator(data[key])
return data
def __len__(self):
if self.load_from_cache:
return len(self.cached_data) * self.repeat
else:
return len(self.data) * self.repeat
def check_data_equal(self, data1, data2):
# Debug only
if len(data1) != len(data2):
return False
for k in data1:
if data1[k] != data2[k]:
return False
return True

View File

@@ -0,0 +1 @@
from .gradient_checkpoint import gradient_checkpoint_forward

View File

@@ -0,0 +1,34 @@
import torch
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
def gradient_checkpoint_forward(
model,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*args,
**kwargs,
):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
else:
model_output = model(*args, **kwargs)
return model_output

View File

@@ -0,0 +1,3 @@
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
from .model import load_model, load_model_with_disk_offload
from .config import ModelConfig

View File

@@ -0,0 +1,117 @@
import torch, glob, os
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
from typing import Optional
@dataclass
class ModelConfig:
path: Union[str, list[str]] = None
model_id: str = None
origin_file_pattern: Union[str, list[str]] = None
download_source: str = None
local_model_path: str = None
skip_download: bool = None
offload_device: Optional[Union[str, torch.device]] = None
offload_dtype: Optional[torch.dtype] = None
onload_device: Optional[Union[str, torch.device]] = None
onload_dtype: Optional[torch.dtype] = None
preparing_device: Optional[Union[str, torch.device]] = None
preparing_dtype: Optional[torch.dtype] = None
computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False
def check_input(self):
if self.path is None and self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
def parse_original_file_pattern(self):
if self.origin_file_pattern is None or self.origin_file_pattern == "":
return "*"
elif self.origin_file_pattern.endswith("/"):
return self.origin_file_pattern + "*"
else:
return self.origin_file_pattern
def parse_download_source(self):
if self.download_source is None:
if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
else:
return "modelscope"
else:
return self.download_source
def parse_skip_download(self):
if self.skip_download is None:
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
return True
elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
return False
else:
return False
else:
return self.skip_download
def download(self):
origin_file_pattern = self.parse_original_file_pattern()
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
download_source = self.parse_download_source()
if download_source.lower() == "modelscope":
snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_file_pattern=origin_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
elif download_source.lower() == "huggingface":
hf_snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_patterns=origin_file_pattern,
ignore_patterns=downloaded_files,
local_files_only=False
)
else:
raise ValueError("`download_source` should be `modelscope` or `huggingface`.")
def require_downloading(self):
if self.path is not None:
return False
skip_download = self.parse_skip_download()
return not skip_download
def reset_local_model_path(self):
if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
elif self.local_model_path is None:
self.local_model_path = "./models"
def download_if_necessary(self):
self.check_input()
self.reset_local_model_path()
if self.require_downloading():
self.download()
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.path = os.path.join(self.local_model_path, self.model_id)
else:
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
if isinstance(self.path, list) and len(self.path) == 1:
self.path = self.path[0]
def vram_config(self):
return {
"offload_device": self.offload_device,
"offload_dtype": self.offload_dtype,
"onload_device": self.onload_device,
"onload_dtype": self.onload_dtype,
"preparing_device": self.preparing_device,
"preparing_dtype": self.preparing_dtype,
"computation_device": self.computation_device,
"computation_dtype": self.computation_dtype,
}

View File

@@ -0,0 +1,121 @@
from safetensors import safe_open
import torch, hashlib
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
return state_dict
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
state_dict = {}
with safe_open(file_path, framework="pt", device=str(device)) as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
state_dict = torch.load(file_path, map_location=device, weights_only=True)
if len(state_dict) == 1:
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
elif "module" in state_dict:
state_dict = state_dict["module"]
elif "model_state" in state_dict:
state_dict = state_dict["model_state"]
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def load_keys_dict(file_path):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_keys_dict(file_path_))
return state_dict
if file_path.endswith(".safetensors"):
return load_keys_dict_from_safetensors(file_path)
else:
return load_keys_dict_from_bin(file_path)
def load_keys_dict_from_safetensors(file_path):
keys_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
keys_dict[k] = f.get_slice(k).get_shape()
return keys_dict
def convert_state_dict_to_keys_dict(state_dict):
keys_dict = {}
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
keys_dict[k] = list(v.shape)
else:
keys_dict[k] = convert_state_dict_to_keys_dict(v)
return keys_dict
def load_keys_dict_from_bin(file_path):
state_dict = load_state_dict_from_bin(file_path)
keys_dict = convert_state_dict_to_keys_dict(state_dict)
return keys_dict
def convert_keys_dict_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, dict):
keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
else:
if with_shape:
shape = "_".join(map(str, list(value)))
keys.append(key + ":" + shape)
keys.append(key)
keys.sort()
keys_str = ",".join(keys)
return keys_str
def hash_model_file(path, with_shape=True):
keys_dict = load_keys_dict(path)
keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()

View File

@@ -0,0 +1,79 @@
from ..vram.initialization import skip_model_initialization
from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
config = {} if config is None else config
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
# thereby speeding up model loading and avoiding excessive memory usage.
with skip_model_initialization():
model = model_class(**config)
# What is `module_map`?
# This is a module mapping table for VRAM management.
if module_map is not None:
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
device = [d for d in devices if d != "disk"][0]
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk":
state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
else:
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
else:
# Why do we use `DiskMap`?
# Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
if use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)
# Why do we use `state_dict_converter`?
# Some models are saved in complex formats,
# and we need to convert the state dict into the appropriate format.
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True)
# Why do we call `to()`?
# Because some models override the behavior of `to()`,
# especially those from libraries like Transformers.
model = model.to(dtype=torch_dtype, device=device)
if hasattr(model, "eval"):
model = model.eval()
return model
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
if isinstance(path, str):
path = [path]
config = {} if config is None else config
with skip_model_initialization():
model = model_class(**config)
if hasattr(model, "eval"):
model = model.eval()
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": device,
"computation_dtype": torch_dtype,
"computation_device": device,
}
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
return model

View File

@@ -0,0 +1,2 @@
from .initialization import skip_model_initialization
from .layers import *

View File

@@ -0,0 +1,93 @@
from safetensors import safe_open
import torch, os
class SafetensorsCompatibleTensor:
def __init__(self, tensor):
self.tensor = tensor
def get_shape(self):
return list(self.tensor.shape)
class SafetensorsCompatibleBinaryLoader:
def __init__(self, path, device):
print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
self.state_dict = torch.load(path, weights_only=True, map_location=device)
def keys(self):
return self.state_dict.keys()
def get_tensor(self, name):
return self.state_dict[name]
def get_slice(self, name):
return SafetensorsCompatibleTensor(self.state_dict[name])
class DiskMap:
def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
self.path = path if isinstance(path, list) else [path]
self.device = device
self.torch_dtype = torch_dtype
if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
else:
self.buffer_size = buffer_size
self.files = []
self.flush_files()
self.name_map = {}
for file_id, file in enumerate(self.files):
for name in file.keys():
self.name_map[name] = file_id
self.rename_dict = self.fetch_rename_dict(state_dict_converter)
def flush_files(self):
if len(self.files) == 0:
for path in self.path:
if path.endswith(".safetensors"):
self.files.append(safe_open(path, framework="pt", device=str(self.device)))
else:
self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
else:
for i, path in enumerate(self.path):
if path.endswith(".safetensors"):
self.files[i] = safe_open(path, framework="pt", device=str(self.device))
self.num_params = 0
def __getitem__(self, name):
if self.rename_dict is not None: name = self.rename_dict[name]
file_id = self.name_map[name]
param = self.files[file_id].get_tensor(name)
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
param = param.to(self.torch_dtype)
if isinstance(param, torch.Tensor) and param.device == "cpu":
param = param.clone()
if isinstance(param, torch.Tensor):
self.num_params += param.numel()
if self.num_params > self.buffer_size:
self.flush_files()
return param
def fetch_rename_dict(self, state_dict_converter):
if state_dict_converter is None:
return None
state_dict = {}
for file in self.files:
for name in file.keys():
state_dict[name] = name
state_dict = state_dict_converter(state_dict)
return state_dict
def __iter__(self):
if self.rename_dict is not None:
return self.rename_dict.__iter__()
else:
return self.name_map.__iter__()
def __contains__(self, x):
if self.rename_dict is not None:
return x in self.rename_dict
else:
return x in self.name_map

View File

@@ -0,0 +1,21 @@
import torch
from contextlib import contextmanager
@contextmanager
def skip_model_initialization(device=torch.device("meta")):
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
old_register_parameter = torch.nn.Module.register_parameter
torch.nn.Module.register_parameter = register_empty_parameter
try:
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter

View File

@@ -0,0 +1,475 @@
import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
class AutoTorchModule(torch.nn.Module):
def __init__(
self,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
):
super().__init__()
self.set_dtype_and_device(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.state = 0
self.name = ""
def set_dtype_and_device(
self,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
):
self.offload_dtype = offload_dtype or computation_dtype
self.offload_device = offload_device or computation_dtype
self.onload_dtype = onload_dtype or computation_dtype
self.onload_device = onload_device or computation_dtype
self.preparing_dtype = preparing_dtype or computation_dtype
self.preparing_device = preparing_device or computation_dtype
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
def cast_to(self, weight, dtype, device):
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def check_free_vram(self):
gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit
def offload(self):
if self.state != 0:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state != 1:
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def param_name(self, name):
if self.name == "":
return name
else:
return self.name + "." + name
class AutoWrappedModule(AutoTorchModule):
def __init__(
self,
module: torch.nn.Module,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
super().__init__(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.module = module
if offload_dtype == "disk":
self.name = name
self.disk_map = disk_map
self.required_params = [name for name, _ in self.module.named_parameters()]
self.disk_offload = True
else:
self.disk_offload = False
def load_from_disk(self, torch_dtype, device, copy_module=False):
if copy_module:
module = copy.deepcopy(self.module)
else:
module = self.module
state_dict = {}
for name in self.required_params:
param = self.disk_map[self.param_name(name)]
param = param.to(dtype=torch_dtype, device=device)
state_dict[name] = param
module.load_state_dict(state_dict, assign=True)
module.to(dtype=torch_dtype, device=device)
return module
def offload_to_disk(self, model: torch.nn.Module):
for buf in model.buffers():
# If there are some parameters are registed in buffers (not in state dict),
# We cannot offload the model.
for children in model.children():
self.offload_to_disk(children)
break
else:
model.to("meta")
def offload(self):
# offload / onload / preparing -> offload
if self.state != 0:
if self.disk_offload:
self.offload_to_disk(self.module)
else:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
# offload / onload / preparing -> onload
if self.state < 1:
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
self.load_from_disk(self.onload_dtype, self.onload_device)
elif self.onload_device != "disk":
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def preparing(self):
# onload / preparing -> preparing
if self.state != 2:
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
self.load_from_disk(self.preparing_dtype, self.preparing_device)
elif self.preparing_device != "disk":
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
self.state = 2
def cast_to(self, module, dtype, device):
return copy.deepcopy(module).to(dtype=dtype, device=device)
def computation(self):
# onload / preparing -> computation (temporary)
if self.state == 2:
torch_dtype, device = self.preparing_dtype, self.preparing_device
else:
torch_dtype, device = self.onload_dtype, self.onload_device
if torch_dtype == self.computation_dtype and device == self.computation_device:
module = self.module
elif self.disk_offload and device == "disk":
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
else:
module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
return module
def forward(self, *args, **kwargs):
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
self.preparing()
module = self.computation()
return module(*args, **kwargs)
def __getattr__(self, name):
if name in self.__dict__ or name == "module":
return super().__getattr__(name)
else:
return getattr(self.module, name)
class AutoWrappedNonRecurseModule(AutoWrappedModule):
def __init__(
self,
module: torch.nn.Module,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
super().__init__(
module,
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
name,
disk_map,
**kwargs
)
if self.disk_offload:
self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]
def load_from_disk(self, torch_dtype, device, copy_module=False):
if copy_module:
module = copy.deepcopy(self.module)
else:
module = self.module
state_dict = {}
for name in self.required_params:
param = self.disk_map[self.param_name(name)]
param = param.to(dtype=torch_dtype, device=device)
state_dict[name] = param
module.load_state_dict(state_dict, assign=True, strict=False)
return module
def offload_to_disk(self, model: torch.nn.Module):
for name in self.required_params:
getattr(self, name).to("meta")
def cast_to(self, module, dtype, device):
# Parameter casting is implemented in the model architecture.
return module
def __getattr__(self, name):
if name in self.__dict__ or name == "module":
return super().__getattr__(name)
else:
return getattr(self.module, name)
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def __init__(
self,
module: torch.nn.Linear,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
with skip_model_initialization():
super().__init__(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
)
self.set_dtype_and_device(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.weight = module.weight
self.bias = module.bias
self.state = 0
self.name = name
self.lora_A_weights = []
self.lora_B_weights = []
self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
if offload_dtype == "disk":
self.disk_map = disk_map
self.disk_offload = True
else:
self.disk_offload = False
def fp8_linear(
self,
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor = None,
) -> torch.Tensor:
device = input.device
origin_dtype = input.dtype
origin_shape = input.shape
input = input.reshape(-1, origin_shape[-1])
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
fp8_max = 448.0
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
# To avoid overflow and ensure numerical compatibility during FP8 computation,
# we scale down the input by 2.0 in advance.
# This scaling will be compensated later during the final result scaling.
if self.computation_dtype == torch.float8_e4m3fnuz:
fp8_max = fp8_max / 2.0
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
input = input / (scale_a + 1e-8)
input = input.to(self.computation_dtype)
weight = weight.to(self.computation_dtype)
bias = bias.to(torch.bfloat16)
result = torch._scaled_mm(
input,
weight.T,
scale_a=scale_a,
scale_b=scale_b.T,
bias=bias,
out_dtype=origin_dtype,
)
new_shape = origin_shape[:-1] + result.shape[-1:]
result = result.reshape(new_shape)
return result
def load_from_disk(self, torch_dtype, device, assign=True):
weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
if assign:
state_dict = {"weight": weight}
if bias is not None: state_dict["bias"] = bias
self.load_state_dict(state_dict, assign=True)
return weight, bias
def offload(self):
# offload / onload / preparing -> offload
if self.state != 0:
if self.disk_offload:
self.to("meta")
else:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
# offload / onload / preparing -> onload
if self.state < 1:
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
self.load_from_disk(self.onload_dtype, self.onload_device)
elif self.onload_device != "disk":
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def preparing(self):
# onload / preparing -> preparing
if self.state != 2:
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
self.load_from_disk(self.preparing_dtype, self.preparing_device)
elif self.preparing_device != "disk":
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
self.state = 2
def computation(self):
# onload / preparing -> computation (temporary)
if self.state == 2:
torch_dtype, device = self.preparing_dtype, self.preparing_device
else:
torch_dtype, device = self.onload_dtype, self.onload_device
if torch_dtype == self.computation_dtype and device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.disk_offload and device == "disk":
weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
else:
weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
return weight, bias
def linear_forward(self, x, weight, bias):
if self.enable_fp8:
out = self.fp8_linear(x, weight, bias)
else:
out = torch.nn.functional.linear(x, weight, bias)
return out
def lora_forward(self, x, out):
if self.lora_merger is None:
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
out = out + x @ lora_A.T @ lora_B.T
else:
lora_output = []
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
lora_output.append(x @ lora_A.T @ lora_B.T)
lora_output = torch.stack(lora_output)
out = self.lora_merger(out, lora_output)
return out
def forward(self, x, *args, **kwargs):
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
self.preparing()
weight, bias = self.computation()
out = self.linear_forward(x, weight, bias)
if len(self.lora_A_weights) > 0:
out = self.lora_forward(x, out)
return out
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
if isinstance(model, AutoWrappedNonRecurseModule):
model = model.module
for name, module in model.named_children():
layer_name = name if name_prefix == "" else name_prefix + "." + name
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
if isinstance(module_, AutoWrappedNonRecurseModule):
enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
setattr(model, name, module_)
break
else:
enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
def fill_vram_config(model, vram_config):
vram_config_ = vram_config.copy()
vram_config_["onload_dtype"] = vram_config["computation_dtype"]
vram_config_["onload_device"] = vram_config["computation_device"]
vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
vram_config_["preparing_device"] = vram_config["computation_device"]
for k in vram_config:
if vram_config[k] != vram_config_[k]:
print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
break
return vram_config_
def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
for source_module, target_module in module_map.items():
# If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
if isinstance(model, source_module):
vram_config = fill_vram_config(model, vram_config)
model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
break
else:
enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
# `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
model.vram_management_enabled = True
return model