mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
update sd training scripts
This commit is contained in:
@@ -902,7 +902,7 @@ mova_series = [
|
|||||||
]
|
]
|
||||||
stable_diffusion_xl_series = [
|
stable_diffusion_xl_series = [
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
|
||||||
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
|
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
|
||||||
"model_name": "stable_diffusion_xl_unet",
|
"model_name": "stable_diffusion_xl_unet",
|
||||||
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
|
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
|
||||||
@@ -916,21 +916,21 @@ stable_diffusion_xl_series = [
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
|
||||||
"model_hash": "98cc34ccc5b54ae0e56bdea8688dcd5a",
|
"model_hash": "98cc34ccc5b54ae0e56bdea8688dcd5a",
|
||||||
"model_name": "stable_diffusion_xl_text_encoder",
|
"model_name": "stable_diffusion_xl_text_encoder",
|
||||||
"model_class": "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2",
|
"model_class": "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_xl_text_encoder.SDXLTextEncoder2StateDictConverter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_xl_text_encoder.SDXLTextEncoder2StateDictConverter",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors")
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors")
|
||||||
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
|
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
|
||||||
"model_name": "stable_diffusion_text_encoder",
|
"model_name": "stable_diffusion_text_encoder",
|
||||||
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
|
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||||
"model_hash": "13115dd45a6e1c39860f91ab073b8a78",
|
"model_hash": "13115dd45a6e1c39860f91ab073b8a78",
|
||||||
"model_name": "stable_diffusion_xl_vae",
|
"model_name": "stable_diffusion_xl_vae",
|
||||||
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
|
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
|
||||||
|
|||||||
@@ -1,269 +1,107 @@
|
|||||||
import torch, math
|
import torch, math
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
|
|
||||||
class DDIMScheduler:
|
class DDIMScheduler():
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
|
||||||
self,
|
|
||||||
num_train_timesteps: int = 1000,
|
|
||||||
beta_start: float = 0.00085,
|
|
||||||
beta_end: float = 0.012,
|
|
||||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear",
|
|
||||||
clip_sample: bool = False,
|
|
||||||
set_alpha_to_one: bool = False,
|
|
||||||
steps_offset: int = 1,
|
|
||||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
|
|
||||||
timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
|
|
||||||
rescale_betas_zero_snr: bool = False,
|
|
||||||
):
|
|
||||||
self.num_train_timesteps = num_train_timesteps
|
self.num_train_timesteps = num_train_timesteps
|
||||||
self.beta_start = beta_start
|
if beta_schedule == "scaled_linear":
|
||||||
self.beta_end = beta_end
|
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
|
||||||
self.beta_schedule = beta_schedule
|
elif beta_schedule == "linear":
|
||||||
self.clip_sample = clip_sample
|
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||||
self.set_alpha_to_one = set_alpha_to_one
|
|
||||||
self.steps_offset = steps_offset
|
|
||||||
self.prediction_type = prediction_type
|
|
||||||
self.timestep_spacing = timestep_spacing
|
|
||||||
|
|
||||||
# Compute betas
|
|
||||||
if beta_schedule == "linear":
|
|
||||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
|
||||||
elif beta_schedule == "scaled_linear":
|
|
||||||
# SD 1.5 specific: sqrt-linear interpolation
|
|
||||||
self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
|
||||||
elif beta_schedule == "squaredcos_cap_v2":
|
|
||||||
self.betas = self._betas_for_alpha_bar(num_train_timesteps)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported beta_schedule: {beta_schedule}")
|
raise NotImplementedError(f"{beta_schedule} is not implemented")
|
||||||
|
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
|
||||||
# Rescale for zero SNR
|
if rescale_zero_terminal_snr:
|
||||||
if rescale_betas_zero_snr:
|
self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
|
||||||
self.betas = self._rescale_zero_terminal_snr(self.betas)
|
self.alphas_cumprod = self.alphas_cumprod.tolist()
|
||||||
|
self.set_timesteps(10)
|
||||||
self.alphas = 1.0 - self.betas
|
self.prediction_type = prediction_type
|
||||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
|
||||||
|
|
||||||
# For the final step, there is no previous alphas_cumprod
|
|
||||||
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
|
||||||
|
|
||||||
# standard deviation of the initial noise distribution
|
|
||||||
self.init_noise_sigma = 1.0
|
|
||||||
|
|
||||||
# Setable values (will be populated by set_timesteps)
|
|
||||||
self.num_inference_steps = None
|
|
||||||
self.timesteps = torch.from_numpy(self._default_timesteps().astype("int64"))
|
|
||||||
self.training = False
|
self.training = False
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
|
|
||||||
"""Create beta schedule via cosine alpha_bar function."""
|
|
||||||
def alpha_bar_fn(t):
|
|
||||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
|
||||||
|
|
||||||
betas = []
|
def rescale_zero_terminal_snr(self, alphas_cumprod):
|
||||||
for i in range(num_diffusion_timesteps):
|
|
||||||
t1 = i / num_diffusion_timesteps
|
|
||||||
t2 = (i + 1) / num_diffusion_timesteps
|
|
||||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
|
||||||
return torch.tensor(betas, dtype=torch.float32)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Rescale betas to have zero terminal SNR."""
|
|
||||||
alphas = 1.0 - betas
|
|
||||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
|
||||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||||
|
|
||||||
|
# Store old values.
|
||||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||||
|
|
||||||
|
# Shift so the last timestep is zero.
|
||||||
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||||
|
|
||||||
|
# Scale so the first timestep is back to the old value.
|
||||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||||
alphas_bar = alphas_bar_sqrt ** 2
|
|
||||||
alphas = torch.cat([alphas_bar[1:], alphas_bar[:1]])
|
|
||||||
return 1 - alphas
|
|
||||||
|
|
||||||
def _default_timesteps(self):
|
# Convert alphas_bar_sqrt to betas
|
||||||
"""Default timesteps before set_timesteps is called."""
|
alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
|
||||||
import numpy as np
|
|
||||||
return np.arange(0, self.num_train_timesteps)[::-1].copy().astype(np.int64)
|
|
||||||
|
|
||||||
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
|
return alphas_bar
|
||||||
"""Compute the variance for the DDIM step."""
|
|
||||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
|
||||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
|
||||||
beta_prod_t = 1 - alpha_prod_t
|
|
||||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
|
||||||
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
|
||||||
return variance
|
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps: int = 100, denoising_strength: float = 1.0, training: bool = False, **kwargs):
|
|
||||||
"""
|
|
||||||
Sets the discrete timesteps used for the diffusion chain.
|
|
||||||
Follows FlowMatchScheduler interface: (num_inference_steps, denoising_strength, training, **kwargs)
|
|
||||||
"""
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
if denoising_strength != 1.0:
|
def set_timesteps(self, num_inference_steps, denoising_strength=1.0, training=False, **kwargs):
|
||||||
# For img2img: adjust effective steps
|
# The timesteps are aligned to 999...0, which is different from other implementations,
|
||||||
num_inference_steps = int(num_inference_steps * denoising_strength)
|
# but I think this implementation is more reasonable in theory.
|
||||||
|
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
|
||||||
# Compute step ratio
|
num_inference_steps = min(num_inference_steps, max_timestep + 1)
|
||||||
if self.timestep_spacing == "leading":
|
if num_inference_steps == 1:
|
||||||
# leading: arange * step_ratio, reverse, then add offset
|
self.timesteps = torch.Tensor([max_timestep])
|
||||||
step_ratio = self.num_train_timesteps // num_inference_steps
|
|
||||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
|
|
||||||
timesteps = timesteps + self.steps_offset
|
|
||||||
elif self.timestep_spacing == "trailing":
|
|
||||||
# trailing: timesteps = arange(num_steps, 0, -1) * step_ratio - 1
|
|
||||||
step_ratio = self.num_train_timesteps / num_inference_steps
|
|
||||||
timesteps = (np.arange(num_inference_steps, 0, -1) * step_ratio - 1).round()[::-1]
|
|
||||||
elif self.timestep_spacing == "linspace":
|
|
||||||
# linspace: evenly spaced from num_train_timesteps - 1 to 0
|
|
||||||
timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps).round()[::-1]
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported timestep_spacing: {self.timestep_spacing}")
|
step_length = max_timestep / (num_inference_steps - 1)
|
||||||
|
self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
|
||||||
|
self.training = training
|
||||||
|
|
||||||
# Clamp timesteps to valid range [0, num_train_timesteps - 1]
|
|
||||||
timesteps = np.clip(timesteps, 0, self.num_train_timesteps - 1)
|
|
||||||
self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.int64)
|
|
||||||
self.num_inference_steps = num_inference_steps
|
|
||||||
|
|
||||||
if training:
|
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
|
||||||
self.set_training_weight()
|
|
||||||
self.training = True
|
|
||||||
else:
|
|
||||||
self.training = False
|
|
||||||
|
|
||||||
def set_training_weight(self):
|
|
||||||
"""Set timestep weights for training (similar to FlowMatchScheduler)."""
|
|
||||||
steps = 1000
|
|
||||||
x = self.timesteps
|
|
||||||
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
|
|
||||||
y_shifted = y - y.min()
|
|
||||||
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
|
|
||||||
if len(self.timesteps) != 1000:
|
|
||||||
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
|
|
||||||
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
|
|
||||||
self.linear_timesteps_weights = bsmntw_weighing
|
|
||||||
|
|
||||||
def step(self, model_output, timestep, sample, to_final: bool = False, eta: float = 0.0, **kwargs):
|
|
||||||
"""
|
|
||||||
DDIM step function.
|
|
||||||
Follows FlowMatchScheduler interface: step(model_output, timestep, sample, to_final=False)
|
|
||||||
|
|
||||||
For SD 1.5, prediction_type="epsilon" and eta=0.0 (deterministic DDIM).
|
|
||||||
"""
|
|
||||||
if isinstance(timestep, torch.Tensor):
|
|
||||||
timestep = timestep.cpu()
|
|
||||||
if timestep.dim() == 0:
|
|
||||||
timestep = timestep.item()
|
|
||||||
elif timestep.dim() == 1:
|
|
||||||
timestep = timestep[0].item()
|
|
||||||
|
|
||||||
# Ensure timestep is int
|
|
||||||
timestep = int(timestep)
|
|
||||||
|
|
||||||
# Find the index of the current timestep
|
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep).abs()).item()
|
|
||||||
|
|
||||||
if timestep_id + 1 >= len(self.timesteps):
|
|
||||||
prev_timestep = -1
|
|
||||||
else:
|
|
||||||
prev_timestep = self.timesteps[timestep_id + 1].item()
|
|
||||||
|
|
||||||
# Get alphas
|
|
||||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
|
||||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
|
||||||
|
|
||||||
alpha_prod_t = alpha_prod_t.to(device=sample.device, dtype=sample.dtype)
|
|
||||||
alpha_prod_t_prev = alpha_prod_t_prev.to(device=sample.device, dtype=sample.dtype)
|
|
||||||
|
|
||||||
beta_prod_t = 1 - alpha_prod_t
|
|
||||||
|
|
||||||
# Compute predicted original sample (x_0)
|
|
||||||
if self.prediction_type == "epsilon":
|
if self.prediction_type == "epsilon":
|
||||||
pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
|
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
|
||||||
elif self.prediction_type == "sample":
|
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
|
||||||
pred_original_sample = model_output
|
prev_sample = sample * weight_x + model_output * weight_e
|
||||||
elif self.prediction_type == "v_prediction":
|
elif self.prediction_type == "v_prediction":
|
||||||
pred_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
|
weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
|
||||||
|
weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
|
||||||
|
prev_sample = sample * weight_x + model_output * weight_e
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported prediction_type: {self.prediction_type}")
|
raise NotImplementedError(f"{self.prediction_type} is not implemented")
|
||||||
|
|
||||||
# Clip sample if needed
|
|
||||||
if self.clip_sample:
|
|
||||||
pred_original_sample = pred_original_sample.clamp(-1.0, 1.0)
|
|
||||||
|
|
||||||
# Compute predicted noise (re-derived from x_0)
|
|
||||||
pred_epsilon = (sample - alpha_prod_t.sqrt() * pred_original_sample) / beta_prod_t.sqrt()
|
|
||||||
|
|
||||||
# DDIM formula: prev_sample = sqrt(alpha_prev) * x0 + sqrt(1 - alpha_prev) * epsilon
|
|
||||||
prev_sample = alpha_prod_t_prev.sqrt() * pred_original_sample + (1 - alpha_prod_t_prev).sqrt() * pred_epsilon
|
|
||||||
|
|
||||||
# Add variance noise if eta > 0 (DDIM: eta=0, DDPM: eta=1)
|
|
||||||
if eta > 0:
|
|
||||||
variance = self._get_variance(timestep, prev_timestep)
|
|
||||||
variance = variance.to(device=sample.device, dtype=sample.dtype)
|
|
||||||
std_dev_t = eta * variance.sqrt()
|
|
||||||
device = sample.device
|
|
||||||
noise = torch.randn_like(sample)
|
|
||||||
prev_sample = prev_sample + std_dev_t * noise
|
|
||||||
|
|
||||||
return prev_sample
|
return prev_sample
|
||||||
|
|
||||||
def add_noise(self, original_samples, noise, timestep):
|
|
||||||
"""Add noise to original samples (forward diffusion).
|
def step(self, model_output, timestep, sample, to_final=False):
|
||||||
Follows FlowMatchScheduler interface: add_noise(original_samples, noise, timestep)
|
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
||||||
"""
|
|
||||||
if isinstance(timestep, torch.Tensor):
|
if isinstance(timestep, torch.Tensor):
|
||||||
timestep = timestep.cpu()
|
timestep = timestep.cpu()
|
||||||
if timestep.dim() == 0:
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
timestep = timestep.item()
|
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||||
elif timestep.dim() == 1:
|
alpha_prod_t_prev = 1.0
|
||||||
timestep = timestep[0].item()
|
else:
|
||||||
|
timestep_prev = int(self.timesteps[timestep_id + 1])
|
||||||
|
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
|
||||||
|
|
||||||
timestep = int(timestep)
|
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
|
||||||
# Defensive clamp: ensure timestep is within valid range
|
|
||||||
timestep = max(0, min(timestep, self.num_train_timesteps - 1))
|
|
||||||
sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt()
|
|
||||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt()
|
|
||||||
|
|
||||||
sqrt_alpha_prod = sqrt_alpha_prod.to(device=original_samples.device, dtype=original_samples.dtype)
|
|
||||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.to(device=original_samples.device, dtype=original_samples.dtype)
|
|
||||||
|
|
||||||
# Handle broadcasting for batch timesteps
|
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||||
while sqrt_alpha_prod.dim() < original_samples.dim():
|
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
||||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
|
||||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
return noise_pred
|
||||||
|
|
||||||
|
|
||||||
|
def add_noise(self, original_samples, noise, timestep):
|
||||||
|
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||||
|
return noisy_samples
|
||||||
|
|
||||||
sample = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
|
||||||
return sample
|
|
||||||
|
|
||||||
def training_target(self, sample, noise, timestep):
|
def training_target(self, sample, noise, timestep):
|
||||||
"""Return the training target for the given prediction type."""
|
|
||||||
if isinstance(timestep, torch.Tensor):
|
|
||||||
timestep = timestep.cpu()
|
|
||||||
if timestep.dim() == 0:
|
|
||||||
timestep = timestep.item()
|
|
||||||
elif timestep.dim() == 1:
|
|
||||||
timestep = timestep[0].item()
|
|
||||||
timestep = int(timestep)
|
|
||||||
timestep = max(0, min(timestep, self.num_train_timesteps - 1))
|
|
||||||
if self.prediction_type == "epsilon":
|
if self.prediction_type == "epsilon":
|
||||||
return noise
|
return noise
|
||||||
elif self.prediction_type == "v_prediction":
|
|
||||||
sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt()
|
|
||||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt()
|
|
||||||
return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
|
||||||
elif self.prediction_type == "sample":
|
|
||||||
return sample
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported prediction_type: {self.prediction_type}")
|
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||||
|
return target
|
||||||
|
|
||||||
|
|
||||||
def training_weight(self, timestep):
|
def training_weight(self, timestep):
|
||||||
"""Return training weight for the given timestep."""
|
return 1.0
|
||||||
timestep = max(0, min(int(timestep), self.num_train_timesteps - 1))
|
|
||||||
timestep_tensor = torch.tensor(timestep, device=self.timesteps.device)
|
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep_tensor).abs())
|
|
||||||
return self.linear_timesteps_weights[timestep_id]
|
|
||||||
|
|||||||
@@ -196,19 +196,14 @@ class SDUnit_InputImageEmbedder(PipelineUnit):
|
|||||||
|
|
||||||
def process(self, pipe: StableDiffusionPipeline, input_image, noise):
|
def process(self, pipe: StableDiffusionPipeline, input_image, noise):
|
||||||
if input_image is None:
|
if input_image is None:
|
||||||
return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
|
return {"latents": noise}
|
||||||
if pipe.scheduler.training:
|
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
input_tensor = pipe.preprocess_image(input_image)
|
input_tensor = pipe.preprocess_image(input_image)
|
||||||
input_latents = pipe.vae.encode(input_tensor).sample()
|
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
|
||||||
latents = noise * pipe.scheduler.init_noise_sigma
|
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
|
||||||
|
if pipe.scheduler.training:
|
||||||
return {"latents": latents, "input_latents": input_latents}
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
else:
|
else:
|
||||||
# Inference mode: VAE encode input image, add noise for initial latent
|
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
|
||||||
input_tensor = pipe.preprocess_image(input_image)
|
|
||||||
input_latents = pipe.vae.encode(input_tensor).sample()
|
|
||||||
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
|
|
||||||
return {"latents": latents}
|
return {"latents": latents}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class StableDiffusionXLPipeline(BasePipeline):
|
|||||||
SDXLUnit_PromptEmbedder(),
|
SDXLUnit_PromptEmbedder(),
|
||||||
SDXLUnit_NoiseInitializer(),
|
SDXLUnit_NoiseInitializer(),
|
||||||
SDXLUnit_InputImageEmbedder(),
|
SDXLUnit_InputImageEmbedder(),
|
||||||
|
SDXLUnit_AddTimeIdsComputer(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_stable_diffusion_xl
|
self.model_fn = model_fn_stable_diffusion_xl
|
||||||
self.compilable_models = ["unet"]
|
self.compilable_models = ["unet"]
|
||||||
@@ -94,20 +95,11 @@ class StableDiffusionXLPipeline(BasePipeline):
|
|||||||
seed: int = None,
|
seed: int = None,
|
||||||
rand_device: str = "cpu",
|
rand_device: str = "cpu",
|
||||||
num_inference_steps: int = 50,
|
num_inference_steps: int = 50,
|
||||||
eta: float = 0.0,
|
|
||||||
guidance_rescale: float = 0.0,
|
guidance_rescale: float = 0.0,
|
||||||
original_size: tuple = None,
|
|
||||||
crops_coords_top_left: tuple = (0, 0),
|
|
||||||
target_size: tuple = None,
|
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
):
|
):
|
||||||
original_size = original_size or (height, width)
|
|
||||||
target_size = target_size or (height, width)
|
|
||||||
|
|
||||||
# 1. Scheduler
|
# 1. Scheduler
|
||||||
self.scheduler.set_timesteps(
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
num_inference_steps, eta=eta,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Three-dict input preparation
|
# 2. Three-dict input preparation
|
||||||
inputs_posi = {
|
inputs_posi = {
|
||||||
@@ -121,9 +113,7 @@ class StableDiffusionXLPipeline(BasePipeline):
|
|||||||
"height": height, "width": width,
|
"height": height, "width": width,
|
||||||
"seed": seed, "rand_device": rand_device,
|
"seed": seed, "rand_device": rand_device,
|
||||||
"guidance_rescale": guidance_rescale,
|
"guidance_rescale": guidance_rescale,
|
||||||
"original_size": original_size,
|
"crops_coords_top_left": (0, 0),
|
||||||
"crops_coords_top_left": crops_coords_top_left,
|
|
||||||
"target_size": target_size,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 3. Unit chain execution
|
# 3. Unit chain execution
|
||||||
@@ -132,18 +122,7 @@ class StableDiffusionXLPipeline(BasePipeline):
|
|||||||
unit, self, inputs_shared, inputs_posi, inputs_nega
|
unit, self, inputs_shared, inputs_posi, inputs_nega
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Compute add_time_ids (micro-conditioning)
|
# 4. Denoise loop
|
||||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
|
||||||
add_time_ids = self._get_add_time_ids(
|
|
||||||
original_size, crops_coords_top_left, target_size,
|
|
||||||
dtype=self.torch_dtype,
|
|
||||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
|
||||||
)
|
|
||||||
neg_add_time_ids = add_time_ids.clone()
|
|
||||||
inputs_posi["add_time_ids"] = add_time_ids
|
|
||||||
inputs_nega["add_time_ids"] = neg_add_time_ids
|
|
||||||
|
|
||||||
# 5. Denoise loop
|
|
||||||
self.load_models_to_device(self.in_iteration_models)
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
@@ -183,21 +162,6 @@ class StableDiffusionXLPipeline(BasePipeline):
|
|||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None):
|
|
||||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
|
||||||
# SDXL UNet doesn't have a config attribute, so we access add_embedding directly
|
|
||||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
|
||||||
# addition_time_embed_dim is the dimension of each time ID projection (256 for SDXL base)
|
|
||||||
addition_time_embed_dim = self.unet.add_time_proj.num_channels
|
|
||||||
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
|
||||||
if expected_add_embed_dim != passed_add_embed_dim:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
|
|
||||||
f"but a vector of {passed_add_embed_dim} was created."
|
|
||||||
)
|
|
||||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device)
|
|
||||||
return add_time_ids
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLUnit_ShapeChecker(PipelineUnit):
|
class SDXLUnit_ShapeChecker(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -294,22 +258,51 @@ class SDXLUnit_InputImageEmbedder(PipelineUnit):
|
|||||||
|
|
||||||
def process(self, pipe: StableDiffusionXLPipeline, input_image, noise):
|
def process(self, pipe: StableDiffusionXLPipeline, input_image, noise):
|
||||||
if input_image is None:
|
if input_image is None:
|
||||||
return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
|
return {"latents": noise}
|
||||||
if pipe.scheduler.training:
|
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
input_tensor = pipe.preprocess_image(input_image)
|
input_tensor = pipe.preprocess_image(input_image)
|
||||||
input_latents = pipe.vae.encode(input_tensor).sample()
|
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
|
||||||
latents = noise * pipe.scheduler.init_noise_sigma
|
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
|
||||||
|
if pipe.scheduler.training:
|
||||||
return {"latents": latents, "input_latents": input_latents}
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
else:
|
else:
|
||||||
# Inference mode: VAE encode input image, add noise for initial latent
|
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
|
||||||
input_tensor = pipe.preprocess_image(input_image)
|
|
||||||
input_latents = pipe.vae.encode(input_tensor).sample()
|
|
||||||
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
|
|
||||||
return {"latents": latents}
|
return {"latents": latents}
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUnit_AddTimeIdsComputer(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width"),
|
||||||
|
output_params=("add_time_ids",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_add_time_ids(self, pipe, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim):
|
||||||
|
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||||
|
expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
|
||||||
|
addition_time_embed_dim = pipe.unet.add_time_proj.num_channels
|
||||||
|
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||||
|
if expected_add_embed_dim != passed_add_embed_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
|
||||||
|
f"but a vector of {passed_add_embed_dim} was created."
|
||||||
|
)
|
||||||
|
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=pipe.device)
|
||||||
|
return add_time_ids
|
||||||
|
|
||||||
|
def process(self, pipe: StableDiffusionXLPipeline, height, width):
|
||||||
|
original_size = (height, width)
|
||||||
|
target_size = (height, width)
|
||||||
|
crops_coords_top_left = (0, 0)
|
||||||
|
|
||||||
|
text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
|
||||||
|
add_time_ids = self._get_add_time_ids(
|
||||||
|
pipe, original_size, crops_coords_top_left, target_size,
|
||||||
|
dtype=pipe.torch_dtype,
|
||||||
|
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||||
|
)
|
||||||
|
return {"add_time_ids": add_time_ids}
|
||||||
|
|
||||||
|
|
||||||
def model_fn_stable_diffusion_xl(
|
def model_fn_stable_diffusion_xl(
|
||||||
unet: SDXLUNet2DConditionModel,
|
unet: SDXLUNet2DConditionModel,
|
||||||
latents=None,
|
latents=None,
|
||||||
|
|||||||
@@ -13,13 +13,13 @@ pipe = StableDiffusionPipeline.from_pretrained(
|
|||||||
)
|
)
|
||||||
|
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt="a photo of an astronaut riding a horse on mars",
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
negative_prompt="",
|
negative_prompt="blurry, low quality, deformed",
|
||||||
cfg_scale=7.5,
|
cfg_scale=7.5,
|
||||||
height=512,
|
height=512,
|
||||||
width=512,
|
width=512,
|
||||||
seed=42,
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
num_inference_steps=50,
|
num_inference_steps=50,
|
||||||
)
|
)
|
||||||
image.save("output_stable_diffusion_t2i.png")
|
image.save("image.jpg")
|
||||||
print("Image saved to output_stable_diffusion_t2i.png")
|
|
||||||
@@ -12,7 +12,6 @@ vram_config = {
|
|||||||
"computation_dtype": torch.float32,
|
"computation_dtype": torch.float32,
|
||||||
"computation_device": "cuda",
|
"computation_device": "cuda",
|
||||||
}
|
}
|
||||||
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
torch_dtype=torch.float32,
|
torch_dtype=torch.float32,
|
||||||
model_configs=[
|
model_configs=[
|
||||||
@@ -25,13 +24,13 @@ pipe = StableDiffusionPipeline.from_pretrained(
|
|||||||
)
|
)
|
||||||
|
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt="a photo of an astronaut riding a horse on mars",
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
negative_prompt="",
|
negative_prompt="blurry, low quality, deformed",
|
||||||
cfg_scale=7.5,
|
cfg_scale=7.5,
|
||||||
height=512,
|
height=512,
|
||||||
width=512,
|
width=512,
|
||||||
seed=42,
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
num_inference_steps=50,
|
num_inference_steps=50,
|
||||||
)
|
)
|
||||||
image.save("output_stable_diffusion_t2i_low_vram.png")
|
image.save("image.jpg")
|
||||||
print("Image saved to output_stable_diffusion_t2i_low_vram.png")
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
# Dataset: data/diffsynth_example_dataset/stable_diffusion/StableDiffusion/
|
|
||||||
# Debug test: num_epochs=1, dataset_repeat=1 for quick validation
|
|
||||||
|
|
||||||
# ===== 固定参数(无需修改) =====
|
|
||||||
accelerate launch examples/stable_diffusion/model_training/train.py \
|
|
||||||
--learning_rate 1e-4 --num_epochs 1 \
|
|
||||||
--lora_rank 32 \
|
|
||||||
--use_gradient_checkpointing --find_unused_parameters \
|
|
||||||
--dataset_base_path "./data/diffsynth_example_dataset/stable_diffusion/StableDiffusion" \
|
|
||||||
--dataset_metadata_path "./data/diffsynth_example_dataset/stable_diffusion/StableDiffusion/metadata.csv" \
|
|
||||||
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--lora_base_model "unet" \
|
|
||||||
--remove_prefix_in_ckpt "pipe.unet." \
|
|
||||||
--max_pixels 262144 \
|
|
||||||
--height 512 --width 512 \
|
|
||||||
--dataset_repeat 1 \
|
|
||||||
--output_path "./models/train/StableDiffusion_lora_debug" \
|
|
||||||
--lora_target_modules "to_q,to_k,to_v,to_out.0" \
|
|
||||||
--data_file_keys "image"
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
# Dataset: data/diffsynth_example_dataset/stable_diffusion/StableDiffusion/
|
|
||||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/StableDiffusion/*" --local_dir ./data/diffsynth_example_dataset
|
|
||||||
|
|
||||||
# ===== 固定参数(无需修改) =====
|
|
||||||
accelerate launch examples/stable_diffusion/model_training/train.py \
|
|
||||||
--learning_rate 1e-4 --num_epochs 5 \
|
|
||||||
--lora_rank 32 \
|
|
||||||
--use_gradient_checkpointing --find_unused_parameters \
|
|
||||||
--dataset_base_path "./data/diffsynth_example_dataset/stable_diffusion/StableDiffusion" \
|
|
||||||
--dataset_metadata_path "./data/diffsynth_example_dataset/stable_diffusion/StableDiffusion/metadata.csv" \
|
|
||||||
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--lora_base_model "unet" \
|
|
||||||
--remove_prefix_in_ckpt "pipe.unet." \
|
|
||||||
--max_pixels 262144 \
|
|
||||||
--height 512 --width 512 \
|
|
||||||
--dataset_repeat 50 \
|
|
||||||
--output_path "./models/train/StableDiffusion_lora" \
|
|
||||||
--lora_target_modules "to_q,to_k,to_v,to_out.0" \
|
|
||||||
--data_file_keys "image"
|
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/stable-diffusion-v1-5/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
|
||||||
|
accelerate launch examples/stable_diffusion/model_training/train.py \
|
||||||
|
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5 \
|
||||||
|
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5/metadata.csv \
|
||||||
|
--height 512 \
|
||||||
|
--width 512 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.unet." \
|
||||||
|
--output_path "./models/train/stable-diffusion-v1-5_lora" \
|
||||||
|
--lora_base_model "unet" \
|
||||||
|
--lora_target_modules "" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
import torch, os, argparse, accelerate
|
import torch, os, argparse, accelerate
|
||||||
from diffsynth.core import UnifiedDataset
|
from diffsynth.core import UnifiedDataset
|
||||||
from diffsynth.core.data.operators import ToAbsolutePath, LoadImage, ImageCropAndResize, RouteByType, SequencialProcess
|
|
||||||
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
|
||||||
from diffsynth.diffusion import *
|
from diffsynth.diffusion import *
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -23,16 +22,13 @@ class StableDiffusionTrainingModule(DiffusionTrainingModule):
|
|||||||
task="sft",
|
task="sft",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# ===== 解析模型配置 =====
|
# Load models
|
||||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||||
# ===== Tokenizer 配置 =====
|
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"))
|
||||||
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"))
|
self.pipe = StableDiffusionPipeline.from_pretrained(torch_dtype=torch.float32, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||||
# ===== 构建 Pipeline =====
|
|
||||||
self.pipe = StableDiffusionPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
|
||||||
# ===== 拆分 Pipeline Units =====
|
|
||||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||||
|
|
||||||
# ===== 切换到训练模式 =====
|
# Training mode
|
||||||
self.switch_pipe_to_training_mode(
|
self.switch_pipe_to_training_mode(
|
||||||
self.pipe, trainable_models,
|
self.pipe, trainable_models,
|
||||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||||
@@ -40,42 +36,41 @@ class StableDiffusionTrainingModule(DiffusionTrainingModule):
|
|||||||
task=task,
|
task=task,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ===== 其他配置 =====
|
# Other configs
|
||||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
self.fp8_models = fp8_models
|
self.fp8_models = fp8_models
|
||||||
self.task = task
|
self.task = task
|
||||||
# ===== 任务模式路由 =====
|
|
||||||
self.task_to_loss = {
|
self.task_to_loss = {
|
||||||
"sft:data_process": lambda pipe, *args: args,
|
"sft:data_process": lambda pipe, *args: args,
|
||||||
|
"direct_distill:data_process": lambda pipe, *args: args,
|
||||||
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_pipeline_inputs(self, data):
|
def get_pipeline_inputs(self, data):
|
||||||
# ===== 正向提示词 =====
|
|
||||||
inputs_posi = {"prompt": data["prompt"]}
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
# ===== 负向提示词:训练不需要 =====
|
|
||||||
inputs_nega = {"negative_prompt": ""}
|
inputs_nega = {"negative_prompt": ""}
|
||||||
# ===== 共享参数 =====
|
|
||||||
inputs_shared = {
|
inputs_shared = {
|
||||||
# ===== 核心字段映射 =====
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
"input_image": data["image"],
|
"input_image": data["image"],
|
||||||
"height": data["image"].size[1],
|
"height": data["image"].size[1],
|
||||||
"width": data["image"].size[0],
|
"width": data["image"].size[0],
|
||||||
# ===== 框架控制参数 =====
|
# Please do not modify the following parameters
|
||||||
|
# unless you clearly know what this will cause.
|
||||||
"cfg_scale": 1,
|
"cfg_scale": 1,
|
||||||
"rand_device": self.pipe.device,
|
"rand_device": self.pipe.device,
|
||||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
}
|
}
|
||||||
# ===== 额外字段注入 =====
|
|
||||||
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||||
return inputs_shared, inputs_posi, inputs_nega
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
def forward(self, data, inputs=None):
|
def forward(self, data, inputs=None):
|
||||||
# ===== 标准实现,不要修改 =====
|
|
||||||
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||||
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
for unit in self.pipe.units:
|
for unit in self.pipe.units:
|
||||||
@@ -84,24 +79,21 @@ class StableDiffusionTrainingModule(DiffusionTrainingModule):
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def stable_diffusion_parser():
|
def parser():
|
||||||
parser = argparse.ArgumentParser(description="Stable Diffusion training.")
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
parser = add_general_config(parser)
|
parser = add_general_config(parser)
|
||||||
parser = add_image_size_config(parser)
|
parser = add_image_size_config(parser)
|
||||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||||
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = stable_diffusion_parser()
|
parser = parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# ===== Accelerator 配置 =====
|
|
||||||
accelerator = accelerate.Accelerator(
|
accelerator = accelerate.Accelerator(
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||||
)
|
)
|
||||||
# ===== 数据集定义 =====
|
|
||||||
dataset = UnifiedDataset(
|
dataset = UnifiedDataset(
|
||||||
base_path=args.dataset_base_path,
|
base_path=args.dataset_base_path,
|
||||||
metadata_path=args.dataset_metadata_path,
|
metadata_path=args.dataset_metadata_path,
|
||||||
@@ -112,17 +104,10 @@ if __name__ == "__main__":
|
|||||||
max_pixels=args.max_pixels,
|
max_pixels=args.max_pixels,
|
||||||
height=args.height,
|
height=args.height,
|
||||||
width=args.width,
|
width=args.width,
|
||||||
height_division_factor=8,
|
height_division_factor=32,
|
||||||
width_division_factor=8,
|
width_division_factor=32,
|
||||||
),
|
)
|
||||||
special_operator_map={
|
|
||||||
"image": RouteByType(operator_map=[
|
|
||||||
(str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 8, 8)),
|
|
||||||
(list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 8, 8))),
|
|
||||||
]),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
# ===== TrainingModule =====
|
|
||||||
model = StableDiffusionTrainingModule(
|
model = StableDiffusionTrainingModule(
|
||||||
model_paths=args.model_paths,
|
model_paths=args.model_paths,
|
||||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
@@ -140,17 +125,18 @@ if __name__ == "__main__":
|
|||||||
fp8_models=args.fp8_models,
|
fp8_models=args.fp8_models,
|
||||||
offload_models=args.offload_models,
|
offload_models=args.offload_models,
|
||||||
task=args.task,
|
task=args.task,
|
||||||
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
device=accelerator.device,
|
||||||
)
|
)
|
||||||
# ===== ModelLogger =====
|
|
||||||
model_logger = ModelLogger(
|
model_logger = ModelLogger(
|
||||||
args.output_path,
|
args.output_path,
|
||||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
)
|
)
|
||||||
# ===== 任务路由 =====
|
|
||||||
launcher_map = {
|
launcher_map = {
|
||||||
"sft:data_process": launch_data_process_task,
|
"sft:data_process": launch_data_process_task,
|
||||||
|
"direct_distill:data_process": launch_data_process_task,
|
||||||
"sft": launch_training_task,
|
"sft": launch_training_task,
|
||||||
"sft:train": launch_training_task,
|
"sft:train": launch_training_task,
|
||||||
|
"direct_distill": launch_training_task,
|
||||||
|
"direct_distill:train": launch_training_task,
|
||||||
}
|
}
|
||||||
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.float32,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
prompt = "dog, white and brown dog, sitting on wall, under pink flowers"
|
|
||||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=7.5)
|
|
||||||
image.save("image.jpg")
|
|
||||||
@@ -1,10 +1,9 @@
|
|||||||
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
|
|
||||||
import torch
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
torch_dtype=torch.float32,
|
torch_dtype=torch.float32,
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||||
@@ -12,7 +11,16 @@ pipe = StableDiffusionPipeline.from_pretrained(
|
|||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
)
|
)
|
||||||
pipe.load_lora(pipe.unet, "./models/train/StableDiffusion_lora/epoch-4.safetensors")
|
pipe.load_lora(pipe.unet, "models/train/stable-diffusion-v1-5_lora/epoch-4.safetensors")
|
||||||
prompt = "dog, white and brown dog, sitting on wall, under pink flowers"
|
|
||||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=7.5)
|
image = pipe(
|
||||||
image.save("image.jpg")
|
prompt="a dog",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image_stable-diffusion-v1-5.jpg")
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
import torch
|
|
||||||
from diffsynth.core import ModelConfig
|
|
||||||
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
|
||||||
|
|
||||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.float32,
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
|
||||||
tokenizer_2_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
image = pipe(
|
|
||||||
prompt="a photo of an astronaut riding a horse on mars",
|
|
||||||
negative_prompt="",
|
|
||||||
cfg_scale=5.0,
|
|
||||||
height=1024,
|
|
||||||
width=1024,
|
|
||||||
seed=42,
|
|
||||||
num_inference_steps=50,
|
|
||||||
)
|
|
||||||
image.save("output_stable_diffusion_xl_t2i.png")
|
|
||||||
print("Image saved to output_stable_diffusion_xl_t2i.png")
|
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
import torch
|
|
||||||
from diffsynth.core import ModelConfig
|
|
||||||
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": torch.bfloat16,
|
|
||||||
"offload_device": "cpu",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
|
|
||||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.float32,
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
|
||||||
tokenizer_2_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
|
||||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
||||||
)
|
|
||||||
|
|
||||||
image = pipe(
|
|
||||||
prompt="a photo of an astronaut riding a horse on mars",
|
|
||||||
negative_prompt="",
|
|
||||||
cfg_scale=5.0,
|
|
||||||
height=1024,
|
|
||||||
width=1024,
|
|
||||||
seed=42,
|
|
||||||
num_inference_steps=50,
|
|
||||||
)
|
|
||||||
image.save("output_stable_diffusion_xl_t2i_low_vram.png")
|
|
||||||
print("Image saved to output_stable_diffusion_xl_t2i_low_vram.png")
|
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
# Dataset: data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL/
|
|
||||||
# Debug test: num_epochs=1, dataset_repeat=1 for quick validation
|
|
||||||
|
|
||||||
# ===== 固定参数(无需修改) =====
|
|
||||||
accelerate launch examples/stable_diffusion_xl/model_training/train.py \
|
|
||||||
--learning_rate 1e-4 --num_epochs 1 \
|
|
||||||
--lora_rank 32 \
|
|
||||||
--use_gradient_checkpointing --find_unused_parameters \
|
|
||||||
--dataset_base_path "./data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL" \
|
|
||||||
--dataset_metadata_path "./data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL/metadata.csv" \
|
|
||||||
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--tokenizer_path "AI-ModelScope/stable-diffusion-xl-base-1.0:tokenizer/" \
|
|
||||||
--tokenizer_2_path "AI-ModelScope/stable-diffusion-xl-base-1.0:tokenizer_2/" \
|
|
||||||
--lora_base_model "unet" \
|
|
||||||
--remove_prefix_in_ckpt "pipe.unet." \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--height 1024 --width 1024 \
|
|
||||||
--dataset_repeat 1 \
|
|
||||||
--output_path "./models/train/StableDiffusionXL_lora_debug" \
|
|
||||||
--lora_target_modules "to_q,to_k,to_v,to_out.0" \
|
|
||||||
--data_file_keys "image"
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
# Dataset: data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL/
|
|
||||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/StableDiffusionXL/*" --local_dir ./data/diffsynth_example_dataset
|
|
||||||
|
|
||||||
# ===== 固定参数(无需修改) =====
|
|
||||||
accelerate launch examples/stable_diffusion_xl/model_training/train.py \
|
|
||||||
--learning_rate 1e-4 --num_epochs 5 \
|
|
||||||
--lora_rank 32 \
|
|
||||||
--use_gradient_checkpointing --find_unused_parameters \
|
|
||||||
--dataset_base_path "./data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL" \
|
|
||||||
--dataset_metadata_path "./data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL/metadata.csv" \
|
|
||||||
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--tokenizer_path "AI-ModelScope/stable-diffusion-xl-base-1.0:tokenizer/" \
|
|
||||||
--tokenizer_2_path "AI-ModelScope/stable-diffusion-xl-base-1.0:tokenizer_2/" \
|
|
||||||
--lora_base_model "unet" \
|
|
||||||
--remove_prefix_in_ckpt "pipe.unet." \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--height 1024 --width 1024 \
|
|
||||||
--dataset_repeat 50 \
|
|
||||||
--output_path "./models/train/StableDiffusionXL_lora" \
|
|
||||||
--lora_target_modules "to_q,to_k,to_v,to_out.0" \
|
|
||||||
--data_file_keys "image"
|
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/stable-diffusion-xl-base-1.0/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
|
||||||
|
accelerate launch examples/stable_diffusion_xl/model_training/train.py \
|
||||||
|
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0 \
|
||||||
|
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0/metadata.csv \
|
||||||
|
--height 1024 \
|
||||||
|
--width 1024 \
|
||||||
|
--dataset_repeat 10 \
|
||||||
|
--model_id_with_origin_paths "stabilityai/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.unet." \
|
||||||
|
--output_path "./models/train/stable-diffusion-xl-base-1.0_lora" \
|
||||||
|
--lora_base_model "unet" \
|
||||||
|
--lora_target_modules "" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
import torch, os, argparse, accelerate
|
import torch, os, argparse, accelerate
|
||||||
from diffsynth.core import UnifiedDataset
|
from diffsynth.core import UnifiedDataset
|
||||||
from diffsynth.core.data.operators import ToAbsolutePath, LoadImage, ImageCropAndResize, RouteByType, SequencialProcess
|
|
||||||
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
|
||||||
from diffsynth.diffusion import *
|
from diffsynth.diffusion import *
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -10,7 +9,7 @@ class StableDiffusionXLTrainingModule(DiffusionTrainingModule):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_paths=None, model_id_with_origin_paths=None,
|
model_paths=None, model_id_with_origin_paths=None,
|
||||||
tokenizer_path=None, tokenizer_2_path=None,
|
tokenizer_path=None,
|
||||||
trainable_models=None,
|
trainable_models=None,
|
||||||
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||||
preset_lora_path=None, preset_lora_model=None,
|
preset_lora_path=None, preset_lora_model=None,
|
||||||
@@ -23,17 +22,14 @@ class StableDiffusionXLTrainingModule(DiffusionTrainingModule):
|
|||||||
task="sft",
|
task="sft",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# ===== 解析模型配置 =====
|
# Load models
|
||||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||||
# ===== Tokenizer 配置 =====
|
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"))
|
||||||
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"))
|
tokenizer_2_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"))
|
||||||
tokenizer_2_config = self.parse_path_or_model_id(tokenizer_2_path, default_value=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"))
|
self.pipe = StableDiffusionXLPipeline.from_pretrained(torch_dtype=torch.float32, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, tokenizer_2_config=tokenizer_2_config)
|
||||||
# ===== 构建 Pipeline =====
|
|
||||||
self.pipe = StableDiffusionXLPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, tokenizer_2_config=tokenizer_2_config)
|
|
||||||
# ===== 拆分 Pipeline Units =====
|
|
||||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||||
|
|
||||||
# ===== 切换到训练模式 =====
|
# Training mode
|
||||||
self.switch_pipe_to_training_mode(
|
self.switch_pipe_to_training_mode(
|
||||||
self.pipe, trainable_models,
|
self.pipe, trainable_models,
|
||||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||||
@@ -41,57 +37,41 @@ class StableDiffusionXLTrainingModule(DiffusionTrainingModule):
|
|||||||
task=task,
|
task=task,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ===== 其他配置 =====
|
# Other configs
|
||||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
self.fp8_models = fp8_models
|
self.fp8_models = fp8_models
|
||||||
self.task = task
|
self.task = task
|
||||||
# ===== 任务模式路由 =====
|
|
||||||
self.task_to_loss = {
|
self.task_to_loss = {
|
||||||
"sft:data_process": lambda pipe, *args: args,
|
"sft:data_process": lambda pipe, *args: args,
|
||||||
|
"direct_distill:data_process": lambda pipe, *args: args,
|
||||||
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_pipeline_inputs(self, data):
|
def get_pipeline_inputs(self, data):
|
||||||
# ===== 正向提示词 =====
|
|
||||||
inputs_posi = {"prompt": data["prompt"]}
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
# ===== 负向提示词:训练不需要 =====
|
|
||||||
inputs_nega = {"negative_prompt": ""}
|
inputs_nega = {"negative_prompt": ""}
|
||||||
# ===== 共享参数 =====
|
|
||||||
height = data["image"].size[1]
|
|
||||||
width = data["image"].size[0]
|
|
||||||
inputs_shared = {
|
inputs_shared = {
|
||||||
# ===== 核心字段映射 =====
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
"input_image": data["image"],
|
"input_image": data["image"],
|
||||||
"height": height,
|
"height": data["image"].size[1],
|
||||||
"width": width,
|
"width": data["image"].size[0],
|
||||||
# ===== 框架控制参数 =====
|
# Please do not modify the following parameters
|
||||||
|
# unless you clearly know what this will cause.
|
||||||
"cfg_scale": 1,
|
"cfg_scale": 1,
|
||||||
"rand_device": self.pipe.device,
|
"rand_device": self.pipe.device,
|
||||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
}
|
}
|
||||||
# ===== SDXL 特有:add_time_ids (micro-conditioning) =====
|
|
||||||
# 在 __call__ 中计算,但训练不跑 __call__,所以在这里注入
|
|
||||||
text_encoder_projection_dim = self.pipe.text_encoder_2.config.projection_dim
|
|
||||||
add_time_ids = [height, width, 0, 0, height, width]
|
|
||||||
expected_add_embed_dim = self.pipe.unet.add_embedding.linear_1.in_features
|
|
||||||
addition_time_embed_dim = self.pipe.unet.add_time_proj.num_channels
|
|
||||||
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
|
||||||
if expected_add_embed_dim != passed_add_embed_dim:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
|
|
||||||
f"but a vector of {passed_add_embed_dim} was created."
|
|
||||||
)
|
|
||||||
inputs_posi["add_time_ids"] = torch.tensor([add_time_ids], dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
|
||||||
# ===== 额外字段注入 =====
|
|
||||||
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||||
return inputs_shared, inputs_posi, inputs_nega
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
def forward(self, data, inputs=None):
|
def forward(self, data, inputs=None):
|
||||||
# ===== 标准实现,不要修改 =====
|
|
||||||
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||||
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
for unit in self.pipe.units:
|
for unit in self.pipe.units:
|
||||||
@@ -100,25 +80,22 @@ class StableDiffusionXLTrainingModule(DiffusionTrainingModule):
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def stable_diffusion_xl_parser():
|
def parser():
|
||||||
parser = argparse.ArgumentParser(description="Stable Diffusion XL training.")
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
parser = add_general_config(parser)
|
parser = add_general_config(parser)
|
||||||
parser = add_image_size_config(parser)
|
parser = add_image_size_config(parser)
|
||||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||||
parser.add_argument("--tokenizer_2_path", type=str, default=None, help="Path to tokenizer 2.")
|
parser.add_argument("--tokenizer_2_path", type=str, default=None, help="Path to tokenizer 2.")
|
||||||
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = stable_diffusion_xl_parser()
|
parser = parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# ===== Accelerator 配置 =====
|
|
||||||
accelerator = accelerate.Accelerator(
|
accelerator = accelerate.Accelerator(
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||||
)
|
)
|
||||||
# ===== 数据集定义 =====
|
|
||||||
dataset = UnifiedDataset(
|
dataset = UnifiedDataset(
|
||||||
base_path=args.dataset_base_path,
|
base_path=args.dataset_base_path,
|
||||||
metadata_path=args.dataset_metadata_path,
|
metadata_path=args.dataset_metadata_path,
|
||||||
@@ -129,22 +106,14 @@ if __name__ == "__main__":
|
|||||||
max_pixels=args.max_pixels,
|
max_pixels=args.max_pixels,
|
||||||
height=args.height,
|
height=args.height,
|
||||||
width=args.width,
|
width=args.width,
|
||||||
height_division_factor=8,
|
height_division_factor=32,
|
||||||
width_division_factor=8,
|
width_division_factor=32,
|
||||||
),
|
)
|
||||||
special_operator_map={
|
|
||||||
"image": RouteByType(operator_map=[
|
|
||||||
(str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 8, 8)),
|
|
||||||
(list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 8, 8))),
|
|
||||||
]),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
# ===== TrainingModule =====
|
|
||||||
model = StableDiffusionXLTrainingModule(
|
model = StableDiffusionXLTrainingModule(
|
||||||
model_paths=args.model_paths,
|
model_paths=args.model_paths,
|
||||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
tokenizer_path=args.tokenizer_path,
|
tokenizer_path=args.tokenizer_path,
|
||||||
tokenizer_2_path=args.tokenizer_2_path,
|
|
||||||
trainable_models=args.trainable_models,
|
trainable_models=args.trainable_models,
|
||||||
lora_base_model=args.lora_base_model,
|
lora_base_model=args.lora_base_model,
|
||||||
lora_target_modules=args.lora_target_modules,
|
lora_target_modules=args.lora_target_modules,
|
||||||
@@ -158,17 +127,18 @@ if __name__ == "__main__":
|
|||||||
fp8_models=args.fp8_models,
|
fp8_models=args.fp8_models,
|
||||||
offload_models=args.offload_models,
|
offload_models=args.offload_models,
|
||||||
task=args.task,
|
task=args.task,
|
||||||
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
device=accelerator.device,
|
||||||
)
|
)
|
||||||
# ===== ModelLogger =====
|
|
||||||
model_logger = ModelLogger(
|
model_logger = ModelLogger(
|
||||||
args.output_path,
|
args.output_path,
|
||||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
)
|
)
|
||||||
# ===== 任务路由 =====
|
|
||||||
launcher_map = {
|
launcher_map = {
|
||||||
"sft:data_process": launch_data_process_task,
|
"sft:data_process": launch_data_process_task,
|
||||||
|
"direct_distill:data_process": launch_data_process_task,
|
||||||
"sft": launch_training_task,
|
"sft": launch_training_task,
|
||||||
"sft:train": launch_training_task,
|
"sft:train": launch_training_task,
|
||||||
|
"direct_distill": launch_training_task,
|
||||||
|
"direct_distill:train": launch_training_task,
|
||||||
}
|
}
|
||||||
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.float32,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
|
||||||
tokenizer_2_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
|
||||||
)
|
|
||||||
prompt = "dog, white and brown dog, sitting on wall, under pink flowers"
|
|
||||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=5.0)
|
|
||||||
image.save("image.jpg")
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.float32,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
|
||||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
|
||||||
tokenizer_2_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
|
||||||
)
|
|
||||||
pipe.load_lora(pipe.unet, "./models/train/StableDiffusionXL_lora/epoch-4.safetensors")
|
|
||||||
prompt = "dog, white and brown dog, sitting on wall, under pink flowers"
|
|
||||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=5.0)
|
|
||||||
image.save("image.jpg")
|
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.unet, "models/train/stable-diffusion-xl-base-1.0_lora/epoch-4.safetensors")
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a dog",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=7.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image_stable-diffusion-xl-base-1.0.jpg")
|
||||||
Reference in New Issue
Block a user