import math

import numpy as np
import torch
import torch.nn.functional as F


def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None):
    """Pick `num_sample` frame indices that resample a clip recorded at
    `original_fps` down to `target_fps`. The window starts at `fixed_start`
    when given, otherwise at a random valid position."""
    required_duration = num_sample / target_fps
    required_origin_frames = int(np.ceil(required_duration * original_fps))
    if required_duration > total_frames / original_fps:
        raise ValueError("required_duration must be less than the video length")

    if fixed_start is not None and fixed_start >= 0:
        start_frame = fixed_start
    else:
        max_start = total_frames - required_origin_frames
        if max_start < 0:
            raise ValueError("video length is too short")
        start_frame = np.random.randint(0, max_start + 1)
    start_time = start_frame / original_fps

    end_time = start_time + required_duration
    time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)

    frame_indices = np.round(time_points * original_fps).astype(int)
    frame_indices = np.clip(frame_indices, 0, total_frames - 1)
    return frame_indices
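
# Usage sketch (hypothetical numbers): sample 16 frames at 16 fps from a
# 30 fps, 120-frame clip, pinned to the clip start:
#   idx = get_sample_indices(original_fps=30, total_frames=120,
#                            target_fps=16, num_sample=16, fixed_start=0)
#   # -> array([0, 2, 4, 6, ...]); indices spaced ~30/16 = 1.875 frames apart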


def linear_interpolation(features, input_fps, output_fps, output_len=None):
    """
    features: shape=[1, T, 512]
    input_fps: fps of the audio features, f_a
    output_fps: fps of the video, f_m
    output_len: target output length in video frames
    """
    features = features.transpose(1, 2)  # [1, 512, T]
    seq_len = features.shape[2] / float(input_fps)  # duration in seconds
    if output_len is None:
        output_len = int(seq_len * output_fps)
    output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')  # [1, 512, output_len]
    return output_features.transpose(1, 2)  # [1, output_len, 512]
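
# Usage sketch (hypothetical tensor): stretch 50 Hz features to a 30 fps video:
#   feats = torch.randn(1, 100, 512)  # 2 s of features at 50 Hz
#   out = linear_interpolation(feats, input_fps=50, output_fps=30)
#   # out.shape -> torch.Size([1, 60, 512])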


class WanS2VAudioEncoder(torch.nn.Module):

    def __init__(self):
        super().__init__()
        from transformers import Wav2Vec2ForCTC, Wav2Vec2Config
        # Configuration of facebook/wav2vec2-large-xlsr-53. The model is built
        # from the config alone, so pretrained weights must be loaded separately.
        config = {
            "_name_or_path": "facebook/wav2vec2-large-xlsr-53",
            "activation_dropout": 0.05,
            "apply_spec_augment": True,
            "architectures": ["Wav2Vec2ForCTC"],
            "attention_dropout": 0.1,
            "bos_token_id": 1,
            "conv_bias": True,
            "conv_dim": [512, 512, 512, 512, 512, 512, 512],
            "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
            "conv_stride": [5, 2, 2, 2, 2, 2, 2],
            "ctc_loss_reduction": "mean",
            "ctc_zero_infinity": True,
            "do_stable_layer_norm": True,
            "eos_token_id": 2,
            "feat_extract_activation": "gelu",
            "feat_extract_dropout": 0.0,
            "feat_extract_norm": "layer",
            "feat_proj_dropout": 0.05,
            "final_dropout": 0.0,
            "hidden_act": "gelu",
            "hidden_dropout": 0.05,
            "hidden_size": 1024,
            "initializer_range": 0.02,
            "intermediate_size": 4096,
            "layer_norm_eps": 1e-05,
            "layerdrop": 0.05,
            "mask_channel_length": 10,
            "mask_channel_min_space": 1,
            "mask_channel_other": 0.0,
            "mask_channel_prob": 0.0,
            "mask_channel_selection": "static",
            "mask_feature_length": 10,
            "mask_feature_prob": 0.0,
            "mask_time_length": 10,
            "mask_time_min_space": 1,
            "mask_time_other": 0.0,
            "mask_time_prob": 0.05,
            "mask_time_selection": "static",
            "model_type": "wav2vec2",
            "num_attention_heads": 16,
            "num_conv_pos_embedding_groups": 16,
            "num_conv_pos_embeddings": 128,
            "num_feat_extract_layers": 7,
            "num_hidden_layers": 24,
            "pad_token_id": 0,
            "transformers_version": "4.7.0.dev0",
            "vocab_size": 33
        }
        self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
        # Audio features are resampled to this frame rate (frames per second).
        self.video_rate = 30

    def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'):
        input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device)

        # Run wav2vec2 and collect the hidden states of every layer.
        res = self.model(input_values, output_hidden_states=True)
        if return_all_layers:
            feat = torch.cat(res.hidden_states)  # [num_layers, T, hidden_size]
        else:
            feat = res.hidden_states[-1]  # [1, T, hidden_size]
        # wav2vec2 emits features at ~50 Hz; resample them to the video rate.
        feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
        return feat
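
    # Usage sketch, assuming a feature extractor loaded elsewhere (the
    # checkpoint name below is illustrative):
    #   from transformers import Wav2Vec2FeatureExtractor
    #   processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
    #   feat = encoder.extract_audio_feat(audio_16k, 16000, processor, return_all_layers=True)
    #   # feat: [25, T_video, 1024] -- 24 transformer layers plus the embedding output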

    def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
        num_layers, audio_frame_num, audio_dim = audio_embed.shape

        return_all_layers = num_layers > 1

        # Number of batches needed to cover the audio at the given stride.
        min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1

        bucket_num = min_batch_num * batch_frames
        batch_idx = [stride * i for i in range(bucket_num)]
        batch_audio_eb = []
        for bi in batch_idx:
            if bi < audio_frame_num:
                audio_sample_stride = 2
                # Take a window of 2m+1 strided frames centered on bi, clamp it
                # to the valid range, and flatten it into a single embedding.
                chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
                chosen_idx = [0 if c < 0 else c for c in chosen_idx]
                chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]

                if return_all_layers:
                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
                else:
                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()
            else:
                # Past the end of the audio: pad with zeros.
                frame_audio_embed = \
                    torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
                    else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
            batch_audio_eb.append(frame_audio_embed)
        batch_audio_eb = torch.stack(batch_audio_eb, dim=0)

        return batch_audio_eb, min_batch_num
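
    # Sketch of the bucketing (hypothetical shapes): with the defaults stride=2,
    # batch_frames=12, m=2 and audio_embed of shape [25, 100, 1024], each bucket
    # entry is a flattened window of 2m+1 = 5 frames:
    #   bucket, n = encoder.get_audio_embed_bucket(audio_embed)
    #   # n = int(100 / (12 * 2)) + 1 = 5; bucket: [n * 12, 25, 5 * 1024]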

    def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
        num_layers, audio_frame_num, audio_dim = audio_embed.shape

        return_all_layers = num_layers > 1

        # Ratio between the audio feature rate (video_rate) and the target fps.
        scale = self.video_rate / fps

        min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1

        bucket_num = min_batch_num * batch_frames
        # Virtually pad the audio so the sampled indices cover every batch.
        padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num
        batch_idx = get_sample_indices(
            original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0
        )
        batch_audio_eb = []
        audio_sample_stride = int(self.video_rate / fps)
        for bi in batch_idx:
            if bi < audio_frame_num:
                # Take a window of 2m+1 strided frames centered on bi, clamp it
                # to the valid range, and flatten it into a single embedding.
                chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
                chosen_idx = [0 if c < 0 else c for c in chosen_idx]
                chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]

                if return_all_layers:
                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
                else:
                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()
            else:
                # Past the end of the audio: pad with zeros.
                frame_audio_embed = \
                    torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
                    else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
            batch_audio_eb.append(frame_audio_embed)
        batch_audio_eb = torch.stack(batch_audio_eb, dim=0)

        return batch_audio_eb, min_batch_num
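
    # Same bucketing as above, but indices are resampled from the 30 fps feature
    # rate down to the video fps. Sketch (hypothetical shapes), with the defaults
    # fps=16, batch_frames=81, m=0 and audio_embed of shape [25, 150, 1024]:
    #   bucket, n = encoder.get_audio_embed_bucket_fps(audio_embed)
    #   # n = int(150 / (81 * 30/16)) + 1 = 1; bucket: [n * 81, 25, 1024]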

    def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'):
        audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device)
        audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m)
        # [T, L, D] -> [1, L, D, T]: batch, layers, embedding dim, frames.
        audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype)
        # Split the bucket into per-inference chunks of `batch_frames` frames.
        audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)]
        return audio_embeds
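
    # End-to-end usage sketch (hypothetical inputs; `processor` as in the
    # extract_audio_feat sketch above, weights loaded beforehand):
    #   encoder = WanS2VAudioEncoder()
    #   chunks = encoder.get_audio_feats_per_inference(audio_16k, 16000, processor)
    #   # chunks: list of [1, 25, 1024, 80] tensors, one per inference pass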

    @staticmethod
    def state_dict_converter():
        return WanS2VAudioEncoderStateDictConverter()


class WanS2VAudioEncoderStateDictConverter:

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        # Prefix every key with 'model.' so the weights land on self.model.
        state_dict = {'model.' + k: v for k, v in state_dict.items()}
        return state_dict
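
# Usage sketch (hypothetical checkpoint path), assuming a raw wav2vec2 state
# dict loaded elsewhere:
#   raw_sd = torch.load("wav2vec2.pth", map_location="cpu")
#   sd = WanS2VAudioEncoder.state_dict_converter().from_civitai(raw_sd)
#   encoder.load_state_dict(sd)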