mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-08 08:58:20 +00:00
Diffusion Templates framework
@@ -28,38 +28,45 @@ class GeneralUnit_RemoveCache(PipelineUnit):
         return inputs_shared, inputs_posi, inputs_nega
 
 
-class GeneralUnit_SkillProcessInputs(PipelineUnit):
+class GeneralUnit_TemplateProcessInputs(PipelineUnit):
     # Only used for training
     def __init__(self, data_processor):
         super().__init__(
-            input_params=("skill_inputs",),
-            output_params=("skill_inputs",),
+            input_params=("template_inputs",),
+            output_params=("template_inputs",),
         )
         self.data_processor = data_processor
 
-    def process(self, pipe, skill_inputs):
-        if not hasattr(pipe, "skill_model"):
+    def process(self, pipe, template_inputs):
+        if not hasattr(pipe, "template_model"):
             return {}
         if self.data_processor is not None:
-            skill_inputs = self.data_processor(**skill_inputs)
-        skill_inputs = pipe.skill_model.process_inputs(pipe=pipe, **skill_inputs)
-        return {"skill_inputs": skill_inputs}
+            template_inputs = self.data_processor(**template_inputs)
+        template_inputs = pipe.template_model.process_inputs(pipe=pipe, **template_inputs)
+        return {"template_inputs": template_inputs}
 
 
-class GeneralUnit_SkillForward(PipelineUnit):
+class GeneralUnit_TemplateForward(PipelineUnit):
     # Only used for training
-    def __init__(self):
+    def __init__(self, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
         super().__init__(
-            input_params=("skill_inputs",),
-            output_params=("skill_cache",),
-            onload_model_names=("skill_model",)
+            input_params=("template_inputs",),
+            output_params=("kv_cache",),
+            onload_model_names=("template_model",)
         )
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
 
-    def process(self, pipe, skill_inputs):
-        if not hasattr(pipe, "skill_model"):
+    def process(self, pipe, template_inputs):
+        if not hasattr(pipe, "template_model"):
             return {}
-        skill_cache = pipe.skill_model.forward(**skill_inputs)
-        return {"skill_cache": skill_cache}
+        template_cache = pipe.template_model.forward(
+            **template_inputs,
+            pipe=pipe,
+            use_gradient_checkpointing=self.use_gradient_checkpointing,
+            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload,
+        )
+        return template_cache
 
 
 class DiffusionTrainingModule(torch.nn.Module):
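Both units in this hunk follow the same PipelineUnit contract: declare input_params/output_params in __init__, and have process() return a dict that the pipeline merges back into its shared inputs, with an early `return {}` making the unit a no-op when the pipeline has no template_model attached. A minimal sketch of that contract follows; PipelineUnit here is a simplified stand-in, and ExampleUnit/DummyPipe are hypothetical names for illustration, not code from this repository.

# Minimal sketch of the PipelineUnit contract the diff relies on.
# PipelineUnit is a simplified stand-in; ExampleUnit and DummyPipe
# are hypothetical names, not part of DiffSynth-Studio.

class PipelineUnit:
    def __init__(self, input_params=(), output_params=(), onload_model_names=()):
        self.input_params = input_params
        self.output_params = output_params
        self.onload_model_names = onload_model_names

    def process(self, pipe, **kwargs):
        raise NotImplementedError

class ExampleUnit(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("template_inputs",),
            output_params=("template_inputs",),
        )

    def process(self, pipe, template_inputs):
        # Like GeneralUnit_TemplateProcessInputs: a no-op when the
        # pipeline carries no template_model, so the unit can sit in
        # the unit list even for runs that don't use templates.
        if not hasattr(pipe, "template_model"):
            return {}
        return {"template_inputs": dict(template_inputs, processed=True)}

class DummyPipe:
    template_model = object()  # pretend a template model is loaded

unit = ExampleUnit()
outputs = unit.process(DummyPipe(), template_inputs={"prompt": "a cat"})
print(outputs)  # {'template_inputs': {'prompt': 'a cat', 'processed': True}}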
@@ -246,13 +253,13 @@ class DiffusionTrainingModule(torch.nn.Module):
         return lora_target_modules
 
 
-    def load_training_skill_model(self, pipe, path_or_model_id):
+    def load_training_template_model(self, pipe, path_or_model_id, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
         if path_or_model_id is None:
             return pipe
         model_config = self.parse_path_or_model_id(path_or_model_id)
-        pipe.load_training_skill_model(model_config)
-        pipe.units.append(GeneralUnit_SkillProcessInputs(pipe.skill_data_processor))
-        pipe.units.append(GeneralUnit_SkillForward())
+        pipe.load_training_template_model(model_config)
+        pipe.units.append(GeneralUnit_TemplateProcessInputs(pipe.template_data_processor))
+        pipe.units.append(GeneralUnit_TemplateForward(use_gradient_checkpointing, use_gradient_checkpointing_offload))
         return pipe
 
 
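For context, a call to the renamed loader might look as follows; training_module, pipe, and the model path are assumptions for illustration, not code from this commit.

# Hypothetical usage sketch: attach a template model and its two units
# to an existing pipeline, enabling gradient checkpointing to trade
# recomputation for activation memory. training_module stands in for a
# DiffusionTrainingModule instance and the path is a placeholder.
pipe = training_module.load_training_template_model(
    pipe,
    path_or_model_id="models/template_model",
    use_gradient_checkpointing=True,
    use_gradient_checkpointing_offload=False,
)
# pipe.units now ends with GeneralUnit_TemplateProcessInputs followed by
# GeneralUnit_TemplateForward, which run in order on each training step.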