mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
DiffSynth-Studio 2.0 major update
This commit is contained in:
5
diffsynth/core/__init__.py
Normal file
5
diffsynth/core/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .attention import *
|
||||
from .data import *
|
||||
from .gradient import *
|
||||
from .loader import *
|
||||
from .vram import *
|
||||
1
diffsynth/core/attention/__init__.py
Normal file
1
diffsynth/core/attention/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .attention import attention_forward
|
||||
121
diffsynth/core/attention/attention.py
Normal file
121
diffsynth/core/attention/attention.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import torch, os
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
# Probe the optional attention backends once at import time. Each flag only
# records whether the package can be imported; the preferred backend is
# selected later by `initialize_attention_priority`.
try:
    import flash_attn_interface  # FlashAttention-3 interface package
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn  # FlashAttention-2
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ModuleNotFoundError:
    XFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
def initialize_attention_priority():
    """Select the attention backend identifier.

    The DIFFSYNTH_ATTENTION_IMPLEMENTATION environment variable (lowercased)
    takes priority; otherwise the first available backend wins, in order
    FA3 > FA2 > SageAttention > xFormers, falling back to torch SDPA.

    Returns:
        str: backend name, e.g. "flash_attention_3" or "torch".
    """
    # Hoisted: the env var was previously read twice.
    override = os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION')
    if override is not None:
        # The override is trusted as-is, matching the original behavior
        # (no availability check is performed on it).
        return override.lower()
    if FLASH_ATTN_3_AVAILABLE:
        return "flash_attention_3"
    if FLASH_ATTN_2_AVAILABLE:
        return "flash_attention_2"
    if SAGE_ATTN_AVAILABLE:
        return "sage_attention"
    if XFORMERS_AVAILABLE:
        return "xformers"
    return "torch"
|
||||
|
||||
|
||||
# Resolved once at import time; set DIFFSYNTH_ATTENTION_IMPLEMENTATION
# before importing this module to override the choice.
ATTENTION_IMPLEMENTATION = initialize_attention_priority()
|
||||
|
||||
|
||||
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
    """Rearrange q/k/v from their caller-provided einops patterns into the
    layout a backend requires. Tensors already in the required layout are
    returned unchanged.

    Args:
        q, k, v: attention inputs in the layouts named by the `*_pattern` args.
        required_in_pattern: the backend's expected layout.
        dims: extra axis sizes passed through to `einops.rearrange`.

    Returns:
        (q, k, v) in `required_in_pattern` layout.
    """
    dims = {} if dims is None else dims
    if q_pattern != required_in_pattern:
        q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
    if k_pattern != required_in_pattern:
        k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
    if v_pattern != required_in_pattern:
        # BUG FIX: the source pattern for v must be v_pattern; it previously
        # used q_pattern, silently mis-rearranging v whenever the patterns
        # differed.
        v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
    return q, k, v
|
||||
|
||||
|
||||
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
    """Rearrange a backend's output from its native layout
    (`required_out_pattern`) back into the caller's `out_pattern`."""
    dims = {} if dims is None else dims
    if out_pattern == required_out_pattern:
        return out
    return rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
|
||||
|
||||
|
||||
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
    """Attention via torch.nn.functional.scaled_dot_product_attention.

    The only backend in this module that accepts an explicit `attn_mask`.
    Native layout: (batch, heads, seq, head_dim).
    """
    layout = "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, layout, dims)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
    return rearrange_out(out, out_pattern, layout, dims)
|
||||
|
||||
|
||||
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via the FlashAttention-3 kernel.

    Native layout: (batch, seq, heads, head_dim).
    """
    layout = "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, layout, dims)
    result = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
    # Some versions return (out, lse); keep only the attention output.
    out = result[0] if isinstance(result, tuple) else result
    return rearrange_out(out, out_pattern, layout, dims)
|
||||
|
||||
|
||||
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via the FlashAttention-2 kernel.

    Native layout: (batch, seq, heads, head_dim).
    """
    layout = "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, layout, dims)
    out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
    return rearrange_out(out, out_pattern, layout, dims)
|
||||
|
||||
|
||||
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via SageAttention's `sageattn`.

    Native layout: (batch, heads, seq, head_dim).
    """
    layout = "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, layout, dims)
    out = sageattn(q, k, v, sm_scale=scale)
    return rearrange_out(out, out_pattern, layout, dims)
|
||||
|
||||
|
||||
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via xFormers' memory-efficient attention.

    Native layout: (batch, seq, heads, head_dim).
    """
    layout = "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, layout, dims)
    out = xops.memory_efficient_attention(q, k, v, scale=scale)
    return rearrange_out(out, out_pattern, layout, dims)
