mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
LTX-2 IC-LoRA (in-context LoRA) training
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import torch, os, argparse, accelerate, warnings
|
||||
from diffsynth.core import UnifiedDataset
|
||||
from diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath, RouteByType, SequencialProcess
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.diffusion import *
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -69,6 +68,7 @@ class LTX2TrainingModule(DiffusionTrainingModule):
|
||||
"height": data["video"][0].size[1],
|
||||
"width": data["video"][0].size[0],
|
||||
"num_frames": len(data["video"]),
|
||||
"frame_rate": data.get("frame_rate", 24),
|
||||
# Please do not modify the following parameters
|
||||
# unless you clearly understand what effect the change will have.
|
||||
"cfg_scale": 1,
|
||||
@@ -108,12 +108,7 @@ if __name__ == "__main__":
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||
)
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=UnifiedDataset.default_video_operator(
|
||||
video_processor = UnifiedDataset.default_video_operator(
|
||||
base_path=args.dataset_base_path,
|
||||
max_pixels=args.max_pixels,
|
||||
height=args.height,
|
||||
@@ -123,9 +118,19 @@ if __name__ == "__main__":
|
||||
num_frames=args.num_frames,
|
||||
time_division_factor=8,
|
||||
time_division_remainder=1,
|
||||
),
|
||||
)
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=video_processor,
|
||||
special_operator_map={
|
||||
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(duration=float(args.num_frames) / float(args.frame_rate)),
|
||||
"in_context_videos": RouteByType(operator_map=[
|
||||
(str, video_processor),
|
||||
(list, SequencialProcess(video_processor)),
|
||||
]),
|
||||
}
|
||||
)
|
||||
model = LTX2TrainingModule(
|
||||
|
||||
Reference in New Issue
Block a user