mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
support wan-series models
This commit is contained in:
217
diffsynth/utils/data/__init__.py
Normal file
217
diffsynth/utils/data/__init__.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import imageio, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
|
||||
class LowMemoryVideo:
    """Lazy, frame-by-frame view over a video file.

    Frames are fetched from an ``imageio`` reader on demand in
    ``__getitem__`` instead of being decoded up front.
    """

    def __init__(self, file_name):
        """Open ``file_name`` with ``imageio.get_reader``."""
        # Define the attribute first so ``__del__`` stays safe even when
        # ``imageio.get_reader`` raises.
        self.reader = None
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        """Return the frame count reported by the reader's ``count_frames()``."""
        return self.reader.count_frames()

    def __getitem__(self, item):
        """Return frame ``item`` as an RGB ``PIL.Image``."""
        return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")

    def __del__(self):
        # ``getattr`` covers instances whose ``__init__`` body never ran
        # (for example, a TypeError raised while binding arguments).
        reader = getattr(self, "reader", None)
        if reader is not None:
            # Drop the reference first so a second finalization is a no-op.
            self.reader = None
            reader.close()
|
||||
|
||||
|
||||
def split_file_name(file_name):
    """Split a file name into single characters and integer digit runs.

    Each maximal run of ASCII digits becomes one ``int``; every other
    character is kept as a one-character string.

    Parameters:
        file_name (str): Name to split.

    Returns:
        tuple: Mixed ``int`` / ``str`` parts in their original order.
    """
    result = []
    digits = ""
    for char in file_name:
        if "0" <= char <= "9":
            digits += char
            continue
        if digits:
            result.append(int(digits))
            digits = ""
        result.append(char)
    if digits:
        result.append(int(digits))
    return tuple(result)


def _natural_sort_key(file_name):
    """Return a sort key in which numeric and text parts never compare directly."""
    # Tagging each part with its kind avoids the TypeError that Python 3
    # raises for ``int < str``. Numbers order before characters; parts of
    # the same kind keep their natural ordering.
    return tuple(
        (0, part, "") if isinstance(part, int) else (1, 0, part)
        for part in split_file_name(file_name)
    )


def search_for_images(folder):
    """List the ``.jpg`` / ``.png`` files in ``folder`` in natural order.

    Digit runs are compared numerically, so ``2.png`` sorts before
    ``10.png``.

    Parameters:
        folder (str): Directory to scan (not recursive).

    Returns:
        list[str]: Matching file names joined onto ``folder``.
    """
    file_names = [name for name in os.listdir(folder) if name.endswith((".jpg", ".png"))]
    # The raw name breaks ties between names that split identically
    # (e.g. "01.png" and "1.png").
    file_names.sort(key=lambda name: (_natural_sort_key(name), name))
    return [os.path.join(folder, name) for name in file_names]
|
||||
|
||||
|
||||
class LowMemoryImageFolder:
    """Sequence-like view over image files that opens each image on access."""

    def __init__(self, folder, file_list=None):
        """Store the image paths.

        Parameters:
            folder (str): Directory that holds the images.
            file_list (list[str] | None): File names relative to ``folder``.
                When omitted, the folder is scanned with ``search_for_images``.
        """
        if file_list is not None:
            self.file_list = [os.path.join(folder, name) for name in file_list]
        else:
            self.file_list = search_for_images(folder)

    def __len__(self):
        """Return the number of image paths."""
        return len(self.file_list)

    def __getitem__(self, item):
        """Open the image at position ``item`` and return it converted to RGB."""
        path = self.file_list[item]
        return Image.open(path).convert("RGB")

    def __del__(self):
        # Nothing to release: no handle is kept between accesses.
        pass
