prompt processing

This commit is contained in:
Artiprocher
2024-01-21 22:36:03 +08:00
parent 22328f48ca
commit e076e66827
11 changed files with 134 additions and 75 deletions

View File

@@ -31,8 +31,6 @@ class SDImagePipeline(torch.nn.Module):
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
-# load textual inversion
-self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
@@ -47,9 +45,8 @@ class SDImagePipeline(torch.nn.Module):
self.controlnet = MultiControlNetManager(controlnet_units)
-def fetch_beautiful_prompt(self, model_manager: ModelManager):
-if "beautiful_prompt" in model_manager.model:
-self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
+def fetch_prompter(self, model_manager: ModelManager):
+self.prompter.load_from_model_manager(model_manager)
@staticmethod
@@ -59,7 +56,7 @@ class SDImagePipeline(torch.nn.Module):
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
-pipe.fetch_beautiful_prompt(model_manager)
+pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
return pipe

View File

@@ -82,8 +82,6 @@ class SDVideoPipeline(torch.nn.Module):
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
-# load textual inversion
-self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
@@ -103,9 +101,8 @@ class SDVideoPipeline(torch.nn.Module):
self.motion_modules = model_manager.motion_modules
-def fetch_beautiful_prompt(self, model_manager: ModelManager):
-if "beautiful_prompt" in model_manager.model:
-self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
+def fetch_prompter(self, model_manager: ModelManager):
+self.prompter.load_from_model_manager(model_manager)
@staticmethod
@@ -117,7 +114,7 @@ class SDVideoPipeline(torch.nn.Module):
)
pipe.fetch_main_models(model_manager)
pipe.fetch_motion_modules(model_manager)
-pipe.fetch_beautiful_prompt(model_manager)
+pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
return pipe

View File

@@ -39,9 +39,8 @@ class SDXLImagePipeline(torch.nn.Module):
pass
-def fetch_beautiful_prompt(self, model_manager: ModelManager):
-if "beautiful_prompt" in model_manager.model:
-self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
+def fetch_prompter(self, model_manager: ModelManager):
+self.prompter.load_from_model_manager(model_manager)
@staticmethod
@@ -51,7 +50,7 @@ class SDXLImagePipeline(torch.nn.Module):
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
-pipe.fetch_beautiful_prompt(model_manager)
+pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
return pipe
@@ -106,7 +105,8 @@ class SDXLImagePipeline(torch.nn.Module):
self.text_encoder_2,
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
-device=self.device
+device=self.device,
+positive=True,
)
if cfg_scale != 1.0:
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
@@ -114,7 +114,8 @@ class SDXLImagePipeline(torch.nn.Module):
self.text_encoder_2,
negative_prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
-device=self.device
+device=self.device,
+positive=False,
)
# Prepare scheduler