|
||||
|
||||
|
||||
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
    """Unified attention entry point.

    Falls back to torch SDPA when `compatibility_mode` is set or when an
    attention mask is supplied (the fast backends here do not take a mask);
    otherwise dispatches to the globally selected ATTENTION_IMPLEMENTATION.
    """
    if compatibility_mode or (attn_mask is not None):
        return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
    backends = {
        "flash_attention_3": flash_attention_3,
        "flash_attention_2": flash_attention_2,
        "sage_attention": sage_attention,
        "xformers": xformers_attention,
    }
    # Unknown identifiers fall back to torch SDPA, as before.
    backend = backends.get(ATTENTION_IMPLEMENTATION, torch_sdpa)
    return backend(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||
1
diffsynth/core/data/__init__.py
Normal file
1
diffsynth/core/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .unified_dataset import UnifiedDataset
|
||||
218
diffsynth/core/data/operators.py
Normal file
218
diffsynth/core/data/operators.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import torch, torchvision, imageio, os
|
||||
import imageio.v3 as iio
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class DataProcessingPipeline:
    """An ordered chain of data operators, composable with the `>>` operator."""

    def __init__(self, operators=None):
        # Annotation kept from the original (it is evaluated at runtime for
        # attribute targets, so it is part of observable behavior).
        self.operators: list[DataProcessingOperator] = [] if operators is None else operators

    def __call__(self, data):
        """Run `data` through every operator in order."""
        for op in self.operators:
            data = op(data)
        return data

    def __rshift__(self, pipe):
        """Concatenate with another pipeline or a single operator."""
        tail = DataProcessingPipeline([pipe]) if isinstance(pipe, DataProcessingOperator) else pipe
        return DataProcessingPipeline(self.operators + tail.operators)


class DataProcessingOperator:
    """Abstract base for a single data-processing step.

    Subclasses implement `__call__`; operators compose into pipelines
    with `>>`.
    """

    def __call__(self, data):
        raise NotImplementedError("DataProcessingOperator cannot be called directly.")

    def __rshift__(self, pipe):
        if isinstance(pipe, DataProcessingOperator):
            pipe = DataProcessingPipeline([pipe])
        return DataProcessingPipeline([self]).__rshift__(pipe)
|
||||
|
||||
|
||||
class DataProcessingOperatorRaw(DataProcessingOperator):
    """Identity operator: passes data through untouched."""

    def __call__(self, data):
        return data
|
||||
|
||||
|
||||
class ToInt(DataProcessingOperator):
    """Cast the incoming value to ``int``."""

    def __call__(self, data):
        return int(data)
|
||||
|
||||
|
||||
class ToFloat(DataProcessingOperator):
    """Cast the incoming value to ``float``."""

    def __call__(self, data):
        return float(data)
|
||||
|
||||
|
||||
class ToStr(DataProcessingOperator):
    """Cast to ``str``, substituting a placeholder for ``None``."""

    def __init__(self, none_value=""):
        # Value used in place of None before conversion.
        self.none_value = none_value

    def __call__(self, data):
        return str(self.none_value if data is None else data)
|
||||
|
||||
|
||||
class LoadImage(DataProcessingOperator):
    """Open an image file with PIL, optionally forcing RGB mode."""

    def __init__(self, convert_RGB=True):
        self.convert_RGB = convert_RGB

    def __call__(self, data: str):
        image = Image.open(data)
        return image.convert("RGB") if self.convert_RGB else image
|
||||
|
||||
|
||||
class ImageCropAndResize(DataProcessingOperator):
    """Resize an image so it covers a target box, then center-crop to it.

    If `height`/`width` are not both given, a per-image target size is derived
    from the source image, downscaled so the pixel count stays within
    `max_pixels` and rounded down to the division factors.
    """

    def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
        # Fixed output size; if either is None the size is derived per image.
        self.height = height
        self.width = width
        # Upper bound on height*width for derived sizes.
        # NOTE(review): assumed to be set whenever height/width are None —
        # a None max_pixels would raise a TypeError in get_height_width; confirm.
        self.max_pixels = max_pixels
        # Derived sizes are rounded down to multiples of these factors.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor

    def crop_and_resize(self, image, target_height, target_width):
        """Scale the image to cover (target_height, target_width), then center-crop."""
        width, height = image.size
        # max() guarantees both dimensions reach the target ("cover" semantics).
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
        return image

    def get_height_width(self, image):
        """Return the (height, width) target size for this image."""
        if self.height is None or self.width is None:
            width, height = image.size
            if width * height > self.max_pixels:
                # Shrink proportionally so height*width <= max_pixels.
                scale = (width * height / self.max_pixels) ** 0.5
                height, width = int(height / scale), int(width / scale)
            height = height // self.height_division_factor * self.height_division_factor
            width = width // self.width_division_factor * self.width_division_factor
        else:
            height, width = self.height, self.width
        return height, width

    def __call__(self, data: Image.Image):
        image = self.crop_and_resize(data, *self.get_height_width(data))
        return image
|
||||
|
||||
|
||||
class ToList(DataProcessingOperator):
    """Wrap the incoming value in a single-element list."""

    def __call__(self, data):
        return [data]
|
||||
|
||||
|
||||
class LoadVideo(DataProcessingOperator):
    """Decode a video file into a list of PIL frames.

    The frame count is clipped to `num_frames`, then reduced until it
    satisfies ``count % time_division_factor == time_division_remainder``
    (unless only a single frame remains).
    """

    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # frame_processor is built into the video loader for high efficiency:
        # each frame is processed as it is decoded instead of in a second pass.
        self.frame_processor = frame_processor

    def get_num_frames(self, reader):
        """Number of frames to read from `reader` (see class docstring)."""
        # Hoisted: count_frames() was previously called twice.
        total_frames = int(reader.count_frames())
        num_frames = min(self.num_frames, total_frames)
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def __call__(self, data: str):
        reader = imageio.get_reader(data)
        # try/finally added so the reader is closed even if decoding fails.
        try:
            num_frames = self.get_num_frames(reader)
            frames = []
            for frame_id in range(num_frames):
                frame = Image.fromarray(reader.get_data(frame_id))
                frames.append(self.frame_processor(frame))
        finally:
            reader.close()
        return frames
|
||||
|
||||
|
||||
class SequencialProcess(DataProcessingOperator):
    """Apply one operator to every element of a sequence.

    (Class name kept as-is for backward compatibility, although
    "Sequential" is the conventional spelling.)
    """

    def __init__(self, operator=lambda x: x):
        self.operator = operator

    def __call__(self, data):
        return list(map(self.operator, data))
|
||||
|
||||
|
||||
class LoadGIF(DataProcessingOperator):
    """Decode a GIF file into a list of PIL frames.

    The frame count is clipped to `num_frames`, then reduced until it
    satisfies ``count % time_division_factor == time_division_remainder``
    (unless only a single frame remains).
    """

    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # frame_processor is built into the loader so each frame is processed
        # as it is converted instead of in a second pass.
        self.frame_processor = frame_processor

    def _trim_num_frames(self, num_frames):
        # Reduce the count until it satisfies the temporal divisibility rule.
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def get_num_frames(self, path):
        """Number of frames that would be returned for `path`."""
        images = iio.imread(path, mode="RGB")
        return self._trim_num_frames(min(self.num_frames, len(images)))

    def __call__(self, data: str):
        # Fix: the GIF was previously decoded twice per call (once inside
        # get_num_frames, once here). Decode once and derive the count from
        # the decoded frames.
        images = iio.imread(data, mode="RGB")
        num_frames = self._trim_num_frames(min(self.num_frames, len(images)))
        frames = []
        for raw_frame in images[:num_frames]:
            frames.append(self.frame_processor(Image.fromarray(raw_frame)))
        return frames
|
||||
|
||||
|
||||
class RouteByExtensionName(DataProcessingOperator):
    """Dispatch a file path to an operator based on its extension.

    `operator_map` is a sequence of (ext_names, operator) pairs; an
    ext_names of None acts as a wildcard.
    """

    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data: str):
        ext = data.split(".")[-1].lower()
        for ext_names, operator in self.operator_map:
            if ext_names is None or ext in ext_names:
                return operator(data)
        raise ValueError(f"Unsupported file: {data}")
|
||||
|
||||
|
||||
class RouteByType(DataProcessingOperator):
    """Dispatch data to an operator based on its Python type.

    `operator_map` is a sequence of (dtype, operator) pairs; a dtype of
    None acts as a wildcard.
    """

    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data):
        for dtype, operator in self.operator_map:
            if dtype is None or isinstance(data, dtype):
                return operator(data)
        raise ValueError(f"Unsupported data: {data}")
|
||||
|
||||
|
||||
class LoadTorchPickle(DataProcessingOperator):
    """Load a torch-pickled file (e.g. a cached ``.pth`` sample).

    SECURITY NOTE: ``weights_only=False`` executes arbitrary pickle code at
    load time — only feed this operator trusted, locally produced files.
    """

    def __init__(self, map_location="cpu"):
        # Device the loaded tensors are mapped onto.
        self.map_location = map_location

    def __call__(self, data):
        return torch.load(data, map_location=self.map_location, weights_only=False)
