This commit is contained in:
Artiprocher
2025-11-15 16:47:13 +08:00
parent e3356556ee
commit ea0a5c5908
72 changed files with 481 additions and 472 deletions

View File

@@ -15,10 +15,10 @@ from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d
from ..models.wan_video_dit_s2v import rope_precompute
from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
@@ -526,13 +526,13 @@ class WanVideoUnit_FunReference(PipelineUnit):
super().__init__(
input_params=("reference_image", "height", "width", "reference_image"),
output_params=("reference_latents", "clip_feature"),
onload_model_names=("vae",)
onload_model_names=("vae", "image_encoder")
)
def process(self, pipe: WanVideoPipeline, reference_image, height, width):
if reference_image is None:
return {}
pipe.load_models_to_device(["vae"])
pipe.load_models_to_device(self.onload_model_names)
reference_image = reference_image.resize((width, height))
reference_latents = pipe.preprocess_video([reference_image])
reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)