mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
update wans2v training
This commit is contained in:
@@ -225,6 +225,13 @@ class ToAbsolutePath(DataProcessingOperator):
|
||||
def __call__(self, data):
|
||||
return os.path.join(self.base_path, data)
|
||||
|
||||
class LoadAudio(DataProcessingOperator):
    """Data-processing operator that reads an audio file with librosa.

    Args:
        sr: Value forwarded to ``librosa.load`` as its ``sr`` argument
            (defaults to 16000).
    """

    def __init__(self, sr=16000):
        self.sr = sr

    def __call__(self, data: str):
        """Load the audio located at ``data``.

        Args:
            data: Path handed to ``librosa.load``.

        Returns:
            A dict with two entries: ``'input_audio'`` (the first value
            returned by ``librosa.load``) and ``'sample_rate'`` (the second).
        """
        # Function-scope import: librosa is only imported when the operator
        # is actually invoked.
        import librosa

        waveform, rate = librosa.load(data, sr=self.sr)
        result = {'input_audio': waveform, 'sample_rate': rate}
        return result
|
||||
|
||||
|
||||
class UnifiedDataset(torch.utils.data.Dataset):
|
||||
|
||||
@@ -603,6 +603,7 @@ def wan_parser():
|
||||
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||
parser.add_argument("--audio_processor_config", type=str, default=None, help="Model ID with origin paths to the audio processor config, e.g., Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||
|
||||
Reference in New Issue
Block a user