Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-04-24 15:06:17 +00:00).
Commit: "update sd training scripts". This commit is contained in:
@@ -196,19 +196,14 @@ class SDUnit_InputImageEmbedder(PipelineUnit):
|
||||
|
||||
def process(self, pipe: "StableDiffusionPipeline", input_image, noise):
    """Turn the optional input image into latents and build the initial latents.

    Args:
        pipe: owning pipeline; provides the scheduler, the VAE and device helpers.
        input_image: optional input image; ``None`` means pure text-to-image.
        noise: pre-sampled Gaussian noise tensor in latent space.

    Returns:
        dict merged into the pipeline's shared inputs, with "latents" and,
        where available, "input_latents".
    """
    # Text-to-image: nothing to encode, start from scheduler-scaled noise.
    if input_image is None:
        return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
    if pipe.scheduler.training:
        # Training mode: keep the clean image latents as the target and start
        # the denoising trajectory from pure (scaled) noise.
        pipe.load_models_to_device(self.onload_model_names)
        input_tensor = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(input_tensor).sample()
        latents = noise * pipe.scheduler.init_noise_sigma
        return {"latents": latents, "input_latents": input_latents}
    else:
        # Inference mode: VAE encode input image, add noise at the first
        # timestep for the initial latent (image-to-image style init).
        pipe.load_models_to_device(self.onload_model_names)
        input_tensor = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(input_tensor).sample()
        latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
        return {"latents": latents}
|
||||
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ class StableDiffusionXLPipeline(BasePipeline):
|
||||
SDXLUnit_PromptEmbedder(),
|
||||
SDXLUnit_NoiseInitializer(),
|
||||
SDXLUnit_InputImageEmbedder(),
|
||||
SDXLUnit_AddTimeIdsComputer(),
|
||||
]
|
||||
self.model_fn = model_fn_stable_diffusion_xl
|
||||
self.compilable_models = ["unet"]
|
||||
@@ -94,20 +95,11 @@ class StableDiffusionXLPipeline(BasePipeline):
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
num_inference_steps: int = 50,
|
||||
eta: float = 0.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: tuple = None,
|
||||
crops_coords_top_left: tuple = (0, 0),
|
||||
target_size: tuple = None,
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 1. Scheduler
|
||||
self.scheduler.set_timesteps(
|
||||
num_inference_steps, eta=eta,
|
||||
)
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# 2. Three-dict input preparation
|
||||
inputs_posi = {
|
||||
@@ -121,9 +113,7 @@ class StableDiffusionXLPipeline(BasePipeline):
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"guidance_rescale": guidance_rescale,
|
||||
"original_size": original_size,
|
||||
"crops_coords_top_left": crops_coords_top_left,
|
||||
"target_size": target_size,
|
||||
"crops_coords_top_left": (0, 0),
|
||||
}
|
||||
|
||||
# 3. Unit chain execution
|
||||
@@ -132,18 +122,7 @@ class StableDiffusionXLPipeline(BasePipeline):
|
||||
unit, self, inputs_shared, inputs_posi, inputs_nega
|
||||
)
|
||||
|
||||
# 4. Compute add_time_ids (micro-conditioning)
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size,
|
||||
dtype=self.torch_dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
neg_add_time_ids = add_time_ids.clone()
|
||||
inputs_posi["add_time_ids"] = add_time_ids
|
||||
inputs_nega["add_time_ids"] = neg_add_time_ids
|
||||
|
||||
# 5. Denoise loop
|
||||
# 4. Denoise loop
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
@@ -183,21 +162,6 @@ class StableDiffusionXLPipeline(BasePipeline):
|
||||
|
||||
return image
|
||||
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
# SDXL UNet doesn't have a config attribute, so we access add_embedding directly
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
# addition_time_embed_dim is the dimension of each time ID projection (256 for SDXL base)
|
||||
addition_time_embed_dim = self.unet.add_time_proj.num_channels
|
||||
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
if expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
|
||||
f"but a vector of {passed_add_embed_dim} was created."
|
||||
)
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device)
|
||||
return add_time_ids
|
||||
|
||||
|
||||
class SDXLUnit_ShapeChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
@@ -294,22 +258,51 @@ class SDXLUnit_InputImageEmbedder(PipelineUnit):
|
||||
|
||||
def process(self, pipe: "StableDiffusionXLPipeline", input_image, noise):
    """Turn the optional input image into latents and build the initial latents.

    Args:
        pipe: owning SDXL pipeline; provides the scheduler, the VAE and
            device helpers.
        input_image: optional input image; ``None`` means pure text-to-image.
        noise: pre-sampled Gaussian noise tensor in latent space.

    Returns:
        dict merged into the pipeline's shared inputs, with "latents" and,
        where available, "input_latents".
    """
    # Text-to-image: nothing to encode, start from scheduler-scaled noise.
    if input_image is None:
        return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
    if pipe.scheduler.training:
        # Training mode: keep the clean image latents as the target and start
        # the denoising trajectory from pure (scaled) noise.
        pipe.load_models_to_device(self.onload_model_names)
        input_tensor = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(input_tensor).sample()
        latents = noise * pipe.scheduler.init_noise_sigma
        return {"latents": latents, "input_latents": input_latents}
    else:
        # Inference mode: VAE encode input image, add noise at the first
        # timestep for the initial latent (image-to-image style init).
        pipe.load_models_to_device(self.onload_model_names)
        input_tensor = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(input_tensor).sample()
        latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
        return {"latents": latents}
|
||||
|
||||
|
||||
class SDXLUnit_AddTimeIdsComputer(PipelineUnit):
    """Pipeline unit that computes SDXL's `add_time_ids` micro-conditioning.

    Consumes "height" and "width" from the shared inputs and produces a
    single "add_time_ids" tensor.
    """

    def __init__(self):
        super().__init__(
            input_params=("height", "width"),
            output_params=("add_time_ids",),
        )

    def _get_add_time_ids(self, pipe, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim):
        """Assemble and validate the six-scalar micro-conditioning tensor.

        Raises:
            ValueError: when the projected embedding width disagrees with the
                UNet's add_embedding layer.
        """
        time_ids = list(original_size + crops_coords_top_left + target_size)
        # Probe the UNet layers directly for the expected embedding width.
        expected_dim = pipe.unet.add_embedding.linear_1.in_features
        # Channels used to project each individual time ID.
        per_id_dim = pipe.unet.add_time_proj.num_channels
        actual_dim = per_id_dim * len(time_ids) + text_encoder_projection_dim
        if actual_dim != expected_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_dim}, "
                f"but a vector of {actual_dim} was created."
            )
        return torch.tensor([time_ids], dtype=dtype, device=pipe.device)

    def process(self, pipe: StableDiffusionXLPipeline, height, width):
        """Default micro-conditioning: original == target == requested size, no crop."""
        size = (height, width)
        add_time_ids = self._get_add_time_ids(
            pipe, size, (0, 0), size,
            dtype=pipe.torch_dtype,
            text_encoder_projection_dim=pipe.text_encoder_2.config.projection_dim,
        )
        return {"add_time_ids": add_time_ids}
|
||||
|
||||
|
||||
def model_fn_stable_diffusion_xl(
|
||||
unet: SDXLUNet2DConditionModel,
|
||||
latents=None,
|
||||
|
||||
Reference in New Issue
Block a user