DiffSynth-Studio 2.0 major update

This commit is contained in:
root
2025-12-04 16:33:07 +08:00
parent afd101f345
commit 72af7122b3
758 changed files with 26462 additions and 2221398 deletions

View File

@@ -1,287 +0,0 @@
import torch, warnings, glob, os
import numpy as np
from PIL import Image
from einops import repeat, reduce
from typing import Optional, Union
from dataclasses import dataclass
from huggingface_hub import snapshot_download as hf_snapshot_download
from modelscope import snapshot_download
import numpy as np
from PIL import Image
from typing import Optional
class BasePipeline(torch.nn.Module):
def __init__(
self,
device="cuda", torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):
super().__init__()
# The device and torch_dtype is used for the storage of intermediate variables, not models.
self.device = device
self.torch_dtype = torch_dtype
# The following parameters are used for shape check.
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
self.vram_management_enabled = False
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def check_resize_height_width(self, height, width, num_frames=None):
# Shape check
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
if num_frames is None:
return height, width
else:
if num_frames % self.time_division_factor != self.time_division_remainder:
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
return height, width, num_frames
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
# Transform a PIL.Image to torch.Tensor
image = torch.Tensor(np.array(image, dtype=np.float32))
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
image = image * ((max_value - min_value) / 255) + min_value
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
return image
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a list of PIL.Image to torch.Tensor
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
video = torch.stack(video, dim=pattern.index("T") // 2)
return video
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to PIL.Image
if pattern != "H W C":
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
image = image.to(device="cpu", dtype=torch.uint8)
image = Image.fromarray(image.numpy())
return image
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to list of PIL.Image
if pattern != "T H W C":
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
return video
def load_models_to_device(self, model_names=[]):
if self.vram_management_enabled:
# offload models
for name, model in self.named_children():
if name not in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
else:
model.cpu()
torch.cuda.empty_cache()
# onload models
for name, model in self.named_children():
if name in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
else:
model.to(self.device)
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
# Initialize Gaussian noise
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
return noise
def enable_cpu_offload(self):
warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
self.vram_management_enabled = True
def get_vram(self):
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
def freeze_except(self, model_names):
for name, model in self.named_children():
if name in model_names:
model.train()
model.requires_grad_(True)
else:
model.eval()
model.requires_grad_(False)
def blend_with_mask(self, base, addition, mask):
return base * (1 - mask) + addition * mask
def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
timestep = scheduler.timesteps[progress_id]
if inpaint_mask is not None:
noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
latents_next = scheduler.step(noise_pred, timestep, latents)
return latents_next
@dataclass
class ModelConfig:
path: Union[str, list[str]] = None
model_id: str = None
origin_file_pattern: Union[str, list[str]] = None
download_resource: str = "ModelScope"
offload_device: Optional[Union[str, torch.device]] = None
offload_dtype: Optional[torch.dtype] = None
local_model_path: str = None
skip_download: bool = False
def download_if_necessary(self, use_usp=False):
if self.path is None:
# Check model_id and origin_file_pattern
if self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
# Skip if not in rank 0
if use_usp:
import torch.distributed as dist
skip_download = self.skip_download or dist.get_rank() != 0
else:
skip_download = self.skip_download
# Check whether the origin path is a folder
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.origin_file_pattern = ""
allow_file_pattern = None
is_folder = True
elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
allow_file_pattern = self.origin_file_pattern + "*"
is_folder = True
else:
allow_file_pattern = self.origin_file_pattern
is_folder = False
# Download
if self.local_model_path is None:
self.local_model_path = "./models"
if not skip_download:
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
if self.download_resource.lower() == "modelscope":
snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_file_pattern=allow_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
elif self.download_resource.lower() == "huggingface":
hf_snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_patterns=allow_file_pattern,
ignore_patterns=downloaded_files,
local_files_only=False
)
else:
raise ValueError("`download_resource` should be `modelscope` or `huggingface`.")
# Let rank 1, 2, ... wait for rank 0
if use_usp:
import torch.distributed as dist
dist.barrier(device_ids=[dist.get_rank()])
# Return downloaded files
if is_folder:
self.path = os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)
else:
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
if isinstance(self.path, list) and len(self.path) == 1:
self.path = self.path[0]
class PipelineUnit:
def __init__(
self,
seperate_cfg: bool = False,
take_over: bool = False,
input_params: tuple[str] = None,
input_params_posi: dict[str, str] = None,
input_params_nega: dict[str, str] = None,
onload_model_names: tuple[str] = None
):
self.seperate_cfg = seperate_cfg
self.take_over = take_over
self.input_params = input_params
self.input_params_posi = input_params_posi
self.input_params_nega = input_params_nega
self.onload_model_names = onload_model_names
def process(self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs) -> dict:
raise NotImplementedError("`process` is not implemented.")
class PipelineUnitRunner:
def __init__(self):
pass
def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
if unit.take_over:
# Let the pipeline unit take over this function.
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
elif unit.seperate_cfg:
# Positive side
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_posi.update(processor_outputs)
# Negative side
if inputs_shared["cfg_scale"] != 1:
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_nega.update(processor_outputs)
else:
inputs_nega.update(processor_outputs)
else:
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_shared.update(processor_outputs)
return inputs_shared, inputs_posi, inputs_nega

