diff --git a/README.md b/README.md
index aa116b7..11403a5 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,7 @@ image.save("image.jpg")
 |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
 |[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|||
 |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
+|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
 |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)|
 |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
diff --git a/README_zh.md b/README_zh.md
index 7c0b569..650d2ec 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -100,6 +100,7 @@ image.save("image.jpg")
 |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
 |[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|||
 |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
+|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
 |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)|
 |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
diff --git a/diffsynth/lora/flux_lora.py b/diffsynth/lora/flux_lora.py
index 45baeaa..cb53b73 100644
--- a/diffsynth/lora/flux_lora.py
+++ b/diffsynth/lora/flux_lora.py
@@ -1,6 +1,8 @@
 import torch, math
-from diffsynth.lora import GeneralLoRALoader
-from diffsynth.models.lora import FluxLoRAFromCivitai
+from . import GeneralLoRALoader
+from ..utils import ModelConfig
+from ..models.utils import load_state_dict
+from typing import Union
 
 
 class FluxLoRALoader(GeneralLoRALoader):
@@ -276,3 +278,59 @@ class FluxLoraPatcherStateDictConverter:
 
     def from_civitai(self, state_dict):
         return state_dict
+
+
+class FluxLoRAFuser:
+    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+        self.device = device
+        self.torch_dtype = torch_dtype
+
+    # Factorize a fused weight delta back into a rank-k (B, A) pair via truncated SVD.
+    def Matrix_Decomposition_lowrank(self, A, k):
+        U, S, V = torch.svd_lowrank(A.float(), q=k)
+        S_k = torch.diag(S[:k])
+        U_hat = U @ S_k
+        return U_hat, V.t()
+
+    def LoRA_State_Dicts_Decomposition(self, lora_state_dicts=[], q=4):
+        # Sum the full weight deltas B @ A across all LoRAs, then re-factorize each sum at rank q.
+        lora_1 = lora_state_dicts[0]
+        state_dict_ = {}
+        for k, v in lora_1.items():
+            if 'lora_A.' in k:
+                lora_B_name = k.replace('lora_A.', 'lora_B.')
+                lora_B = lora_1[lora_B_name]
+                weight = torch.mm(lora_B, v)
+                for lora_dict in lora_state_dicts[1:]:
+                    lora_A_ = lora_dict[k]
+                    lora_B_ = lora_dict[lora_B_name]
+                    weight_ = torch.mm(lora_B_, lora_A_)
+                    weight += weight_
+                new_B, new_A = self.Matrix_Decomposition_lowrank(weight, q)
+                state_dict_[lora_B_name] = new_B.to(dtype=self.torch_dtype)
+                state_dict_[k] = new_A.to(dtype=self.torch_dtype)
+        return state_dict_
+
+    def __call__(self, lora_configs: list[Union[ModelConfig, str]]):
+        loras = []
+        loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+        for lora_config in lora_configs:
+            if isinstance(lora_config, str):
+                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+            else:
+                lora_config.download_if_necessary()
+                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+            lora = loader.convert_state_dict(lora)
+            loras.append(lora)
+        lora = self.LoRA_State_Dicts_Decomposition(loras)
+        return lora
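+
+
+# Minimal usage sketch (the model ID below is illustrative, taken from this repo's examples):
+#
+#     fuser = FluxLoRAFuser(device="cuda", torch_dtype=torch.bfloat16)
+#     fused = fuser([
+#         ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"),
+#         "/path/to/another_lora.safetensors",
+#     ])
+#     # `fused` is a LoRA state dict; load it via `FluxImagePipeline.load_lora(pipe.dit, state_dict=fused)`.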
diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py
index 6525dd4..2abd16c 100644
--- a/diffsynth/pipelines/flux_image_new.py
+++ b/diffsynth/pipelines/flux_image_new.py
@@ -22,8 +22,8 @@ from ..models.flux_value_control import MultiValueEncoder
 from ..models.flux_infiniteyou import InfiniteYouImageProjector
 from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock
 from ..models.tiler import FastTileWorker
-from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
-from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
+from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
+from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher, FluxLoRAFuser
 from ..models.flux_dit import RMSNorm
 from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
@@ -125,18 +125,20 @@ class FluxImagePipeline(BasePipeline):
     def load_lora(
         self,
         module: torch.nn.Module,
-        lora_config: Union[ModelConfig, str],
+        lora_config: Union[ModelConfig, str] = None,
         alpha=1,
         hotload=False,
-        local_model_path="./models",
-        skip_download=False
+        state_dict=None,
     ):
-        if isinstance(lora_config, str):
-            lora_config = ModelConfig(path=lora_config)
+        if state_dict is None:
+            if isinstance(lora_config, str):
+                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+            else:
+                lora_config.download_if_necessary()
+                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
         else:
-            lora_config.download_if_necessary(local_model_path, skip_download=skip_download)
+            lora = state_dict
         loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
-        lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
         lora = loader.convert_state_dict(lora)
         if hotload:
             for name, module in module.named_modules():
@@ -150,19 +152,22 @@ class FluxImagePipeline(BasePipeline):
         loader.load(module, lora, alpha=alpha)
 
 
-    def enable_lora_patcher(self):
-        if not (hasattr(self, "vram_management_enabled") and self.vram_management_enabled):
-            print("Please enable VRAM management using `enable_vram_management()` before `enable_lora_patcher()`.")
-            return
-        if self.lora_patcher is None:
-            print("Please load lora patcher models before `enable_lora_patcher()`.")
-            return
-        for name, module in self.dit.named_modules():
-            if isinstance(module, AutoWrappedLinear):
-                merger_name = name.replace(".", "___")
-                if merger_name in self.lora_patcher.model_dict:
-                    module.lora_merger = self.lora_patcher.model_dict[merger_name]
-
+    def load_loras(
+        self,
+        module: torch.nn.Module,
+        lora_configs: list[Union[ModelConfig, str]],
+        alpha=1,
+        hotload=False,
+        extra_fused_lora=False,
+    ):
+        for lora_config in lora_configs:
+            self.load_lora(module, lora_config, hotload=hotload, alpha=alpha)
+        if extra_fused_lora:
+            # Additionally load one LoRA that SVD-fuses all of the above (see FluxLoRAFuser).
+            lora_fuser = FluxLoRAFuser(device=self.device, torch_dtype=self.torch_dtype)
+            fused_lora = lora_fuser(lora_configs)
+            self.load_lora(module, state_dict=fused_lora, hotload=hotload, alpha=alpha)
+
 
     def clear_lora(self):
         for name, module in self.named_modules():
@@ -365,16 +369,11 @@ class FluxImagePipeline(BasePipeline):
         torch_dtype: torch.dtype = torch.bfloat16,
         device: Union[str, torch.device] = "cuda",
         model_configs: list[ModelConfig] = [],
-        tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
-        local_model_path: str = "./models",
-        skip_download: bool = False,
-        redirect_common_files: bool = True,
-        use_usp=False,
     ):
         # Download and load models
         model_manager = ModelManager()
         for model_config in model_configs:
-            model_config.download_if_necessary(local_model_path, skip_download=skip_download)
+            model_config.download_if_necessary()
             model_manager.load_model(
                 model_config.path,
                 device=model_config.offload_device or device,
warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.") - self.vram_management_enabled = True - - - def get_vram(self): - return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3) - - - def freeze_except(self, model_names): - for name, model in self.named_children(): - if name in model_names: - model.train() - model.requires_grad_(True) - else: - model.eval() - model.requires_grad_(False) - - -@dataclass -class ModelConfig: - path: Union[str, list[str]] = None - model_id: str = None - origin_file_pattern: Union[str, list[str]] = None - download_resource: str = "ModelScope" - offload_device: Optional[Union[str, torch.device]] = None - offload_dtype: Optional[torch.dtype] = None - skip_download: bool = False - - def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False): - if self.path is None: - # Check model_id and origin_file_pattern - if self.model_id is None: - raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""") - - # Skip if not in rank 0 - if use_usp: - import torch.distributed as dist - skip_download = dist.get_rank() != 0 - - # Check whether the origin path is a folder - if self.origin_file_pattern is None or self.origin_file_pattern == "": - self.origin_file_pattern = "" - allow_file_pattern = None - is_folder = True - elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"): - allow_file_pattern = self.origin_file_pattern + "*" - is_folder = True - else: - allow_file_pattern = self.origin_file_pattern - is_folder = False - - # Download - skip_download = skip_download or self.skip_download - if not skip_download: - downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id)) - snapshot_download( - self.model_id, - local_dir=os.path.join(local_model_path, self.model_id), - allow_file_pattern=allow_file_pattern, - ignore_file_pattern=downloaded_files, - local_files_only=False - ) - - # Let rank 1, 2, ... 
wait for rank 0 - if use_usp: - import torch.distributed as dist - dist.barrier(device_ids=[dist.get_rank()]) - - # Return downloaded files - if is_folder: - self.path = os.path.join(local_model_path, self.model_id, self.origin_file_pattern) - else: - self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) - if isinstance(self.path, list) and len(self.path) == 1: - self.path = self.path[0] - - class WanVideoPipeline(BasePipeline): def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None): @@ -438,8 +249,6 @@ class WanVideoPipeline(BasePipeline): device: Union[str, torch.device] = "cuda", model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"), - local_model_path: str = "./models", - skip_download: bool = False, redirect_common_files: bool = True, use_usp=False, ): @@ -464,7 +273,7 @@ class WanVideoPipeline(BasePipeline): # Download and load models model_manager = ModelManager() for model_config in model_configs: - model_config.download_if_necessary(local_model_path, skip_download=skip_download, use_usp=use_usp) + model_config.download_if_necessary(use_usp=use_usp) model_manager.load_model( model_config.path, device=model_config.offload_device or device, @@ -480,7 +289,7 @@ class WanVideoPipeline(BasePipeline): pipe.vace = model_manager.fetch_model("wan_video_vace") # Initialize tokenizer - tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download) + tokenizer_config.download_if_necessary(use_usp=use_usp) pipe.prompter.fetch_models(pipe.text_encoder) pipe.prompter.fetch_tokenizer(tokenizer_config.path) @@ -606,63 +415,6 @@ class WanVideoPipeline(BasePipeline): -class PipelineUnit: - def __init__( - self, - seperate_cfg: bool = False, - take_over: bool = False, - input_params: tuple[str] = None, - input_params_posi: dict[str, str] = None, - input_params_nega: dict[str, str] = None, - onload_model_names: tuple[str] = None - ): - self.seperate_cfg = seperate_cfg - self.take_over = take_over - self.input_params = input_params - self.input_params_posi = input_params_posi - self.input_params_nega = input_params_nega - self.onload_model_names = onload_model_names - - - def process(self, pipe: WanVideoPipeline, inputs: dict, positive=True, **kwargs) -> dict: - raise NotImplementedError("`process` is not implemented.") - - - -class PipelineUnitRunner: - def __init__(self): - pass - - def __call__(self, unit: PipelineUnit, pipe: WanVideoPipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]: - if unit.take_over: - # Let the pipeline unit take over this function. 
- inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega) - elif unit.seperate_cfg: - # Positive side - processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()} - if unit.input_params is not None: - for name in unit.input_params: - processor_inputs[name] = inputs_shared.get(name) - processor_outputs = unit.process(pipe, **processor_inputs) - inputs_posi.update(processor_outputs) - # Negative side - if inputs_shared["cfg_scale"] != 1: - processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()} - if unit.input_params is not None: - for name in unit.input_params: - processor_inputs[name] = inputs_shared.get(name) - processor_outputs = unit.process(pipe, **processor_inputs) - inputs_nega.update(processor_outputs) - else: - inputs_nega.update(processor_outputs) - else: - processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params} - processor_outputs = unit.process(pipe, **processor_inputs) - inputs_shared.update(processor_outputs) - return inputs_shared, inputs_posi, inputs_nega - - - class WanVideoUnit_ShapeChecker(PipelineUnit): def __init__(self): super().__init__(input_params=("height", "width", "num_frames")) diff --git a/diffsynth/utils/__init__.py b/diffsynth/utils/__init__.py new file mode 100644 index 0000000..58733c1 --- /dev/null +++ b/diffsynth/utils/__init__.py @@ -0,0 +1,261 @@ +import torch, warnings, glob, os +import numpy as np +from PIL import Image +from einops import repeat, reduce +from typing import Optional, Union +from dataclasses import dataclass +from modelscope import snapshot_download +import numpy as np +from PIL import Image +from typing import Optional + + +class BasePipeline(torch.nn.Module): + + def __init__( + self, + device="cuda", torch_dtype=torch.float16, + height_division_factor=64, width_division_factor=64, + time_division_factor=None, time_division_remainder=None, + ): + super().__init__() + # The device and torch_dtype is used for the storage of intermediate variables, not models. + self.device = device + self.torch_dtype = torch_dtype + # The following parameters are used for shape check. + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + self.vram_management_enabled = False + + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + if device is not None: + self.device = device + if dtype is not None: + self.torch_dtype = dtype + super().to(*args, **kwargs) + return self + + + def check_resize_height_width(self, height, width, num_frames=None): + # Shape check + if height % self.height_division_factor != 0: + height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor + print(f"height % {self.height_division_factor} != 0. We round it up to {height}.") + if width % self.width_division_factor != 0: + width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor + print(f"width % {self.width_division_factor} != 0. 
+
+
+    def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
+        # Transform a torch.Tensor to PIL.Image
+        if pattern != "H W C":
+            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
+        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
+        image = image.to(device="cpu", dtype=torch.uint8)
+        image = Image.fromarray(image.numpy())
+        return image
+
+
+    def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
+        # Transform a torch.Tensor to list of PIL.Image
+        if pattern != "T H W C":
+            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
+        video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
+        return video
+
+
+    def load_models_to_device(self, model_names=[]):
+        if self.vram_management_enabled:
+            # offload models
+            for name, model in self.named_children():
+                if name not in model_names:
+                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
+                        for module in model.modules():
+                            if hasattr(module, "offload"):
+                                module.offload()
+                    else:
+                        model.cpu()
+            torch.cuda.empty_cache()
+            # onload models
+            for name, model in self.named_children():
+                if name in model_names:
+                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
+                        for module in model.modules():
+                            if hasattr(module, "onload"):
+                                module.onload()
+                    else:
+                        model.to(self.device)
+
+
+    def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
+        # Initialize Gaussian noise
+        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
+        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
+        noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
+        return noise
+
+
+    def enable_cpu_offload(self):
+        warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
+        self.vram_management_enabled = True
+
+
+    def get_vram(self):
+        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
+
+
+    def freeze_except(self, model_names):
+        for name, model in self.named_children():
+            if name in model_names:
+                model.train()
+                model.requires_grad_(True)
+            else:
+                model.eval()
+                model.requires_grad_(False)
+
+
+@dataclass
+class ModelConfig:
+    path: Union[str, list[str]] = None
+    model_id: str = None
+    origin_file_pattern: Union[str, list[str]] = None
+    download_resource: str = "ModelScope"
+    offload_device: Optional[Union[str, torch.device]] = None
+    offload_dtype: Optional[torch.dtype] = None
+    local_model_path: str = None
+    skip_download: bool = False
+
+    def download_if_necessary(self, use_usp=False):
+        if self.path is None:
+            # Check model_id and origin_file_pattern
+            if self.model_id is None:
+                raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
+
+            # Skip if not in rank 0
+            if use_usp:
+                import torch.distributed as dist
+                skip_download = self.skip_download or dist.get_rank() != 0
+            else:
+                skip_download = self.skip_download
+
+            # Check whether the origin path is a folder
+            if self.origin_file_pattern is None or self.origin_file_pattern == "":
+                self.origin_file_pattern = ""
+                allow_file_pattern = None
+                is_folder = True
+            elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
+                allow_file_pattern = self.origin_file_pattern + "*"
+                is_folder = True
+            else:
+                allow_file_pattern = self.origin_file_pattern
+                is_folder = False
+
+            # Download
+            if not skip_download:
+                if self.local_model_path is None:
+                    self.local_model_path = "./models"
+                downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
+                snapshot_download(
+                    self.model_id,
+                    local_dir=os.path.join(self.local_model_path, self.model_id),
+                    allow_file_pattern=allow_file_pattern,
+                    ignore_file_pattern=downloaded_files,
+                    local_files_only=False
+                )
+
+            # Let rank 1, 2, ... wait for rank 0
+            if use_usp:
+                import torch.distributed as dist
+                dist.barrier(device_ids=[dist.get_rank()])
+
+            # Return downloaded files
+            if is_folder:
+                self.path = os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)
+            else:
+                self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
+            if isinstance(self.path, list) and len(self.path) == 1:
+                self.path = self.path[0]
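+
+# Example (a hedged sketch; the model ID is one used elsewhere in this repo):
+#
+#     config = ModelConfig(
+#         model_id="black-forest-labs/FLUX.1-dev",
+#         origin_file_pattern="flux1-dev.safetensors",
+#         local_model_path="./models",  # download cache directory
+#         skip_download=True,           # set True if the files were fetched manually
+#     )
+#     config.download_if_necessary()
+#     print(config.path)  # resolved local file path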
+
+
+class PipelineUnit:
+    def __init__(
+        self,
+        seperate_cfg: bool = False,
+        take_over: bool = False,
+        input_params: tuple[str] = None,
+        input_params_posi: dict[str, str] = None,
+        input_params_nega: dict[str, str] = None,
+        onload_model_names: tuple[str] = None
+    ):
+        self.seperate_cfg = seperate_cfg
+        self.take_over = take_over
+        self.input_params = input_params
+        self.input_params_posi = input_params_posi
+        self.input_params_nega = input_params_nega
+        self.onload_model_names = onload_model_names
+
+
+    def process(self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs) -> dict:
+        raise NotImplementedError("`process` is not implemented.")
+
+
+
+class PipelineUnitRunner:
+    def __init__(self):
+        pass
+
+    def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
+        if unit.take_over:
+            # Let the pipeline unit take over this function.
+            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
+        elif unit.seperate_cfg:
+            # Positive side
+            processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
+            if unit.input_params is not None:
+                for name in unit.input_params:
+                    processor_inputs[name] = inputs_shared.get(name)
+            processor_outputs = unit.process(pipe, **processor_inputs)
+            inputs_posi.update(processor_outputs)
+            # Negative side
+            if inputs_shared["cfg_scale"] != 1:
+                processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
+                if unit.input_params is not None:
+                    for name in unit.input_params:
+                        processor_inputs[name] = inputs_shared.get(name)
+                processor_outputs = unit.process(pipe, **processor_inputs)
+                inputs_nega.update(processor_outputs)
+            else:
+                inputs_nega.update(processor_outputs)
+        else:
+            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
+            processor_outputs = unit.process(pipe, **processor_inputs)
+            inputs_shared.update(processor_outputs)
+        return inputs_shared, inputs_posi, inputs_nega
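+
+
+# A PipelineUnit declares which inputs it consumes; PipelineUnitRunner routes them
+# between the shared, positive, and negative input dicts. A sketch of a unit (the
+# `process` body is illustrative, modeled on WanVideoUnit_ShapeChecker):
+#
+#     class ShapeChecker(PipelineUnit):
+#         def __init__(self):
+#             super().__init__(input_params=("height", "width", "num_frames"))
+#         def process(self, pipe, height, width, num_frames):
+#             return dict(zip(("height", "width", "num_frames"), pipe.check_resize_height_width(height, width, num_frames)))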
diff --git a/examples/flux/README.md b/examples/flux/README.md
index b1451d1..a66e2bc 100644
--- a/examples/flux/README.md
+++ b/examples/flux/README.md
@@ -52,6 +52,7 @@ image.save("image.jpg")
 |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
 |[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./model_inference/FLUX.1-dev-EliGen.py)|[code](./model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|||
 |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
+|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
 |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./model_inference/Step1X-Edit.py)|[code](./model_inference_low_vram/Step1X-Edit.py)|[code](./model_training/full/Step1X-Edit.sh)|[code](./model_training/validate_full/Step1X-Edit.py)|[code](./model_training/lora/Step1X-Edit.sh)|[code](./model_training/validate_lora/Step1X-Edit.py)|
 |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./model_inference/FLEX.2-preview.py)|[code](./model_inference_low_vram/FLEX.2-preview.py)|[code](./model_training/full/FLEX.2-preview.sh)|[code](./model_training/validate_full/FLEX.2-preview.py)|[code](./model_training/lora/FLEX.2-preview.sh)|[code](./model_training/validate_lora/FLEX.2-preview.py)|
@@ -105,7 +106,20 @@ ModelConfig(path=[
 ])
 ```
 
-The `from_pretrained` method also provides extra arguments to control model loading behavior:
+The `ModelConfig` class also provides extra arguments to control model loading behavior:
 
 * `local_model_path`: Path to save downloaded models. Default is `"./models"`.
 * `skip_download`: Whether to skip downloading. Default is `False`. If your network cannot access [ModelScope](https://modelscope.cn/ ), download the required files manually and set this to `True`.
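+
+For example, to cache downloads in a custom folder and skip re-downloading (an illustrative sketch):
+
+```python
+from diffsynth.pipelines.flux_image_new import ModelConfig
+
+config = ModelConfig(
+    model_id="black-forest-labs/FLUX.1-dev",
+    origin_file_pattern="flux1-dev.safetensors",
+    local_model_path="./models",  # download cache directory
+    skip_download=True,           # assumes the file is already present locally
+)
+```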
diff --git a/examples/flux/README_zh.md b/examples/flux/README_zh.md
index aeade8f..3d3dc35 100644
--- a/examples/flux/README_zh.md
+++ b/examples/flux/README_zh.md
@@ -52,6 +52,7 @@ image.save("image.jpg")
 |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
 |[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./model_inference/FLUX.1-dev-EliGen.py)|[code](./model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|||
 |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
+|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
 |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./model_inference/Step1X-Edit.py)|[code](./model_inference_low_vram/Step1X-Edit.py)|[code](./model_training/full/Step1X-Edit.sh)|[code](./model_training/validate_full/Step1X-Edit.py)|[code](./model_training/lora/Step1X-Edit.sh)|[code](./model_training/validate_lora/Step1X-Edit.py)|
 |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./model_inference/FLEX.2-preview.py)|[code](./model_inference_low_vram/FLEX.2-preview.py)|[code](./model_training/full/FLEX.2-preview.sh)|[code](./model_training/validate_full/FLEX.2-preview.py)|[code](./model_training/lora/FLEX.2-preview.sh)|[code](./model_training/validate_lora/FLEX.2-preview.py)|
@@ -105,7 +106,7 @@ ModelConfig(path=[
 ])
 ```
 
-`from_pretrained` 还提供了额外的参数用于控制模型加载时的行为:
+`ModelConfig` 还提供了额外的参数用于控制模型加载时的行为:
 
 * `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`。
 * `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`。
diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py b/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py
new file mode 100644
index 0000000..69b20db
--- /dev/null
+++ b/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py
@@ -0,0 +1,29 @@
+import torch
+from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+
+
+pipe = FluxImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
+        ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors"),
+    ],
+)
+pipe.enable_lora_magic()
+
+pipe.load_lora(
+    pipe.dit,
+    ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"),
+    hotload=True,
+)
+pipe.load_lora(
+    pipe.dit,
+    ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"),
+    hotload=True,
+)
+image = pipe(prompt="a cat", seed=0)
+image.save("image_fused.jpg")
diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py b/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py
deleted file mode 100644
index 68116d0..0000000
--- a/examples/flux/model_inference/FLUX.1-dev-LoRAFusion.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import torch
-from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
-
-
-pipe = FluxImagePipeline.from_pretrained(
-    torch_dtype=torch.bfloat16,
-    device="cuda",
-    model_configs=[
-        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
-        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
-        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
-        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
-        ModelConfig(model_id="DiffSynth-Studio/FLUX.1-dev-LoRAFusion", origin_file_pattern="model.safetensors")
-    ],
-)
-pipe.enable_vram_management()
-pipe.enable_lora_patcher()
-pipe.load_lora(
-    pipe.dit,
-    ModelConfig(model_id="yangyufeng/fgao", origin_file_pattern="30.safetensors"),
-    hotload=True
-)
-pipe.load_lora(
-    pipe.dit,
-    ModelConfig(model_id="bobooblue/LoRA-bling-mai", origin_file_pattern="10.safetensors"),
-    hotload=True
-)
-pipe.load_lora(
-    pipe.dit,
-    ModelConfig(model_id="JIETANGAB/E", origin_file_pattern="17.safetensors"),
-    hotload=True
-)
-
-image = pipe(prompt="This is a digital painting in a soft, ethereal style. a beautiful Asian girl Shine like a diamond. Everywhere is shining with bling bling luster.The background is a textured blue with visible brushstrokes, giving the image an impressionistic style reminiscent of Vincent van Gogh's work", seed=0)
-image.save("flux.jpg")
diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md
index 1b78e78..2c811f5 100644
--- a/examples/wanvideo/README.md
+++ b/examples/wanvideo/README.md
@@ -121,11 +121,29 @@ ModelConfig(path=[
 ])
 ```
 
-The `from_pretrained` function also provides additional parameters to control the behavior during model loading:
+The `ModelConfig` class provides additional parameters to control the behavior during model loading:
 
-* `tokenizer_config`: Path to the tokenizer of the Wan model. Default value is `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`.
 * `local_model_path`: Path where downloaded models are saved. Default value is `"./models"`.
 * `skip_download`: Whether to skip downloading models. Default value is `False`. When your network cannot access [ModelScope](https://modelscope.cn/), manually download the necessary files and set this to `True`.
+
+The `from_pretrained` function provides additional parameters to control the behavior during model loading:
+
+* `tokenizer_config`: Path to the tokenizer of the Wan model. Default value is `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`.
 * `redirect_common_files`: Whether to redirect duplicate model files. Default value is `True`. Since the Wan series models include multiple base models, some modules like text encoder are shared across these models. To avoid redundant downloads, we redirect the model paths.
 * `use_usp`: Whether to enable Unified Sequence Parallel. Default value is `False`. Used for multi-GPU parallel inference.
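+
+For example, a hedged sketch of `from_pretrained` with these options (the file pattern is illustrative; see the inference examples for complete model lists):
+
+```python
+import torch
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")],
+    redirect_common_files=True,  # reuse shared modules across Wan base models
+    use_usp=False,               # set True for multi-GPU Unified Sequence Parallel inference
+)
+```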
diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md
index d9cd43b..517cd9e 100644
--- a/examples/wanvideo/README_zh.md
+++ b/examples/wanvideo/README_zh.md
@@ -120,11 +120,14 @@ ModelConfig(path=[
 ])
 ```
 
-`from_pretrained` 还提供了额外的参数用于控制模型加载时的行为:
+`ModelConfig` 提供了额外的参数用于控制模型加载时的行为:
 
-* `tokenizer_config`: Wan 模型的 tokenizer 路径,默认值为 `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`。
 * `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`。
 * `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`。
+
+`from_pretrained` 提供了额外的参数用于控制模型加载时的行为:
+
+* `tokenizer_config`: Wan 模型的 tokenizer 路径,默认值为 `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`。
 * `redirect_common_files`: 是否重定向重复模型文件,默认值为 `True`。由于 Wan 系列模型包括多个基础模型,每个基础模型的 text encoder 等模块都是相同的,为避免重复下载,我们会对模型路径进行重定向。
 * `use_usp`: 是否启用 Unified Sequence Parallel,默认值为 `False`。用于多 GPU 并行推理。