|
||||
|
||||
|
||||
def crop_and_resize(image, height, width):
    """Center-crop ``image`` to the target aspect ratio, then resize it.

    Parameters:
        image: Anything ``np.array`` can convert whose first two axes are
            (height, width), e.g. a PIL image.
        height (int): Output height in pixels.
        width (int): Output width in pixels.

    Returns:
        PIL.Image.Image: Image of size ``(width, height)``.
    """
    image = np.array(image)
    # Read only the two leading axes so 2-D (single-channel) arrays work too.
    image_height, image_width = image.shape[:2]
    if image_height / image_width < height / width:
        # Source is relatively wider than the target: trim columns.
        # Keep at least one pixel so extreme ratios do not yield an empty crop.
        cropped_width = max(1, int(image_height / height * width))
        left = (image_width - cropped_width) // 2
        image = image[:, left:left + cropped_width]
    else:
        # Source is relatively taller (or equal): trim rows.
        cropped_height = max(1, int(image_width / width * height))
        top = (image_height - cropped_height) // 2
        image = image[top:top + cropped_height, :]
    return Image.fromarray(image).resize((width, height))
|
||||
|
||||
|
||||
class VideoData:
    """Uniform frame access over a video file or a folder of images.

    Frames are fetched lazily from the underlying source. When a target
    ``height`` and ``width`` are set, frames of a different size are passed
    through ``crop_and_resize`` on access.
    """

    def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
        """Open the data source.

        Parameters:
            video_file: Path to a video; takes precedence over ``image_folder``.
            image_folder: Folder of images, used when ``video_file`` is None.
            height, width: Optional target frame size.
            **kwargs: Forwarded to the underlying source class.

        Raises:
            ValueError: If neither ``video_file`` nor ``image_folder`` is given.
        """
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.set_shape(height, width)

    def raw_data(self):
        """Return all frames as a list."""
        return [self[index] for index in range(len(self))]

    def set_length(self, length):
        """Override the length reported by ``len()``."""
        self.length = length

    def set_shape(self, height, width):
        """Set the target frame size; either value may be None."""
        self.height = height
        self.width = width

    def __len__(self):
        """Return the overridden length if set, else the source length."""
        if self.length is None:
            return len(self.data)
        return self.length

    def shape(self):
        """Return ``(height, width)``: the target size, or the first frame's size."""
        if self.height is not None and self.width is not None:
            return self.height, self.width
        # Frames are PIL images, which expose ``size`` as (width, height)
        # and have no ``shape`` attribute.
        width, height = self[0].size
        return height, width

    def __getitem__(self, item):
        """Return frame ``item``, cropped and resized when a target size is set."""
        frame = self.data[item]
        width, height = frame.size
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = crop_and_resize(frame, self.height, self.width)
        return frame

    def __del__(self):
        # The underlying source object handles its own cleanup.
        pass

    def save_images(self, folder):
        """Write every frame to ``folder`` as ``<index>.png``, creating the folder if needed."""
        os.makedirs(folder, exist_ok=True)
        for index in tqdm(range(len(self)), desc="Saving images"):
            frame = self[index]
            frame.save(os.path.join(folder, f"{index}.png"))
|
||||
|
||||
|
||||
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
    """Encode ``frames`` into a video file with an ``imageio`` writer.

    Parameters:
        frames: Iterable of frames; each is converted with ``np.array``.
        save_path (str): Output file path.
        fps: Frame rate passed to the writer.
        quality: Quality setting passed to the writer.
        ffmpeg_params: Extra ffmpeg arguments passed to the writer.
    """
    writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
    try:
        for frame in tqdm(frames, desc="Saving video"):
            writer.append_data(np.array(frame))
    finally:
        # Close the writer even if a frame fails to convert or encode.
        writer.close()
|
||||
|
||||
def save_frames(frames, save_path):
    """Write each frame to ``save_path`` as ``<index>.png``.

    The directory is created if it does not exist. Each frame must provide
    a ``save(path)`` method.
    """
    os.makedirs(save_path, exist_ok=True)
    progress = tqdm(frames, desc="Saving images")
    for index, frame in enumerate(progress):
        target = os.path.join(save_path, f"{index}.png")
        frame.save(target)