View File

@@ -0,0 +1,2 @@
from .controlnet_input import ControlNetInput
from .annotator import Annotator

View File

@@ -0,0 +1,62 @@
from typing_extensions import Literal, TypeAlias
Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
if not skip_processor:
if processor_id == "canny":
from controlnet_aux.processor import CannyDetector
self.processor = CannyDetector()
elif processor_id == "depth":
from controlnet_aux.processor import MidasDetector
self.processor = MidasDetector.from_pretrained(model_path).to(device)
elif processor_id == "softedge":
from controlnet_aux.processor import HEDdetector
self.processor = HEDdetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart":
from controlnet_aux.processor import LineartDetector
self.processor = LineartDetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart_anime":
from controlnet_aux.processor import LineartAnimeDetector
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
elif processor_id == "openpose":
from controlnet_aux.processor import OpenposeDetector
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
elif processor_id == "normal":
from controlnet_aux.processor import NormalBaeDetector
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
self.processor = None
else:
raise ValueError(f"Unsupported processor_id: {processor_id}")
else:
self.processor = None
self.processor_id = processor_id
self.detect_resolution = detect_resolution
def to(self,device):
if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"):
self.processor.model.to(device)
def __call__(self, image, mask=None):
width, height = image.size
if self.processor_id == "openpose":
kwargs = {
"include_body": True,
"include_hand": True,
"include_face": True
}
else:
kwargs = {}
if self.processor is not None:
detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
image = image.resize((width, height))
return image

View File

@@ -0,0 +1,13 @@
from dataclasses import dataclass
from PIL import Image
@dataclass
class ControlNetInput:
controlnet_id: int = 0
scale: float = 1.0
start: float = 1.0
end: float = 0.0
image: Image.Image = None
inpaint_mask: Image.Image = None
processor_id: str = None

View File

