mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
prompt processing
This commit is contained in:
@@ -31,8 +31,6 @@ class SDImagePipeline(torch.nn.Module):
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
# load textual inversion
|
||||
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
@@ -47,9 +45,8 @@ class SDImagePipeline(torch.nn.Module):
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -59,7 +56,7 @@ class SDImagePipeline(torch.nn.Module):
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -82,8 +82,6 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
# load textual inversion
|
||||
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
@@ -103,9 +101,8 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
self.motion_modules = model_manager.motion_modules
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -117,7 +114,7 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_motion_modules(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -39,9 +39,8 @@ class SDXLImagePipeline(torch.nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -51,7 +50,7 @@ class SDXLImagePipeline(torch.nn.Module):
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
@@ -106,7 +105,8 @@ class SDXLImagePipeline(torch.nn.Module):
|
||||
self.text_encoder_2,
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device
|
||||
device=self.device,
|
||||
positive=True,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
|
||||
@@ -114,7 +114,8 @@ class SDXLImagePipeline(torch.nn.Module):
|
||||
self.text_encoder_2,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device
|
||||
device=self.device,
|
||||
positive=False,
|
||||
)
|
||||
|
||||
# Prepare scheduler
|
||||
|
||||
Reference in New Issue
Block a user