update template framework

This commit is contained in:
Artiprocher
2026-04-15 14:07:51 +08:00
parent 9f8c352a15
commit 59b4bbb62c
7 changed files with 85 additions and 24 deletions

View File

@@ -38,7 +38,7 @@ class GeneralUnit_TemplateProcessInputs(PipelineUnit):
self.data_processor = data_processor
def process(self, pipe, template_inputs):
if not hasattr(pipe, "template_model"):
if not hasattr(pipe, "template_model") or template_inputs is None:
return {}
if self.data_processor is not None:
template_inputs = self.data_processor(**template_inputs)
@@ -58,7 +58,7 @@ class GeneralUnit_TemplateForward(PipelineUnit):
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
def process(self, pipe, template_inputs):
if not hasattr(pipe, "template_model"):
if not hasattr(pipe, "template_model") or template_inputs is None:
return {}
template_cache = pipe.template_model.forward(
**template_inputs,