mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
Merge pull request #884 from modelscope/dev2-dzj
Unified Dataset & Split Training
This commit is contained in:
@@ -174,9 +174,12 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
computation_dtype=self.torch_dtype,
|
computation_dtype=self.torch_dtype,
|
||||||
computation_device="cuda",
|
computation_device="cuda",
|
||||||
)
|
)
|
||||||
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
|
if self.text_encoder is not None:
|
||||||
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
|
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
|
||||||
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
|
if self.dit is not None:
|
||||||
|
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
|
||||||
|
if self.vae is not None:
|
||||||
|
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
|
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
|
||||||
|
|||||||
334
diffsynth/trainers/unified_dataset.py
Normal file
334
diffsynth/trainers/unified_dataset.py
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
import torch, torchvision, imageio, os, json, pandas
|
||||||
|
import imageio.v3 as iio
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessingPipeline:
    """A chain of data-processing operators applied in sequence.

    Pipelines compose with the ``>>`` operator: ``a >> b`` yields a new
    pipeline that runs ``a``'s operators followed by ``b``'s.
    """

    def __init__(self, operators=None):
        # Build an explicit empty list (never a mutable default) so that
        # separate pipeline instances stay independent.
        self.operators: list[DataProcessingOperator] = [] if operators is None else operators

    def __call__(self, data):
        """Run every operator over ``data`` in order and return the result."""
        result = data
        for op in self.operators:
            result = op(result)
        return result

    def __rshift__(self, pipe):
        """Concatenate with another pipeline (or a single operator)."""
        if isinstance(pipe, DataProcessingOperator):
            pipe = DataProcessingPipeline([pipe])
        return DataProcessingPipeline(self.operators + pipe.operators)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessingOperator:
    """Abstract base for single-step data transforms.

    Subclasses implement ``__call__``; operators compose into
    ``DataProcessingPipeline`` objects with the ``>>`` operator.
    """

    def __call__(self, data):
        # The base class carries no behavior of its own.
        raise NotImplementedError("DataProcessingOperator cannot be called directly.")

    def __rshift__(self, pipe):
        """Chain this operator with another operator or a pipeline."""
        if isinstance(pipe, DataProcessingOperator):
            pipe = DataProcessingPipeline([pipe])
        # Wrap self in a one-element pipeline and delegate to its chaining.
        return DataProcessingPipeline([self]).__rshift__(pipe)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessingOperatorRaw(DataProcessingOperator):
    """Identity operator: passes the input through untouched."""

    def __call__(self, data):
        return data
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ToInt(DataProcessingOperator):
    """Cast the input to ``int``."""

    def __call__(self, data):
        return int(data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ToFloat(DataProcessingOperator):
    """Cast the input to ``float``."""

    def __call__(self, data):
        return float(data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ToStr(DataProcessingOperator):
    """Cast the input to ``str``, substituting a fallback for ``None``."""

    def __init__(self, none_value=""):
        # Value used in place of None before conversion.
        self.none_value = none_value

    def __call__(self, data):
        return str(self.none_value if data is None else data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LoadImage(DataProcessingOperator):
    """Open an image file from a path with PIL."""

    def __init__(self, convert_RGB=True):
        # When True, normalize any source mode (RGBA, L, P, ...) to 3-channel RGB.
        self.convert_RGB = convert_RGB

    def __call__(self, data: str):
        image = Image.open(data)
        return image.convert("RGB") if self.convert_RGB else image
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCropAndResize(DataProcessingOperator):
    """Resize an image to cover a target size, then center-crop to it.

    When ``height``/``width`` are None the target is derived from the source
    aspect ratio, capped at ``max_pixels`` and rounded down to multiples of
    the division factors (required by models whose latent space downsamples
    by a fixed factor).
    """

    def __init__(self, height, width, max_pixels, height_division_factor, width_division_factor):
        # Fixed target size; when either is None the size is derived per image.
        self.height = height
        self.width = width
        # Upper bound on height*width for derived sizes.
        self.max_pixels = max_pixels
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor

    def crop_and_resize(self, image, target_height, target_width):
        """Scale ``image`` so it covers the target box, then center-crop it."""
        width, height = image.size
        # "Cover" scaling: take the larger ratio so both dimensions reach the target.
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        return torchvision.transforms.functional.center_crop(image, (target_height, target_width))

    def get_height_width(self, image):
        """Return the (height, width) target, fixed or derived from ``image``."""
        if self.height is not None and self.width is not None:
            return self.height, self.width
        width, height = image.size
        if width * height > self.max_pixels:
            # Shrink uniformly so the pixel count fits under the cap.
            scale = (width * height / self.max_pixels) ** 0.5
            height, width = int(height / scale), int(width / scale)
        # Fix: snap to the division factors unconditionally. Previously this
        # only happened when the image exceeded max_pixels, so small images
        # could yield sizes not divisible by the factors.
        height = height // self.height_division_factor * self.height_division_factor
        width = width // self.width_division_factor * self.width_division_factor
        return height, width

    def __call__(self, data: Image.Image):
        return self.crop_and_resize(data, *self.get_height_width(data))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ToList(DataProcessingOperator):
    """Wrap the input in a single-element list."""

    def __call__(self, data):
        return [data]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LoadVideo(DataProcessingOperator):
    """Load up to ``num_frames`` frames from a video file as PIL images.

    The frame count is clipped to the video length, then reduced until
    ``count % time_division_factor == time_division_remainder`` (or a single
    frame remains), matching the temporal alignment used elsewhere in this
    module (see LoadGIF).
    """

    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # frame_processor is built into the video loader for high efficiency:
        # each frame is transformed while streaming, avoiding a second pass.
        self.frame_processor = frame_processor

    def get_num_frames(self, reader):
        """Return the clipped, remainder-aligned number of frames to read."""
        # Query the (potentially expensive) frame count only once; the
        # original code called reader.count_frames() twice.
        total = int(reader.count_frames())
        num_frames = min(self.num_frames, total)
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def __call__(self, data: str):
        reader = imageio.get_reader(data)
        try:
            num_frames = self.get_num_frames(reader)
            frames = []
            for frame_id in range(num_frames):
                frame = Image.fromarray(reader.get_data(frame_id))
                frames.append(self.frame_processor(frame))
        finally:
            # Always release the underlying file handle, even on decode errors.
            reader.close()
        return frames
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SequencialProcess(DataProcessingOperator):
    """Apply one operator to every element of a list-like input."""

    def __init__(self, operator=lambda x: x):
        # Per-element transform; defaults to identity.
        self.operator = operator

    def __call__(self, data):
        return list(map(self.operator, data))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LoadGIF(DataProcessingOperator):
    """Load up to ``num_frames`` frames from a GIF file as PIL images.

    Frame-count alignment follows the same rule as LoadVideo: clipped to the
    file length, then reduced until
    ``count % time_division_factor == time_division_remainder``.
    """

    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # frame_processor is built into the loader for high efficiency: frames
        # are transformed as they are materialized.
        self.frame_processor = frame_processor

    def _clip_num_frames(self, total):
        # Clip to the configured count, then align to the remainder rule.
        num_frames = min(self.num_frames, total)
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def get_num_frames(self, path):
        """Return the aligned frame count for ``path`` (decodes the file)."""
        # Kept for interface compatibility; __call__ no longer uses it.
        return self._clip_num_frames(len(iio.imread(path, mode="RGB")))

    def __call__(self, data: str):
        # Decode the GIF once. The original implementation decoded the whole
        # file twice: once in get_num_frames and again here.
        images = iio.imread(data, mode="RGB")
        num_frames = self._clip_num_frames(len(images))
        frames = []
        for img in images[:num_frames]:
            frames.append(self.frame_processor(Image.fromarray(img)))
        return frames
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class RouteByExtensionName(DataProcessingOperator):
    """Dispatch a file path to an operator based on its extension.

    ``operator_map`` is a sequence of ``(extensions, operator)`` pairs; a
    pair whose extensions are ``None`` acts as a catch-all.
    """

    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data: str):
        ext = data.split(".")[-1].lower()
        for ext_names, operator in self.operator_map:
            if ext_names is None or ext in ext_names:
                return operator(data)
        raise ValueError(f"Unsupported file: {data}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class RouteByType(DataProcessingOperator):
    """Dispatch data to an operator based on its Python type.

    ``operator_map`` is a sequence of ``(type, operator)`` pairs; a pair
    whose type is ``None`` acts as a catch-all.
    """

    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data):
        for dtype, operator in self.operator_map:
            if dtype is None or isinstance(data, dtype):
                return operator(data)
        raise ValueError(f"Unsupported data: {data}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LoadTorchPickle(DataProcessingOperator):
    """Deserialize a ``torch.save``-produced file (e.g. a cached data sample)."""

    def __init__(self, map_location="cpu"):
        # Device to map loaded tensors onto.
        self.map_location = map_location

    def __call__(self, data):
        # SECURITY NOTE: weights_only=False uses full pickle deserialization,
        # which can execute arbitrary code — only load trusted cache files.
        return torch.load(data, map_location=self.map_location, weights_only=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ToAbsolutePath(DataProcessingOperator):
    """Join a relative path with a fixed base directory."""

    def __init__(self, base_path=""):
        # Directory prepended to incoming paths.
        self.base_path = base_path

    def __call__(self, data):
        # os.path.join ignores base_path when data is already absolute.
        return os.path.join(self.base_path, data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class UnifiedDataset(torch.utils.data.Dataset):
    """Dataset serving either metadata-described samples or cached tensors.

    Two modes, selected by ``metadata_path``:

    * ``metadata_path`` given — samples come from a json/jsonl/csv metadata
      file, and values under ``data_file_keys`` are loaded and processed on
      the fly by ``main_data_operator`` (or a per-key override from
      ``special_operator_map``).
    * ``metadata_path`` is ``None`` — ``base_path`` is scanned recursively
      for ``.pth`` files produced by a data-caching run, and each sample is
      unpickled from one of them.
    """

    def __init__(
        self,
        base_path=None, metadata_path=None,
        repeat=1,
        data_file_keys=tuple(),
        main_data_operator=lambda x: x,
        special_operator_map=None,
    ):
        self.base_path = base_path
        self.metadata_path = metadata_path
        # Virtually repeats the dataset; useful for short epochs on small datasets.
        self.repeat = repeat
        # Keys whose values are file references to be loaded/processed.
        self.data_file_keys = data_file_keys
        self.main_data_operator = main_data_operator
        self.cached_data_operator = LoadTorchPickle()
        # Per-key operators that override main_data_operator.
        self.special_operator_map = {} if special_operator_map is None else special_operator_map
        self.data = []
        self.cached_data = []
        # Cache mode is implied by the absence of a metadata file.
        self.load_from_cache = metadata_path is None
        self.load_metadata(metadata_path)

    @staticmethod
    def default_image_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
    ):
        """Build the standard image-loading pipeline for str or list inputs."""
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
            (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
        ])

    @staticmethod
    def default_video_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        num_frames=81, time_division_factor=4, time_division_remainder=1,
    ):
        """Build the standard video-loading pipeline, routed by file extension.

        Still images become one-frame videos; GIFs and regular video
        containers are decoded frame by frame.
        """
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
                (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
                (("gif",), LoadGIF(num_frames, time_division_factor, time_division_remainder) >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
                (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                )),
            ])),
        ])

    def search_for_cached_data_files(self, path):
        """Recursively collect ``.pth`` cache files under ``path``."""
        for file_name in os.listdir(path):
            subpath = os.path.join(path, file_name)
            if os.path.isdir(subpath):
                self.search_for_cached_data_files(subpath)
            elif subpath.endswith(".pth"):
                self.cached_data.append(subpath)

    def load_metadata(self, metadata_path):
        """Populate ``self.data`` (or ``self.cached_data`` in cache mode)."""
        if metadata_path is None:
            print("No metadata_path. Searching for cached data files.")
            self.search_for_cached_data_files(self.base_path)
            print(f"{len(self.cached_data)} cached data files found.")
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
            self.data = metadata
        elif metadata_path.endswith(".jsonl"):
            metadata = []
            with open(metadata_path, 'r') as f:
                for line in f:
                    # Skip blank lines so trailing newlines don't crash json.loads.
                    if line.strip():
                        metadata.append(json.loads(line.strip()))
            self.data = metadata
        else:
            # Anything else is treated as a CSV table.
            metadata = pandas.read_csv(metadata_path)
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]

    def __getitem__(self, data_id):
        """Return one sample; indices wrap modulo the underlying length."""
        if self.load_from_cache:
            data = self.cached_data[data_id % len(self.cached_data)]
            data = self.cached_data_operator(data)
        else:
            # Copy so on-the-fly processing never mutates the stored metadata.
            data = self.data[data_id % len(self.data)].copy()
            for key in self.data_file_keys:
                if key in data:
                    if key in self.special_operator_map:
                        # Fix: apply the special operator to the value. The
                        # previous code stored the operator object itself.
                        data[key] = self.special_operator_map[key](data[key])
                    else:
                        # (The former `elif key in self.data_file_keys` was
                        # always true inside this loop; removed.)
                        data[key] = self.main_data_operator(data[key])
        return data

    def __len__(self):
        if self.load_from_cache:
            return len(self.cached_data) * self.repeat
        else:
            return len(self.data) * self.repeat

    def check_data_equal(self, data1, data2):
        """Debug only: shallow equality of two sample dicts."""
        if len(data1) != len(data2):
            return False
        for k in data1:
            # Guard against same-length dicts with different key sets, which
            # previously raised KeyError instead of returning False.
            if k not in data2 or data1[k] != data2[k]:
                return False
        return True