|
||||
|
||||
|
||||
def merge_video_audio(video_path: str, audio_path: str):
    """
    Merge an audio track into a video by running ``ffmpeg``, replacing the video in place.

    The video stream is copied, the audio is encoded as AAC, and the output
    ends with the shorter of the two inputs. The result is written to a
    temporary sibling file and then moved over ``video_path``.

    Parameters:
        video_path (str): Path to the original video file
        audio_path (str): Path to the audio file

    Raises:
        FileNotFoundError: If either input path does not exist. Failures
            after this check are printed, not raised.
    """
    # TODO: may need a in-python implementation to avoid subprocess dependency

    # These checks sit outside the try block, so they reach the caller.
    video_exists = os.path.exists(video_path)
    if not video_exists:
        raise FileNotFoundError(f"video file {video_path} does not exist")
    audio_exists = os.path.exists(audio_path)
    if not audio_exists:
        raise FileNotFoundError(f"audio file {audio_path} does not exist")

    stem, extension = os.path.splitext(video_path)
    temp_output = f"{stem}_temp{extension}"

    try:
        command = ['ffmpeg', '-y']            # overwrite existing output
        command += ['-i', video_path]
        command += ['-i', audio_path]
        command += ['-c:v', 'copy']           # copy the video stream as-is
        command += ['-c:a', 'aac']            # encode audio with AAC
        command += ['-b:a', '192k']           # audio bitrate
        command += ['-map', '0:v:0']          # first video stream of input 0
        command += ['-map', '1:a:0']          # first audio stream of input 1
        command += ['-shortest']              # stop at the shorter input
        command += [temp_output]

        completed = subprocess.run(command, capture_output=True, text=True)

        if completed.returncode != 0:
            error_msg = f"FFmpeg execute failed: {completed.stderr}"
            print(error_msg)
            raise RuntimeError(error_msg)

        shutil.move(temp_output, video_path)
        print(f"Merge completed, saved to {video_path}")

    except Exception as e:
        # Best effort: remove the partial output and report without raising.
        if os.path.exists(temp_output):
            os.remove(temp_output)
        print(f"merge_video_audio failed with error: {e}")
|
||||
|
||||
|
||||
def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
    """Write ``frames`` to ``save_path`` with ``save_video``, then merge in the audio.

    The audio at ``audio_path`` is added by ``merge_video_audio``, which
    rewrites the video file in place.
    """
    save_video(frames, save_path, fps, quality=quality, ffmpeg_params=ffmpeg_params)
    merge_video_audio(video_path=save_path, audio_path=audio_path)
|
||||
@@ -0,0 +1,6 @@
|
||||
def WanAnimateAdapterStateDictConverter(state_dict):
    """Keep only the entries whose key starts with one of the adapter prefixes.

    Keys and values are passed through unchanged; all other entries are
    dropped.
    """
    kept_prefixes = ("pose_patch_embedding.", "face_adapter", "face_encoder", "motion_encoder")
    return {
        key: state_dict[key]
        for key in state_dict
        if key.startswith(kept_prefixes)
    }
