mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
refine code
This commit is contained in:
@@ -426,7 +426,7 @@ class ModelManager:
|
||||
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None):
|
||||
fetched_models = []
|
||||
fetched_model_paths = []
|
||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||
@@ -440,12 +440,25 @@ class ModelManager:
|
||||
return None
|
||||
if len(fetched_models) == 1:
|
||||
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
||||
model = fetched_models[0]
|
||||
path = fetched_model_paths[0]
|
||||
else:
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
||||
if index is None:
|
||||
model = fetched_models[0]
|
||||
path = fetched_model_paths[0]
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
||||
elif isinstance(index, int):
|
||||
model = fetched_models[:index]
|
||||
path = fetched_model_paths[:index]
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.")
|
||||
else:
|
||||
model = fetched_models
|
||||
path = fetched_model_paths
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.")
|
||||
if require_model_path:
|
||||
return fetched_models[0], fetched_model_paths[0]
|
||||
return model, path
|
||||
else:
|
||||
return fetched_models[0]
|
||||
return model
|
||||
|
||||
|
||||
def to(self, device):
|
||||
|
||||
@@ -288,6 +288,9 @@ class WanModel(torch.nn.Module):
|
||||
add_control_adapter: bool = False,
|
||||
in_dim_control_adapter: int = 24,
|
||||
seperated_timestep: bool = False,
|
||||
require_vae_embedding: bool = True,
|
||||
require_clip_embedding: bool = True,
|
||||
fuse_vae_embedding_in_latents: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -295,6 +298,9 @@ class WanModel(torch.nn.Module):
|
||||
self.has_image_input = has_image_input
|
||||
self.patch_size = patch_size
|
||||
self.seperated_timestep = seperated_timestep
|
||||
self.require_vae_embedding = require_vae_embedding
|
||||
self.require_clip_embedding = require_clip_embedding
|
||||
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
|
||||
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
@@ -352,7 +358,6 @@ class WanModel(torch.nn.Module):
|
||||
context: torch.Tensor,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
fused_y: Optional[torch.Tensor] = None,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**kwargs,
|
||||
@@ -366,8 +371,6 @@ class WanModel(torch.nn.Module):
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embdding = self.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
if fused_y is not None:
|
||||
x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w)
|
||||
|
||||
x, (f, h, w) = self.patchify(x)
|
||||
|
||||
@@ -690,6 +693,9 @@ class WanModelStateDictConverter:
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"seperated_timestep": True,
|
||||
"require_clip_embedding": False,
|
||||
"require_vae_embedding": False,
|
||||
"fuse_vae_embedding_in_latents": True,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
|
||||
# Wan-AI/Wan2.2-I2V-A14B
|
||||
@@ -705,6 +711,7 @@ class WanModelStateDictConverter:
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
|
||||
Reference in New Issue
Block a user