mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-08 08:58:20 +00:00
Diffusion Templates framework
This commit is contained in:
@@ -9,7 +9,7 @@ from ..utils.lora import GeneralLoRALoader
|
||||
from ..models.model_loader import ModelPool
|
||||
from ..utils.controlnet import ControlNetInput
|
||||
from ..core.device import get_device_name, IS_NPU_AVAILABLE
|
||||
from .skills import load_skill_model, load_skill_data_processor
|
||||
from .template import load_template_model, load_template_data_processor
|
||||
|
||||
|
||||
class PipelineUnit:
|
||||
@@ -320,14 +320,21 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
|
||||
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
|
||||
# Positive side forward
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
|
||||
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
|
||||
if cfg_scale != 1.0:
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
# Negative side forward
|
||||
if inputs_shared.get("negative_only_lora", None) is not None:
|
||||
self.load_lora(self.dit, state_dict=inputs_shared["negative_only_lora"], verbose=0)
|
||||
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||
if inputs_shared.get("negative_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
|
||||
if isinstance(noise_pred_posi, tuple):
|
||||
# Separately handling different output types of latents, eg. video and audio latents.
|
||||
noise_pred = tuple(
|
||||
@@ -341,11 +348,11 @@ class BasePipeline(torch.nn.Module):
|
||||
return noise_pred
|
||||
|
||||
|
||||
def load_training_skill_model(self, model_config: ModelConfig = None):
|
||||
def load_training_template_model(self, model_config: ModelConfig = None):
|
||||
if model_config is not None:
|
||||
model_config.download_if_necessary()
|
||||
self.skill_model = load_skill_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
self.skill_data_processor = load_skill_data_processor(model_config.path)()
|
||||
self.template_model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
self.template_data_processor = load_template_data_processor(model_config.path)()
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user