|
||||||
@@ -417,6 +417,13 @@ class DiffusionTrainingModule(torch.nn.Module):
|
|||||||
state_dict_[name] = param
|
state_dict_[name] = param
|
||||||
state_dict = state_dict_
|
state_dict = state_dict_
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def transfer_data_to_device(self, data, device):
|
||||||
|
for key in data:
|
||||||
|
if isinstance(data[key], torch.Tensor):
|
||||||
|
data[key] = data[key].to(device)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -484,7 +491,10 @@ def launch_training_task(
|
|||||||
for data in tqdm(dataloader):
|
for data in tqdm(dataloader):
|
||||||
with accelerator.accumulate(model):
|
with accelerator.accumulate(model):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss = model(data)
|
if dataset.load_from_cache:
|
||||||
|
loss = model({}, inputs=data)
|
||||||
|
else:
|
||||||
|
loss = model(data)
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
model_logger.on_step_end(accelerator, model, save_steps)
|
model_logger.on_step_end(accelerator, model, save_steps)
|
||||||
@@ -494,16 +504,24 @@ def launch_training_task(
|
|||||||
model_logger.on_training_end(accelerator, model, save_steps)
|
model_logger.on_training_end(accelerator, model, save_steps)
|
||||||
|
|
||||||
|
|
||||||
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
|
def launch_data_process_task(
|
||||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0])
|
dataset: torch.utils.data.Dataset,
|
||||||
|
model: DiffusionTrainingModule,
|
||||||
|
model_logger: ModelLogger,
|
||||||
|
num_workers: int = 8,
|
||||||
|
):
|
||||||
|
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||||
accelerator = Accelerator()
|
accelerator = Accelerator()
|
||||||
model, dataloader = accelerator.prepare(model, dataloader)
|
model, dataloader = accelerator.prepare(model, dataloader)
|
||||||
os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True)
|
|
||||||
for data_id, data in enumerate(tqdm(dataloader)):
|
for data_id, data in tqdm(enumerate(dataloader)):
|
||||||
with torch.no_grad():
|
with accelerator.accumulate(model):
|
||||||
inputs = model.forward_preprocess(data)
|
with torch.no_grad():
|
||||||
inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs}
|
folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
|
||||||
torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth"))
|
os.makedirs(folder, exist_ok=True)
|
||||||
|
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
|
||||||
|
data = model(data)
|
||||||
|
torch.save(data, save_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import torch, os, json
|
import torch, os, json
|
||||||
from diffsynth import load_state_dict
|
from diffsynth import load_state_dict
|
||||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
|
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, flux_parser
|
||||||
from diffsynth.models.lora import FluxLoRAConverter
|
from diffsynth.models.lora import FluxLoRAConverter
|
||||||
|
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
@@ -106,7 +107,20 @@ class FluxTrainingModule(DiffusionTrainingModule):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = flux_parser()
|
parser = flux_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
dataset = ImageDataset(args=args)
|
dataset = UnifiedDataset(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
metadata_path=args.dataset_metadata_path,
|
||||||
|
repeat=args.dataset_repeat,
|
||||||
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
height_division_factor=16,
|
||||||
|
width_division_factor=16,
|
||||||
|
)
|
||||||
|
)
|
||||||
model = FluxTrainingModule(
|
model = FluxTrainingModule(
|
||||||
model_paths=args.model_paths,
|
model_paths=args.model_paths,
|
||||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
|||||||
@@ -0,0 +1,25 @@
|
|||||||
|
accelerate launch examples/qwen_image/model_training/train_data_process.py \
|
||||||
|
--dataset_base_path data/example_image_dataset \
|
||||||
|
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--output_path "./models/train/Qwen-Image_lora_cache" \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--dataset_num_workers 8
|
||||||
|
|
||||||
|
accelerate launch examples/qwen_image/model_training/train.py \
|
||||||
|
--dataset_base_path models/train/Qwen-Image_lora_cache \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Qwen-Image_lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--dataset_num_workers 8 \
|
||||||
|
--find_unused_parameters \
|
||||||
|
--enable_fp8_training
|
||||||
@@ -2,7 +2,8 @@ import torch, os, json
|
|||||||
from diffsynth import load_state_dict
|
from diffsynth import load_state_dict
|
||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||||
from diffsynth.pipelines.flux_image_new import ControlNetInput
|
from diffsynth.pipelines.flux_image_new import ControlNetInput
|
||||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
|
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, qwen_image_parser
|
||||||
|
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
@@ -110,6 +111,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
|
|
||||||
def forward(self, data, inputs=None):
|
def forward(self, data, inputs=None):
|
||||||
if inputs is None: inputs = self.forward_preprocess(data)
|
if inputs is None: inputs = self.forward_preprocess(data)
|
||||||
|
else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
|
||||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||||
loss = self.pipe.training_loss(**models, **inputs)
|
loss = self.pipe.training_loss(**models, **inputs)
|
||||||
return loss
|
return loss
|
||||||
@@ -119,7 +121,20 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = qwen_image_parser()
|
parser = qwen_image_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
dataset = ImageDataset(args=args)
|
dataset = UnifiedDataset(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
metadata_path=args.dataset_metadata_path,
|
||||||
|
repeat=args.dataset_repeat,
|
||||||
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
height_division_factor=16,
|
||||||
|
width_division_factor=16,
|
||||||
|
)
|
||||||
|
)
|
||||||
model = QwenImageTrainingModule(
|
model = QwenImageTrainingModule(
|
||||||
model_paths=args.model_paths,
|
model_paths=args.model_paths,
|
||||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
|||||||
154
examples/qwen_image/model_training/train_data_process.py
Normal file
154
examples/qwen_image/model_training/train_data_process.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import torch, os, json
|
||||||
|
from diffsynth import load_state_dict
|
||||||
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||||
|
from diffsynth.pipelines.flux_image_new import ControlNetInput
|
||||||
|
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_data_process_task, qwen_image_parser
|
||||||
|
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_paths=None, model_id_with_origin_paths=None,
|
||||||
|
tokenizer_path=None, processor_path=None,
|
||||||
|
trainable_models=None,
|
||||||
|
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||||
|
use_gradient_checkpointing=True,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
extra_inputs=None,
|
||||||
|
enable_fp8_training=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Load models
|
||||||
|
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
|
||||||
|
model_configs = []
|
||||||
|
if model_paths is not None:
|
||||||
|
model_paths = json.loads(model_paths)
|
||||||
|
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
|
||||||
|
if model_id_with_origin_paths is not None:
|
||||||
|
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||||
|
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
|
||||||
|
|
||||||
|
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||||
|
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
|
||||||
|
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
|
||||||
|
|
||||||
|
# Enable FP8
|
||||||
|
if enable_fp8_training:
|
||||||
|
self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
# Reset training scheduler (do it in each training step)
|
||||||
|
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
|
||||||
|
# Freeze untrainable models
|
||||||
|
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||||
|
|
||||||
|
# Add LoRA to the base models
|
||||||
|
if lora_base_model is not None:
|
||||||
|
model = self.add_lora_to_model(
|
||||||
|
getattr(self.pipe, lora_base_model),
|
||||||
|
target_modules=lora_target_modules.split(","),
|
||||||
|
lora_rank=lora_rank,
|
||||||
|
upcast_dtype=self.pipe.torch_dtype,
|
||||||
|
)
|
||||||
|
if lora_checkpoint is not None:
|
||||||
|
state_dict = load_state_dict(lora_checkpoint)
|
||||||
|
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||||
|
load_result = model.load_state_dict(state_dict, strict=False)
|
||||||
|
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
|
||||||
|
if len(load_result[1]) > 0:
|
||||||
|
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||||
|
setattr(self.pipe, lora_base_model, model)
|
||||||
|
|
||||||
|
# Store other configs
|
||||||
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
|
|
||||||
|
|
||||||
|
def forward_preprocess(self, data):
|
||||||
|
# CFG-sensitive parameters
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {"negative_prompt": ""}
|
||||||
|
|
||||||
|
# CFG-unsensitive parameters
|
||||||
|
inputs_shared = {
|
||||||
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
# Please do not modify the following parameters
|
||||||
|
# unless you clearly know what this will cause.
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"rand_device": self.pipe.device,
|
||||||
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
|
"edit_image_auto_resize": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Extra inputs
|
||||||
|
controlnet_input, blockwise_controlnet_input = {}, {}
|
||||||
|
for extra_input in self.extra_inputs:
|
||||||
|
if extra_input.startswith("blockwise_controlnet_"):
|
||||||
|
blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input]
|
||||||
|
elif extra_input.startswith("controlnet_"):
|
||||||
|
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
|
||||||
|
else:
|
||||||
|
inputs_shared[extra_input] = data[extra_input]
|
||||||
|
if len(controlnet_input) > 0:
|
||||||
|
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
|
||||||
|
if len(blockwise_controlnet_input) > 0:
|
||||||
|
inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)]
|
||||||
|
|
||||||
|
# Pipeline units will automatically process the input parameters.
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
return {**inputs_shared, **inputs_posi}
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, data, inputs=None):
|
||||||
|
if inputs is None: inputs = self.forward_preprocess(data)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = qwen_image_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
metadata_path=args.dataset_metadata_path,
|
||||||
|
repeat=1, # Set repeat = 1
|
||||||
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
height_division_factor=16,
|
||||||
|
width_division_factor=16,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model = QwenImageTrainingModule(
|
||||||
|
model_paths=args.model_paths,
|
||||||
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
tokenizer_path=args.tokenizer_path,
|
||||||
|
processor_path=args.processor_path,
|
||||||
|
trainable_models=args.trainable_models,
|
||||||
|
lora_base_model=args.lora_base_model,
|
||||||
|
lora_target_modules=args.lora_target_modules,
|
||||||
|
lora_rank=args.lora_rank,
|
||||||
|
lora_checkpoint=args.lora_checkpoint,
|
||||||
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
|
extra_inputs=args.extra_inputs,
|
||||||
|
enable_fp8_training=args.enable_fp8_training,
|
||||||
|
)
|
||||||
|
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
|
||||||
|
launch_data_process_task(
|
||||||
|
dataset, model, model_logger,
|
||||||
|
num_workers=args.dataset_num_workers,
|
||||||
|
)
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
import torch, os, json
|
import torch, os, json
|
||||||
from diffsynth import load_state_dict
|
from diffsynth import load_state_dict
|
||||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser
|
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
|
||||||
|
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
@@ -112,7 +113,23 @@ class WanTrainingModule(DiffusionTrainingModule):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = wan_parser()
|
parser = wan_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
dataset = VideoDataset(args=args)
|
dataset = UnifiedDataset(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
metadata_path=args.dataset_metadata_path,
|
||||||
|
repeat=args.dataset_repeat,
|
||||||
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
|
main_data_operator=UnifiedDataset.default_video_operator(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
height_division_factor=16,
|
||||||
|
width_division_factor=16,
|
||||||
|
num_frames=args.num_frames,
|
||||||
|
time_division_factor=4,
|
||||||
|
time_division_remainder=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
model = WanTrainingModule(
|
model = WanTrainingModule(
|
||||||
model_paths=args.model_paths,
|
model_paths=args.model_paths,
|
||||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
|||||||
Reference in New Issue
Block a user