mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
2.6 KiB
2.6 KiB
接入模型训练
在接入模型并实现 Pipeline后,接下来接入模型训练功能。
训推一致的 Pipeline 改造
为了保证训练和推理过程严格的一致性,我们会在训练过程中沿用大部分推理代码,但仍需作出少量改造。
首先,在推理过程中添加额外的逻辑,让图生图/视频生视频逻辑根据 scheduler 状态进行切换。以 Qwen-Image 为例:
class QwenImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
然后,在 model_fn 中启用 Gradient Checkpointing,这将以计算速度为代价,大幅度减少训练所需的显存。这并不是必需的,但我们强烈建议这么做。
以 Qwen-Image 为例,修改前:
text, image = block(
image=image,
text=text,
temb=conditioning,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
)
修改后:
from ..core import gradient_checkpoint_forward
text, image = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
image=image,
text=text,
temb=conditioning,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
)
编写训练脚本
DiffSynth-Studio 没有对训练框架做严格的封装,而是将脚本内容暴露给开发者,这种方式可以更方便地对训练脚本进行修改,实现额外的功能。开发者可参考现有的训练脚本,例如 examples/qwen_image/model_training/train.py 进行修改,从而适配新的模型训练。