# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# Synced 2026-04-24 15:06:17 +00:00
import torch
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
from typing import Union
|
|
|
|
from ..core.device.npu_compatible_device import get_device_type
|
|
from ..diffusion.ddim_scheduler import DDIMScheduler
|
|
from ..core import ModelConfig
|
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
|
|
|
from transformers import AutoTokenizer, CLIPTextModel
|
|
from ..models.stable_diffusion_text_encoder import SDTextEncoder
|
|
from ..models.stable_diffusion_unet import UNet2DConditionModel
|
|
from ..models.stable_diffusion_vae import StableDiffusionVAE
|
|
|
|
|
|
class StableDiffusionPipeline(BasePipeline):
    """Stable Diffusion 1.x pipeline: CLIP text encoder + UNet + VAE, driven by a DDIM scheduler.

    Generation is organized as a chain of ``PipelineUnit`` steps (shape check,
    prompt embedding, noise initialization, optional input-image encoding)
    followed by a CFG-guided denoising loop and a final VAE decode.
    """

    def __init__(self, device=get_device_type(), torch_dtype=torch.float16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=8, width_division_factor=8,
        )
        self.scheduler = DDIMScheduler()
        # Model slots are populated later by from_pretrained().
        self.text_encoder: SDTextEncoder = None
        self.unet: UNet2DConditionModel = None
        self.vae: StableDiffusionVAE = None
        self.tokenizer: AutoTokenizer = None

        # Only the UNet needs to stay on-device during the denoising loop.
        self.in_iteration_models = ("unet",)
        self.units = [
            SDUnit_ShapeChecker(),
            SDUnit_PromptEmbedder(),
            SDUnit_NoiseInitializer(),
            SDUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_stable_diffusion
        self.compilable_models = ["unet"]

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.float16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = None,
        tokenizer_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        """Build a pipeline and load the models described by ``model_configs``.

        Args:
            torch_dtype: dtype used for both on-load storage and computation.
            device: target device for the pipeline.
            model_configs: configs of the models to download/load. Defaults to
                no models. (A ``None`` sentinel is used instead of ``[]`` to
                avoid a shared mutable default — the configs are mutated below.)
            tokenizer_config: optional config pointing at the tokenizer files.
            vram_limit: optional VRAM budget forwarded to the model loader.

        Returns:
            A fully wired ``StableDiffusionPipeline``.
        """
        if model_configs is None:
            model_configs = []
        pipe = StableDiffusionPipeline(device=device, torch_dtype=torch_dtype)
        # Override vram_config so every model is stored and computed in torch_dtype.
        for mc in model_configs:
            mc._vram_config_override = {
                'onload_dtype': torch_dtype,
                'computation_dtype': torch_dtype,
            }
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
        pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
        pipe.unet = model_pool.fetch_model("stable_diffusion_unet")
        pipe.vae = model_pool.fetch_model("stable_diffusion_vae")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 7.5,
        height: int = 512,
        width: int = 512,
        seed: int = None,
        rand_device: str = "cpu",
        num_inference_steps: int = 50,
        eta: float = 0.0,
        guidance_rescale: float = 0.0,
        progress_bar_cmd=tqdm,
    ):
        """Generate an image from ``prompt`` (and optionally an input image via units).

        Args:
            prompt: positive text prompt.
            negative_prompt: text whose features are pushed away under CFG.
            cfg_scale: classifier-free guidance scale.
            height, width: output resolution in pixels (snapped to multiples of 8).
            seed: RNG seed; ``None`` means non-deterministic.
            rand_device: device on which the initial noise is drawn.
            num_inference_steps: number of DDIM denoising steps.
            eta: DDIM stochasticity parameter.
            guidance_rescale: CFG rescale factor forwarded to the units.
            progress_bar_cmd: progress wrapper around the timestep iterator.

        Returns:
            The decoded output image (PIL, via ``vae_output_to_image``).
        """
        # 1. Scheduler
        self.scheduler.set_timesteps(
            num_inference_steps, eta=eta,
        )

        # 2. Three-dict input preparation: positive / negative / shared inputs
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "guidance_rescale": guidance_rescale,
        }

        # 3. Unit chain execution (each unit reads/writes the three dicts)
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
                unit, self, inputs_shared, inputs_posi, inputs_nega
            )

        # 4. Denoise loop
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(
                self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
            )

        # 5. VAE decode (latents are rescaled by the VAE scaling factor first)
        self.load_models_to_device(['vae'])
        latents = inputs_shared["latents"] / self.vae.scaling_factor
        image = self.vae.decode(latents)
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])

        return image
