mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
34 lines
1.3 KiB
Python
34 lines
1.3 KiB
Python
# Initialize a Qwen-Image ControlNet checkpoint from the pretrained Qwen-Image DiT weights.
|
|
from diffsynth import load_state_dict, hash_state_dict_keys
|
|
from diffsynth.pipelines.qwen_image import QwenImageControlNet
|
|
import torch
|
|
from safetensors.torch import save_file
|
|
|
|
|
|
# Load the pretrained Qwen-Image DiT, which is stored as several safetensors
# shards named "diffusion_pytorch_model-{index:05d}-of-{total:05d}.safetensors",
# and merge every shard into one flat state dict.
# NOTE(review): the whole DiT is materialized on the GPU in addition to the
# ControlNet built below; presumably loading with device="cpu" would lower
# peak GPU memory if that becomes a problem — confirm before changing.
DIT_DIR = "models/Qwen/Qwen-Image/transformer"
NUM_DIT_SHARDS = 9

state_dict_dit = {}
for shard_index in range(1, NUM_DIT_SHARDS + 1):
    # Zero-pad both numbers to 5 digits so the name stays correct for any
    # shard count (the old "0000{i}" form broke for indices >= 10).
    shard_path = (
        f"{DIT_DIR}/diffusion_pytorch_model-"
        f"{shard_index:05d}-of-{NUM_DIT_SHARDS:05d}.safetensors"
    )
    state_dict_dit.update(
        load_state_dict(shard_path, torch_dtype=torch.bfloat16, device="cuda")
    )

# Instantiate a fresh ControlNet in bfloat16 on the GPU. Its state dict
# supplies the target key set and tensor shapes for initialization.
controlnet = QwenImageControlNet()
controlnet = controlnet.to(dtype=torch.bfloat16, device="cuda")
state_dict_controlnet = controlnet.state_dict()

# Build the ControlNet's initial weights key by key from the DiT weights:
#   * keys only in the ControlNet are zero-initialized;
#   * keys shared with the DiT and of identical shape are copied as-is;
#   * "img_in.weight" differs in shape, so the DiT weight is concatenated
#     with itself along the last dim (presumably the input-feature dim of a
#     Linear weight, doubled for the extra control input — confirm).
# A shared key with any other shape mismatch has no initialization rule.
# Fail loudly on it instead of leaving it out, which would only surface as
# a generic "Missing key(s)" error from load_state_dict below.
state_dict_init = {}
for k, param in state_dict_controlnet.items():
    if k not in state_dict_dit:
        print("Zero Initialized:", k)
        state_dict_init[k] = torch.zeros_like(param)
        continue
    dit_param = state_dict_dit[k]
    if dit_param.shape == param.shape:
        state_dict_init[k] = dit_param
    elif k == "img_in.weight":
        # NOTE(review): this duplicates the DiT projection for the extra
        # input channels rather than zero-initializing them — confirm intended.
        state_dict_init[k] = torch.concat([dit_param, dit_param], dim=-1)
    else:
        raise ValueError(
            f"Cannot initialize {k!r}: DiT shape {tuple(dit_param.shape)} "
            f"does not match ControlNet shape {tuple(param.shape)}"
        )
# Strict load: confirms that every ControlNet key received a tensor of the right shape.
controlnet.load_state_dict(state_dict_init)

# Print the hash of the new checkpoint's keys (as computed by diffsynth's
# hash_state_dict_keys), then write the initialized weights to disk.
init_keys_hash = hash_state_dict_keys(state_dict_init)
print(init_keys_hash)
save_file(state_dict_init, "models/controlnet.safetensors")