@@ -0,0 +1,217 @@
import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm
import subprocess
import shutil
class LowMemoryVideo:
def __init__(self, file_name):
self.reader = imageio.get_reader(file_name)
def __len__(self):
return self.reader.count_frames()
def __getitem__(self, item):
return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
def __del__(self):
self.reader.close()
def split_file_name(file_name):
result = []
number = -1
for i in file_name:
if ord(i)>=ord("0") and ord(i)<=ord("9"):
if number == -1:
number = 0
number = number*10 + ord(i) - ord("0")
else:
if number != -1:
result.append(number)
number = -1
result.append(i)
if number != -1:
result.append(number)
result = tuple(result)
return result
def search_for_images(folder):
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
file_list = [i[1] for i in sorted(file_list)]
file_list = [os.path.join(folder, i) for i in file_list]
return file_list
class LowMemoryImageFolder:
def __init__(self, folder, file_list=None):
if file_list is None:
self.file_list = search_for_images(folder)
else:
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
def __len__(self):
return len(self.file_list)
def __getitem__(self, item):
return Image.open(self.file_list[item]).convert("RGB")
def __del__(self):
pass
def crop_and_resize(image, height, width):
image = np.array(image)
image_height, image_width, _ = image.shape
if image_height / image_width < height / width:
croped_width = int(image_height / height * width)
left = (image_width - croped_width) // 2
image = image[:, left: left+croped_width]
image = Image.fromarray(image).resize((width, height))
else:
croped_height = int(image_width / width * height)
left = (image_height - croped_height) // 2
image = image[left: left+croped_height, :]
image = Image.fromarray(image).resize((width, height))
return image
class VideoData:
def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
if video_file is not None:
self.data_type = "video"
self.data = LowMemoryVideo(video_file, **kwargs)
elif image_folder is not None:
self.data_type = "images"
self.data = LowMemoryImageFolder(image_folder, **kwargs)
else:
raise ValueError("Cannot open video or image folder")
self.length = None
self.set_shape(height, width)
def raw_data(self):
frames = []
for i in range(self.__len__()):
frames.append(self.__getitem__(i))
return frames
def set_length(self, length):
self.length = length
def set_shape(self, height, width):
self.height = height
self.width = width
def __len__(self):
if self.length is None:
return len(self.data)
else:
return self.length
def shape(self):
if self.height is not None and self.width is not None:
return self.height, self.width
else:
height, width, _ = self.__getitem__(0).shape
return height, width
def __getitem__(self, item):
frame = self.data.__getitem__(item)
width, height = frame.size
if self.height is not None and self.width is not None:
if self.height != height or self.width != width:
frame = crop_and_resize(frame, self.height, self.width)
return frame
def __del__(self):
pass
def save_images(self, folder):
os.makedirs(folder, exist_ok=True)
for i in tqdm(range(self.__len__()), desc="Saving images"):
frame = self.__getitem__(i)
frame.save(os.path.join(folder, f"{i}.png"))
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
for frame in tqdm(frames, desc="Saving video"):
frame = np.array(frame)
writer.append_data(frame)
writer.close()
def save_frames(frames, save_path):
os.makedirs(save_path, exist_ok=True)
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
frame.save(os.path.join(save_path, f"{i}.png"))
def merge_video_audio(video_path: str, audio_path: str):
# TODO: may need a in-python implementation to avoid subprocess dependency
"""
Merge the video and audio into a new video, with the duration set to the shorter of the two,
and overwrite the original video file.
Parameters:
video_path (str): Path to the original video file
audio_path (str): Path to the audio file
"""
# check
if not os.path.exists(video_path):
raise FileNotFoundError(f"video file {video_path} does not exist")
if not os.path.exists(audio_path):
raise FileNotFoundError(f"audio file {audio_path} does not exist")
base, ext = os.path.splitext(video_path)
temp_output = f"{base}_temp{ext}"
try:
# create ffmpeg command
command = [
'ffmpeg',
'-y', # overwrite
'-i',
video_path,
'-i',
audio_path,
'-c:v',
'copy', # copy video stream
'-c:a',
'aac', # use AAC audio encoder
'-b:a',
'192k', # set audio bitrate (optional)
'-map',
'0:v:0', # select the first video stream
'-map',
'1:a:0', # select the first audio stream
'-shortest', # choose the shortest duration
temp_output
]
# execute the command
result = subprocess.run(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# check result
if result.returncode != 0:
error_msg = f"FFmpeg execute failed: {result.stderr}"
print(error_msg)
raise RuntimeError(error_msg)
shutil.move(temp_output, video_path)
print(f"Merge completed, saved to {video_path}")
except Exception as e:
if os.path.exists(temp_output):
os.remove(temp_output)
print(f"merge_video_audio failed with error: {e}")
def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
save_video(frames, save_path, fps, quality, ffmpeg_params)
merge_video_audio(save_path, audio_path)

View File

@@ -0,0 +1 @@
from .general import GeneralLoRALoader

View File

@@ -0,0 +1,204 @@
from .general import GeneralLoRALoader
import torch, math
class FluxLoRALoader(GeneralLoRALoader):
def __init__(self, device="cpu", torch_dtype=torch.float32):
super().__init__(device=device, torch_dtype=torch_dtype)
self.diffusers_rename_dict = {
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.weight",
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.weight",
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.weight",
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.weight",
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.weight",
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.weight",
"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.weight",
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.weight",
"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.weight",
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.weight",
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.weight",
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.weight",
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.weight",
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.weight",
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.weight",
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.weight",
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.weight",
"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.weight",
"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.weight",
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.weight",
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.weight",
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.weight",
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.weight",
"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.weight",
"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.weight",
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.weight",
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.weight",
"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.weight",
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.weight",
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.weight",
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.weight",
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.weight",
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.weight",
"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.weight",
"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.weight",
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.weight",
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.weight",
}
self.civitai_rename_dict = {
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.weight",
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.weight",
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.weight",
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.weight",
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.weight",
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.weight",
}
def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
super().fuse_lora_to_base_model(model, state_dict_lora, alpha)
def convert_state_dict(self, state_dict):
def guess_block_id(name,model_resource):
if model_resource == 'civitai':
names = name.split("_")
for i in names:
if i.isdigit():
return i, name.replace(f"_{i}_", "_blockid_")
if model_resource == 'diffusers':
names = name.split(".")
for i in names:
if i.isdigit():
return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
return None, None
def guess_resource(state_dict):
for k in state_dict:
if "lora_unet_" in k:
return 'civitai'
elif k.startswith("transformer."):
return 'diffusers'
else:
None
model_resource = guess_resource(state_dict)
if model_resource is None:
return state_dict
rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
def guess_alpha(state_dict):
for name, param in state_dict.items():
if ".alpha" in name:
for suffix in [".lora_down.weight", ".lora_A.weight"]:
name_ = name.replace(".alpha", suffix)
if name_ in state_dict:
lora_alpha = param.item() / state_dict[name_].shape[0]
lora_alpha = math.sqrt(lora_alpha)
return lora_alpha
return 1
alpha = guess_alpha(state_dict)
state_dict_ = {}
for name, param in state_dict.items():
block_id, source_name = guess_block_id(name,model_resource)
if alpha != 1:
param *= alpha
if source_name in rename_dict:
target_name = rename_dict[source_name]
target_name = target_name.replace(".blockid.", f".{block_id}.")
state_dict_[target_name] = param
else:
state_dict_[name] = param
if model_resource == 'diffusers':
for name in list(state_dict_.keys()):
if "single_blocks." in name and ".a_to_q." in name:
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
if mlp is None:
dim = 4
if 'lora_A' in name:
dim = 1
mlp = torch.zeros(dim * state_dict_[name].shape[0],
*state_dict_[name].shape[1:],
dtype=state_dict_[name].dtype)
else:
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
if 'lora_A' in name:
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
elif 'lora_B' in name:
d, r = state_dict_[name].shape
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
param[:d, :r] = state_dict_.pop(name)
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
param[3*d:, 3*r:] = mlp
else:
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
state_dict_[name_] = param
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
concat_dim = 0
if 'lora_A' in name:
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
elif 'lora_B' in name:
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
d, r = origin.shape
# print(d, r)
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
else:
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_

View File

@@ -0,0 +1,62 @@
import torch
class GeneralLoRALoader:
def __init__(self, device="cpu", torch_dtype=torch.float32):
self.device = device
self.torch_dtype = torch_dtype
def get_name_dict(self, lora_state_dict):
lora_name_dict = {}
for key in lora_state_dict:
if ".lora_up." in key:
lora_A_key = "lora_down"
lora_B_key = "lora_up"
else:
lora_A_key = "lora_A"
lora_B_key = "lora_B"
if lora_B_key not in key:
continue
keys = key.split(".")
if len(keys) > keys.index(lora_B_key) + 2:
keys.pop(keys.index(lora_B_key) + 1)
keys.pop(keys.index(lora_B_key))
if keys[0] == "diffusion_model":
keys.pop(0)
keys.pop(-1)
target_name = ".".join(keys)
lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key))
return lora_name_dict
def convert_state_dict(self, state_dict, suffix=".weight"):
name_dict = self.get_name_dict(state_dict)
state_dict_ = {}
for name in name_dict:
weight_up = state_dict[name_dict[name][0]]
weight_down = state_dict[name_dict[name][1]]
state_dict_[name + f".lora_B{suffix}"] = weight_up
state_dict_[name + f".lora_A{suffix}"] = weight_down
return state_dict_
def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict, alpha=1.0):
updated_num = 0
state_dict = self.convert_state_dict(state_dict)
lora_layer_names = set([i.replace(".lora_B.weight", "") for i in state_dict if i.endswith(".lora_B.weight")])
for name, module in model.named_modules():
if name in lora_layer_names:
weight_up = state_dict[name + ".lora_B.weight"].to(device=self.device, dtype=self.torch_dtype)
weight_down = state_dict[name + ".lora_A.weight"].to(device=self.device, dtype=self.torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_lora = alpha * torch.mm(weight_up, weight_down)
state_dict_base = module.state_dict()
state_dict_base["weight"] = state_dict_base["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
module.load_state_dict(state_dict_base)
updated_num += 1
print(f"{updated_num} tensors are fused by LoRA. Fused LoRA layers cannot be cleared by `pipe.clear_lora()`.")

View File

@@ -0,0 +1,20 @@
import torch
from typing import Dict, List
def merge_lora_weight(tensors_A, tensors_B):
lora_A = torch.concat(tensors_A, dim=0)
lora_B = torch.concat(tensors_B, dim=1)
return lora_A, lora_B
def merge_lora(loras: List[Dict[str, torch.Tensor]], alpha=1):
lora_merged = {}
keys = [i for i in loras[0].keys() if ".lora_A." in i]
for key in keys:
tensors_A = [lora[key] for lora in loras]
tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras]
lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B)
lora_merged[key] = lora_A * alpha
lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B
return lora_merged

View File

@@ -0,0 +1,17 @@
def Flux2TextEncoderStateDictConverter(state_dict):
rename_dict = {
"multi_modal_projector.linear_1.weight": "model.multi_modal_projector.linear_1.weight",
"multi_modal_projector.linear_2.weight": "model.multi_modal_projector.linear_2.weight",
"multi_modal_projector.norm.weight": "model.multi_modal_projector.norm.weight",
"multi_modal_projector.patch_merger.merging_layer.weight": "model.multi_modal_projector.patch_merger.merging_layer.weight",
"language_model.lm_head.weight": "lm_head.weight",
}
state_dict_ = {}
for k in state_dict:
k_ = k
k_ = k_.replace("language_model.model", "model.language_model")
k_ = k_.replace("vision_tower", "model.vision_tower")
if k_ in rename_dict:
k_ = rename_dict[k_]
state_dict_[k_] = state_dict[k]
return state_dict_

View File

@@ -0,0 +1,103 @@
import torch
def FluxControlNetStateDictConverter(state_dict):
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name in state_dict:
param = state_dict[name]
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
state_dict_[name] = param
else:
state_dict_[name] = param
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_

View File

@@ -0,0 +1,92 @@
import torch
def FluxDiTStateDictConverter(state_dict):
is_nexus_gen = sum([key.startswith("pipe.dit.") for key in state_dict]) > 0
if is_nexus_gen:
dit_state_dict = {}
for key in state_dict:
if key.startswith('pipe.dit.'):
param = state_dict[key]
new_key = key.replace("pipe.dit.", "")
if new_key.startswith("final_norm_out.linear."):
param = torch.concat([param[3072:], param[:3072]], dim=0)
dit_state_dict[new_key] = param
return dit_state_dict
rename_dict = {
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
"txt_in.bias": "context_embedder.bias",
"txt_in.weight": "context_embedder.weight",
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
"final_layer.linear.bias": "final_proj_out.bias",
"final_layer.linear.weight": "final_proj_out.weight",
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
"img_in.bias": "x_embedder.bias",
"img_in.weight": "x_embedder.weight",
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
}
suffix_rename_dict = {
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
"img_attn.proj.bias": "attn.a_to_out.bias",
"img_attn.proj.weight": "attn.a_to_out.weight",
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
"img_mlp.0.bias": "ff_a.0.bias",
"img_mlp.0.weight": "ff_a.0.weight",
"img_mlp.2.bias": "ff_a.2.bias",
"img_mlp.2.weight": "ff_a.2.weight",
"img_mod.lin.bias": "norm1_a.linear.bias",
"img_mod.lin.weight": "norm1_a.linear.weight",
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
"txt_attn.proj.bias": "attn.b_to_out.bias",
"txt_attn.proj.weight": "attn.b_to_out.weight",
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
"txt_mlp.0.bias": "ff_b.0.bias",
"txt_mlp.0.weight": "ff_b.0.weight",
"txt_mlp.2.bias": "ff_b.2.bias",
"txt_mlp.2.weight": "ff_b.2.weight",
"txt_mod.lin.bias": "norm1_b.linear.bias",
"txt_mod.lin.weight": "norm1_b.linear.weight",
"linear1.bias": "to_qkv_mlp.bias",
"linear1.weight": "to_qkv_mlp.weight",
"linear2.bias": "proj_out.bias",
"linear2.weight": "proj_out.weight",
"modulation.lin.bias": "norm.linear.bias",
"modulation.lin.weight": "norm.linear.weight",
"norm.key_norm.scale": "norm_k_a.weight",
"norm.query_norm.scale": "norm_q_a.weight",
}
state_dict_ = {}
for name in state_dict:
original_name = name
if name.startswith("model.diffusion_model."):
name = name[len("model.diffusion_model."):]
names = name.split(".")
if name in rename_dict:
rename = rename_dict[name]
state_dict_[rename] = state_dict[original_name]
elif names[0] == "double_blocks":
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = state_dict[original_name]
elif names[0] == "single_blocks":
if ".".join(names[2:]) in suffix_rename_dict:
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = state_dict[original_name]
else:
pass
return state_dict_

View File

@@ -0,0 +1,2 @@
def FluxInfiniteYouImageProjectorStateDictConverter(state_dict):
return state_dict['image_proj']

View File

@@ -0,0 +1,32 @@
def FluxIpAdapterStateDictConverter(state_dict):
state_dict_ = {}
if "ip_adapter" in state_dict and isinstance(state_dict["ip_adapter"], dict):
for name, param in state_dict["ip_adapter"].items():
name_ = 'ipadapter_modules.' + name
state_dict_[name_] = param
if "image_proj" in state_dict:
for name, param in state_dict["image_proj"].items():
name_ = "image_proj." + name
state_dict_[name_] = param
return state_dict_
for key, value in state_dict.items():
if key.startswith("image_proj."):
state_dict_[key] = value
elif key.startswith("ip_adapter."):
new_key = key.replace("ip_adapter.", "ipadapter_modules.")
state_dict_[new_key] = value
else:
pass
return state_dict_
def SiglipStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
if key.startswith("vision_model."):
new_state_dict[key] = state_dict[key]
return new_state_dict

View File

@@ -0,0 +1,31 @@
def FluxTextEncoderClipStateDictConverter(state_dict):
rename_dict = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embeds",
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
"text_model.final_layer_norm.bias": "final_layer_norm.bias",
}
attn_rename_dict = {
"self_attn.q_proj": "attn.to_q",
"self_attn.k_proj": "attn.to_k",
"self_attn.v_proj": "attn.to_v",
"self_attn.out_proj": "attn.to_out",
"layer_norm1": "layer_norm1",
"layer_norm2": "layer_norm2",
"mlp.fc1": "fc1",
"mlp.fc2": "fc2",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif name.startswith("text_model.encoder.layers."):
param = state_dict[name]
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
state_dict_[name_] = param
return state_dict_

View File

@@ -0,0 +1,4 @@
def FluxTextEncoderT5StateDictConverter(state_dict):
state_dict_ = {i: state_dict[i] for i in state_dict}
state_dict_["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
return state_dict_

View File

@@ -0,0 +1,382 @@
def FluxVAEEncoderStateDictConverter(state_dict):
rename_dict = {
"encoder.conv_in.bias": "conv_in.bias",
"encoder.conv_in.weight": "conv_in.weight",
"encoder.conv_out.bias": "conv_out.bias",
"encoder.conv_out.weight": "conv_out.weight",
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
"encoder.norm_out.bias": "conv_norm_out.bias",
"encoder.norm_out.weight": "conv_norm_out.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
state_dict_[rename_dict[name]] = param
return state_dict_
def FluxVAEDecoderStateDictConverter(state_dict):
rename_dict = {
"decoder.conv_in.bias": "conv_in.bias",
"decoder.conv_in.weight": "conv_in.weight",
"decoder.conv_out.bias": "conv_out.bias",
"decoder.conv_out.weight": "conv_out.weight",
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
"decoder.norm_out.bias": "conv_norm_out.bias",
"decoder.norm_out.weight": "conv_norm_out.weight",
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
state_dict_[rename_dict[name]] = param
return state_dict_
def FluxVAEEncoderStateDictConverterDiffusers(state_dict):
# architecture
block_types = [
'ResnetBlock', 'ResnetBlock', 'DownSampler',
'ResnetBlock', 'ResnetBlock', 'DownSampler',
'ResnetBlock', 'ResnetBlock', 'DownSampler',
'ResnetBlock', 'ResnetBlock',
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
]
# Rename each parameter
local_rename_dict = {
"quant_conv": "quant_conv",
"encoder.conv_in": "conv_in",
"encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
"encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
"encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
"encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
"encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
"encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
"encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
"encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
"encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
"encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
"encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
"encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
"encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
"encoder.conv_norm_out": "conv_norm_out",
"encoder.conv_out": "conv_out",
}
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
name_prefix = ".".join(names[:-1])
if name_prefix in local_rename_dict:
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
elif name.startswith("encoder.down_blocks"):
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
block_type_with_id = ".".join(names[:5])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:5])
names = ["blocks", str(block_id[block_type])] + names[5:]
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
state_dict_[rename_dict[name]] = state_dict[name]
return state_dict_
def FluxVAEDecoderStateDictConverterDiffusers(state_dict):
# architecture
block_types = [
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
]
# Rename each parameter
local_rename_dict = {
"post_quant_conv": "post_quant_conv",
"decoder.conv_in": "conv_in",
"decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
"decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
"decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
"decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
"decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
"decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
"decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
"decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
"decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
"decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
"decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
"decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
"decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
"decoder.conv_norm_out": "conv_norm_out",
"decoder.conv_out": "conv_out",
}
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
name_prefix = ".".join(names[:-1])
if name_prefix in local_rename_dict:
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
elif name.startswith("decoder.up_blocks"):
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
block_type_with_id = ".".join(names[:5])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:5])
names = ["blocks", str(block_id[block_type])] + names[5:]
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
state_dict_[rename_dict[name]] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,6 @@
def NexusGenAutoregressiveModelStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
value = state_dict[key]
new_state_dict["model." + key] = value
return new_state_dict

