From e7a21dbf0b8ad0dd9a192f769831f3c1d19af7dd Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 19 Jun 2025 14:53:11 +0800 Subject: [PATCH] flux-refactor --- diffsynth/flux/FLUX.1-dev.py | 23 ++ diffsynth/pipelines/flux_image_new.py | 510 ++++++++++++++++++++++++++ diffsynth/pipelines/wan_video_new.py | 30 +- 3 files changed, 561 insertions(+), 2 deletions(-) create mode 100644 diffsynth/flux/FLUX.1-dev.py create mode 100644 diffsynth/pipelines/flux_image_new.py diff --git a/diffsynth/flux/FLUX.1-dev.py b/diffsynth/flux/FLUX.1-dev.py new file mode 100644 index 0000000..6355436 --- /dev/null +++ b/diffsynth/flux/FLUX.1-dev.py @@ -0,0 +1,23 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +image = pipe( + prompt="a girl", + seed=0, +) +image.save("0.jpg") \ No newline at end of file diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py new file mode 100644 index 0000000..bab84b8 --- /dev/null +++ b/diffsynth/pipelines/flux_image_new.py @@ -0,0 +1,510 @@ +import torch, warnings, glob, os, types +import numpy as np +from PIL import Image +from einops import repeat, reduce +from typing import Optional, Union +from dataclasses import dataclass +from modelscope import snapshot_download +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional +from typing_extensions import Literal + +from ..schedulers import FlowMatchScheduler +from ..prompters import FluxPrompter +from ..models import ModelManager, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder +from ..models.tiler import FastTileWorker +from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit + + + +class FluxImagePipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler() + self.prompter = FluxPrompter() + self.text_encoder_1: SD3TextEncoder1 = None + self.text_encoder_2: FluxTextEncoder2 = None + self.dit: FluxDiT = None + self.vae_decoder: FluxVAEDecoder = None + self.vae_encoder: FluxVAEEncoder = None + self.unit_runner = PipelineUnitRunner() + self.in_iteration_models = ("dit", ) + self.units = [ + FluxImageUnit_ShapeChecker(), + FluxImageUnit_NoiseInitializer(), + FluxImageUnit_PromptEmbedder(), + FluxImageUnit_InputImageEmbedder(), + FluxImageUnit_ImageIDs(), + FluxImageUnit_EmbeddedGuidanceEmbedder(), + ] + self.model_fn = model_fn_flux_image + + + def load_lora(self, module, path, alpha=1): + pass + + + def training_loss(self, **inputs): + timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) + timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) + + inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) + training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep) + + noise_pred = self.model_fn(**inputs, timestep=timestep) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.scheduler.training_weight(timestep) + return loss + + + def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): + pass + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"), + local_model_path: str = "./models", + skip_download: bool = False, + redirect_common_files: bool = True, + use_usp=False, + ): + # Download and load models + model_manager = ModelManager() + for model_config in model_configs: + model_config.download_if_necessary(local_model_path, skip_download=skip_download) + model_manager.load_model( + model_config.path, + device=model_config.offload_device or device, + torch_dtype=model_config.offload_dtype or torch_dtype + ) + + # Initialize pipeline + pipe = FluxImagePipeline(device=device, torch_dtype=torch_dtype) + pipe.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1") + pipe.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2") + pipe.dit = model_manager.fetch_model("flux_dit") + pipe.vae_decoder = model_manager.fetch_model("flux_vae_decoder") + pipe.vae_encoder = model_manager.fetch_model("flux_vae_encoder") + pipe.prompter.fetch_models(pipe.text_encoder_1, pipe.text_encoder_2) + + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt, + negative_prompt="", + cfg_scale=1.0, + embedded_guidance=3.5, + t5_sequence_length=512, + # Image + input_image=None, + denoising_strength=1.0, + # Shape + height=1024, + width=1024, + # Randomness + seed=None, + rand_device: Optional[str] = "cpu", + # Scheduler + sigma_shift=None, + # Steps + num_inference_steps=30, + # local prompts + multidiffusion_prompts=(), + multidiffusion_masks=(), + multidiffusion_scales=(), + # ControlNet + controlnet_inputs=None, + # IP-Adapter + ipadapter_images=None, + ipadapter_scale=1.0, + # EliGen + eligen_entity_prompts=None, + eligen_entity_masks=None, + eligen_enable_on_negative=False, + eligen_enable_inpaint=False, + # InfiniteYou + infinityou_id_image=None, + infinityou_guidance=1.0, + # Flex + flex_inpaint_image=None, + flex_inpaint_mask=None, + flex_control_image=None, + flex_control_strength=0.5, + flex_control_stop=0.5, + # Step1x + step1x_reference_image=None, + # TeaCache + tea_cache_l1_thresh=None, + # Tile + tiled=False, + tile_size=128, + tile_stride=64, + # Progress bar + progress_bar_cmd=tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, "t5_sequence_length": t5_sequence_length, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps, + "multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales, + "controlnet_inputs": controlnet_inputs, + "ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale, + "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint, + "infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance, + "flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop, + "step1x_reference_image": step1x_reference_image, + "tea_cache_l1_thresh": tea_cache_l1_thresh, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "progress_bar_cmd": progress_bar_cmd, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) + if cfg_scale != 1.0: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + + # Decode + self.load_models_to_device(['vae_decoder']) + image = self.vae_decoder(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + + +class FluxImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width")) + + def process(self, pipe: FluxImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + + +class FluxImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width", "seed", "rand_device")) + + def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device) + return {"noise": noise} + + + +class FluxImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, input_image, noise, tiled, tile_size, tile_stride): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae_encoder']) + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": None} + + + +class FluxImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + input_params=("t5_sequence_length",), + onload_model_names=("text_encoder_1", "text_encoder_2") + ) + + def process(self, pipe: FluxImagePipeline, prompt, t5_sequence_length, positive) -> dict: + if pipe.text_encoder_1 is not None and pipe.text_encoder_2 is not None: + prompt_emb, pooled_prompt_emb, text_ids = pipe.prompter.encode_prompt( + prompt, device=pipe.device, positive=positive, t5_sequence_length=t5_sequence_length + ) + return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} + else: + return {} + + +class FluxImageUnit_ImageIDs(PipelineUnit): + def __init__(self): + super().__init__(input_params=("latents",)) + + def process(self, pipe: FluxImagePipeline, latents): + latent_image_ids = pipe.dit.prepare_image_ids(latents) + return {"image_ids": latent_image_ids} + + + +class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit): + def __init__(self): + super().__init__(input_params=("embedded_guidance", "latents")) + + def process(self, pipe: FluxImagePipeline, embedded_guidance, latents): + guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + return {"guidance": guidance} + + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + def check(self, dit: FluxDiT, hidden_states, conditioning): + inp = hidden_states.clone() + temb_ = conditioning.clone() + modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_) + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = hidden_states.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + +def model_fn_flux_image( + dit: FluxDiT, + controlnet=None, + step1x_connector=None, + latents=None, + timestep=None, + prompt_emb=None, + pooled_prompt_emb=None, + guidance=None, + text_ids=None, + image_ids=None, + controlnet_frames=None, + tiled=False, + tile_size=128, + tile_stride=64, + entity_prompt_emb=None, + entity_masks=None, + ipadapter_kwargs_list={}, + id_emb=None, + infinityou_guidance=None, + flex_condition=None, + flex_uncondition=None, + flex_control_stop_timestep=None, + step1x_llm_embedding=None, + step1x_mask=None, + step1x_reference_latents=None, + tea_cache: TeaCache = None, + **kwargs +): + if tiled: + def flux_forward_fn(hl, hr, wl, wr): + tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None + return model_fn_flux_image( + dit=dit, + controlnet=controlnet, + latents=latents[:, :, hl: hr, wl: wr], + timestep=timestep, + prompt_emb=prompt_emb, + pooled_prompt_emb=pooled_prompt_emb, + guidance=guidance, + text_ids=text_ids, + image_ids=None, + controlnet_frames=tiled_controlnet_frames, + tiled=False, + **kwargs + ) + return FastTileWorker().tiled_forward( + flux_forward_fn, + latents, + tile_size=tile_size, + tile_stride=tile_stride, + tile_device=latents.device, + tile_dtype=latents.dtype + ) + + hidden_states = latents + + # ControlNet + if controlnet is not None and controlnet_frames is not None: + controlnet_extra_kwargs = { + "hidden_states": hidden_states, + "timestep": timestep, + "prompt_emb": prompt_emb, + "pooled_prompt_emb": pooled_prompt_emb, + "guidance": guidance, + "text_ids": text_ids, + "image_ids": image_ids, + "tiled": tiled, + "tile_size": tile_size, + "tile_stride": tile_stride, + } + if id_emb is not None: + controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype) + controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance}) + controlnet_res_stack, controlnet_single_res_stack = controlnet( + controlnet_frames, **controlnet_extra_kwargs + ) + + # Flex + if flex_condition is not None: + if timestep.tolist()[0] >= flex_control_stop_timestep: + hidden_states = torch.concat([hidden_states, flex_condition], dim=1) + else: + hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1) + + # Step1x + if step1x_llm_embedding is not None: + prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask) + text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device) + + if image_ids is None: + image_ids = dit.prepare_image_ids(hidden_states) + + conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb) + if dit.guidance_embedder is not None: + guidance = guidance * 1000 + conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) + + height, width = hidden_states.shape[-2:] + hidden_states = dit.patchify(hidden_states) + + # Step1x + if step1x_reference_latents is not None: + step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents) + step1x_reference_latents = dit.patchify(step1x_reference_latents) + image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2) + hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1) + + hidden_states = dit.x_embedder(hidden_states) + + if entity_prompt_emb is not None and entity_masks is not None: + prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids) + else: + prompt_emb = dit.context_embedder(prompt_emb) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + attention_mask = None + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, hidden_states, conditioning) + else: + tea_cache_update = False + + if tea_cache_update: + hidden_states = tea_cache.update(hidden_states) + else: + # Joint Blocks + for block_id, block in enumerate(dit.blocks): + hidden_states, prompt_emb = block( + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None) + ) + # ControlNet + if controlnet is not None and controlnet_frames is not None: + hidden_states = hidden_states + controlnet_res_stack[block_id] + + # Single Blocks + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + num_joint_blocks = len(dit.blocks) + for block_id, block in enumerate(dit.single_blocks): + hidden_states, prompt_emb = block( + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None) + ) + # ControlNet + if controlnet is not None and controlnet_frames is not None: + hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] + hidden_states = hidden_states[:, prompt_emb.shape[1]:] + + if tea_cache is not None: + tea_cache.store(hidden_states) + + hidden_states = dit.final_norm_out(hidden_states, conditioning) + hidden_states = dit.final_proj_out(hidden_states) + + # Step1x + if step1x_reference_latents is not None: + hidden_states = hidden_states[:, :hidden_states.shape[1] // 2] + + hidden_states = dit.unpatchify(hidden_states, height, width) + + return hidden_states diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 4a7a2a2..a93694c 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -168,24 +168,44 @@ class ModelConfig: def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False): if self.path is None: + # Check model_id and origin_file_pattern if self.model_id is None or self.origin_file_pattern is None: raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""") + + # Skip if not in rank 0 if use_usp: import torch.distributed as dist skip_download = dist.get_rank() != 0 + + # Check whether the origin path is a folder + if isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"): + allow_file_pattern = self.origin_file_pattern + "*" + is_folder = True + else: + allow_file_pattern = self.origin_file_pattern + is_folder = False + + # Download if not skip_download: downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id)) snapshot_download( self.model_id, local_dir=os.path.join(local_model_path, self.model_id), - allow_file_pattern=self.origin_file_pattern, + allow_file_pattern=allow_file_pattern, ignore_file_pattern=downloaded_files, local_files_only=False ) + + # Let rank 1, 2, ... wait for rank 0 if use_usp: import torch.distributed as dist dist.barrier(device_ids=[dist.get_rank()]) - self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) + + # Return downloaded files + if is_folder: + self.path = os.path.join(local_model_path, self.model_id, self.origin_file_pattern) + else: + self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) if isinstance(self.path, list) and len(self.path) == 1: self.path = self.path[0] @@ -614,11 +634,17 @@ class PipelineUnitRunner: elif unit.seperate_cfg: # Positive side processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()} + if unit.input_params is not None: + for name in unit.input_params: + processor_inputs[name] = inputs_shared.get(name) processor_outputs = unit.process(pipe, **processor_inputs) inputs_posi.update(processor_outputs) # Negative side if inputs_shared["cfg_scale"] != 1: processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()} + if unit.input_params is not None: + for name in unit.input_params: + processor_inputs[name] = inputs_shared.get(name) processor_outputs = unit.process(pipe, **processor_inputs) inputs_nega.update(processor_outputs) else: