mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
update
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,5 +1,6 @@
|
|||||||
/data
|
/data
|
||||||
/models
|
/models
|
||||||
|
/scripts
|
||||||
*.pkl
|
*.pkl
|
||||||
*.safetensors
|
*.safetensors
|
||||||
*.pth
|
*.pth
|
||||||
|
|||||||
@@ -28,27 +28,52 @@ class ModelConfig:
|
|||||||
if self.path is None and self.model_id is None:
|
if self.path is None and self.model_id is None:
|
||||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
||||||
|
|
||||||
def download(self):
|
def parse_original_file_pattern(self):
|
||||||
origin_file_pattern = self.origin_file_pattern + ("*" if self.origin_file_pattern.endswith("/") else "")
|
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||||
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
return "*"
|
||||||
|
elif self.origin_file_pattern.endswith("/"):
|
||||||
|
return self.origin_file_pattern + "*"
|
||||||
|
else:
|
||||||
|
return self.origin_file_pattern
|
||||||
|
|
||||||
|
def parse_download_resource(self):
|
||||||
if self.download_resource is None:
|
if self.download_resource is None:
|
||||||
if os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE') is not None:
|
if os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE') is not None:
|
||||||
self.download_resource = os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE')
|
return os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE')
|
||||||
else:
|
else:
|
||||||
self.download_resource = "modelscope"
|
return "modelscope"
|
||||||
if self.download_resource.lower() == "modelscope":
|
else:
|
||||||
|
return self.download_resource
|
||||||
|
|
||||||
|
def parse_skip_download(self):
|
||||||
|
if self.skip_download is None:
|
||||||
|
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
|
||||||
|
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
|
||||||
|
return True
|
||||||
|
elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return self.skip_download
|
||||||
|
|
||||||
|
def download(self):
|
||||||
|
origin_file_pattern = self.parse_original_file_pattern()
|
||||||
|
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||||
|
download_resource = self.parse_download_resource()
|
||||||
|
if download_resource.lower() == "modelscope":
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
self.model_id,
|
self.model_id,
|
||||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||||
allow_file_pattern=self.origin_file_pattern,
|
allow_file_pattern=origin_file_pattern,
|
||||||
ignore_file_pattern=downloaded_files,
|
ignore_file_pattern=downloaded_files,
|
||||||
local_files_only=False
|
local_files_only=False
|
||||||
)
|
)
|
||||||
elif self.download_resource.lower() == "huggingface":
|
elif download_resource.lower() == "huggingface":
|
||||||
hf_snapshot_download(
|
hf_snapshot_download(
|
||||||
self.model_id,
|
self.model_id,
|
||||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||||
allow_patterns=self.origin_file_pattern,
|
allow_patterns=origin_file_pattern,
|
||||||
ignore_patterns=downloaded_files,
|
ignore_patterns=downloaded_files,
|
||||||
local_files_only=False
|
local_files_only=False
|
||||||
)
|
)
|
||||||
@@ -58,15 +83,8 @@ class ModelConfig:
|
|||||||
def require_downloading(self):
|
def require_downloading(self):
|
||||||
if self.path is not None:
|
if self.path is not None:
|
||||||
return False
|
return False
|
||||||
if self.skip_download is None:
|
skip_download = self.parse_skip_download()
|
||||||
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
|
return not skip_download
|
||||||
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') in ["True", "true"]:
|
|
||||||
self.skip_download = True
|
|
||||||
elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') in ["False", "false"]:
|
|
||||||
self.skip_download = False
|
|
||||||
else:
|
|
||||||
self.skip_download = False
|
|
||||||
return not self.skip_download
|
|
||||||
|
|
||||||
def reset_local_model_path(self):
|
def reset_local_model_path(self):
|
||||||
if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
|
if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
|
||||||
@@ -79,7 +97,10 @@ class ModelConfig:
|
|||||||
self.reset_local_model_path()
|
self.reset_local_model_path()
|
||||||
if self.require_downloading():
|
if self.require_downloading():
|
||||||
self.download()
|
self.download()
|
||||||
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||||
|
self.path = os.path.join(self.local_model_path, self.model_id)
|
||||||
|
else:
|
||||||
|
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
||||||
if isinstance(self.path, list) and len(self.path) == 1:
|
if isinstance(self.path, list) and len(self.path) == 1:
|
||||||
self.path = self.path[0]
|
self.path = self.path[0]
|
||||||
|
|
||||||
|
|||||||
@@ -60,9 +60,10 @@ class DiskMap:
|
|||||||
if self.rename_dict is not None: name = self.rename_dict[name]
|
if self.rename_dict is not None: name = self.rename_dict[name]
|
||||||
file_id = self.name_map[name]
|
file_id = self.name_map[name]
|
||||||
param = self.files[file_id].get_tensor(name)
|
param = self.files[file_id].get_tensor(name)
|
||||||
if self.torch_dtype is not None:
|
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
|
||||||
param = param.to(self.torch_dtype)
|
param = param.to(self.torch_dtype)
|
||||||
self.num_params += param.numel()
|
if isinstance(param, torch.Tensor):
|
||||||
|
self.num_params += param.numel()
|
||||||
if self.num_params > self.buffer_size:
|
if self.num_params > self.buffer_size:
|
||||||
self.flush_files()
|
self.flush_files()
|
||||||
return param
|
return param
|
||||||
|
|||||||
20
diffsynth/utils/lora/merge.py
Normal file
20
diffsynth/utils/lora/merge.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
def merge_lora_weight(tensors_A, tensors_B):
|
||||||
|
lora_A = torch.concat(tensors_A, dim=0)
|
||||||
|
lora_B = torch.concat(tensors_B, dim=1)
|
||||||
|
return lora_A, lora_B
|
||||||
|
|
||||||
|
|
||||||
|
def merge_lora(loras: List[Dict[str, torch.Tensor]]):
|
||||||
|
lora_merged = {}
|
||||||
|
keys = [i for i in loras[0].keys() if ".lora_A." in i]
|
||||||
|
for key in keys:
|
||||||
|
tensors_A = [lora[key] for lora in loras]
|
||||||
|
tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras]
|
||||||
|
lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B)
|
||||||
|
lora_merged[key] = lora_A
|
||||||
|
lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B
|
||||||
|
return lora_merged
|
||||||
@@ -5,7 +5,9 @@ from modelscope import snapshot_download
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
# This model has additional requirements.
|
||||||
|
# Please install the following packages.
|
||||||
|
# pip install facexlib insightface onnxruntime
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
"ByteDance/InfiniteYou",
|
"ByteDance/InfiniteYou",
|
||||||
allow_file_pattern="supports/insightface/models/antelopev2/*",
|
allow_file_pattern="supports/insightface/models/antelopev2/*",
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ from PIL import Image
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# This model has additional requirements.
|
||||||
|
# Please install the following packages.
|
||||||
|
# pip install facexlib insightface onnxruntime
|
||||||
vram_config = {
|
vram_config = {
|
||||||
"onload_dtype": torch.float8_e4m3fn,
|
"onload_dtype": torch.float8_e4m3fn,
|
||||||
"onload_device": "cpu",
|
"onload_device": "cpu",
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ version = "2.0.0"
|
|||||||
description = "Enjoy the magic of Diffusion models!"
|
description = "Enjoy the magic of Diffusion models!"
|
||||||
authors = [{name = "ModelScope Team"}]
|
authors = [{name = "ModelScope Team"}]
|
||||||
license = {text = "Apache-2.0"}
|
license = {text = "Apache-2.0"}
|
||||||
requires-python = ">=3.6"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"torch>=2.0.0",
|
"torch>=2.0.0",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
@@ -21,10 +21,10 @@ dependencies = [
|
|||||||
"protobuf",
|
"protobuf",
|
||||||
"modelscope",
|
"modelscope",
|
||||||
"ftfy",
|
"ftfy",
|
||||||
"pynvml",
|
|
||||||
"pandas",
|
"pandas",
|
||||||
"accelerate",
|
"accelerate",
|
||||||
"peft",
|
"peft",
|
||||||
|
"datasets",
|
||||||
]
|
]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
|
|||||||
Reference in New Issue
Block a user