update sd training scripts

Artiprocher
2026-04-24 14:30:09 +08:00
parent 5cdab9ed01
commit 3799bdc23a
23 changed files with 323 additions and 612 deletions


@@ -196,19 +196,14 @@ class SDUnit_InputImageEmbedder(PipelineUnit):
     def process(self, pipe: StableDiffusionPipeline, input_image, noise):
         if input_image is None:
-            return {"latents": noise}
-        if pipe.scheduler.training:
-            pipe.load_models_to_device(self.onload_model_names)
-            input_tensor = pipe.preprocess_image(input_image)
-            input_latents = pipe.vae.encode(input_tensor).sample()
-            latents = noise * pipe.scheduler.init_noise_sigma
-            return {"latents": latents, "input_latents": input_latents}
-        else:
-            # Inference mode: VAE encode input image, add noise for initial latent
-            pipe.load_models_to_device(self.onload_model_names)
-            input_tensor = pipe.preprocess_image(input_image)
-            input_latents = pipe.vae.encode(input_tensor).sample()
-            latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
-            return {"latents": latents}
+            return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
+        pipe.load_models_to_device(self.onload_model_names)
+        input_tensor = pipe.preprocess_image(input_image)
+        input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
+        latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
+        return {"latents": latents, "input_latents": input_latents}


@@ -49,6 +49,7 @@ class StableDiffusionXLPipeline(BasePipeline):
             SDXLUnit_PromptEmbedder(),
             SDXLUnit_NoiseInitializer(),
             SDXLUnit_InputImageEmbedder(),
+            SDXLUnit_AddTimeIdsComputer(),
         ]
         self.model_fn = model_fn_stable_diffusion_xl
         self.compilable_models = ["unet"]
@@ -94,20 +95,11 @@ class StableDiffusionXLPipeline(BasePipeline):
         seed: int = None,
         rand_device: str = "cpu",
         num_inference_steps: int = 50,
-        eta: float = 0.0,
         guidance_rescale: float = 0.0,
-        original_size: tuple = None,
-        crops_coords_top_left: tuple = (0, 0),
-        target_size: tuple = None,
         progress_bar_cmd=tqdm,
     ):
-        original_size = original_size or (height, width)
-        target_size = target_size or (height, width)
         # 1. Scheduler
-        self.scheduler.set_timesteps(
-            num_inference_steps, eta=eta,
-        )
+        self.scheduler.set_timesteps(num_inference_steps)
         # 2. Three-dict input preparation
         inputs_posi = {
@@ -121,9 +113,7 @@ class StableDiffusionXLPipeline(BasePipeline):
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"guidance_rescale": guidance_rescale,
"original_size": original_size,
"crops_coords_top_left": crops_coords_top_left,
"target_size": target_size,
"crops_coords_top_left": (0, 0),
}
# 3. Unit chain execution
@@ -132,18 +122,7 @@ class StableDiffusionXLPipeline(BasePipeline):
                 unit, self, inputs_shared, inputs_posi, inputs_nega
             )
-        # 4. Compute add_time_ids (micro-conditioning)
-        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-        add_time_ids = self._get_add_time_ids(
-            original_size, crops_coords_top_left, target_size,
-            dtype=self.torch_dtype,
-            text_encoder_projection_dim=text_encoder_projection_dim,
-        )
-        neg_add_time_ids = add_time_ids.clone()
-        inputs_posi["add_time_ids"] = add_time_ids
-        inputs_nega["add_time_ids"] = neg_add_time_ids
-        # 5. Denoise loop
+        # 4. Denoise loop
         self.load_models_to_device(self.in_iteration_models)
         models = {name: getattr(self, name) for name in self.in_iteration_models}
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -183,21 +162,6 @@ class StableDiffusionXLPipeline(BasePipeline):
         return image

-    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None):
-        add_time_ids = list(original_size + crops_coords_top_left + target_size)
-        # SDXL UNet doesn't have a config attribute, so we access add_embedding directly
-        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
-        # addition_time_embed_dim is the dimension of each time ID projection (256 for SDXL base)
-        addition_time_embed_dim = self.unet.add_time_proj.num_channels
-        passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
-        if expected_add_embed_dim != passed_add_embed_dim:
-            raise ValueError(
-                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
-                f"but a vector of {passed_add_embed_dim} was created."
-            )
-        add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device)
-        return add_time_ids

 class SDXLUnit_ShapeChecker(PipelineUnit):
     def __init__(self):
@@ -294,22 +258,51 @@ class SDXLUnit_InputImageEmbedder(PipelineUnit):
     def process(self, pipe: StableDiffusionXLPipeline, input_image, noise):
         if input_image is None:
-            return {"latents": noise}
-        if pipe.scheduler.training:
-            pipe.load_models_to_device(self.onload_model_names)
-            input_tensor = pipe.preprocess_image(input_image)
-            input_latents = pipe.vae.encode(input_tensor).sample()
-            latents = noise * pipe.scheduler.init_noise_sigma
-            return {"latents": latents, "input_latents": input_latents}
-        else:
-            # Inference mode: VAE encode input image, add noise for initial latent
-            pipe.load_models_to_device(self.onload_model_names)
-            input_tensor = pipe.preprocess_image(input_image)
-            input_latents = pipe.vae.encode(input_tensor).sample()
-            latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
-            return {"latents": latents}
+            return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
+        pipe.load_models_to_device(self.onload_model_names)
+        input_tensor = pipe.preprocess_image(input_image)
+        input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
+        latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
+        return {"latents": latents, "input_latents": input_latents}
+
+
+class SDXLUnit_AddTimeIdsComputer(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("height", "width"),
+            output_params=("add_time_ids",),
+        )
+
+    def _get_add_time_ids(self, pipe, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim):
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)
+        expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
+        addition_time_embed_dim = pipe.unet.add_time_proj.num_channels
+        passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+        if expected_add_embed_dim != passed_add_embed_dim:
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
+                f"but a vector of {passed_add_embed_dim} was created."
+            )
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=pipe.device)
+        return add_time_ids
+
+    def process(self, pipe: StableDiffusionXLPipeline, height, width):
+        original_size = (height, width)
+        target_size = (height, width)
+        crops_coords_top_left = (0, 0)
+        text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
+        add_time_ids = self._get_add_time_ids(
+            pipe, original_size, crops_coords_top_left, target_size,
+            dtype=pipe.torch_dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+        return {"add_time_ids": add_time_ids}

 def model_fn_stable_diffusion_xl(
     unet: SDXLUNet2DConditionModel,
     latents=None,
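
For readers new to this codebase: both files above rely on the unit-chain pattern, in which each PipelineUnit declares input_params/output_params and the pipeline merges each process() result back into a shared input dict (the "# 3. Unit chain execution" loop). Below is a minimal, self-contained sketch of that pattern. It is not part of this commit: the runner function run_units and the toy unit are assumed names for illustration only.

# Illustrative sketch of the unit-chain pattern (assumed names, not repo code).

class PipelineUnit:
    def __init__(self, input_params=(), output_params=()):
        self.input_params = input_params
        self.output_params = output_params

    def process(self, pipe, **kwargs):
        raise NotImplementedError


class ToyAddTimeIdsComputer(PipelineUnit):
    # Same shape as SDXLUnit_AddTimeIdsComputer above: consume height and
    # width, emit a single "add_time_ids" entry into the shared inputs.
    def __init__(self):
        super().__init__(input_params=("height", "width"),
                         output_params=("add_time_ids",))

    def process(self, pipe, height, width):
        # original_size + crops_coords_top_left + target_size, flattened
        return {"add_time_ids": [height, width, 0, 0, height, width]}


def run_units(pipe, units, inputs_shared):
    # Analogue of the unit-chain loop in the diff: gather each unit's
    # declared inputs, call process(), and merge the outputs back in.
    for unit in units:
        kwargs = {name: inputs_shared[name] for name in unit.input_params}
        inputs_shared.update(unit.process(pipe, **kwargs))
    return inputs_shared


inputs = run_units(None, [ToyAddTimeIdsComputer()], {"height": 1024, "width": 1024})
print(inputs["add_time_ids"])  # [1024, 1024, 0, 0, 1024, 1024]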