refine code

Author: Artiprocher
Date: 2025-07-28 11:00:54 +08:00
Parent: bba44173d2
Commit: 5f68727ad3
22 changed files with 474 additions and 86 deletions


@@ -426,7 +426,7 @@ class ModelManager:
         self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)

-    def fetch_model(self, model_name, file_path=None, require_model_path=False):
+    def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None):
         fetched_models = []
         fetched_model_paths = []
         for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
@@ -440,12 +440,25 @@ class ModelManager:
             return None
         if len(fetched_models) == 1:
             print(f"Using {model_name} from {fetched_model_paths[0]}.")
+            model = fetched_models[0]
+            path = fetched_model_paths[0]
         else:
-            print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
+            if index is None:
+                model = fetched_models[0]
+                path = fetched_model_paths[0]
+                print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
+            elif isinstance(index, int):
+                model = fetched_models[:index]
+                path = fetched_model_paths[:index]
+                print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.")
+            else:
+                model = fetched_models
+                path = fetched_model_paths
+                print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.")
         if require_model_path:
-            return fetched_models[0], fetched_model_paths[0]
+            return model, path
         else:
-            return fetched_models[0]
+            return model

     def to(self, device):
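
For context, here is a standalone sketch of the new index dispatch; it mirrors the branch added above but is not the library code itself, and the example values are made up:

# Standalone sketch of the index dispatch in fetch_model (illustrative only).
def pick(fetched_models, index=None):
    if index is None:
        return fetched_models[0]       # default: first match, as before
    elif isinstance(index, int):
        return fetched_models[:index]  # an int slices off the first `index` matches
    else:
        return fetched_models          # any other value returns all matches

assert pick(["a", "b", "c"]) == "a"
assert pick(["a", "b", "c"], index=2) == ["a", "b"]
assert pick(["a", "b", "c"], index="all") == ["a", "b", "c"]

Note that an integer index selects a slice of the first matches rather than a single model, so callers get a list back in that case.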


@@ -288,6 +288,9 @@ class WanModel(torch.nn.Module):
         add_control_adapter: bool = False,
         in_dim_control_adapter: int = 24,
         seperated_timestep: bool = False,
+        require_vae_embedding: bool = True,
+        require_clip_embedding: bool = True,
+        fuse_vae_embedding_in_latents: bool = False,
     ):
         super().__init__()
         self.dim = dim
@@ -295,6 +298,9 @@ class WanModel(torch.nn.Module):
         self.has_image_input = has_image_input
         self.patch_size = patch_size
         self.seperated_timestep = seperated_timestep
+        self.require_vae_embedding = require_vae_embedding
+        self.require_clip_embedding = require_clip_embedding
+        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
         self.patch_embedding = nn.Conv3d(
             in_dim, dim, kernel_size=patch_size, stride=patch_size)
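
A hedged sketch of how downstream code might consult the three new flags; only the attribute names come from the diff, the helper and its arguments are hypothetical:

from types import SimpleNamespace

# Hypothetical helper: decide which embeddings to pass into forward() based
# on the new WanModel flags (illustrative, not the library's pipeline code).
def prepare_model_kwargs(model, vae_emb, clip_emb):
    kwargs = {}
    if model.require_vae_embedding:
        kwargs["y"] = vae_emb              # concatenated with x inside forward()
    if model.require_clip_embedding:
        kwargs["clip_feature"] = clip_emb  # routed through img_emb inside forward()
    # With fuse_vae_embedding_in_latents set, the VAE features are instead
    # expected to be merged into the latents before forward() is called.
    return kwargs

flags = SimpleNamespace(require_vae_embedding=False, require_clip_embedding=False,
                        fuse_vae_embedding_in_latents=True)
assert prepare_model_kwargs(flags, vae_emb=None, clip_emb=None) == {}

All three defaults preserve the previous behavior, so existing checkpoints and call sites are unaffected.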
@@ -352,7 +358,6 @@ class WanModel(torch.nn.Module):
         context: torch.Tensor,
         clip_feature: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
-        fused_y: Optional[torch.Tensor] = None,
         use_gradient_checkpointing: bool = False,
         use_gradient_checkpointing_offload: bool = False,
         **kwargs,
@@ -366,8 +371,6 @@ class WanModel(torch.nn.Module):
             x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
             clip_embdding = self.img_emb(clip_feature)
             context = torch.cat([clip_embdding, context], dim=1)
-        if fused_y is not None:
-            x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w)
         x, (f, h, w) = self.patchify(x)
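
With fused_y removed from forward(), a caller that previously passed it would now fuse the VAE features into the latents before the call (consistent with the fuse_vae_embedding_in_latents flag above); a toy sketch with made-up shapes:

import torch

# Toy sketch: the channel concatenation that forward() used to perform for
# fused_y now happens on the caller's side, before forward() is invoked.
latents = torch.randn(1, 16, 4, 8, 8)        # (b, c_x, f, h, w), shapes made up
vae_features = torch.randn(1, 4, 4, 8, 8)    # (b, c_fused_y, f, h, w)
x = torch.cat([latents, vae_features], dim=1)  # replaces the old in-forward concat
assert x.shape == (1, 20, 4, 8, 8)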
@@ -690,6 +693,9 @@ class WanModelStateDictConverter:
                 "num_layers": 30,
                 "eps": 1e-6,
                 "seperated_timestep": True,
+                "require_clip_embedding": False,
+                "require_vae_embedding": False,
+                "fuse_vae_embedding_in_latents": True,
             }
         elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
             # Wan-AI/Wan2.2-I2V-A14B
@@ -705,6 +711,7 @@ class WanModelStateDictConverter:
                 "num_heads": 40,
                 "num_layers": 40,
                 "eps": 1e-6,
+                "require_clip_embedding": False,
             }
         else:
             config = {}
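
A simplified mirror of the hash-based dispatch above, showing how one checkpoint hash maps to its config; the stub hasher only makes the sketch runnable, and the real hash_state_dict_keys implementation differs:

import hashlib

def hash_state_dict_keys(state_dict):
    # Stand-in for the library helper of the same name (implementation differs).
    return hashlib.md5(",".join(sorted(state_dict)).encode()).hexdigest()

def detect_config(state_dict):
    # One hash, one config: mirrors the elif chain above in miniature.
    if hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
        # Wan-AI/Wan2.2-I2V-A14B: the CLIP image branch is switched off.
        return {"num_heads": 40, "num_layers": 40, "eps": 1e-6,
                "require_clip_embedding": False}
    return {}  # unknown checkpoint: fall back to WanModel's constructor defaults

assert detect_config({"unknown": 0}) == {}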