|
||||
|
||||
|
||||
class ToAbsolutePath(DataProcessingOperator):
    """Join a relative path onto a configured base path."""

    def __init__(self, base_path=""):
        self.base_path = base_path

    def __call__(self, data):
        return os.path.join(self.base_path, data)
|
||||
|
||||
|
||||
class LoadAudio(DataProcessingOperator):
    """Load an audio file as a waveform resampled to `sr` Hz."""

    def __init__(self, sr=16000):
        # Target sample rate passed to librosa.load.
        self.sr = sr

    def __call__(self, data: str):
        # Imported lazily: librosa is only needed when audio is loaded.
        import librosa
        waveform, _ = librosa.load(data, sr=self.sr)
        return waveform
|
||||
112
diffsynth/core/data/unified_dataset.py
Normal file
112
diffsynth/core/data/unified_dataset.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from .operators import *
|
||||
import torch, json, pandas
|
||||
|
||||
|
||||
class UnifiedDataset(torch.utils.data.Dataset):
    """Dataset that serves samples either from a metadata file (json / jsonl /
    csv) or, when no metadata is given, from pre-processed ``.pth`` cache
    files found under ``base_path``.
    """

    def __init__(
        self,
        base_path=None, metadata_path=None,
        repeat=1,
        data_file_keys=tuple(),
        main_data_operator=lambda x: x,
        special_operator_map=None,
    ):
        # Root folder; metadata file paths are resolved relative to it.
        self.base_path = base_path
        self.metadata_path = metadata_path
        # Virtual dataset multiplier: __len__ is scaled by this factor.
        self.repeat = repeat
        # Metadata keys whose values are files to load via an operator.
        self.data_file_keys = data_file_keys
        self.main_data_operator = main_data_operator
        self.cached_data_operator = LoadTorchPickle()
        # Per-key operator overrides; avoid a mutable default argument.
        self.special_operator_map = {} if special_operator_map is None else special_operator_map
        self.data = []
        self.cached_data = []
        # No metadata means we scan base_path for cached .pth files instead.
        self.load_from_cache = metadata_path is None
        self.load_metadata(metadata_path)

    @staticmethod
    def default_image_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
    ):
        """Operator for image samples: a single path or a list of paths."""
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
            (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
        ])

    @staticmethod
    def default_video_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        num_frames=81, time_division_factor=4, time_division_remainder=1,
    ):
        """Operator for video samples; stills become single-frame videos."""
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
                (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
                (("gif",), LoadGIF(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                )),
                (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                )),
            ])),
        ])

    def search_for_cached_data_files(self, path):
        """Recursively collect all ``.pth`` files under `path`."""
        for file_name in os.listdir(path):
            subpath = os.path.join(path, file_name)
            if os.path.isdir(subpath):
                self.search_for_cached_data_files(subpath)
            elif subpath.endswith(".pth"):
                self.cached_data.append(subpath)

    def load_metadata(self, metadata_path):
        """Populate self.data (or self.cached_data) from `metadata_path`."""
        if metadata_path is None:
            print("No metadata_path. Searching for cached data files.")
            self.search_for_cached_data_files(self.base_path)
            print(f"{len(self.cached_data)} cached data files found.")
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
            self.data = metadata
        elif metadata_path.endswith(".jsonl"):
            metadata = []
            with open(metadata_path, 'r') as f:
                for line in f:
                    metadata.append(json.loads(line.strip()))
            self.data = metadata
        else:
            # Any other extension is treated as a csv table.
            metadata = pandas.read_csv(metadata_path)
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]

    def __getitem__(self, data_id):
        # Modulo lets `repeat` expose the same records multiple times.
        if self.load_from_cache:
            data = self.cached_data[data_id % len(self.cached_data)]
            data = self.cached_data_operator(data)
        else:
            # Copy so operators never mutate the stored metadata record.
            data = self.data[data_id % len(self.data)].copy()
            for key in self.data_file_keys:
                if key in data:
                    if key in self.special_operator_map:
                        data[key] = self.special_operator_map[key](data[key])
                    else:
                        # Fix: was `elif key in self.data_file_keys`, which is
                        # always true inside this loop — plain `else` is the
                        # intended (and equivalent) logic.
                        data[key] = self.main_data_operator(data[key])
        return data

    def __len__(self):
        if self.load_from_cache:
            return len(self.cached_data) * self.repeat
        else:
            return len(self.data) * self.repeat

    def check_data_equal(self, data1, data2):
        # Debug only: shallow key-by-key comparison of two samples.
        if len(data1) != len(data2):
            return False
        for k in data1:
            if data1[k] != data2[k]:
                return False
        return True
|
||||
1
diffsynth/core/gradient/__init__.py
Normal file
1
diffsynth/core/gradient/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .gradient_checkpoint import gradient_checkpoint_forward
|
||||
34
diffsynth/core/gradient/gradient_checkpoint.py
Normal file
34
diffsynth/core/gradient/gradient_checkpoint.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
|
||||
|
||||
def create_custom_forward(module):
    """Wrap `module` in a plain function, as torch checkpointing expects a
    callable rather than a bare nn.Module."""
    return lambda *inputs, **kwargs: module(*inputs, **kwargs)
