support qwen-image blockwise controlnet training

This commit is contained in:
Artiprocher
2025-08-15 18:41:01 +08:00
parent 024fdad76d
commit 01a1f48f70
17 changed files with 269 additions and 15 deletions

View File

@@ -80,17 +80,18 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
}
# Extra inputs
controlnet_input = {}
controlnet_input, blockwise_controlnet_input = {}, {}
for extra_input in self.extra_inputs:
if extra_input.startswith("blockwise_controlnet_"):
controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input]
blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input]
elif extra_input.startswith("controlnet_"):
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
else:
inputs_shared[extra_input] = data[extra_input]
if len(controlnet_input) > 0:
controlnet_key = "blockwise_controlnet_inputs" if "blockwise_controlnet_image" in self.extra_inputs else "controlnet_inputs"
inputs_shared[controlnet_key] = [ControlNetInput(**controlnet_input)]
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
if len(blockwise_controlnet_input) > 0:
inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)]
# Pipeline units will automatically process the input parameters.
for unit in self.pipe.units: