qwen-image controlnet

This commit is contained in:
Artiprocher
2025-08-08 11:29:23 +08:00
parent 32cf5d32ce
commit 6e13deb6de
6 changed files with 284 additions and 3 deletions

View File

@@ -1,5 +1,5 @@
import torch, os, json
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
from diffsynth.models.lora import QwenImageLoRAConverter
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -73,8 +73,15 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
}
# Extra inputs
controlnet_input = {}
for extra_input in self.extra_inputs:
inputs_shared[extra_input] = data[extra_input]
if extra_input.startswith("controlnet_"):
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
else:
inputs_shared[extra_input] = data[extra_input]
if len(controlnet_input) > 0:
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
# Pipeline units will automatically process the input parameters.
for unit in self.pipe.units: