mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
ace-step train
This commit is contained in:
@@ -3,6 +3,7 @@ import torch, torchvision, imageio, os
|
||||
import imageio.v3 as iio
|
||||
from PIL import Image
|
||||
import torchaudio
|
||||
from diffsynth.utils.data.audio import read_audio
|
||||
|
||||
|
||||
class DataProcessingPipeline:
|
||||
@@ -276,3 +277,27 @@ class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
|
||||
except:
|
||||
warnings.warn(f"Cannot load audio in {data}. The audio will be `None`.")
|
||||
return None
|
||||
|
||||
|
||||
class LoadPureAudioWithTorchaudio(DataProcessingOperator):
    """Load an audio file and optionally force it to a fixed duration.

    The file is read through ``read_audio``. When ``target_duration`` is set,
    the waveform's last dimension is cut or right-padded so that it holds
    exactly ``int(target_duration * sample_rate)`` samples.

    Args:
        target_sample_rate: Forwarded to ``read_audio`` as ``resample_rate``.
            Resampling is requested only when this is not ``None``.
        target_duration: Desired length; multiplied by the sample rate
            returned from ``read_audio`` to get a sample count (presumably
            seconds -- confirm against ``read_audio``). ``None`` keeps the
            loaded length unchanged.
    """

    def __init__(self, target_sample_rate=None, target_duration=None):
        self.target_sample_rate = target_sample_rate
        self.target_duration = target_duration
        # Resampling is requested exactly when a target rate was supplied.
        self.resample = target_sample_rate is not None

    def __call__(self, data: str):
        """Return ``(waveform, sample_rate)`` for ``data``, or ``None`` on failure.

        Any exception raised while reading or resizing the audio is reported
        with ``warnings.warn`` and converted into a ``None`` result.
        """
        try:
            audio, rate = read_audio(
                data,
                resample=self.resample,
                resample_rate=self.target_sample_rate,
            )
            if self.target_duration is None:
                return audio, rate

            wanted = int(self.target_duration * rate)
            # Positive: samples to append; negative: the clip is too long.
            missing = wanted - audio.shape[-1]
            if missing < 0:
                audio = audio[..., :wanted]
            elif missing > 0:
                # Pads only the end of the last dimension (default constant fill).
                audio = torch.nn.functional.pad(audio, (0, missing))
            return audio, rate
        except Exception as e:
            warnings.warn(f"Cannot load audio in '{data}' due to '{e}'. The audio will be `None`.")
            return None
|
||||
|
||||
Reference in New Issue
Block a user