|
||||
|
||||
|
||||
def gradient_checkpoint_forward(
|
||||
model,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
model_output = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(model),
|
||||
*args,
|
||||
**kwargs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
model_output = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(model),
|
||||
*args,
|
||||
**kwargs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
model_output = model(*args, **kwargs)
|
||||
return model_output
|
||||
3
diffsynth/core/loader/__init__.py
Normal file
3
diffsynth/core/loader/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
|
||||
from .model import load_model, load_model_with_disk_offload
|
||||
from .config import ModelConfig
|
||||
117
diffsynth/core/loader/config.py
Normal file
117
diffsynth/core/loader/config.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import torch, glob, os
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from modelscope import snapshot_download
|
||||
from huggingface_hub import snapshot_download as hf_snapshot_download
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
class ModelConfig:
    """Where to find a model's files and how to place it in (V)RAM.

    Either `path` points directly at local file(s), or `model_id` (+
    optional `origin_file_pattern`) names a repo to download from
    ModelScope / HuggingFace.
    """
    path: Union[str, list[str]] = None
    model_id: str = None
    origin_file_pattern: Union[str, list[str]] = None
    # "modelscope" or "huggingface"; falls back to env / modelscope.
    download_source: str = None
    local_model_path: str = None
    skip_download: bool = None
    # Per-stage device/dtype placement consumed by the VRAM manager.
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None
    onload_device: Optional[Union[str, torch.device]] = None
    onload_dtype: Optional[torch.dtype] = None
    preparing_device: Optional[Union[str, torch.device]] = None
    preparing_dtype: Optional[torch.dtype] = None
    computation_device: Optional[Union[str, torch.device]] = None
    computation_dtype: Optional[torch.dtype] = None
    clear_parameters: bool = False

    def check_input(self):
        """Raise if neither a local path nor a model id was provided."""
        if self.path is None and self.model_id is None:
            # (stray f-prefix removed; the message text is unchanged)
            raise ValueError("""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")

    def parse_original_file_pattern(self):
        """Normalize origin_file_pattern into a glob pattern."""
        if self.origin_file_pattern is None or self.origin_file_pattern == "":
            return "*"
        elif self.origin_file_pattern.endswith("/"):
            # A directory pattern means "everything inside it".
            return self.origin_file_pattern + "*"
        else:
            return self.origin_file_pattern

    def parse_download_source(self):
        """Explicit field wins, then DIFFSYNTH_DOWNLOAD_SOURCE, then modelscope."""
        if self.download_source is not None:
            return self.download_source
        return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE', "modelscope")

    def parse_skip_download(self):
        """Resolve whether downloading should be skipped.

        Explicit field wins; otherwise DIFFSYNTH_SKIP_DOWNLOAD == "true"
        (case-insensitive) skips, anything else downloads.
        """
        if self.skip_download is not None:
            return self.skip_download
        env_value = os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD')
        if env_value is not None and env_value.lower() == "true":
            return True
        # Fix: unrecognized env values previously fell through and
        # returned None instead of False.
        return False

    def download(self):
        """Download the repo files, skipping any already present locally."""
        origin_file_pattern = self.parse_original_file_pattern()
        # Already-downloaded files are excluded from the transfer.
        downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
        download_source = self.parse_download_source()
        if download_source.lower() == "modelscope":
            snapshot_download(
                self.model_id,
                local_dir=os.path.join(self.local_model_path, self.model_id),
                allow_file_pattern=origin_file_pattern,
                ignore_file_pattern=downloaded_files,
                local_files_only=False
            )
        elif download_source.lower() == "huggingface":
            hf_snapshot_download(
                self.model_id,
                local_dir=os.path.join(self.local_model_path, self.model_id),
                allow_patterns=origin_file_pattern,
                ignore_patterns=downloaded_files,
                local_files_only=False
            )
        else:
            raise ValueError("`download_source` should be `modelscope` or `huggingface`.")

    def require_downloading(self):
        """A direct local path never downloads; otherwise honor skip settings."""
        if self.path is not None:
            return False
        skip_download = self.parse_skip_download()
        return not skip_download

    def reset_local_model_path(self):
        """Resolve the local model base directory (env override > field > ./models)."""
        if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
            self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
        elif self.local_model_path is None:
            self.local_model_path = "./models"

    def download_if_necessary(self):
        """Ensure `self.path` points at local files, downloading if needed."""
        self.check_input()
        self.reset_local_model_path()
        if self.require_downloading():
            self.download()
        if self.origin_file_pattern is None or self.origin_file_pattern == "":
            self.path = os.path.join(self.local_model_path, self.model_id)
        else:
            self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
            # Unwrap single-file matches for convenience.
            if isinstance(self.path, list) and len(self.path) == 1:
                self.path = self.path[0]

    def vram_config(self):
        """Return the device/dtype placement fields as a plain dict."""
        return {
            "offload_device": self.offload_device,
            "offload_dtype": self.offload_dtype,
            "onload_device": self.onload_device,
            "onload_dtype": self.onload_dtype,
            "preparing_device": self.preparing_device,
            "preparing_dtype": self.preparing_dtype,
            "computation_device": self.computation_device,
            "computation_dtype": self.computation_dtype,
        }
|
||||
121
diffsynth/core/loader/file.py
Normal file
121
diffsynth/core/loader/file.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from safetensors import safe_open
|
||||
import torch, hashlib
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
    """Load one file (or merge a list of files) into a state dict.

    Dispatches on extension: .safetensors uses safe_open, anything else
    goes through torch.load.
    """
    if isinstance(file_path, list):
        merged = {}
        for single_path in file_path:
            merged.update(load_state_dict(single_path, torch_dtype, device))
        return merged
    loader = load_state_dict_from_safetensors if file_path.endswith(".safetensors") else load_state_dict_from_bin
    return loader(file_path, torch_dtype=torch_dtype, device=device)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
    """Load every tensor from a .safetensors file, optionally casting dtype."""
    state_dict = {}
    with safe_open(file_path, framework="pt", device=str(device)) as f:
        for name in f.keys():
            tensor = f.get_tensor(name)
            state_dict[name] = tensor if torch_dtype is None else tensor.to(torch_dtype)
    return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
    """Load a torch-pickled checkpoint, unwrapping common single-key wrappers
    ("state_dict" / "module" / "model_state") and optionally casting dtype."""
    state_dict = torch.load(file_path, map_location=device, weights_only=True)
    if len(state_dict) == 1:
        for wrapper_key in ("state_dict", "module", "model_state"):
            if wrapper_key in state_dict:
                state_dict = state_dict[wrapper_key]
                break
    if torch_dtype is not None:
        for name, value in state_dict.items():
            # Non-tensor entries (e.g. metadata) are left untouched.
            if isinstance(value, torch.Tensor):
                state_dict[name] = value.to(torch_dtype)
    return state_dict
|
||||
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    """Flatten a (possibly nested) state dict's keys into one sorted,
    comma-joined string; tensor entries also contribute "key:shape" tokens
    when `with_shape` is set."""
    keys = []
    for key, value in state_dict.items():
        if not isinstance(key, str):
            # Non-string keys are ignored, as in the original.
            continue
        if isinstance(value, torch.Tensor):
            if with_shape:
                keys.append(key + ":" + "_".join(str(d) for d in value.shape))
            keys.append(key)
        elif isinstance(value, dict):
            keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
    return ",".join(sorted(keys))
|
||||
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
    """MD5 fingerprint of a state dict's key/shape structure (not weights)."""
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    return hashlib.md5(keys_str.encode("UTF-8")).hexdigest()
|
||||
|
||||
|
||||
def load_keys_dict(file_path):
    """Load only the key -> shape mapping of one file (or a merged list of
    files), without materializing tensor data where the format allows it."""
    if isinstance(file_path, list):
        merged = {}
        for single_path in file_path:
            merged.update(load_keys_dict(single_path))
        return merged
    if file_path.endswith(".safetensors"):
        return load_keys_dict_from_safetensors(file_path)
    return load_keys_dict_from_bin(file_path)
|
||||
|
||||
|
||||
def load_keys_dict_from_safetensors(file_path):
    """Read key -> shape from a .safetensors file via slices (no tensor data
    is loaded)."""
    with safe_open(file_path, framework="pt", device="cpu") as f:
        return {name: f.get_slice(name).get_shape() for name in f.keys()}
|
||||
|
||||
|
||||
def convert_state_dict_to_keys_dict(state_dict):
    """Recursively map a state dict to key -> shape (tensors become their
    shape list; nested dicts recurse)."""
    return {
        key: list(value.shape) if isinstance(value, torch.Tensor) else convert_state_dict_to_keys_dict(value)
        for key, value in state_dict.items()
    }
|
||||
|
||||
|
||||
def load_keys_dict_from_bin(file_path):
    """Key -> shape mapping for a torch-pickled file (the full state dict is
    loaded first; the bin format has no lazy metadata access)."""
    return convert_state_dict_to_keys_dict(load_state_dict_from_bin(file_path))
|
||||
|
||||
|
||||
def convert_keys_dict_to_single_str(state_dict, with_shape=True):
    """Flatten a nested key -> shape mapping into one sorted, comma-joined
    string; leaf entries also contribute "key:shape" tokens when
    `with_shape` is set."""
    keys = []
    for key, value in state_dict.items():
        if not isinstance(key, str):
            # Non-string keys are ignored, as in the original.
            continue
        if isinstance(value, dict):
            keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
        else:
            if with_shape:
                keys.append(key + ":" + "_".join(str(d) for d in value))
            keys.append(key)
    return ",".join(sorted(keys))
|
||||
|
||||
|
||||
def hash_model_file(path, with_shape=True):
    """MD5 fingerprint of a model file's key/shape structure, computed from
    metadata without loading the weights where possible."""
    keys_dict = load_keys_dict(path)
    flat = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
    return hashlib.md5(flat.encode("UTF-8")).hexdigest()
|
||||
79
diffsynth/core/loader/model.py
Normal file
79
diffsynth/core/loader/model.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from ..vram.initialization import skip_model_initialization
|
||||
from ..vram.disk_map import DiskMap
|
||||
from ..vram.layers import enable_vram_management
|
||||
from .file import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
    """Instantiate `model_class` and load its weights from `path`.

    Args:
        model_class: model constructor, called as `model_class(**config)`.
        path: checkpoint path (or list of paths) readable by `DiskMap` / `load_state_dict`.
        config: kwargs forwarded to `model_class`; defaults to {}.
        torch_dtype / device: target dtype and device when VRAM management is disabled.
        state_dict_converter: optional callable that remaps the raw state dict
            into the format the model expects.
        use_disk_map: read parameters lazily via `DiskMap` instead of loading
            the whole file into memory at once.
        module_map: module mapping table for VRAM management; when given,
            dtype/device are derived from `vram_config` instead of the
            `torch_dtype` / `device` arguments.
        vram_config: per-stage dtype/device settings (offload/onload/preparing/computation).
        vram_limit: VRAM budget forwarded to `enable_vram_management`.

    Returns:
        The loaded model, switched to eval mode when the class supports it.
    """
    config = {} if config is None else config
    # Why do we use `skip_model_initialization`?
    # It skips the random initialization of model parameters,
    # thereby speeding up model loading and avoiding excessive memory usage.
    with skip_model_initialization():
        model = model_class(**config)
    # What is `module_map`?
    # This is a module mapping table for VRAM management.
    if module_map is not None:
        # Pick the first non-"disk" device/dtype from the stage configuration;
        # assumes at least one stage names a real device — TODO confirm.
        devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
        device = [d for d in devices if d != "disk"][0]
        dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
        dtype = [d for d in dtypes if d != "disk"][0]
        if vram_config["offload_device"] != "disk":
            # Weights fit in memory: load them now, then enable management.
            state_dict = DiskMap(path, device, torch_dtype=dtype)
            if state_dict_converter is not None:
                state_dict = state_dict_converter(state_dict)
            else:
                # Materialize the lazy DiskMap into a plain dict.
                state_dict = {i: state_dict[i] for i in state_dict}
            model.load_state_dict(state_dict, assign=True)
            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
        else:
            # Disk offload: parameters stay on disk and are streamed on demand.
            disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
    else:
        # Why do we use `DiskMap`?
        # Sometimes a model file contains multiple models,
        # and DiskMap can load only the parameters of a single model,
        # avoiding the need to load all parameters in the file.
        if use_disk_map:
            state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
        else:
            state_dict = load_state_dict(path, torch_dtype, device)
        # Why do we use `state_dict_converter`?
        # Some models are saved in complex formats,
        # and we need to convert the state dict into the appropriate format.
        if state_dict_converter is not None:
            state_dict = state_dict_converter(state_dict)
        else:
            state_dict = {i: state_dict[i] for i in state_dict}
        model.load_state_dict(state_dict, assign=True)
        # Why do we call `to()`?
        # Because some models override the behavior of `to()`,
        # especially those from libraries like Transformers.
        model = model.to(dtype=torch_dtype, device=device)
    if hasattr(model, "eval"):
        model = model.eval()
    return model
|
||||
|
||||
|
||||
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None, vram_limit=80):
    """Instantiate `model_class` with all parameters kept on disk.

    The model is built on the meta device and every managed submodule streams
    its parameters from `path` on demand (offload/onload stage = "disk",
    preparing in FP8, computing in `torch_dtype`).

    Args:
        model_class: model constructor, called as `model_class(**config)`.
        path: checkpoint path or list of paths backing the `DiskMap`.
        config: kwargs forwarded to `model_class`; defaults to {}.
        torch_dtype: dtype used for the computation stage.
        device: device used for the preparing and computation stages.
        state_dict_converter: optional key-remapping callable for the disk map.
        module_map: module mapping table for VRAM management.
        vram_limit: VRAM budget (GB) forwarded to `enable_vram_management`;
            previously hard-coded to 80, now a backward-compatible parameter.

    Returns:
        The VRAM-managed model (in eval mode when the class supports it).
    """
    if isinstance(path, str):
        path = [path]
    config = {} if config is None else config
    # Skip random parameter initialization: weights come from disk anyway.
    with skip_model_initialization():
        model = model_class(**config)
    if hasattr(model, "eval"):
        model = model.eval()
    disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
    vram_config = {
        "offload_dtype": "disk",
        "offload_device": "disk",
        "onload_dtype": "disk",
        "onload_device": "disk",
        "preparing_dtype": torch.float8_e4m3fn,
        "preparing_device": device,
        "computation_dtype": torch_dtype,
        "computation_device": device,
    }
    # Fix: keep the return value — `enable_vram_management` may replace the
    # top-level model with a wrapper (consistent with `load_model`). The
    # original call discarded it.
    model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
    return model
|
||||
2
diffsynth/core/vram/__init__.py
Normal file
2
diffsynth/core/vram/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .initialization import skip_model_initialization
|
||||
from .layers import *
|
||||
93
diffsynth/core/vram/disk_map.py
Normal file
93
diffsynth/core/vram/disk_map.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from safetensors import safe_open
|
||||
import torch, os
|
||||
|
||||
|
||||
class SafetensorsCompatibleTensor:
    """Adapter giving an in-memory tensor the minimal slice interface of
    safetensors (`get_shape`), so binary checkpoints can mimic `safe_open`."""

    def __init__(self, tensor):
        # Keep a reference only; no copy is made.
        self.tensor = tensor

    def get_shape(self):
        """Return the tensor's shape as a plain list, like a safetensors slice."""
        return [*self.tensor.shape]
|
||||
|
||||
|
||||
class SafetensorsCompatibleBinaryLoader:
    """Expose a torch-pickled checkpoint behind the `safe_open` handle
    interface (`keys` / `get_tensor` / `get_slice`)."""

    def __init__(self, path, device):
        print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
        # Unlike safetensors there is no lazy header access: the whole file
        # is loaded eagerly onto `device`.
        self.state_dict = torch.load(path, weights_only=True, map_location=device)

    def keys(self):
        """All parameter names in the checkpoint."""
        return self.state_dict.keys()

    def get_tensor(self, name):
        """Return the tensor stored under `name`."""
        return self.state_dict[name]

    def get_slice(self, name):
        """Return a safetensors-slice-like wrapper around the tensor."""
        return SafetensorsCompatibleTensor(self.get_tensor(name))
|
||||
|
||||
|
||||
class DiskMap:
    """Dict-like, lazy view over one or more checkpoint files on disk.

    Tensors are fetched on demand via `__getitem__`. A counter tracks the
    number of tensor elements read; once it exceeds `buffer_size` the
    safetensors handles are reopened (`flush_files`) to release any memory
    the handles have accumulated.
    """

    def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
        # Accept a single path or a list of paths.
        self.path = path if isinstance(path, list) else [path]
        self.device = device
        # Optional dtype cast applied to every tensor returned.
        self.torch_dtype = torch_dtype
        # Flush threshold (in tensor elements); overridable through the
        # DIFFSYNTH_DISK_MAP_BUFFER_SIZE environment variable.
        if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
            self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
        else:
            self.buffer_size = buffer_size
        self.files = []
        self.flush_files()
        # Map each parameter name to the index of the file that contains it.
        self.name_map = {}
        for file_id, file in enumerate(self.files):
            for name in file.keys():
                self.name_map[name] = file_id
        # Optional external-name -> on-disk-name translation table.
        self.rename_dict = self.fetch_rename_dict(state_dict_converter)

    def flush_files(self):
        # (Re)open the underlying files and reset the read counter.
        if len(self.files) == 0:
            # First call: open every file, falling back to the binary loader
            # for non-safetensors files.
            for path in self.path:
                if path.endswith(".safetensors"):
                    self.files.append(safe_open(path, framework="pt", device=str(self.device)))
                else:
                    self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
        else:
            # Subsequent calls: only safetensors handles are reopened; binary
            # loaders hold plain in-memory dicts and are left untouched.
            for i, path in enumerate(self.path):
                if path.endswith(".safetensors"):
                    self.files[i] = safe_open(path, framework="pt", device=str(self.device))
        self.num_params = 0

    def __getitem__(self, name):
        # Translate external names when a converter-derived rename table exists.
        if self.rename_dict is not None: name = self.rename_dict[name]
        file_id = self.name_map[name]
        param = self.files[file_id].get_tensor(name)
        if self.torch_dtype is not None and isinstance(param, torch.Tensor):
            param = param.to(self.torch_dtype)
        # Clone CPU tensors so they detach from file-backed storage before the
        # handle is flushed. NOTE(review): relies on `param.device == "cpu"`
        # comparing a torch.device against a string — confirm this matches on
        # the supported PyTorch versions.
        if isinstance(param, torch.Tensor) and param.device == "cpu":
            param = param.clone()
        if isinstance(param, torch.Tensor):
            self.num_params += param.numel()
            # Reopen files once the cumulative read volume exceeds the budget.
            if self.num_params > self.buffer_size:
                self.flush_files()
        return param

    def fetch_rename_dict(self, state_dict_converter):
        # Build the name translation table by running the converter over an
        # identity mapping (name -> name) of all on-disk keys. Presumably the
        # converter only renames/filters keys without touching values — TODO confirm.
        if state_dict_converter is None:
            return None
        state_dict = {}
        for file in self.files:
            for name in file.keys():
                state_dict[name] = name
        state_dict = state_dict_converter(state_dict)
        return state_dict

    def __iter__(self):
        # Iterate external names when a rename table exists, raw names otherwise.
        if self.rename_dict is not None:
            return self.rename_dict.__iter__()
        else:
            return self.name_map.__iter__()

    def __contains__(self, x):
        if self.rename_dict is not None:
            return x in self.rename_dict
        else:
            return x in self.name_map
|
||||
21
diffsynth/core/vram/initialization.py
Normal file
21
diffsynth/core/vram/initialization.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
def skip_model_initialization(device=torch.device("meta")):
    """Context manager that redirects every parameter registered while it is
    active onto `device` (the meta device by default).

    Building a model inside this context allocates no real parameter storage
    and skips random initialization, which makes construction fast and cheap;
    real weights are expected to be assigned later (e.g. via
    `load_state_dict(..., assign=True)`).
    """

    def register_empty_parameter(module, name, param):
        # Run the normal registration first so module bookkeeping stays intact.
        old_register_parameter(module, name, param)
        if param is not None:
            # Rebuild the parameter on the target device, preserving the
            # parameter subclass and its attributes (including requires_grad).
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    # Monkey-patch register_parameter globally; the original is restored in
    # `finally` even if the managed block raises.
    old_register_parameter = torch.nn.Module.register_parameter
    torch.nn.Module.register_parameter = register_empty_parameter
    try:
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
|
||||
475
diffsynth/core/vram/layers.py
Normal file
475
diffsynth/core/vram/layers.py
Normal file
@@ -0,0 +1,475 @@
|
||||
import torch, copy
|
||||
from typing import Union
|
||||
from .initialization import skip_model_initialization
|
||||
from .disk_map import DiskMap
|
||||
|
||||
|
||||
class AutoTorchModule(torch.nn.Module):
    """Base class for VRAM-managed modules.

    Tracks a (dtype, device) pair for each lifecycle stage — `offload`
    (parked), `onload` (ready), `preparing` (staged for compute) and
    `computation` (actively computing) — plus an optional VRAM budget.
    `self.state` records the current stage: 0 = offloaded, 1 = onloaded,
    2 = preparing.
    """

    def __init__(
        self,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
    ):
        super().__init__()
        self.set_dtype_and_device(
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
        )
        self.state = 0
        self.name = ""

    def set_dtype_and_device(
        self,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
    ):
        """Record the per-stage dtypes/devices, defaulting every unset stage
        to the computation-stage value."""
        self.offload_dtype = offload_dtype or computation_dtype
        # Fix: device fallbacks previously used `computation_dtype`, which
        # would store a dtype object where a device is expected whenever a
        # stage device was left unset.
        self.offload_device = offload_device or computation_device
        self.onload_dtype = onload_dtype or computation_dtype
        self.onload_device = onload_device or computation_device
        self.preparing_dtype = preparing_dtype or computation_dtype
        self.preparing_device = preparing_device or computation_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.vram_limit = vram_limit

    def cast_to(self, weight, dtype, device):
        """Return a copy of `weight` converted to the given dtype/device."""
        r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight)
        return r

    def check_free_vram(self):
        """Return True while the used VRAM (GB) on the computation device is
        still below `vram_limit`."""
        gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
        used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
        return used_memory < self.vram_limit

    def offload(self):
        # Any state -> offloaded (state 0).
        if self.state != 0:
            self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        # Any state -> onloaded (state 1).
        if self.state != 1:
            self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def param_name(self, name):
        """Qualify `name` with this module's dotted path inside the parent model."""
        if self.name == "":
            return name
        else:
            return self.name + "." + name