|
||||
83
diffsynth/utils/state_dict_converters/wan_video_dit.py
Normal file
83
diffsynth/utils/state_dict_converters/wan_video_dit.py
Normal file
@@ -0,0 +1,83 @@
|
||||
def WanVideoDiTFromDiffusers(state_dict):
    """Rename state-dict keys according to a fixed source-to-target table.

    The table lists per-block entries for block index ``0`` only. Any other
    key is looked up by replacing its second dot-separated component with
    ``"0"``, and that original component is put back into the renamed key.
    Keys that match nothing are dropped.

    Parameters:
        state_dict: Mapping from parameter name to value.

    Returns:
        dict: Renamed entries, in the order they were encountered.
    """
    rename_dict = {
        "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
        "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
        "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
        "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
        "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
        "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
        "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
        "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
        "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
        "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
        "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
        "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
        "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
        "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
        "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
        "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
        "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
        "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
        "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
        "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
        "blocks.0.attn2.add_k_proj.bias": "blocks.0.cross_attn.k_img.bias",
        "blocks.0.attn2.add_k_proj.weight": "blocks.0.cross_attn.k_img.weight",
        "blocks.0.attn2.add_v_proj.bias": "blocks.0.cross_attn.v_img.bias",
        "blocks.0.attn2.add_v_proj.weight": "blocks.0.cross_attn.v_img.weight",
        "blocks.0.attn2.norm_added_k.weight": "blocks.0.cross_attn.norm_k_img.weight",
        "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
        "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
        "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
        "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
        "blocks.0.norm2.bias": "blocks.0.norm3.bias",
        "blocks.0.norm2.weight": "blocks.0.norm3.weight",
        "blocks.0.scale_shift_table": "blocks.0.modulation",
        "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
        "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
        "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
        "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
        "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
        "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
        "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
        "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
        "condition_embedder.time_proj.bias": "time_projection.1.bias",
        "condition_embedder.time_proj.weight": "time_projection.1.weight",
        "condition_embedder.image_embedder.ff.net.0.proj.bias": "img_emb.proj.1.bias",
        "condition_embedder.image_embedder.ff.net.0.proj.weight": "img_emb.proj.1.weight",
        "condition_embedder.image_embedder.ff.net.2.bias": "img_emb.proj.3.bias",
        "condition_embedder.image_embedder.ff.net.2.weight": "img_emb.proj.3.weight",
        "condition_embedder.image_embedder.norm1.bias": "img_emb.proj.0.bias",
        "condition_embedder.image_embedder.norm1.weight": "img_emb.proj.0.weight",
        "condition_embedder.image_embedder.norm2.bias": "img_emb.proj.4.bias",
        "condition_embedder.image_embedder.norm2.weight": "img_emb.proj.4.weight",
        "patch_embedding.bias": "patch_embedding.bias",
        "patch_embedding.weight": "patch_embedding.weight",
        "scale_shift_table": "head.modulation",
        "proj_out.bias": "head.head.bias",
        "proj_out.weight": "head.head.weight",
    }
    converted = {}
    for source_key in state_dict:
        target_key = rename_dict.get(source_key)
        if target_key is None:
            # No exact match: retry with the second component set to "0".
            parts = source_key.split(".")
            template = rename_dict.get(".".join(parts[:1] + ["0"] + parts[2:]))
            if template is None:
                continue
            # Put the original second component back into the renamed key.
            template_parts = template.split(".")
            target_key = ".".join(template_parts[:1] + [parts[1]] + template_parts[2:])
        converted[target_key] = state_dict[source_key]
    return converted
|
||||
|
||||
|
||||
def WanVideoDiTStateDictConverter(state_dict):
    """Filter a state dict and strip a leading ``"model."`` from its keys.

    Entries are dropped when the key starts with ``"vace"`` or when its
    first dot-separated component is one of the skipped roots listed below.
    """
    skipped_roots = ("pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder")
    prefix = "model."
    converted = {}
    for key in state_dict:
        root = key.split(".")[0]
        if key.startswith("vace") or root in skipped_roots:
            continue
        stripped = key[len(prefix):] if key.startswith(prefix) else key
        converted[stripped] = state_dict[key]
    return converted
|
||||
@@ -0,0 +1,8 @@
|
||||
def WanImageEncoderStateDictConverter(state_dict):
    """Prefix every key with ``"model."``, dropping keys that start with ``"textual."``."""
    return {
        "model." + key: state_dict[key]
        for key in state_dict
        if not key.startswith("textual.")
    }
