Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)

Commit: diffusion skills framework
@@ -6,6 +6,7 @@ from peft import LoraConfig, inject_adapter_in_model


class GeneralUnit_RemoveCache(PipelineUnit):
    # Only used for training
    def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
        super().__init__(take_over=True)
        self.required_params = required_params
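For orientation before the next hunk: a unit constructed with take_over=True bypasses the declarative input_params/output_params routing and handles the pipeline's raw input dicts itself. Below is a minimal sketch of another unit in that style, assuming the take_over process signature matches the (inputs_shared, inputs_posi, inputs_nega) triple that GeneralUnit_RemoveCache returns in the next hunk; the class name and the keys it drops are hypothetical, not part of this commit.

# Hypothetical take_over-style unit, for illustration only.
class GeneralUnit_DropKeys(PipelineUnit):
    def __init__(self, keys=tuple()):
        super().__init__(take_over=True)
        self.keys = keys

    def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
        # Remove the configured keys from the shared inputs, if present.
        for key in self.keys:
            inputs_shared.pop(key, None)
        return inputs_shared, inputs_posi, inputs_nega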
@@ -27,6 +28,40 @@ class GeneralUnit_RemoveCache(PipelineUnit):
        return inputs_shared, inputs_posi, inputs_nega


class GeneralUnit_SkillProcessInputs(PipelineUnit):
    # Only used for training
    def __init__(self, data_processor):
        super().__init__(
            input_params=("skill_inputs",),
            output_params=("skill_inputs",),
        )
        self.data_processor = data_processor

    def process(self, pipe, skill_inputs):
        if not hasattr(pipe, "skill_model"):
            return {}
        if self.data_processor is not None:
            skill_inputs = self.data_processor(**skill_inputs)
        skill_inputs = pipe.skill_model.process_inputs(pipe=pipe, **skill_inputs)
        return {"skill_inputs": skill_inputs}
class GeneralUnit_SkillForward(PipelineUnit):
    # Only used for training
    def __init__(self):
        super().__init__(
            input_params=("skill_inputs",),
            output_params=("skill_cache",),
            onload_model_names=("skill_model",)
        )

    def process(self, pipe, skill_inputs):
        if not hasattr(pipe, "skill_model"):
            return {}
        skill_cache = pipe.skill_model.forward(**skill_inputs)
        return {"skill_cache": skill_cache}
class DiffusionTrainingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -209,6 +244,16 @@ class DiffusionTrainingModule(torch.nn.Module):
        else:
            lora_target_modules = lora_target_modules.split(",")
        return lora_target_modules
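The context lines above show the target-module parser accepting a comma-separated string and returning a list, which is the form peft's LoraConfig expects (peft is already imported in the first hunk). An illustration only; the module names and rank below are hypothetical.

# Feeding the parsed list into peft's LoraConfig (illustrative values).
from peft import LoraConfig

lora_target_modules = "to_q,to_k,to_v".split(",")
lora_config = LoraConfig(r=16, lora_alpha=16, target_modules=lora_target_modules)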
    def load_training_skill_model(self, pipe, path_or_model_id):
        if path_or_model_id is None:
            return pipe
        model_config = self.parse_path_or_model_id(path_or_model_id)
        pipe.load_training_skill_model(model_config)
        pipe.units.append(GeneralUnit_SkillProcessInputs(pipe.skill_data_processor))
        pipe.units.append(GeneralUnit_SkillForward())
        return pipe
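A usage sketch, assuming a training script that has already constructed a pipe and a DiffusionTrainingModule instance named module; the path below is hypothetical. When path_or_model_id is None the pipe is returned unchanged, so the two skill units are only appended when a skill model is actually requested.

# Hypothetical call site, for illustration only.
pipe = module.load_training_skill_model(pipe, "models/skill_model")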
    def switch_pipe_to_training_mode(