|
||||
|
||||
|
||||
class AutoWrappedModule(AutoTorchModule):
    """Wrap an arbitrary submodule for VRAM management.

    Moves the wrapped module between the offload/onload/preparing stages and,
    when `offload_dtype == "disk"`, streams its parameters from a `DiskMap`
    instead of keeping them resident in memory.
    """

    def __init__(
        self,
        module: torch.nn.Module,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
        name: str = "",
        disk_map: DiskMap = None,
        **kwargs
    ):
        super().__init__(
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
        )
        self.module = module
        if offload_dtype == "disk":
            # Disk offload: remember the module's dotted name and the disk map
            # so the parameters can be reloaded lazily by qualified name.
            self.name = name
            self.disk_map = disk_map
            self.required_params = [name for name, _ in self.module.named_parameters()]
            self.disk_offload = True
        else:
            self.disk_offload = False

    def load_from_disk(self, torch_dtype, device, copy_module=False):
        """Reload all wrapped parameters from the disk map onto `device`.

        With `copy_module=True` the weights are loaded into a deep copy and
        `self.module` is left untouched (used for temporary computation copies).
        """
        if copy_module:
            module = copy.deepcopy(self.module)
        else:
            module = self.module
        state_dict = {}
        for name in self.required_params:
            param = self.disk_map[self.param_name(name)]
            param = param.to(dtype=torch_dtype, device=device)
            state_dict[name] = param
        module.load_state_dict(state_dict, assign=True)
        module.to(dtype=torch_dtype, device=device)
        return module

    def offload_to_disk(self, model: torch.nn.Module):
        # Recursively move parameters to the meta device. A module that owns
        # buffers is not moved wholesale — buffers are not reloadable from the
        # on-disk state dict — so its children are recursed into instead.
        for buf in model.buffers():
            # If there are parameters registered as buffers (not in the state
            # dict), we cannot offload this module directly.
            for children in model.children():
                self.offload_to_disk(children)
            break
        else:
            model.to("meta")

    def offload(self):
        # offload / onload / preparing -> offload
        if self.state != 0:
            if self.disk_offload:
                self.offload_to_disk(self.module)
            else:
                self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        # offload / onload / preparing -> onload
        if self.state < 1:
            if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
                self.load_from_disk(self.onload_dtype, self.onload_device)
            elif self.onload_device != "disk":
                self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def preparing(self):
        # onload / preparing -> preparing
        if self.state != 2:
            if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
                self.load_from_disk(self.preparing_dtype, self.preparing_device)
            elif self.preparing_device != "disk":
                self.to(dtype=self.preparing_dtype, device=self.preparing_device)
            self.state = 2

    def cast_to(self, module, dtype, device):
        # Whole-module variant of the base-class tensor cast: returns a deep
        # copy on the target dtype/device, leaving the original untouched.
        return copy.deepcopy(module).to(dtype=dtype, device=device)

    def computation(self):
        # onload / preparing -> computation (temporary)
        if self.state == 2:
            torch_dtype, device = self.preparing_dtype, self.preparing_device
        else:
            torch_dtype, device = self.onload_dtype, self.onload_device
        if torch_dtype == self.computation_dtype and device == self.computation_device:
            # Already in the computation state; use the module as-is.
            module = self.module
        elif self.disk_offload and device == "disk":
            # Still on disk: build a temporary copy loaded straight into the
            # computation dtype/device.
            module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
        else:
            module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
        return module

    def forward(self, *args, **kwargs):
        # Promote to the preparing stage only while VRAM headroom remains.
        if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
            self.preparing()
        module = self.computation()
        return module(*args, **kwargs)

    def __getattr__(self, name):
        # Delegate unknown attributes to the wrapped module so the wrapper is
        # transparent to callers.
        if name in self.__dict__ or name == "module":
            return super().__getattr__(name)
        else:
            return getattr(self.module, name)
