diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 93ec0bd..98c737f 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -83,6 +83,8 @@ class WanTrainingModule(DiffusionTrainingModule): inputs_shared["input_image"] = data["video"][0] elif extra_input == "end_image": inputs_shared["end_image"] = data["video"][-1] + elif extra_input == "reference_image" or extra_input == "vace_reference_image": + inputs_shared[extra_input] = data[extra_input][0] else: inputs_shared[extra_input] = data[extra_input]