Diffusion Templates framework

This commit is contained in:
Artiprocher
2026-04-08 15:25:33 +08:00
parent f88b99cb4f
commit 9f8c352a15
10 changed files with 526 additions and 241 deletions

View File

@@ -9,7 +9,7 @@ from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
from ..core.device import get_device_name, IS_NPU_AVAILABLE
from .skills import load_skill_model, load_skill_data_processor
from .template import load_template_model, load_template_data_processor
class PipelineUnit:
@@ -320,14 +320,21 @@ class BasePipeline(torch.nn.Module):
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
# Positive side forward
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
if cfg_scale != 1.0:
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
# Negative side forward
if inputs_shared.get("negative_only_lora", None) is not None:
self.load_lora(self.dit, state_dict=inputs_shared["negative_only_lora"], verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
if inputs_shared.get("negative_only_lora", None) is not None:
self.clear_lora(verbose=0)
if isinstance(noise_pred_posi, tuple):
# Separately handling different output types of latents, eg. video and audio latents.
noise_pred = tuple(
@@ -341,11 +348,11 @@ class BasePipeline(torch.nn.Module):
return noise_pred
def load_training_skill_model(self, model_config: ModelConfig = None):
def load_training_template_model(self, model_config: ModelConfig = None):
if model_config is not None:
model_config.download_if_necessary()
self.skill_model = load_skill_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
self.skill_data_processor = load_skill_data_processor(model_config.path)()
self.template_model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
self.template_data_processor = load_template_data_processor(model_config.path)()