|
||||
|
||||
|
||||
class AutoWrappedNonRecurseModule(AutoWrappedModule):
    """Variant of `AutoWrappedModule` that manages only the parameters owned
    directly by the wrapped module (``recurse=False``); child modules are
    expected to be wrapped separately."""

    def __init__(
        self,
        module: torch.nn.Module,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
        name: str = "",
        disk_map: DiskMap = None,
        **kwargs
    ):
        super().__init__(
            module,
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
            name,
            disk_map,
            **kwargs
        )
        if self.disk_offload:
            # Narrow the managed set to this module's own parameters only.
            self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]

    def load_from_disk(self, torch_dtype, device, copy_module=False):
        """Reload only this module's direct parameters from the disk map.

        `strict=False` because child-module parameters are intentionally
        absent from the partial state dict. Unlike the parent class, the
        module itself is not `.to()`-moved afterwards.
        """
        if copy_module:
            module = copy.deepcopy(self.module)
        else:
            module = self.module
        state_dict = {}
        for name in self.required_params:
            param = self.disk_map[self.param_name(name)]
            param = param.to(dtype=torch_dtype, device=device)
            state_dict[name] = param
        module.load_state_dict(state_dict, assign=True, strict=False)
        return module

    def offload_to_disk(self, model: torch.nn.Module):
        # NOTE(review): `Tensor.to("meta")` is not in-place and its return
        # value is discarded here, so this looks like a no-op — confirm
        # whether the direct parameters are actually released.
        for name in self.required_params:
            getattr(self, name).to("meta")

    def cast_to(self, module, dtype, device):
        # Parameter casting is implemented in the model architecture.
        return module

    def __getattr__(self, name):
        # Keep the wrapper transparent: unknown attributes resolve on the
        # wrapped module.
        if name in self.__dict__ or name == "module":
            return super().__getattr__(name)
        else:
            return getattr(self.module, name)
