mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 06:46:13 +00:00
sd and sdxl training
This commit is contained in:
@@ -87,9 +87,7 @@ class StableDiffusionXLPipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_2: str = None,
|
||||
negative_prompt: str = "",
|
||||
negative_prompt_2: str = None,
|
||||
cfg_scale: float = 5.0,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
@@ -103,8 +101,6 @@ class StableDiffusionXLPipeline(BasePipeline):
|
||||
target_size: tuple = None,
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
prompt_2 = prompt_2 or prompt
|
||||
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
@@ -116,11 +112,9 @@ class StableDiffusionXLPipeline(BasePipeline):
|
||||
# 2. Three-dict input preparation
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
"prompt_2": prompt_2,
|
||||
}
|
||||
inputs_nega = {
|
||||
"prompt": negative_prompt,
|
||||
"prompt_2": negative_prompt_2,
|
||||
}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
@@ -221,8 +215,8 @@ class SDXLUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt", "prompt_2": "prompt_2"},
|
||||
input_params_nega={"prompt": "prompt", "prompt_2": "prompt_2"},
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "prompt"},
|
||||
output_params=("prompt_embeds", "pooled_prompt_embeds"),
|
||||
onload_model_names=("text_encoder", "text_encoder_2")
|
||||
)
|
||||
@@ -231,10 +225,9 @@ class SDXLUnit_PromptEmbedder(PipelineUnit):
|
||||
self,
|
||||
pipe: StableDiffusionXLPipeline,
|
||||
prompt: str,
|
||||
prompt_2: str,
|
||||
device: torch.device,
|
||||
) -> tuple:
|
||||
"""Encode prompt using both text encoders.
|
||||
"""Encode prompt using both text encoders (same prompt for both).
|
||||
|
||||
Returns (prompt_embeds, pooled_prompt_embeds):
|
||||
- prompt_embeds: concat(encoder1_output, encoder2_output) -> (B, 77, 2048)
|
||||
@@ -254,7 +247,7 @@ class SDXLUnit_PromptEmbedder(PipelineUnit):
|
||||
|
||||
# Text Encoder 2 (CLIP-bigG, 1280-dim) — uses penultimate hidden states + pooled
|
||||
text_input_ids_2 = pipe.tokenizer_2(
|
||||
prompt_2,
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=pipe.tokenizer_2.model_max_length,
|
||||
truncation=True,
|
||||
@@ -270,9 +263,9 @@ class SDXLUnit_PromptEmbedder(PipelineUnit):
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
def process(self, pipe: StableDiffusionXLPipeline, prompt, prompt_2):
|
||||
def process(self, pipe: StableDiffusionXLPipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt_embeds, pooled_prompt_embeds = self.encode_prompt(pipe, prompt, prompt_2, pipe.device)
|
||||
prompt_embeds, pooled_prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
||||
return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
|
||||
|
||||
|
||||
@@ -294,14 +287,27 @@ class SDXLUnit_NoiseInitializer(PipelineUnit):
|
||||
class SDXLUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("noise",),
|
||||
output_params=("latents",),
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",),
|
||||
)
|
||||
|
||||
def process(self, pipe: StableDiffusionXLPipeline, noise):
|
||||
# For Text-to-Image, latents = noise (scaled by scheduler)
|
||||
latents = noise * pipe.scheduler.init_noise_sigma
|
||||
return {"latents": latents}
|
||||
def process(self, pipe: StableDiffusionXLPipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
|
||||
if pipe.scheduler.training:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_tensor = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(input_tensor).sample()
|
||||
latents = noise * pipe.scheduler.init_noise_sigma
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
else:
|
||||
# Inference mode: VAE encode input image, add noise for initial latent
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_tensor = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(input_tensor).sample()
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents}
|
||||
|
||||
|
||||
def model_fn_stable_diffusion_xl(
|
||||
|
||||
Reference in New Issue
Block a user