mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
training framework
This commit is contained in:
@@ -1,34 +1,26 @@
|
||||
import torch, warnings, glob
|
||||
import torch, warnings, glob, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import repeat, reduce
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from modelscope import snapshot_download
|
||||
|
||||
|
||||
import types
|
||||
from ..models import ModelManager
|
||||
from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import WanPrompter
|
||||
import torch, os
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
|
||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||
from ..models import ModelManager
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from ..prompters import WanPrompter
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
|
||||
|
||||
|
||||
|
||||
@@ -50,6 +42,16 @@ class BasePipeline(torch.nn.Module):
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
self.vram_management_enabled = False
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.torch_dtype = dtype
|
||||
super().to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
def check_resize_height_width(self, height, width, num_frames=None):
|
||||
@@ -135,8 +137,20 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
|
||||
def enable_cpu_offload(self):
|
||||
warnings.warn("enable_cpu_offload is deprecated. This feature is automatically enabled if offload_device != device")
|
||||
|
||||
warnings.warn("`enable_cpu_offload` is deprecated. Please use `enable_vram_management`.")
|
||||
|
||||
|
||||
def get_free_vram(self):
|
||||
total_memory = torch.cuda.get_device_properties(self.device).total_memory
|
||||
allocated_memory = torch.cuda.device_memory_used(self.device)
|
||||
return (total_memory - allocated_memory) / (1024 ** 3)
|
||||
|
||||
|
||||
def freeze_except(self, model_names):
|
||||
for name, model in self.named_children():
|
||||
if name not in model_names:
|
||||
model.eval()
|
||||
model.requires_grad_(False)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -146,17 +160,19 @@ class ModelConfig:
|
||||
origin_file_pattern: Union[str, list[str]] = None
|
||||
download_resource: str = "ModelScope"
|
||||
offload_device: Optional[Union[str, torch.device]] = None
|
||||
quantization_dtype: Optional[torch.dtype] = None
|
||||
offload_dtype: Optional[torch.dtype] = None
|
||||
|
||||
def download_if_necessary(self, local_model_path="./models", skip_download=False):
|
||||
if self.path is None:
|
||||
if self.model_id is None or self.origin_file_pattern is None:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
|
||||
if not skip_download:
|
||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(local_model_path, self.model_id),
|
||||
allow_file_pattern=self.origin_file_pattern,
|
||||
ignore_file_pattern=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
|
||||
@@ -195,10 +211,36 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoUnit_TeaCache(),
|
||||
WanVideoUnit_CfgMerger(),
|
||||
]
|
||||
self.model_fn = model_fn_wan_video
|
||||
|
||||
|
||||
def train(self):
|
||||
super().train()
|
||||
self.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
|
||||
def training_loss(self, **inputs):
|
||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
|
||||
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
|
||||
|
||||
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
|
||||
self.vram_management_enabled = True
|
||||
if num_persistent_param_in_dit is not None:
|
||||
vram_limit = None
|
||||
else:
|
||||
if vram_limit is None:
|
||||
vram_limit = self.get_free_vram()
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
if self.text_encoder is not None:
|
||||
dtype = next(iter(self.text_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
@@ -217,9 +259,11 @@ class WanVideoPipeline(BasePipeline):
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.dit is not None:
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
device = "cpu" if vram_limit is not None else self.device
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
@@ -233,7 +277,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
@@ -246,6 +290,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae is not None:
|
||||
dtype = next(iter(self.vae.parameters())).dtype
|
||||
@@ -304,6 +349,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
),
|
||||
)
|
||||
if self.vace is not None:
|
||||
device = "cpu" if vram_limit is not None else self.device
|
||||
enable_vram_management(
|
||||
self.vace,
|
||||
module_map = {
|
||||
@@ -316,10 +362,11 @@ class WanVideoPipeline(BasePipeline):
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -330,8 +377,23 @@ class WanVideoPipeline(BasePipeline):
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
|
||||
local_model_path: str = "./models",
|
||||
skip_download: bool = False
|
||||
skip_download: bool = False,
|
||||
redirect_common_files: bool = True,
|
||||
):
|
||||
# Redirect model path
|
||||
if redirect_common_files:
|
||||
redirect_dict = {
|
||||
"models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
|
||||
"Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
|
||||
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
|
||||
}
|
||||
for model_config in model_configs:
|
||||
if model_config.origin_file_pattern is None or model_config.model_id is None:
|
||||
continue
|
||||
if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
|
||||
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
|
||||
model_config.model_id = redirect_dict[model_config.origin_file_pattern]
|
||||
|
||||
# Download and load models
|
||||
model_manager = ModelManager()
|
||||
for model_config in model_configs:
|
||||
@@ -339,7 +401,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
model_manager.load_model(
|
||||
model_config.path,
|
||||
device=model_config.offload_device or device,
|
||||
torch_dtype=model_config.quantization_dtype or torch_dtype
|
||||
torch_dtype=model_config.offload_dtype or torch_dtype
|
||||
)
|
||||
|
||||
# Initialize pipeline
|
||||
@@ -356,63 +418,54 @@ class WanVideoPipeline(BasePipeline):
|
||||
pipe.prompter.fetch_models(pipe.text_encoder)
|
||||
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
|
||||
return pipe
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
|
||||
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = "",
|
||||
# Image-to-video
|
||||
input_image=None,
|
||||
input_image: Optional[Image.Image] = None,
|
||||
# First-last-frame-to-video
|
||||
end_image=None,
|
||||
end_image: Optional[Image.Image] = None,
|
||||
# Video-to-video
|
||||
input_video=None,
|
||||
denoising_strength=1.0,
|
||||
input_video: Optional[list[Image.Image]] = None,
|
||||
denoising_strength: Optional[float] = 1.0,
|
||||
# ControlNet
|
||||
control_video=None,
|
||||
reference_image=None,
|
||||
control_video: Optional[list[Image.Image]] = None,
|
||||
reference_image: Optional[Image.Image] = None,
|
||||
# VACE
|
||||
vace_video=None,
|
||||
vace_video_mask=None,
|
||||
vace_reference_image=None,
|
||||
vace_scale=1.0,
|
||||
vace_video: Optional[list[Image.Image]] = None,
|
||||
vace_video_mask: Optional[Image.Image] = None,
|
||||
vace_reference_image: Optional[Image.Image] = None,
|
||||
vace_scale: Optional[float] = 1.0,
|
||||
# Randomness
|
||||
seed=None,
|
||||
rand_device="cpu",
|
||||
seed: Optional[int] = None,
|
||||
rand_device: Optional[str] = "cpu",
|
||||
# Shape
|
||||
height=480,
|
||||
width=832,
|
||||
height: Optional[int] = 480,
|
||||
width: Optional[int] = 832,
|
||||
num_frames=81,
|
||||
# Classifier-free guidance
|
||||
cfg_scale=5.0,
|
||||
cfg_merge=False,
|
||||
cfg_scale: Optional[float] = 5.0,
|
||||
cfg_merge: Optional[bool] = False,
|
||||
# Scheduler
|
||||
num_inference_steps=50,
|
||||
sigma_shift=5.0,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
sigma_shift: Optional[float] = 5.0,
|
||||
# Speed control
|
||||
motion_bucket_id=None,
|
||||
motion_bucket_id: Optional[int] = None,
|
||||
# VAE tiling
|
||||
tiled=True,
|
||||
tile_size=(30, 52),
|
||||
tile_stride=(15, 26),
|
||||
tiled: Optional[bool] = True,
|
||||
tile_size: Optional[tuple[int, int]] = (30, 52),
|
||||
tile_stride: Optional[tuple[int, int]] = (15, 26),
|
||||
# Sliding window
|
||||
sliding_window_size: Optional[int] = None,
|
||||
sliding_window_stride: Optional[int] = None,
|
||||
# Teacache
|
||||
tea_cache_l1_thresh=None,
|
||||
tea_cache_model_id="",
|
||||
tea_cache_l1_thresh: Optional[float] = None,
|
||||
tea_cache_model_id: Optional[str] = "",
|
||||
# progress_bar
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
@@ -452,12 +505,12 @@ class WanVideoPipeline(BasePipeline):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = model_fn_wan_video(**models, **inputs_shared, **inputs_posi, timestep=timestep)
|
||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
|
||||
if cfg_scale != 1.0:
|
||||
if cfg_merge:
|
||||
noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
|
||||
else:
|
||||
noise_pred_nega = model_fn_wan_video(**models, **inputs_shared, **inputs_nega, timestep=timestep)
|
||||
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
@@ -467,7 +520,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
# VACE (TODO: remove it)
|
||||
if vace_reference_image is not None:
|
||||
latents = latents[:, :, 1:]
|
||||
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
@@ -558,18 +611,21 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
|
||||
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride"),
|
||||
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride):
|
||||
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength):
|
||||
if input_video is None:
|
||||
return {"latents": noise}
|
||||
pipe.load_models_to_device(["vae"])
|
||||
input_video = pipe.preprocess_video(input_video)
|
||||
latents = pipe.encode_video(input_video, tiled, tile_size, tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
latents = pipe.scheduler.add_noise(latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents}
|
||||
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents}
|
||||
|
||||
|
||||
|
||||
@@ -639,7 +695,7 @@ class WanVideoUnit_FunControl(PipelineUnit):
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
control_video = pipe.preprocess_video(control_video)
|
||||
control_latents = pipe.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
if clip_feature is None or y is None:
|
||||
clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
@@ -678,7 +734,7 @@ class WanVideoUnit_SpeedControl(PipelineUnit):
|
||||
def process(self, pipe: WanVideoPipeline, motion_bucket_id):
|
||||
if motion_bucket_id is None:
|
||||
return {}
|
||||
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
|
||||
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"motion_bucket_id": motion_bucket_id}
|
||||
|
||||
|
||||
@@ -703,18 +759,16 @@ class WanVideoUnit_VACE(PipelineUnit):
|
||||
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
else:
|
||||
vace_video = pipe.preprocess_video(vace_video)
|
||||
vace_video = torch.stack(vace_video, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
if vace_mask is None:
|
||||
vace_mask = torch.ones_like(vace_video)
|
||||
else:
|
||||
vace_mask = pipe.preprocess_video(vace_mask)
|
||||
vace_mask = torch.stack(vace_mask, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
|
||||
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
|
||||
inactive = pipe.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
reactive = pipe.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
vace_video_latents = torch.concat((inactive, reactive), dim=1)
|
||||
|
||||
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
|
||||
@@ -724,8 +778,7 @@ class WanVideoUnit_VACE(PipelineUnit):
|
||||
pass
|
||||
else:
|
||||
vace_reference_image = pipe.preprocess_video([vace_reference_image])
|
||||
vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
vace_reference_latents = pipe.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
|
||||
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
|
||||
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
|
||||
@@ -894,6 +947,7 @@ def model_fn_wan_video(
|
||||
sliding_window_size: Optional[int] = None,
|
||||
sliding_window_stride: Optional[int] = None,
|
||||
cfg_merge: bool = False,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if sliding_window_size is not None and sliding_window_stride is not None:
|
||||
@@ -978,8 +1032,20 @@ def model_fn_wan_video(
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
x = block(x, context, t_mod, freqs)
|
||||
if use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
|
||||
if tea_cache is not None:
|
||||
|
||||
Reference in New Issue
Block a user