Diffusion Templates framework

This commit is contained in:
Artiprocher
2026-04-08 15:25:33 +08:00
parent f88b99cb4f
commit 9f8c352a15
10 changed files with 526 additions and 241 deletions

View File

@@ -28,38 +28,45 @@ class GeneralUnit_RemoveCache(PipelineUnit):
return inputs_shared, inputs_posi, inputs_nega
class GeneralUnit_TemplateProcessInputs(PipelineUnit):
    """Training-only pipeline unit that pre-processes template inputs.

    Applies the optional ``data_processor`` to the raw template inputs,
    then delegates to the template model's own ``process_inputs`` hook.
    Acts as a no-op when the pipeline carries no template model.
    """
    # Only used for training
    def __init__(self, data_processor):
        super().__init__(
            input_params=("template_inputs",),
            output_params=("template_inputs",),
        )
        # Optional callable applied to the raw inputs before the model hook;
        # invoked as data_processor(**template_inputs).
        self.data_processor = data_processor

    def process(self, pipe, template_inputs):
        # Nothing to do when no template model is attached to the pipeline.
        if not hasattr(pipe, "template_model"):
            return {}
        if self.data_processor is not None:
            template_inputs = self.data_processor(**template_inputs)
        # Let the model itself finish preparing its inputs.
        template_inputs = pipe.template_model.process_inputs(pipe=pipe, **template_inputs)
        return {"template_inputs": template_inputs}
class GeneralUnit_TemplateForward(PipelineUnit):
    """Training-only pipeline unit that runs the template model's forward pass.

    Forwards the prepared ``template_inputs`` through ``pipe.template_model``,
    optionally under gradient checkpointing (with or without CPU offload).
    Acts as a no-op when the pipeline carries no template model.
    """
    # Only used for training
    def __init__(self, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
        super().__init__(
            input_params=("template_inputs",),
            output_params=("kv_cache",),
            onload_model_names=("template_model",)
        )
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

    def process(self, pipe, template_inputs):
        # Nothing to do when no template model is attached to the pipeline.
        if not hasattr(pipe, "template_model"):
            return {}
        # NOTE(review): this unit declares output_params=("kv_cache",) and
        # returns forward()'s result unchanged — presumably forward() returns
        # a dict containing a "kv_cache" key; confirm against the model.
        template_cache = pipe.template_model.forward(
            **template_inputs,
            pipe=pipe,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload,
        )
        return template_cache
class DiffusionTrainingModule(torch.nn.Module):
@@ -246,13 +253,13 @@ class DiffusionTrainingModule(torch.nn.Module):
return lora_target_modules
def load_training_template_model(self, pipe, path_or_model_id, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
    """Attach an optional template model and its pipeline units to ``pipe``.

    Loads the template model described by ``path_or_model_id`` and appends
    the input-processing and forward units to ``pipe.units``. Returns
    ``pipe`` unchanged when ``path_or_model_id`` is None.
    """
    if path_or_model_id is None:
        return pipe
    model_config = self.parse_path_or_model_id(path_or_model_id)
    pipe.load_training_template_model(model_config)
    # Order matters: inputs must be processed before the forward unit runs.
    pipe.units.append(GeneralUnit_TemplateProcessInputs(pipe.template_data_processor))
    pipe.units.append(GeneralUnit_TemplateForward(use_gradient_checkpointing, use_gradient_checkpointing_offload))
    return pipe