|
||||
77
diffsynth/utils/state_dict_converters/wan_video_mot.py
Normal file
77
diffsynth/utils/state_dict_converters/wan_video_mot.py
Normal file
@@ -0,0 +1,77 @@
|
||||
def WanVideoMotStateDictConverter(state_dict):
    """Extract and rename the ``_mot_ref`` entries of a state dict.

    Only keys containing ``"_mot_ref"`` are used. The marker is removed, a
    numeric block index (second dot-separated component) is mapped to its
    position in ``mot_layers``, and the key is renamed through
    ``rename_dict``, whose per-block entries are written for block ``0``.
    Keys that match no entry are dropped.

    Parameters:
        state_dict: Mapping from parameter name to value.

    Returns:
        dict: Renamed entries; each value is the one stored under the
        original ``_mot_ref`` key.

    Raises:
        KeyError: If a numeric block index is not listed in ``mot_layers``.
    """
    rename_dict = {
        "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
        "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
        "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
        "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
        "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
        "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
        "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
        "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
        "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
        "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
        "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
        "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
        "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
        "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
        "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
        "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
        "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
        "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
        "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
        "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
        "blocks.0.attn2.add_k_proj.bias": "blocks.0.cross_attn.k_img.bias",
        "blocks.0.attn2.add_k_proj.weight": "blocks.0.cross_attn.k_img.weight",
        "blocks.0.attn2.add_v_proj.bias": "blocks.0.cross_attn.v_img.bias",
        "blocks.0.attn2.add_v_proj.weight": "blocks.0.cross_attn.v_img.weight",
        "blocks.0.attn2.norm_added_k.weight": "blocks.0.cross_attn.norm_k_img.weight",
        "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
        "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
        "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
        "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
        "blocks.0.norm2.bias": "blocks.0.norm3.bias",
        "blocks.0.norm2.weight": "blocks.0.norm3.weight",
        "blocks.0.scale_shift_table": "blocks.0.modulation",
        "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
        "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
        "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
        "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
        "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
        "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
        "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
        "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
        "condition_embedder.time_proj.bias": "time_projection.1.bias",
        "condition_embedder.time_proj.weight": "time_projection.1.weight",
        "condition_embedder.image_embedder.ff.net.0.proj.bias": "img_emb.proj.1.bias",
        "condition_embedder.image_embedder.ff.net.0.proj.weight": "img_emb.proj.1.weight",
        "condition_embedder.image_embedder.ff.net.2.bias": "img_emb.proj.3.bias",
        "condition_embedder.image_embedder.ff.net.2.weight": "img_emb.proj.3.weight",
        "condition_embedder.image_embedder.norm1.bias": "img_emb.proj.0.bias",
        "condition_embedder.image_embedder.norm1.weight": "img_emb.proj.0.weight",
        "condition_embedder.image_embedder.norm2.bias": "img_emb.proj.4.bias",
        "condition_embedder.image_embedder.norm2.weight": "img_emb.proj.4.weight",
        "patch_embedding.bias": "patch_embedding.bias",
        "patch_embedding.weight": "patch_embedding.weight",
        "scale_shift_table": "head.modulation",
        "proj_out.bias": "head.head.bias",
        "proj_out.weight": "head.head.weight",
    }
    mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
    mot_layers_mapping = {layer: position for position, layer in enumerate(mot_layers)}
    state_dict_ = {}
    for source_name in state_dict:
        if "_mot_ref" not in source_name:
            continue
        # Read the value under the ORIGINAL key. Looking it up after the
        # name is rewritten either fails or silently picks up a different
        # entry that happens to share the rewritten name.
        param = state_dict[source_name]
        name = source_name.replace("_mot_ref", "")
        if name in rename_dict:
            state_dict_[rename_dict[name]] = param
            continue
        parts = name.split(".")
        if len(parts) < 2:
            # No second component to inspect, so this cannot be a block entry.
            continue
        if parts[1].isdigit():
            # Remap only the block-index component, not every occurrence
            # of those digits elsewhere in the key.
            parts[1] = str(mot_layers_mapping[int(parts[1])])
        template = ".".join(parts[:1] + ["0"] + parts[2:])
        if template in rename_dict:
            renamed = rename_dict[template].split(".")
            state_dict_[".".join(renamed[:1] + [parts[1]] + renamed[2:])] = param
    return state_dict_
|
||||
3
diffsynth/utils/state_dict_converters/wan_video_vace.py
Normal file
3
diffsynth/utils/state_dict_converters/wan_video_vace.py
Normal file
@@ -0,0 +1,3 @@
|
||||
def VaceWanModelDictConverter(state_dict):
    """Return only the entries whose key starts with ``"vace"``, unchanged."""
    vace_entries = {}
    for key in state_dict:
        if not key.startswith("vace"):
            continue
        vace_entries[key] = state_dict[key]
    return vace_entries
|
||||
7
diffsynth/utils/state_dict_converters/wan_video_vae.py
Normal file
7
diffsynth/utils/state_dict_converters/wan_video_vae.py
Normal file
@@ -0,0 +1,7 @@
|
||||
def WanVideoVAEStateDictConverter(state_dict):
    """Prefix every key with ``"model."``.

    If the input has a ``'model_state'`` entry, the mapping stored there is
    converted instead of the top-level one.
    """
    weights = state_dict['model_state'] if 'model_state' in state_dict else state_dict
    return {'model.' + key: weights[key] for key in weights}