|
||||
|
||||
|
||||
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
    """`torch.nn.Linear` replacement with VRAM-stage management, optional
    disk offload of weight/bias, FP8 matmul via `torch._scaled_mm`, and
    LoRA adapter support."""

    def __init__(
        self,
        module: torch.nn.Linear,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
        name: str = "",
        disk_map: DiskMap = None,
        **kwargs
    ):
        # Build an uninitialized Linear shell; the real weight/bias are taken
        # from `module` below, so no random init or allocation is needed.
        with skip_model_initialization():
            super().__init__(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
            )
        self.set_dtype_and_device(
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
        )
        self.weight = module.weight
        self.bias = module.bias
        self.state = 0
        self.name = name
        # LoRA adapters: parallel lists of A/B factors applied in forward.
        self.lora_A_weights = []
        self.lora_B_weights = []
        self.lora_merger = None
        # FP8 path is enabled purely by the computation dtype.
        self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]

        if offload_dtype == "disk":
            self.disk_map = disk_map
            self.disk_offload = True
        else:
            self.disk_offload = False

    def fp8_linear(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor = None,
    ) -> torch.Tensor:
        """Linear layer computed with FP8 `torch._scaled_mm`.

        The input is dynamically scaled per-row into the FP8 range; the
        result is produced in the input's original dtype.
        """
        device = input.device
        origin_dtype = input.dtype
        origin_shape = input.shape
        # _scaled_mm requires 2-D operands; flatten leading dims.
        input = input.reshape(-1, origin_shape[-1])

        x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
        fp8_max = 448.0
        # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
        # To avoid overflow and ensure numerical compatibility during FP8 computation,
        # we scale down the input by 2.0 in advance.
        # This scaling will be compensated later during the final result scaling.
        if self.computation_dtype == torch.float8_e4m3fnuz:
            fp8_max = fp8_max / 2.0
        scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
        scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
        input = input / (scale_a + 1e-8)
        input = input.to(self.computation_dtype)
        weight = weight.to(self.computation_dtype)
        # Fix: the original unconditionally called `bias.to(...)`, which
        # crashed for bias-free layers; `torch._scaled_mm` accepts bias=None.
        if bias is not None:
            bias = bias.to(torch.bfloat16)

        result = torch._scaled_mm(
            input,
            weight.T,
            scale_a=scale_a,
            scale_b=scale_b.T,
            bias=bias,
            out_dtype=origin_dtype,
        )
        new_shape = origin_shape[:-1] + result.shape[-1:]
        result = result.reshape(new_shape)
        return result

    def load_from_disk(self, torch_dtype, device, assign=True):
        """Fetch weight (and bias, if present) from the disk map.

        With `assign=True` the loaded tensors also replace this layer's
        parameters; either way the (weight, bias) pair is returned.
        """
        weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
        bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
        if assign:
            state_dict = {"weight": weight}
            if bias is not None: state_dict["bias"] = bias
            self.load_state_dict(state_dict, assign=True)
        return weight, bias

    def offload(self):
        # offload / onload / preparing -> offload
        if self.state != 0:
            if self.disk_offload:
                # Parameters live on disk; drop the in-memory copies.
                self.to("meta")
            else:
                self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        # offload / onload / preparing -> onload
        if self.state < 1:
            if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
                self.load_from_disk(self.onload_dtype, self.onload_device)
            elif self.onload_device != "disk":
                self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def preparing(self):
        # onload / preparing -> preparing
        if self.state != 2:
            if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
                self.load_from_disk(self.preparing_dtype, self.preparing_device)
            elif self.preparing_device != "disk":
                self.to(dtype=self.preparing_dtype, device=self.preparing_device)
            self.state = 2

    def computation(self):
        # onload / preparing -> computation (temporary): return the weight and
        # bias in the computation dtype/device without changing `self.state`.
        if self.state == 2:
            torch_dtype, device = self.preparing_dtype, self.preparing_device
        else:
            torch_dtype, device = self.onload_dtype, self.onload_device
        if torch_dtype == self.computation_dtype and device == self.computation_device:
            weight, bias = self.weight, self.bias
        elif self.disk_offload and device == "disk":
            # Temporary load: do not overwrite the (meta) parameters.
            weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
        else:
            weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
            bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
        return weight, bias

    def linear_forward(self, x, weight, bias):
        """Dispatch to the FP8 path or the standard linear kernel."""
        if self.enable_fp8:
            out = self.fp8_linear(x, weight, bias)
        else:
            out = torch.nn.functional.linear(x, weight, bias)
        return out

    def lora_forward(self, x, out):
        """Apply the registered LoRA adapters to `out`.

        Without a merger, each adapter's delta is summed into the output;
        with a merger, all deltas are stacked and combined by `lora_merger`.
        """
        if self.lora_merger is None:
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
                out = out + x @ lora_A.T @ lora_B.T
        else:
            lora_output = []
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
                lora_output.append(x @ lora_A.T @ lora_B.T)
            lora_output = torch.stack(lora_output)
            out = self.lora_merger(out, lora_output)
        return out

    def forward(self, x, *args, **kwargs):
        # Promote to the preparing stage only while VRAM headroom remains.
        if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
            self.preparing()
        weight, bias = self.computation()
        out = self.linear_forward(x, weight, bias)
        if len(self.lora_A_weights) > 0:
            out = self.lora_forward(x, out)
        return out
|
||||
|
||||
|
||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
    """Walk the module tree and replace every child whose type appears in
    `module_map` with its VRAM-managed wrapper.

    Args:
        model: module whose children are rewired in place via `setattr`.
        module_map: {source module type: wrapper type} mapping.
        vram_config: per-stage dtype/device settings forwarded to each wrapper.
        vram_limit: optional VRAM budget forwarded to each wrapper.
        name_prefix: dotted path of `model` within the root model, used to
            build each child's fully qualified name.
        disk_map: optional DiskMap for disk-offloaded parameters.
    """
    # A non-recurse wrapper manages only its own parameters, so descend into
    # the underlying module to wrap its children as well.
    if isinstance(model, AutoWrappedNonRecurseModule):
        model = model.module
    for name, module in model.named_children():
        layer_name = name if name_prefix == "" else name_prefix + "." + name
        for source_module, target_module in module_map.items():
            if isinstance(module, source_module):
                module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
                if isinstance(module_, AutoWrappedNonRecurseModule):
                    # Non-recurse wrappers leave their children unmanaged;
                    # keep walking inside them.
                    enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
                setattr(model, name, module_)
                break
        else:
            # No wrapper type matched: recurse into the child unchanged.
            enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
|
||||
|
||||
|
||||
def fill_vram_config(model, vram_config):
    """Return a copy of `vram_config` with the onload and preparing stages
    forced to the computation dtype/device, printing a notice if any entry
    was changed."""
    vram_config_ = dict(vram_config)
    for stage in ("onload", "preparing"):
        vram_config_[stage + "_dtype"] = vram_config["computation_dtype"]
        vram_config_[stage + "_device"] = vram_config["computation_device"]
    if any(vram_config[key] != vram_config_[key] for key in vram_config):
        print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
    return vram_config_
|
||||
|
||||
|
||||
def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
    """Enable VRAM management on `model`.

    If the model's own type appears in `module_map`, the whole model is
    wrapped uniformly (with onload/preparing stages collapsed onto the
    computation stage); otherwise the module tree is walked and matching
    children are wrapped individually.
    """
    matched_wrapper = None
    for source_module, target_module in module_map.items():
        # If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
        if isinstance(model, source_module):
            matched_wrapper = target_module
            break
    if matched_wrapper is not None:
        vram_config = fill_vram_config(model, vram_config)
        model = matched_wrapper(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
    else:
        enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
    # `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
    model.vram_management_enabled = True
    return model
|
||||
Reference in New Issue
Block a user