|
|
|
|
|
|
class SDUnit_ShapeChecker(PipelineUnit):
    """Snap the requested output resolution to the pipeline's division factors."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width"),
            output_params=("height", "width"),
        )

    def process(self, pipe: StableDiffusionPipeline, height, width):
        # Delegate to the pipeline, which enforces the 8-pixel division factors.
        checked_height, checked_width = pipe.check_resize_height_width(height, width)
        return {"height": checked_height, "width": checked_width}
|
|
|
|
|
|
class SDUnit_PromptEmbedder(PipelineUnit):
    """Tokenize and encode the positive/negative prompt into CLIP embeddings.

    Runs once for the positive prompt and once for the negative prompt
    (``seperate_cfg=True``), emitting ``prompt_embeds`` for each branch.
    """

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(
        self,
        pipe: StableDiffusionPipeline,
        prompt: str,
        device: torch.device,
    ) -> torch.Tensor:
        """Return CLIP text embeddings for ``prompt`` on ``device``."""
        tokenized = pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        ids_on_device = tokenized.input_ids.to(device)
        encoded = pipe.text_encoder(ids_on_device)
        # The text encoder may return (last_hidden_state, hidden_states) or a
        # bare last_hidden_state. The post-final-layer-norm output matches
        # diffusers' encode_prompt behavior.
        return encoded[0] if isinstance(encoded, tuple) else encoded

    def process(self, pipe: StableDiffusionPipeline, prompt):
        pipe.load_models_to_device(self.onload_model_names)
        return {"prompt_embeds": self.encode_prompt(pipe, prompt, pipe.device)}
|
|
|
|
|
|
class SDUnit_NoiseInitializer(PipelineUnit):
    """Draw the initial Gaussian noise tensor in latent space."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: StableDiffusionPipeline, height, width, seed, rand_device):
        # Latents are 8x downsampled relative to pixel space; batch size is 1.
        latent_shape = (1, pipe.unet.in_channels, height // 8, width // 8)
        noise = pipe.generate_noise(
            latent_shape,
            seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
        )
        return {"noise": noise}
|
|
|
|
|
|
class SDUnit_InputImageEmbedder(PipelineUnit):
    """Produce the initial latents, optionally seeded from an input image.

    With no input image: latents are pure scaled noise. With an input image:
    the VAE encodes it, and in inference mode the scheduler adds noise on top
    (image-to-image); in training mode the clean encoding is kept separately
    as ``input_latents``.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",),
        )

    def process(self, pipe: StableDiffusionPipeline, input_image, noise):
        # Text-to-image path: no VAE work needed.
        if input_image is None:
            return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
        # Common to both training and inference: VAE-encode the input image.
        # (Previously duplicated in both branches.)
        pipe.load_models_to_device(self.onload_model_names)
        input_tensor = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(input_tensor).sample()
        if pipe.scheduler.training:
            # Training: start from pure noise; keep the clean encoding as target.
            latents = noise * pipe.scheduler.init_noise_sigma
            return {"latents": latents, "input_latents": input_latents}
        else:
            # Inference (image-to-image): noise the encoded image at the first timestep.
            latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
            # NOTE(review): this branch omits "input_latents" even though it is
            # declared in output_params — confirm unit_runner tolerates the
            # missing key before changing it.
            return {"latents": latents}
|
|
|
|
|
|
def model_fn_stable_diffusion(
    unet: UNet2DConditionModel,
    latents=None,
    timestep=None,
    prompt_embeds=None,
    cross_attention_kwargs=None,
    timestep_cond=None,
    added_cond_kwargs=None,
    **kwargs,
):
    """Run one UNet denoising step and return the predicted noise tensor.

    SD timesteps are already in the 0-999 range, so no timestep scaling is
    applied. Extra keyword arguments are accepted (and ignored) so the
    pipeline's CFG driver can forward its full input dict.
    """
    prediction = unet(
        latents,
        timestep,
        encoder_hidden_states=prompt_embeds,
        cross_attention_kwargs=cross_attention_kwargs,
        timestep_cond=timestep_cond,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=False,
    )
    # return_dict=False yields a tuple whose first element is the noise prediction.
    return prediction[0]
|