From 96daa30bcc80d57355dac05e3d1fc34ba47f3dc3 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 20 Nov 2025 17:44:00 +0800 Subject: [PATCH] update --- .gitignore | 1 + diffsynth/core/loader/config.py | 59 +++++++++++++------ diffsynth/core/vram/disk_map.py | 5 +- diffsynth/utils/lora/merge.py | 20 +++++++ .../model_inference/FLUX.1-dev-InfiniteYou.py | 4 +- .../FLUX.1-dev-InfiniteYou.py | 3 + pyproject.toml | 4 +- 7 files changed, 72 insertions(+), 24 deletions(-) create mode 100644 diffsynth/utils/lora/merge.py diff --git a/.gitignore b/.gitignore index ca34bed..391b448 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ /data /models +/scripts *.pkl *.safetensors *.pth diff --git a/diffsynth/core/loader/config.py b/diffsynth/core/loader/config.py index b6d3427..673ff07 100644 --- a/diffsynth/core/loader/config.py +++ b/diffsynth/core/loader/config.py @@ -28,27 +28,52 @@ class ModelConfig: if self.path is None and self.model_id is None: raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""") - def download(self): - origin_file_pattern = self.origin_file_pattern + ("*" if self.origin_file_pattern.endswith("/") else "") - downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id)) + def parse_original_file_pattern(self): + if self.origin_file_pattern is None or self.origin_file_pattern == "": + return "*" + elif self.origin_file_pattern.endswith("/"): + return self.origin_file_pattern + "*" + else: + return self.origin_file_pattern + + def parse_download_resource(self): if self.download_resource is None: if os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE') is not None: - self.download_resource = os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE') + return os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE') else: - self.download_resource = "modelscope" - if self.download_resource.lower() == "modelscope": + return "modelscope" + else: + return self.download_resource + + def parse_skip_download(self): + if self.skip_download is None: + if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None: + if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true": + return True + elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false": + return False + else: + return False + else: + return self.skip_download + + def download(self): + origin_file_pattern = self.parse_original_file_pattern() + downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id)) + download_resource = self.parse_download_resource() + if download_resource.lower() == "modelscope": snapshot_download( self.model_id, local_dir=os.path.join(self.local_model_path, self.model_id), - allow_file_pattern=self.origin_file_pattern, + allow_file_pattern=origin_file_pattern, ignore_file_pattern=downloaded_files, local_files_only=False ) - elif self.download_resource.lower() == "huggingface": + elif download_resource.lower() == "huggingface": hf_snapshot_download( self.model_id, local_dir=os.path.join(self.local_model_path, self.model_id), - allow_patterns=self.origin_file_pattern, + allow_patterns=origin_file_pattern, ignore_patterns=downloaded_files, local_files_only=False ) @@ -58,15 +83,8 @@ class ModelConfig: def require_downloading(self): if self.path is not None: return False - if self.skip_download is None: - if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None: - if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') in ["True", "true"]: - self.skip_download = True - elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') in ["False", "false"]: - self.skip_download = False - else: - self.skip_download = False - return not self.skip_download + skip_download = self.parse_skip_download() + return not skip_download def reset_local_model_path(self): if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None: @@ -79,7 +97,10 @@ class ModelConfig: self.reset_local_model_path() if self.require_downloading(): self.download() - self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)) + if self.origin_file_pattern is None or self.origin_file_pattern == "": + self.path = os.path.join(self.local_model_path, self.model_id) + else: + self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)) if isinstance(self.path, list) and len(self.path) == 1: self.path = self.path[0] diff --git a/diffsynth/core/vram/disk_map.py b/diffsynth/core/vram/disk_map.py index 0034990..6f0b6ea 100644 --- a/diffsynth/core/vram/disk_map.py +++ b/diffsynth/core/vram/disk_map.py @@ -60,9 +60,10 @@ class DiskMap: if self.rename_dict is not None: name = self.rename_dict[name] file_id = self.name_map[name] param = self.files[file_id].get_tensor(name) - if self.torch_dtype is not None: + if self.torch_dtype is not None and isinstance(param, torch.Tensor): param = param.to(self.torch_dtype) - self.num_params += param.numel() + if isinstance(param, torch.Tensor): + self.num_params += param.numel() if self.num_params > self.buffer_size: self.flush_files() return param diff --git a/diffsynth/utils/lora/merge.py b/diffsynth/utils/lora/merge.py new file mode 100644 index 0000000..d75ff90 --- /dev/null +++ b/diffsynth/utils/lora/merge.py @@ -0,0 +1,20 @@ +import torch +from typing import Dict, List + + +def merge_lora_weight(tensors_A, tensors_B): + lora_A = torch.concat(tensors_A, dim=0) + lora_B = torch.concat(tensors_B, dim=1) + return lora_A, lora_B + + +def merge_lora(loras: List[Dict[str, torch.Tensor]]): + lora_merged = {} + keys = [i for i in loras[0].keys() if ".lora_A." in i] + for key in keys: + tensors_A = [lora[key] for lora in loras] + tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras] + lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B) + lora_merged[key] = lora_A + lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B + return lora_merged diff --git a/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py b/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py index 2c2eb58..4491ccb 100644 --- a/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py +++ b/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py @@ -5,7 +5,9 @@ from modelscope import snapshot_download from PIL import Image import numpy as np - +# This model has additional requirements. +# Please install the following packages. +# pip install facexlib insightface onnxruntime snapshot_download( "ByteDance/InfiniteYou", allow_file_pattern="supports/insightface/models/antelopev2/*", diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py index bb393ae..c3be51c 100644 --- a/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py @@ -6,6 +6,9 @@ from PIL import Image import numpy as np +# This model has additional requirements. +# Please install the following packages. +# pip install facexlib insightface onnxruntime vram_config = { "onload_dtype": torch.float8_e4m3fn, "onload_device": "cpu", diff --git a/pyproject.toml b/pyproject.toml index 483918c..cb00b4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "2.0.0" description = "Enjoy the magic of Diffusion models!" authors = [{name = "ModelScope Team"}] license = {text = "Apache-2.0"} -requires-python = ">=3.6" +requires-python = ">=3.10" dependencies = [ "torch>=2.0.0", "torchvision", @@ -21,10 +21,10 @@ dependencies = [ "protobuf", "modelscope", "ftfy", - "pynvml", "pandas", "accelerate", "peft", + "datasets", ] classifiers = [ "Programming Language :: Python :: 3",