diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index 299d9d5..85f781c 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -922,6 +922,13 @@ stable_diffusion_xl_series = [
         "model_class": "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2",
         "state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_xl_text_encoder.SDXLTextEncoder2StateDictConverter",
     },
+    {
+        # Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors")
+        "model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
+        "model_name": "stable_diffusion_text_encoder",
+        "model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
+    },
     {
         # Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
         "model_hash": "13115dd45a6e1c39860f91ab073b8a78",
@@ -971,4 +978,4 @@ joyai_image_series = [
     },
 ]
 
-MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + stable_diffusion_xl_series + stable_diffusion_series + joyai_image_series
+MODEL_CONFIGS = stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series
diff --git a/diffsynth/pipelines/stable_diffusion_xl.py b/diffsynth/pipelines/stable_diffusion_xl.py
new file mode 100644
index 0000000..4ef7ef6
--- /dev/null
+++ b/diffsynth/pipelines/stable_diffusion_xl.py
@@ -0,0 +1,332 @@
+import torch
+from PIL import Image
+from tqdm import tqdm
+from typing import Union
+
+from ..core.device.npu_compatible_device import get_device_type
+from ..diffusion.ddim_scheduler import DDIMScheduler
+from ..core import ModelConfig
+from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
+
+from transformers import AutoTokenizer, CLIPTextModel
+from ..models.stable_diffusion_text_encoder import SDTextEncoder
+from ..models.stable_diffusion_xl_unet import SDXLUNet2DConditionModel
+from ..models.stable_diffusion_xl_text_encoder import SDXLTextEncoder2
+from ..models.stable_diffusion_vae import StableDiffusionVAE
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """Rescale noise_cfg based on guidance_rescale to prevent overexposure.
+
+    Based on Section 3.4 from "Common Diffusion Noise Schedules and Sample Steps are Flawed"
+    (https://huggingface.co/papers/2305.08891).
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
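+
+
+# A minimal sanity check for rescale_noise_cfg (an illustrative sketch, not part
+# of the pipeline): with guidance_rescale=1.0 the output is fully rescaled, so
+# its per-sample std matches the std of the text-conditioned prediction.
+#
+#     x_text = torch.randn(1, 4, 8, 8)
+#     x_cfg = x_text * 3.0  # over-amplified, as strong CFG tends to produce
+#     x_fix = rescale_noise_cfg(x_cfg, x_text, guidance_rescale=1.0)
+#     assert torch.allclose(x_fix.std(), x_text.std())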
+
+
+class StableDiffusionXLPipeline(BasePipeline):
+
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
+        super().__init__(
+            device=device, torch_dtype=torch_dtype,
+            height_division_factor=8, width_division_factor=8,
+        )
+        self.scheduler = DDIMScheduler()
+        self.text_encoder: SDTextEncoder = None
+        self.text_encoder_2: SDXLTextEncoder2 = None
+        self.unet: SDXLUNet2DConditionModel = None
+        self.vae: StableDiffusionVAE = None
+        self.tokenizer: AutoTokenizer = None
+        self.tokenizer_2: AutoTokenizer = None
+
+        self.in_iteration_models = ("unet",)
+        self.units = [
+            SDXLUnit_ShapeChecker(),
+            SDXLUnit_PromptEmbedder(),
+            SDXLUnit_NoiseInitializer(),
+            SDXLUnit_InputImageEmbedder(),
+        ]
+        self.model_fn = model_fn_stable_diffusion_xl
+        self.compilable_models = ["unet"]
+
+    @staticmethod
+    def from_pretrained(
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: Union[str, torch.device] = get_device_type(),
+        model_configs: list[ModelConfig] = [],
+        tokenizer_config: ModelConfig = None,
+        tokenizer_2_config: ModelConfig = None,
+        vram_limit: float = None,
+    ):
+        pipe = StableDiffusionXLPipeline(device=device, torch_dtype=torch_dtype)
+        # Override vram_config to use the specified torch_dtype for all models
+        for mc in model_configs:
+            mc._vram_config_override = {
+                'onload_dtype': torch_dtype,
+                'computation_dtype': torch_dtype,
+            }
+        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
+        pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
+        pipe.text_encoder_2 = model_pool.fetch_model("stable_diffusion_xl_text_encoder")
+        pipe.unet = model_pool.fetch_model("stable_diffusion_xl_unet")
+        pipe.vae = model_pool.fetch_model("stable_diffusion_xl_vae")
+        if tokenizer_config is not None:
+            tokenizer_config.download_if_necessary()
+            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
+        if tokenizer_2_config is not None:
+            tokenizer_2_config.download_if_necessary()
+            pipe.tokenizer_2 = AutoTokenizer.from_pretrained(tokenizer_2_config.path)
+        pipe.vram_management_enabled = pipe.check_vram_management_state()
+        return pipe
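+
+    # cfg_guided_model_fn (inherited from BasePipeline) is assumed to evaluate
+    # the model once per prompt branch and blend the two predictions with
+    # standard classifier-free guidance. A sketch of that assumed behavior,
+    # not the actual BasePipeline implementation:
+    #
+    #     noise_posi = model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
+    #     noise_nega = model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
+    #     noise_pred = noise_nega + cfg_scale * (noise_posi - noise_nega)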
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: str,
+        prompt_2: str = None,
+        negative_prompt: str = "",
+        negative_prompt_2: str = None,
+        cfg_scale: float = 5.0,
+        height: int = 1024,
+        width: int = 1024,
+        seed: int = None,
+        rand_device: str = "cpu",
+        num_inference_steps: int = 50,
+        eta: float = 0.0,
+        guidance_rescale: float = 0.0,
+        original_size: tuple = None,
+        crops_coords_top_left: tuple = (0, 0),
+        target_size: tuple = None,
+        progress_bar_cmd=tqdm,
+    ):
+        prompt_2 = prompt_2 or prompt
+        negative_prompt_2 = negative_prompt_2 or negative_prompt
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        # 1. Scheduler
+        self.scheduler.set_timesteps(num_inference_steps, eta=eta)
+
+        # 2. Three-dict input preparation
+        inputs_posi = {
+            "prompt": prompt,
+            "prompt_2": prompt_2,
+        }
+        inputs_nega = {
+            "prompt": negative_prompt,
+            "prompt_2": negative_prompt_2,
+        }
+        inputs_shared = {
+            "cfg_scale": cfg_scale,
+            "height": height, "width": width,
+            "seed": seed, "rand_device": rand_device,
+            "guidance_rescale": guidance_rescale,
+            "original_size": original_size,
+            "crops_coords_top_left": crops_coords_top_left,
+            "target_size": target_size,
+        }
+
+        # 3. Unit chain execution
+        for unit in self.units:
+            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
+                unit, self, inputs_shared, inputs_posi, inputs_nega
+            )
+
+        # 4. Compute add_time_ids (micro-conditioning)
+        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+        add_time_ids = self._get_add_time_ids(
+            original_size, crops_coords_top_left, target_size,
+            dtype=self.torch_dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+        neg_add_time_ids = add_time_ids.clone()
+        inputs_posi["add_time_ids"] = add_time_ids
+        inputs_nega["add_time_ids"] = neg_add_time_ids
+
+        # 5. Denoising loop
+        self.load_models_to_device(self.in_iteration_models)
+        models = {name: getattr(self, name) for name in self.in_iteration_models}
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+            noise_pred = self.cfg_guided_model_fn(
+                self.model_fn, cfg_scale,
+                inputs_shared, inputs_posi, inputs_nega,
+                **models, timestep=timestep, progress_id=progress_id
+            )
+
+            # Apply guidance rescale
+            if guidance_rescale > 0.0:
+                # cfg_guided_model_fn has already applied CFG; the rescale needs
+                # the text-only prediction, so run the positive branch once more.
+                noise_pred_text = self.model_fn(
+                    self.unet,
+                    inputs_shared["latents"],
+                    timestep,
+                    inputs_posi["prompt_embeds"],
+                    pooled_prompt_embeds=inputs_posi["pooled_prompt_embeds"],
+                    add_time_ids=inputs_posi["add_time_ids"],
+                )
+                noise_pred = rescale_noise_cfg(
+                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
+                )
+
+            inputs_shared["latents"] = self.step(
+                self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
+            )
+
+        # 6. VAE decode
+        self.load_models_to_device(['vae'])
+        latents = inputs_shared["latents"] / self.vae.scaling_factor
+        image = self.vae.decode(latents)
+        image = self.vae_output_to_image(image)
+        self.load_models_to_device([])
+
+        return image
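+
+    # Worked example for the shape check in _get_add_time_ids below, using SDXL
+    # base numbers: original_size + crops_coords_top_left + target_size yields
+    # six time IDs, addition_time_embed_dim is 256, and the second text
+    # encoder's projection_dim is 1280, so the UNet expects
+    # 256 * 6 + 1280 = 2816 added-embedding features
+    # (add_embedding.linear_1.in_features).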
+
+    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None):
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)
+        # The SDXL UNet doesn't have a config attribute, so we access add_embedding directly
+        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+        # addition_time_embed_dim is the dimension of each time ID projection (256 for SDXL base)
+        addition_time_embed_dim = self.unet.add_time_proj.num_channels
+        passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+        if expected_add_embed_dim != passed_add_embed_dim:
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
+                f"but a vector of length {passed_add_embed_dim} was created."
+            )
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device)
+        return add_time_ids
+
+
+class SDXLUnit_ShapeChecker(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("height", "width"),
+            output_params=("height", "width"),
+        )
+
+    def process(self, pipe: StableDiffusionXLPipeline, height, width):
+        height, width = pipe.check_resize_height_width(height, width)
+        return {"height": height, "width": width}
+
+
+class SDXLUnit_PromptEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            seperate_cfg=True,
+            input_params_posi={"prompt": "prompt", "prompt_2": "prompt_2"},
+            input_params_nega={"prompt": "prompt", "prompt_2": "prompt_2"},
+            output_params=("prompt_embeds", "pooled_prompt_embeds"),
+            onload_model_names=("text_encoder", "text_encoder_2"),
+        )
+
+    def encode_prompt(
+        self,
+        pipe: StableDiffusionXLPipeline,
+        prompt: str,
+        prompt_2: str,
+        device: torch.device,
+    ) -> tuple:
+        """Encode the prompt using both text encoders.
+
+        Returns (prompt_embeds, pooled_prompt_embeds):
+        - prompt_embeds: concat(encoder1_output, encoder2_output) -> (B, 77, 2048)
+        - pooled_prompt_embeds: encoder 2 pooled output -> (B, 1280)
+        """
+        # Text encoder 1 (CLIP ViT-L, 768-dim)
+        text_input_ids_1 = pipe.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=pipe.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        ).input_ids.to(device)
+        prompt_embeds_1 = pipe.text_encoder(text_input_ids_1)
+        if isinstance(prompt_embeds_1, tuple):
+            prompt_embeds_1 = prompt_embeds_1[0]
+
+        # Text encoder 2 (OpenCLIP ViT-bigG, 1280-dim): uses the penultimate
+        # hidden states plus the pooled output
+        text_input_ids_2 = pipe.tokenizer_2(
+            prompt_2,
+            padding="max_length",
+            max_length=pipe.tokenizer_2.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        ).input_ids.to(device)
+        # SDXLTextEncoder2's forward returns (text_embeds/pooled, hidden_states_tuple)
+        pooled_prompt_embeds, hidden_states = pipe.text_encoder_2(text_input_ids_2, output_hidden_states=True)
+        # Use the penultimate hidden state (same as diffusers: hidden_states[-2])
+        prompt_embeds_2 = hidden_states[-2]
+
+        # Concatenate both encoder outputs along the feature dimension
+        prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)
+
+        return prompt_embeds, pooled_prompt_embeds
+
+    def process(self, pipe: StableDiffusionXLPipeline, prompt, prompt_2):
+        pipe.load_models_to_device(self.onload_model_names)
+        prompt_embeds, pooled_prompt_embeds = self.encode_prompt(pipe, prompt, prompt_2, pipe.device)
+        return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
+
+
+class SDXLUnit_NoiseInitializer(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("height", "width", "seed", "rand_device"),
+            output_params=("noise",),
+        )
+
+    def process(self, pipe: StableDiffusionXLPipeline, height, width, seed, rand_device):
+        noise = pipe.generate_noise(
+            (1, pipe.unet.in_channels, height // 8, width // 8),
+            seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
+        )
+        return {"noise": noise}
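+
+
+# For the 1024x1024 defaults, the noise produced above has shape
+# (1, 4, 128, 128): the SDXL UNet takes 4 latent channels and the VAE
+# downsamples each spatial dimension by a factor of 8 (1024 // 8 == 128).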
+
+
+class SDXLUnit_InputImageEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("noise",),
+            output_params=("latents",),
+        )
+
+    def process(self, pipe: StableDiffusionXLPipeline, noise):
+        # For Text-to-Image, latents = noise (scaled by scheduler)
+        latents = noise * pipe.scheduler.init_noise_sigma
+        return {"latents": latents}
+
+
+def model_fn_stable_diffusion_xl(
+    unet: SDXLUNet2DConditionModel,
+    latents=None,
+    timestep=None,
+    prompt_embeds=None,
+    pooled_prompt_embeds=None,
+    add_time_ids=None,
+    cross_attention_kwargs=None,
+    timestep_cond=None,
+    **kwargs,
+):
+    """SDXL model forward with added_cond_kwargs for micro-conditioning."""
+    added_cond_kwargs = {
+        "text_embeds": pooled_prompt_embeds,
+        "time_ids": add_time_ids,
+    }
+    noise_pred = unet(
+        latents,
+        timestep,
+        encoder_hidden_states=prompt_embeds,
+        added_cond_kwargs=added_cond_kwargs,
+        cross_attention_kwargs=cross_attention_kwargs,
+        timestep_cond=timestep_cond,
+        return_dict=False,
+    )[0]
+    return noise_pred
diff --git a/examples/stable_diffusion_xl/model_inference/StableDiffusionXL-T2I.py b/examples/stable_diffusion_xl/model_inference/StableDiffusionXL-T2I.py
new file mode 100644
index 0000000..715474a
--- /dev/null
+++ b/examples/stable_diffusion_xl/model_inference/StableDiffusionXL-T2I.py
@@ -0,0 +1,27 @@
+import torch
+from diffsynth.core import ModelConfig
+from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    torch_dtype=torch.float32,
+    model_configs=[
+        ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
+        ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
+        ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
+    tokenizer_2_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
+)
+
+image = pipe(
+    prompt="a photo of an astronaut riding a horse on mars",
+    negative_prompt="",
+    cfg_scale=5.0,
+    height=1024,
+    width=1024,
+    seed=42,
+    num_inference_steps=50,
+)
+image.save("output_stable_diffusion_xl_t2i.png")
+print("Image saved to output_stable_diffusion_xl_t2i.png")
diff --git a/examples/stable_diffusion_xl/model_inference_low_vram/StableDiffusionXL-T2I.py b/examples/stable_diffusion_xl/model_inference_low_vram/StableDiffusionXL-T2I.py
new file mode 100644
index 0000000..48496eb
--- /dev/null
+++ b/examples/stable_diffusion_xl/model_inference_low_vram/StableDiffusionXL-T2I.py
@@ -0,0 +1,39 @@
+import torch
+from diffsynth.core import ModelConfig
+from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
+
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
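+
+# Assumed semantics of the stages above (a sketch of DiffSynth's layered VRAM
+# management, not a normative description): "offload" is where weights rest
+# while a model is inactive, "onload" is their state once the model joins the
+# active group, "preparing" is the transitional state while weights are staged
+# onto the GPU, and "computation" is the dtype/device used for forward passes.
+# Keeping weights on the CPU in bfloat16 and moving them to CUDA only for
+# computation trades speed for a much smaller VRAM footprint.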
origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"), + tokenizer_2_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image = pipe( + prompt="a photo of an astronaut riding a horse on mars", + negative_prompt="", + cfg_scale=5.0, + height=1024, + width=1024, + seed=42, + num_inference_steps=50, +) +image.save("output_stable_diffusion_xl_t2i_low_vram.png") +print("Image saved to output_stable_diffusion_xl_t2i_low_vram.png")