|
||||
@@ -0,0 +1,3 @@
|
||||
def WanS2VAudioEncoderStateDictConverter(state_dict):
    """Return a new dict with ``"model."`` prepended to every key."""
    prefixed = {}
    for key in state_dict:
        prefixed['model.' + key] = state_dict[key]
    return prefixed
|
||||
1
diffsynth/utils/xfuser/__init__.py
Normal file
1
diffsynth/utils/xfuser/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp
|
||||
145
diffsynth/utils/xfuser/xdit_context_parallel.py
Normal file
145
diffsynth/utils/xfuser/xdit_context_parallel.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from einops import rearrange
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||
|
||||
|
||||
def initialize_usp():
    """Initialize torch.distributed and the xfuser sequence-parallel groups.

    Creates an NCCL process group from environment variables, sets up the
    xfuser distributed environment with a sequence-parallel degree equal to
    the world size (ring degree 1, Ulysses degree = world size), and binds
    this process to a CUDA device.
    """
    import os
    import torch.distributed as dist
    from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
    dist.init_process_group(backend="nccl", init_method="env://")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    init_distributed_environment(rank=rank, world_size=world_size)
    initialize_model_parallel(
        sequence_parallel_degree=world_size,
        ring_degree=1,
        ulysses_degree=world_size,
    )
    # The device index must be node-local. On multi-node jobs the global
    # rank exceeds the number of GPUs on a node, so prefer the launcher's
    # LOCAL_RANK and fall back to rank modulo the visible device count.
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is None:
        local_rank = rank % max(torch.cuda.device_count(), 1)
    torch.cuda.set_device(int(local_rank))
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
    """Build a sinusoidal embedding for a 1-D tensor of positions.

    The angles are computed in float64. The output concatenates the cosines
    and the sines along dim 1 (``2 * (dim // 2)`` columns in total) and is
    cast back to ``position.dtype``.
    """
    half = dim // 2
    exponents = torch.arange(half, dtype=torch.float64, device=position.device).div(half)
    inverse_frequencies = torch.pow(10000, -exponents)
    angles = torch.outer(position.type(torch.float64), inverse_frequencies)
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    return embedding.to(position.dtype)
|
||||
|
||||
def pad_freqs(original_tensor, target_len):
    """Extend a 3-D tensor along dim 0 to ``target_len`` by appending ones.

    The appended rows use the input's dtype and device.
    """
    current_len, dim_a, dim_b = original_tensor.shape
    filler = torch.ones(
        (target_len - current_len, dim_a, dim_b),
        dtype=original_tensor.dtype,
        device=original_tensor.device,
    )
    return torch.cat((original_tensor, filler), dim=0)
|
||||
|
||||
def rope_apply(x, freqs, num_heads):
    """Multiply this rank's slice of ``x`` by the matching slice of ``freqs``.

    ``x`` is split into heads, viewed as complex numbers built from
    consecutive float64 pairs, multiplied by this rank's rows of ``freqs``,
    and flattened back to ``(b, s, n * d)`` in the input dtype.

    NOTE(review): ``freqs`` is presumably a complex tensor covering the full
    (unsharded) sequence — confirm against the caller.
    """
    heads = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    batch, local_len, head_count = heads.shape[0], heads.shape[1], heads.shape[2]

    as_complex = torch.view_as_complex(
        heads.to(torch.float64).reshape(batch, local_len, head_count, -1, 2)
    )

    world_size = get_sequence_parallel_world_size()
    rank = get_sequence_parallel_rank()
    # Pad to local_len * world_size rows so every rank can take an
    # equally long slice.
    padded_freqs = pad_freqs(freqs, local_len * world_size)
    start = rank * local_len
    local_freqs = padded_freqs[start:start + local_len, :, :]

    rotated = torch.view_as_real(as_complex * local_freqs).flatten(2)
    return rotated.to(heads.dtype)