View File

@@ -0,0 +1,15 @@
def NexusGenMergerStateDictConverter(state_dict):
merger_state_dict = {}
for key in state_dict:
if key.startswith('embedding_merger.'):
value = state_dict[key]
new_key = key.replace("embedding_merger.", "")
merger_state_dict[new_key] = value
return merger_state_dict
def NexusGenAdapterStateDictConverter(state_dict):
adapter_state_dict = {}
for key in state_dict:
if key.startswith('adapter.'):
adapter_state_dict[key] = state_dict[key]
return adapter_state_dict

View File

@@ -0,0 +1,10 @@
def QwenImageTextEncoderStateDictConverter(state_dict):
state_dict_ = {}
for k in state_dict:
v = state_dict[k]
if k.startswith("visual."):
k = "model." + k
elif k.startswith("model."):
k = k.replace("model.", "model.language_model.")
state_dict_[k] = v
return state_dict_

View File

@@ -0,0 +1,7 @@
def Qwen2ConnectorStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("connector."):
name_ = name[len("connector."):]
state_dict_[name_] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,6 @@
def WanAnimateAdapterStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("pose_patch_embedding.") or name.startswith("face_adapter") or name.startswith("face_encoder") or name.startswith("motion_encoder"):
state_dict_[name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,83 @@
def WanVideoDiTFromDiffusers(state_dict):
rename_dict = {
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
"blocks.0.scale_shift_table": "blocks.0.modulation",
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
"condition_embedder.time_proj.bias": "time_projection.1.bias",
"condition_embedder.time_proj.weight": "time_projection.1.weight",
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
"patch_embedding.bias": "patch_embedding.bias",
"patch_embedding.weight": "patch_embedding.weight",
"scale_shift_table": "head.modulation",
"proj_out.bias": "head.head.bias",
"proj_out.weight": "head.head.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
state_dict_[rename_dict[name]] = state_dict[name]
else:
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
if name_ in rename_dict:
name_ = rename_dict[name_]
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
state_dict_[name_] = state_dict[name]
return state_dict_
def WanVideoDiTStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vace"):
continue
if name.split(".")[0] in ["pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"]:
continue
name_ = name
if name_.startswith("model."):
name_ = name_[len("model."):]
state_dict_[name_] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,8 @@
def WanImageEncoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("textual."):
continue
name_ = "model." + name
state_dict_[name_] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,78 @@
def WanVideoMotStateDictConverter(state_dict):
rename_dict = {
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
"blocks.0.scale_shift_table": "blocks.0.modulation",
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
"condition_embedder.time_proj.bias": "time_projection.1.bias",
"condition_embedder.time_proj.weight": "time_projection.1.weight",
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
"patch_embedding.bias": "patch_embedding.bias",
"patch_embedding.weight": "patch_embedding.weight",
"scale_shift_table": "head.modulation",
"proj_out.bias": "head.head.bias",
"proj_out.weight": "head.head.weight",
}
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
mot_layers_mapping = {i:n for n, i in enumerate(mot_layers)}
state_dict_ = {}
for name in state_dict:
if "_mot_ref" not in name:
continue
param = state_dict[name]
name = name.replace("_mot_ref", "")
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
if name.split(".")[1].isdigit():
block_id = int(name.split(".")[1])
name = name.replace(str(block_id), str(mot_layers_mapping[block_id]))
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
if name_ in rename_dict:
name_ = rename_dict[name_]
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
state_dict_[name_] = param
return state_dict_

View File

@@ -0,0 +1,3 @@
def VaceWanModelDictConverter(state_dict):
state_dict_ = {name: state_dict[name] for name in state_dict if name.startswith("vace")}
return state_dict_

View File

@@ -0,0 +1,7 @@
def WanVideoVAEStateDictConverter(state_dict):
state_dict_ = {}
if 'model_state' in state_dict:
state_dict = state_dict['model_state']
for name in state_dict:
state_dict_['model.' + name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,12 @@
def WanS2VAudioEncoderStateDictConverter(state_dict):
rename_dict = {
"model.wav2vec2.encoder.pos_conv_embed.conv.weight_g": "model.wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0",
"model.wav2vec2.encoder.pos_conv_embed.conv.weight_v": "model.wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1",
}
state_dict_ = {}
for name in state_dict:
name_ = "model." + name
if name_ in rename_dict:
name_ = rename_dict[name_]
state_dict_[name_] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1 @@
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp

View File

@@ -0,0 +1,145 @@
import torch
from typing import Optional
from einops import rearrange
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
def initialize_usp():
import torch.distributed as dist
from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
dist.init_process_group(backend="nccl", init_method="env://")
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=1,
ulysses_degree=dist.get_world_size(),
)
torch.cuda.set_device(dist.get_rank())
def sinusoidal_embedding_1d(dim, position):
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x.to(position.dtype)
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
padding_tensor = torch.ones(
pad_size,
s1,
s2,
dtype=original_tensor.dtype,
device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
s_per_rank = x.shape[1]
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
sp_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
freqs = pad_freqs(freqs, s_per_rank * sp_size)
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
return x_out.to(x.dtype)
def usp_dit_forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
if self.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = self.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
x, (f, h, w) = self.patchify(x)
freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# Context Parallel
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
x = chunks[get_sequence_parallel_rank()]
for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
x = self.head(x, t)
# Context Parallel
x = get_sp_group().all_gather(x, dim=1)
x = x[:, :-pad_shape] if pad_shape > 0 else x
# unpatchify
x = self.unpatchify(x, (f, h, w))
return x
def usp_attn_forward(self, x, freqs):
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
x = xFuserLongContextAttention()(
None,
query=q,
key=k,
value=v,
)
x = x.flatten(2)
del q, k, v
torch.cuda.empty_cache()
return self.o(x)