mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
add downloader
This commit is contained in:
14
README.md
14
README.md
@@ -50,18 +50,10 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
Create Python environment:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
conda env create -f environment.yml
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
```
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
We find that sometimes `conda` cannot install `cupy` correctly, please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.
|
|
||||||
|
|
||||||
Enter the Python environment:
|
|
||||||
|
|
||||||
```
|
|
||||||
conda activate DiffSynthStudio
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage (in Python code)
|
## Usage (in Python code)
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
import torch, os
|
import torch, os
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
from typing_extensions import Literal, TypeAlias
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from .downloader import download_from_huggingface, download_from_modelscope
|
||||||
|
|
||||||
from .sd_text_encoder import SDTextEncoder
|
from .sd_text_encoder import SDTextEncoder
|
||||||
from .sd_unet import SDUNet
|
from .sd_unet import SDUNet
|
||||||
@@ -29,13 +33,89 @@ from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5Tex
|
|||||||
from .hunyuan_dit import HunyuanDiT
|
from .hunyuan_dit import HunyuanDiT
|
||||||
|
|
||||||
|
|
||||||
|
# Preset model tables.  Each entry maps a preset model id to a list of
# (repo_id, file path inside the repo, local directory to store the file).
preset_models_on_huggingface = {
    "HunyuanDiT": [
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    "stable-video-diffusion-img2vid-xt": [
        ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
}
preset_models_on_modelscope = {
    "HunyuanDiT": [
        ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    "stable-video-diffusion-img2vid-xt": [
        ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
}

# Closed sets of values accepted by ModelManager's download parameters.
Preset_model_id: TypeAlias = Literal[
    "HunyuanDiT",
    "stable-video-diffusion-img2vid-xt",
    "ExVideo-SVD-128f-v1"
]
Preset_model_website: TypeAlias = Literal[
    "HuggingFace",
    "ModelScope",
]

# Dispatch tables: website name -> preset table / download function.
website_to_preset_models = {
    "HuggingFace": preset_models_on_huggingface,
    "ModelScope": preset_models_on_modelscope,
}
website_to_download_fn = {
    "HuggingFace": download_from_huggingface,
    "ModelScope": download_from_modelscope,
}
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
def __init__(
    self,
    torch_dtype=torch.float16,
    device="cuda",
    model_id_list: List[Preset_model_id] = None,
    downloading_priority: List[Preset_model_website] = None,
    file_path_list: List[str] = None,
):
    """Create a model manager, optionally downloading preset models first.

    Args:
        torch_dtype: dtype used when loading model weights.
        device: device the loaded models are placed on.
        model_id_list: preset model ids to download before loading.
        downloading_priority: websites to try, in order, for each preset id.
        file_path_list: extra local weight files to load in addition to the
            downloaded ones.
    """
    # The original signature used mutable list defaults ([], [...]), which are
    # shared across calls in Python; None sentinels avoid that pitfall while
    # keeping call-site behavior identical.
    if model_id_list is None:
        model_id_list = []
    if downloading_priority is None:
        downloading_priority = ["ModelScope", "HuggingFace"]
    if file_path_list is None:
        file_path_list = []
    self.torch_dtype = torch_dtype
    self.device = device
    self.model = {}
    self.model_path = {}
    self.textual_inversion_dict = {}
    # Download requested presets, then load them together with any
    # explicitly listed local files.
    downloaded_files = self.download_models(model_id_list, downloading_priority)
    self.load_models(downloaded_files + file_path_list)
|
||||||
|
|
||||||
|
def download_models(
    self,
    model_id_list: List[Preset_model_id] = None,
    downloading_priority: List[Preset_model_website] = None,
):
    """Download the preset models named in ``model_id_list``.

    For each preset id, the websites in ``downloading_priority`` are tried in
    order and the first website that knows the id is used.

    Args:
        model_id_list: preset model ids to download.
        downloading_priority: websites to try, in priority order.

    Returns:
        List of local file paths confirmed present after downloading.
    """
    # None sentinels instead of mutable list defaults (shared across calls).
    if model_id_list is None:
        model_id_list = []
    if downloading_priority is None:
        downloading_priority = ["ModelScope", "HuggingFace"]
    downloaded_files = []
    for preset_id in model_id_list:
        for website in downloading_priority:
            if preset_id not in website_to_preset_models[website]:
                continue
            for repo_id, origin_file_path, local_dir in website_to_preset_models[website][preset_id]:
                # Skip files already handled earlier in this call.
                file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
                if file_to_download in downloaded_files:
                    continue
                # Download, then record the file only if it actually landed.
                website_to_download_fn[website](repo_id, origin_file_path, local_dir)
                if os.path.basename(origin_file_path) in os.listdir(local_dir):
                    downloaded_files.append(file_to_download)
            # Bug fix: the original rebound the outer loop variable `model_id`
            # inside the inner loop, which accidentally prevented any later
            # website from matching. An explicit break makes that first-hit
            # semantics intentional without the shadowing.
            break
    return downloaded_files
|
||||||
|
|
||||||
def is_stable_video_diffusion(self, state_dict):
|
def is_stable_video_diffusion(self, state_dict):
|
||||||
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
|
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
|
||||||
|
|||||||
278
diffsynth/models/downloader.py
Normal file
278
diffsynth/models/downloader.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from http.cookiejar import CookieJar
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, List, Union
|
||||||
|
import copy, uuid, requests, io, platform, pickle, os, urllib
|
||||||
|
from requests.adapters import Retry
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def _get_sep(path):
|
||||||
|
if isinstance(path, bytes):
|
||||||
|
return b'/'
|
||||||
|
else:
|
||||||
|
return '/'
|
||||||
|
|
||||||
|
|
||||||
|
def expanduser(path):
    """Expand a leading ``~`` or ``~user`` in *path*.

    If the user or $HOME is unknown, the path is returned unchanged.
    Handles both str and bytes paths (mirrors ``posixpath.expanduser``).
    """
    path = os.fspath(path)
    is_bytes = isinstance(path, bytes)
    tilde = b'~' if is_bytes else '~'
    if not path.startswith(tilde):
        return path
    sep = _get_sep(path)
    end = path.find(sep, 1)
    if end < 0:
        end = len(path)
    if end == 1:
        # Bare "~": resolve the current user's home directory.
        if 'HOME' in os.environ:
            userhome = os.environ['HOME']
        else:
            import pwd
            try:
                userhome = pwd.getpwuid(os.getuid()).pw_dir
            except KeyError:
                # bpo-10496: current uid missing from the password database;
                # return the path unchanged.
                return path
    else:
        # "~name": look the named user up in the password database.
        import pwd
        name = path[1:end]
        if isinstance(name, bytes):
            name = str(name, 'ASCII')
        try:
            userhome = pwd.getpwnam(name).pw_dir
        except KeyError:
            # bpo-10496: unknown user name; return the path unchanged.
            return path
    if is_bytes:
        userhome = os.fsencode(userhome)
        root = b'/'
    else:
        root = '/'
    userhome = userhome.rstrip(root)
    return (userhome + path[end:]) or root
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ModelScopeConfig:
    """Read-side access to the local ModelScope credential store (~/.modelscope)."""

    DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials')
    path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
    COOKIES_FILE_NAME = 'cookies'
    GIT_TOKEN_FILE_NAME = 'git_token'
    USER_INFO_FILE_NAME = 'user'
    USER_SESSION_ID_FILE_NAME = 'session'

    @staticmethod
    def make_sure_credential_path_exist():
        """Create the credential directory if it does not exist yet."""
        os.makedirs(ModelScopeConfig.path_credential, exist_ok=True)

    @staticmethod
    def get_user_session_id():
        """Return the persisted session id, generating and saving one if absent."""
        session_path = os.path.join(
            ModelScopeConfig.path_credential,
            ModelScopeConfig.USER_SESSION_ID_FILE_NAME)
        session_id = ''
        if os.path.exists(session_path):
            with open(session_path, 'rb') as f:
                session_id = str(f.readline().strip(), encoding='utf-8')
                # An existing file wins; its content is returned as-is.
                return session_id
        if session_id == '' or len(session_id) != 32:
            # No stored session: mint a fresh 32-char hex id and persist it.
            session_id = str(uuid.uuid4().hex)
            ModelScopeConfig.make_sure_credential_path_exist()
            with open(session_path, 'w+') as wf:
                wf.write(session_id)
        return session_id

    @staticmethod
    def get_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
        """Format a user-agent string with basic info about a request.

        Args:
            user_agent: extra user agent info, as a dict or a single string.

        Returns:
            The formatted user-agent string.
        """
        # Extra telemetry is read from the environment when executing in
        # dedicated cloud containers.
        env = os.environ.get('MODELSCOPE_ENVIRONMENT', 'custom')
        user_name = os.environ.get('MODELSCOPE_USERNAME', 'unknown')
        ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % (
            "1.15.0",
            platform.python_version(),
            ModelScopeConfig.get_user_session_id(),
            platform.platform(),
            platform.processor(),
            env,
            user_name,
        )
        if isinstance(user_agent, dict):
            ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items())
        elif isinstance(user_agent, str):
            ua += '; ' + user_agent
        return ua

    @staticmethod
    def get_cookies():
        """Load pickled cookies from the credential store, or None if absent."""
        cookies_path = os.path.join(
            ModelScopeConfig.path_credential,
            ModelScopeConfig.COOKIES_FILE_NAME)
        if not os.path.exists(cookies_path):
            return None
        with open(cookies_path, 'rb') as f:
            return pickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def modelscope_http_get_model_file(
    url: str,
    local_dir: str,
    file_name: str,
    file_size: int,
    cookies: CookieJar,
    headers: Optional[Dict[str, str]] = None,
):
    """Download a remote file, retrying up to 5 times before giving up.

    Args:
        url(str):
            actual download url of the file
        local_dir(str):
            local directory where the downloaded file stores
        file_name(str):
            name of the file stored in `local_dir`
        file_size(int):
            The file size.
        cookies(CookieJar):
            cookies used to authentication the user, which is used for
            downloading private repos
        headers(Dict[str, str], optional):
            http headers to carry necessary info when requesting the remote file

    Raises:
        FileDownloadError: File download failed.
    """
    get_headers = {} if headers is None else copy.deepcopy(headers)
    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
    temp_file_path = os.path.join(local_dir, file_name)
    # Retries sleep 0.5s, 1s, 2s, 4s between attempts.
    retry = Retry(
        total=5,
        backoff_factor=1,
        allowed_methods=['GET'])
    while True:
        try:
            progress = tqdm(
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
                total=file_size,
                initial=0,
                desc='Downloading',
            )
            partial_length = 0
            # A partial file on disk means a previous attempt was interrupted;
            # resume from its current end via an HTTP Range request.
            if os.path.exists(temp_file_path):
                with open(temp_file_path, 'rb') as f:
                    partial_length = f.seek(0, io.SEEK_END)
                    progress.update(partial_length)
            if partial_length > file_size:
                # NOTE(review): callers in this module pass file_size=0, so any
                # pre-existing partial file aborts here -- confirm intended.
                break
            get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
                                                    file_size - 1)
            with open(temp_file_path, 'ab') as f:
                r = requests.get(
                    url,
                    stream=True,
                    headers=get_headers,
                    cookies=cookies,
                    timeout=60)
                r.raise_for_status()
                for chunk in r.iter_content(chunk_size=1024 * 1024 * 1):
                    if chunk:  # filter out keep-alive new chunks
                        progress.update(len(chunk))
                        f.write(chunk)
            progress.close()
            break
        except Exception as e:
            # Retry on any failure; Retry.increment raises once exhausted.
            retry = retry.increment('GET', url, error=e)
            retry.sleep()
|
||||||
|
|
||||||
|
|
||||||
|
def get_endpoint():
    """Return the ModelScope API endpoint, honoring $MODELSCOPE_DOMAIN."""
    domain = os.getenv('MODELSCOPE_DOMAIN', 'www.modelscope.cn')
    return 'https://' + domain
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_download_url(model_id: str, file_path: str, revision: str):
    """Format file download url according to `model_id`, `revision` and `file_path`.

    e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
    the resulted download url is:
    https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md

    Args:
        model_id (str): The model_id.
        file_path (str): File path
        revision (str): File revision.

    Returns:
        str: The file url.
    """
    # Only the revision and file path are url-quoted; the model id is used
    # verbatim, matching the upstream ModelScope convention.
    template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
    return template.format(
        endpoint=get_endpoint(),
        model_id=model_id,
        revision=urllib.parse.quote_plus(revision),
        file_path=urllib.parse.quote_plus(file_path),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def download_from_modelscope(model_id, origin_file_path, local_dir):
    """Fetch one file from a ModelScope repo into *local_dir*, skipping files
    that are already present."""
    os.makedirs(local_dir, exist_ok=True)
    if os.path.basename(origin_file_path) in os.listdir(local_dir):
        print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.")
        return
    print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
    # Build the authenticated request from the local credential store.
    headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=None)}
    cookies = ModelScopeConfig.get_cookies()
    url = get_file_download_url(model_id=model_id, file_path=origin_file_path, revision="master")
    modelscope_http_get_model_file(
        url,
        local_dir,
        os.path.basename(origin_file_path),
        file_size=0,
        headers=headers,
        cookies=cookies,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def download_from_huggingface(model_id, origin_file_path, local_dir):
    """Fetch one file from a HuggingFace repo into *local_dir*, skipping files
    that are already present."""
    os.makedirs(local_dir, exist_ok=True)
    if os.path.basename(origin_file_path) in os.listdir(local_dir):
        print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.")
        return
    print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
    hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
||||||
@@ -1,14 +1,20 @@
|
|||||||
from .utils import Prompter
|
from .utils import Prompter
|
||||||
from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
|
from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
|
||||||
import warnings
|
import warnings, os
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiTPrompter(Prompter):
|
class HunyuanDiTPrompter(Prompter):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer_path="configs/hunyuan_dit/tokenizer",
|
tokenizer_path=None,
|
||||||
tokenizer_t5_path="configs/hunyuan_dit/tokenizer_t5"
|
tokenizer_t5_path=None
|
||||||
):
|
):
|
||||||
|
if tokenizer_path is None:
|
||||||
|
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
tokenizer_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
|
||||||
|
if tokenizer_t5_path is None:
|
||||||
|
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
tokenizer_t5_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer_t5")
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
|
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
from .utils import Prompter, tokenize_long_prompt
|
from .utils import Prompter, tokenize_long_prompt
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
from ..models import SDTextEncoder
|
from ..models import SDTextEncoder
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
class SDPrompter(Prompter):
|
class SDPrompter(Prompter):
|
||||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
def __init__(self, tokenizer_path=None):
|
||||||
|
if tokenizer_path is None:
|
||||||
|
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,21 @@
|
|||||||
from .utils import Prompter, tokenize_long_prompt
|
from .utils import Prompter, tokenize_long_prompt
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2
|
from ..models import SDXLTextEncoder, SDXLTextEncoder2
|
||||||
import torch
|
import torch, os
|
||||||
|
|
||||||
|
|
||||||
class SDXLPrompter(Prompter):
|
class SDXLPrompter(Prompter):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer_path="configs/stable_diffusion/tokenizer",
|
tokenizer_path=None,
|
||||||
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
|
tokenizer_2_path=None
|
||||||
):
|
):
|
||||||
|
if tokenizer_path is None:
|
||||||
|
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
|
||||||
|
if tokenizer_2_path is None:
|
||||||
|
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_xl/tokenizer_2")
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ def tokenize_long_prompt(tokenizer, prompt):
|
|||||||
|
|
||||||
|
|
||||||
class BeautifulPrompt:
|
class BeautifulPrompt:
|
||||||
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
|
def __init__(self, tokenizer_path=None, model=None):
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||||
self.model = model
|
self.model = model
|
||||||
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||||
@@ -62,7 +62,7 @@ class BeautifulPrompt:
|
|||||||
|
|
||||||
|
|
||||||
class Translator:
|
class Translator:
|
||||||
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
|
def __init__(self, tokenizer_path=None, model=None):
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ from diffsynth import save_video, ModelManager, SVDVideoPipeline, HunyuanDiTImag
|
|||||||
from diffsynth import ModelManager
|
from diffsynth import ModelManager
|
||||||
import torch, os
|
import torch, os
|
||||||
|
|
||||||
|
# The models will be downloaded automatically.
|
||||||
|
# You can also use the following urls to download them manually.
|
||||||
|
|
||||||
# Download models (from Huggingface)
|
# Download models (from Huggingface)
|
||||||
# Text-to-image model:
|
# Text-to-image model:
|
||||||
@@ -14,7 +16,6 @@ import torch, os
|
|||||||
# ExVideo extension blocks:
|
# ExVideo extension blocks:
|
||||||
# `models/stable_video_diffusion/model.fp16.safetensors`: [link](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1/resolve/main/model.fp16.safetensors)
|
# `models/stable_video_diffusion/model.fp16.safetensors`: [link](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1/resolve/main/model.fp16.safetensors)
|
||||||
|
|
||||||
|
|
||||||
# Download models (from Modelscope)
|
# Download models (from Modelscope)
|
||||||
# Text-to-image model:
|
# Text-to-image model:
|
||||||
# `models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin`: [link](https://www.modelscope.cn/api/v1/models/modelscope/HunyuanDiT/repo?Revision=master&FilePath=t2i%2Fclip_text_encoder%2Fpytorch_model.bin)
|
# `models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin`: [link](https://www.modelscope.cn/api/v1/models/modelscope/HunyuanDiT/repo?Revision=master&FilePath=t2i%2Fclip_text_encoder%2Fpytorch_model.bin)
|
||||||
@@ -30,13 +31,7 @@ import torch, os
|
|||||||
def generate_image():
|
def generate_image():
|
||||||
# Load models
|
# Load models
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
||||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", model_id_list=["HunyuanDiT"])
|
||||||
model_manager.load_models([
|
|
||||||
"models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
|
|
||||||
"models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
|
|
||||||
"models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
|
|
||||||
"models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
|
|
||||||
])
|
|
||||||
pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
|
pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
# Generate an image
|
# Generate an image
|
||||||
@@ -46,16 +41,13 @@ def generate_image():
|
|||||||
negative_prompt="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,",
|
negative_prompt="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,",
|
||||||
num_inference_steps=50, height=1024, width=1024,
|
num_inference_steps=50, height=1024, width=1024,
|
||||||
)
|
)
|
||||||
|
model_manager.to("cpu")
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def generate_video(image):
|
def generate_video(image):
|
||||||
# Load models
|
# Load models
|
||||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", model_id_list=["stable-video-diffusion-img2vid-xt", "ExVideo-SVD-128f-v1"])
|
||||||
model_manager.load_models([
|
|
||||||
"models/stable_video_diffusion/svd_xt.safetensors",
|
|
||||||
"models/stable_video_diffusion/model.fp16.safetensors"
|
|
||||||
])
|
|
||||||
pipe = SVDVideoPipeline.from_model_manager(model_manager)
|
pipe = SVDVideoPipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
# Generate a video
|
# Generate a video
|
||||||
@@ -67,16 +59,13 @@ def generate_video(image):
|
|||||||
num_inference_steps=50,
|
num_inference_steps=50,
|
||||||
min_cfg_scale=2, max_cfg_scale=2, contrast_enhance_scale=1.2
|
min_cfg_scale=2, max_cfg_scale=2, contrast_enhance_scale=1.2
|
||||||
)
|
)
|
||||||
|
model_manager.to("cpu")
|
||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
def upscale_video(image, video):
|
def upscale_video(image, video):
|
||||||
# Load models
|
# Load models
|
||||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", model_id_list=["stable-video-diffusion-img2vid-xt", "ExVideo-SVD-128f-v1"])
|
||||||
model_manager.load_models([
|
|
||||||
"models/stable_video_diffusion/svd_xt.safetensors",
|
|
||||||
"models/stable_video_diffusion/model.fp16.safetensors",
|
|
||||||
])
|
|
||||||
pipe = SVDVideoPipeline.from_model_manager(model_manager)
|
pipe = SVDVideoPipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
# Generate a video
|
# Generate a video
|
||||||
@@ -89,19 +78,20 @@ def upscale_video(image, video):
|
|||||||
num_inference_steps=25,
|
num_inference_steps=25,
|
||||||
min_cfg_scale=2, max_cfg_scale=2, contrast_enhance_scale=1.2
|
min_cfg_scale=2, max_cfg_scale=2, contrast_enhance_scale=1.2
|
||||||
)
|
)
|
||||||
|
model_manager.to("cpu")
|
||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
# We use Hunyuan DiT to generate the first frame.
|
# We use Hunyuan DiT to generate the first frame. 10GB VRAM is required.
|
||||||
# If you want to use your own image,
|
# If you want to use your own image,
|
||||||
# please use `image = Image.open("your_image_file.png")` to replace the following code.
|
# please use `image = Image.open("your_image_file.png")` to replace the following code.
|
||||||
image = generate_image()
|
image = generate_image()
|
||||||
image.save("image.png")
|
image.save("image.png")
|
||||||
|
|
||||||
# Now, generate a video with resolution of 512.
|
# Now, generate a video with resolution of 512. 20GB VRAM is required.
|
||||||
video = generate_video(image)
|
video = generate_video(image)
|
||||||
save_video(video, "video_512.mp4", fps=30)
|
save_video(video, "video_512.mp4", fps=30)
|
||||||
|
|
||||||
# Upscale the video.
|
# Upscale the video. 52GB VRAM is required.
|
||||||
video = upscale_video(image, video)
|
video = upscale_video(image, video)
|
||||||
save_video(video, "video_1024.mp4", fps=30)
|
save_video(video, "video_1024.mp4", fps=30)
|
||||||
|
|||||||
12
requirements.txt
Normal file
12
requirements.txt
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
torch>=2.0.0
|
||||||
|
cupy-cuda12x
|
||||||
|
pip
|
||||||
|
transformers
|
||||||
|
controlnet-aux==0.0.7
|
||||||
|
streamlit
|
||||||
|
streamlit-drawable-canvas
|
||||||
|
imageio
|
||||||
|
imageio[ffmpeg]
|
||||||
|
safetensors
|
||||||
|
einops
|
||||||
|
sentencepiece
|
||||||
20
setup.py
Normal file
20
setup.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import os

import pkg_resources
from setuptools import setup, find_packages


# Install requirements are read from requirements.txt next to this file.
_requirements_file = os.path.join(os.path.dirname(__file__), "requirements.txt")

setup(
    name="diffsynth",
    py_modules=["diffsynth"],
    version="1.0.0",
    description="",
    author="Artiprocher",
    packages=find_packages(exclude=["diffsynth"]),
    install_requires=[
        str(requirement)
        for requirement in pkg_resources.parse_requirements(open(_requirements_file))
    ],
    include_package_data=True
)
|
||||||
Reference in New Issue
Block a user