|
||||
|
||||
def usp_dit_forward(self,
                    x: torch.Tensor,
                    timestep: torch.Tensor,
                    context: torch.Tensor,
                    clip_feature: Optional[torch.Tensor] = None,
                    y: Optional[torch.Tensor] = None,
                    use_gradient_checkpointing: bool = False,
                    use_gradient_checkpointing_offload: bool = False,
                    **kwargs,
                    ):
    """Forward pass that shards the token sequence across sequence-parallel ranks.

    Embeds the timestep and text context, patchifies ``x``, keeps only this
    rank's slice of the sequence (zero-padded so all ranks hold equally long
    slices), runs the blocks and the head, then all-gathers the slices,
    strips the padding and unpatchifies.

    Parameters:
        x: Input tensor passed to ``self.patchify``.
        timestep: Timestep tensor for the sinusoidal embedding.
        context: Text context passed to ``self.text_embedding``.
        clip_feature: Passed to ``self.img_emb`` when ``self.has_image_input``.
        y: Concatenated to ``x`` on dim 1 when ``self.has_image_input``.
        use_gradient_checkpointing: Checkpoint each block while training.
        use_gradient_checkpointing_offload: Also keep saved tensors on CPU.
        **kwargs: Accepted and ignored.

    Returns:
        The output of ``self.unpatchify`` for the gathered sequence.
    """
    t = self.time_embedding(
        sinusoidal_embedding_1d(self.freq_dim, timestep))
    t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
    context = self.text_embedding(context)

    if self.has_image_input:
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embedding = self.img_emb(clip_feature)
        context = torch.cat([clip_embedding, context], dim=1)

    x, (f, h, w) = self.patchify(x)

    freqs = torch.cat([
        self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    # Context Parallel
    # Pad the sequence to a multiple of the world size, then take this
    # rank's slice. torch.chunk is avoided because it may return fewer
    # chunks than requested, which breaks indexing by rank and makes the
    # pad count wrong.
    # NOTE(review): assumes x is (batch, seq, dim) after patchify — confirm.
    sp_size = get_sequence_parallel_world_size()
    sp_rank = get_sequence_parallel_rank()
    seq_len = x.shape[1]
    chunk_len = -(-seq_len // sp_size)  # ceiling division
    pad_shape = chunk_len * sp_size - seq_len
    if pad_shape > 0:
        x = torch.nn.functional.pad(x, (0, 0, 0, pad_shape), value=0)
    x = x[:, sp_rank * chunk_len:(sp_rank + 1) * chunk_len].contiguous()

    for block in self.blocks:
        if self.training and use_gradient_checkpointing:
            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            else:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
        else:
            x = block(x, context, t_mod, freqs)

    x = self.head(x, t)

    # Context Parallel
    x = get_sp_group().all_gather(x, dim=1)
    # Drop the zero rows appended before sharding.
    x = x[:, :-pad_shape] if pad_shape > 0 else x

    # unpatchify
    x = self.unpatchify(x, (f, h, w))
    return x
|
||||
|
||||
|
||||
def usp_attn_forward(self, x, freqs):
    """Self-attention forward that delegates attention to ``xFuserLongContextAttention``.

    Projects ``x`` to query/key/value, normalizes query and key, applies
    ``rope_apply`` to both, runs the attention, and returns the
    output projection ``self.o`` of the flattened result.
    """
    head_pattern = "b s (n d) -> b s n d"
    heads = self.num_heads

    query = self.norm_q(self.q(x))
    key = self.norm_k(self.k(x))
    value = self.v(x)

    query = rope_apply(query, freqs, heads)
    key = rope_apply(key, freqs, heads)
    query = rearrange(query, head_pattern, n=heads)
    key = rearrange(key, head_pattern, n=heads)
    value = rearrange(value, head_pattern, n=heads)

    attention = xFuserLongContextAttention()
    attended = attention(
        None,
        query=query,
        key=key,
        value=value,
    )
    attended = attended.flatten(2)

    # Release the per-head tensors and cached CUDA blocks before the
    # output projection.
    del query, key, value
    torch.cuda.empty_cache()
    return self.o(attended)
|
||||
Reference in New Issue
Block a user