update doc

This commit is contained in:
Artiprocher
2025-11-06 20:35:35 +08:00
parent 6a6eca7baf
commit 74f8181f93
16 changed files with 433 additions and 40 deletions

View File

@@ -3,7 +3,6 @@ import imageio.v3 as iio
from PIL import Image
class DataProcessingPipeline:
def __init__(self, operators=None):
self.operators: list[DataProcessingOperator] = [] if operators is None else operators
@@ -19,7 +18,6 @@ class DataProcessingPipeline:
return DataProcessingPipeline(self.operators + pipe.operators)
class DataProcessingOperator:
def __call__(self, data):
raise NotImplementedError("DataProcessingOperator cannot be called directly.")
@@ -30,25 +28,21 @@ class DataProcessingOperator:
return DataProcessingPipeline([self]).__rshift__(pipe)
class DataProcessingOperatorRaw(DataProcessingOperator):
    """Identity operator: forwards its input unchanged.

    Useful as a no-op stage when composing pipelines.
    """

    def __call__(self, data):
        # Deliberately no transformation.
        return data
class ToInt(DataProcessingOperator):
    """Cast the incoming value to a built-in ``int``.

    Raises whatever ``int()`` raises (e.g. ``ValueError``) on
    non-convertible input.
    """

    def __call__(self, data):
        converted = int(data)
        return converted
class ToFloat(DataProcessingOperator):
    """Cast the incoming value to a built-in ``float``.

    Raises whatever ``float()`` raises (e.g. ``ValueError``) on
    non-convertible input.
    """

    def __call__(self, data):
        converted = float(data)
        return converted
class ToStr(DataProcessingOperator):
def __init__(self, none_value=""):
self.none_value = none_value
@@ -58,7 +52,6 @@ class ToStr(DataProcessingOperator):
return str(data)
class LoadImage(DataProcessingOperator):
def __init__(self, convert_RGB=True):
self.convert_RGB = convert_RGB
@@ -69,9 +62,8 @@ class LoadImage(DataProcessingOperator):
return image
class ImageCropAndResize(DataProcessingOperator):
def __init__(self, height, width, max_pixels, height_division_factor, width_division_factor):
def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
self.height = height
self.width = width
self.max_pixels = max_pixels
@@ -101,19 +93,16 @@ class ImageCropAndResize(DataProcessingOperator):
height, width = self.height, self.width
return height, width
def __call__(self, data: Image.Image):
image = self.crop_and_resize(data, *self.get_height_width(data))
return image
class ToList(DataProcessingOperator):
    """Wrap the incoming value in a single-element list."""

    def __call__(self, data):
        wrapped = [data]
        return wrapped
class LoadVideo(DataProcessingOperator):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
self.num_frames = num_frames
@@ -143,7 +132,6 @@ class LoadVideo(DataProcessingOperator):
return frames
class SequencialProcess(DataProcessingOperator):
def __init__(self, operator=lambda x: x):
self.operator = operator
@@ -152,7 +140,6 @@ class SequencialProcess(DataProcessingOperator):
return [self.operator(i) for i in data]
class LoadGIF(DataProcessingOperator):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
self.num_frames = num_frames
@@ -181,7 +168,6 @@ class LoadGIF(DataProcessingOperator):
if len(frames) >= num_frames:
break
return frames
class RouteByExtensionName(DataProcessingOperator):
@@ -196,7 +182,6 @@ class RouteByExtensionName(DataProcessingOperator):
raise ValueError(f"Unsupported file: {data}")
class RouteByType(DataProcessingOperator):
def __init__(self, operator_map):
self.operator_map = operator_map
@@ -208,7 +193,6 @@ class RouteByType(DataProcessingOperator):
raise ValueError(f"Unsupported data: {data}")
class LoadTorchPickle(DataProcessingOperator):
def __init__(self, map_location="cpu"):
self.map_location = map_location
@@ -217,7 +201,6 @@ class LoadTorchPickle(DataProcessingOperator):
return torch.load(data, map_location=self.map_location, weights_only=False)
class ToAbsolutePath(DataProcessingOperator):
def __init__(self, base_path=""):
self.base_path = base_path
@@ -225,3 +208,11 @@ class ToAbsolutePath(DataProcessingOperator):
def __call__(self, data):
return os.path.join(self.base_path, data)
class LoadAudio(DataProcessingOperator):
    """Load an audio file from a path into a waveform via ``librosa``."""

    def __init__(self, sr=16000):
        # Target sampling rate handed to librosa.load.
        self.sr = sr

    def __call__(self, data: str):
        # Imported lazily so librosa is only required when audio is actually loaded.
        import librosa

        waveform, _sample_rate = librosa.load(data, sr=self.sr)
        # The returned sample rate is discarded; only the waveform is kept.
        return waveform

View File

@@ -402,17 +402,3 @@ def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config
# `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
model.vram_management_enabled = True
return model
def reset_vram_config(model: torch.nn.Module, vram_config: dict, vram_limit=None):
    """Re-apply a VRAM placement configuration to every managed submodule.

    Walks *model* and, for each ``AutoTorchModule``, re-runs
    ``set_dtype_and_device`` with the new *vram_config*. If any managed
    module exposes a non-None ``disk_map``, that map is re-pointed at the
    first configured device that is not ``"disk"`` and its files are flushed.

    Args:
        model: Model whose ``AutoTorchModule`` submodules are reconfigured.
        vram_config: Dict unpacked into ``set_dtype_and_device``; must contain
            the keys "offload_device", "onload_device", "preparing_device"
            and "computation_device".
        vram_limit: Optional VRAM budget forwarded to each module.

    Raises:
        ValueError: If a disk map is present but every configured device
            is ``"disk"`` (previously an opaque ``IndexError``).
    """
    disk_map = None
    for module in model.modules():
        if not isinstance(module, AutoTorchModule):
            continue
        module.set_dtype_and_device(**vram_config, vram_limit=vram_limit)
        # getattr with a default replaces the hasattr + double-getattr pattern.
        # The last non-None map seen wins — presumably the maps are shared
        # across modules; TODO confirm.
        current_map = getattr(module, "disk_map", None)
        if current_map is not None:
            disk_map = current_map
    if disk_map is not None:
        device_keys = ("offload_device", "onload_device", "preparing_device", "computation_device")
        real_devices = [vram_config[k] for k in device_keys if vram_config[k] != "disk"]
        if not real_devices:
            raise ValueError("vram_config must name at least one non-'disk' device")
        # Point the disk map at the first real device, then persist its files.
        disk_map.device = real_devices[0]
        disk_map.flush_files()