Mirror of https://github.com/modelscope/DiffSynth-Studio.git
Synced 2026-03-19 14:58:12 +00:00
75 lines, 2.2 KiB, Python
import torch
|
|
import torch.nn as nn
|
|
from .sd3_dit import RMSNorm
|
|
from .utils import hash_state_dict_keys
|
|
|
|
|
|
class BlockWiseControlBlock(torch.nn.Module):
    """Fuse two feature tensors into one control signal.

    Each input is normalized by its own RMSNorm. The two results are summed
    and passed through Linear -> GELU -> Linear, all of width ``dim``.

    Args:
        dim: Feature width of both inputs and of the two linear layers.
    """

    def __init__(self, dim: int = 3072):
        super().__init__()
        # Registration order is kept as-is: it fixes the state_dict key order
        # and the order in which nn.Linear draws its random initial weights.
        self.x_rms = RMSNorm(dim, eps=1e-6)
        self.y_rms = RMSNorm(dim, eps=1e-6)
        self.input_proj = nn.Linear(dim, dim)
        self.act = nn.GELU()
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, x, y):
        """Return output_proj(GELU(input_proj(rms(x) + rms(y)))).

        Because of the nn.Linear layers, the trailing dimension of both
        inputs must equal ``dim``. x and y are assumed to have broadcastable
        shapes (TODO confirm against callers).
        """
        fused = self.x_rms(x) + self.y_rms(y)
        hidden = self.act(self.input_proj(fused))
        return self.output_proj(hidden)

    def init_weights(self):
        """Zero the final projection so the block initially outputs zeros."""
        for param in (self.output_proj.weight, self.output_proj.bias):
            nn.init.zeros_(param)
|
|
|
|
|
|
class QwenImageBlockWiseControlNet(torch.nn.Module):
    """Block-wise ControlNet: an input projection plus a list of control blocks.

    The conditioning tensor is projected to width ``dim`` by ``img_in``
    (see ``process_controlnet_conditioning``). ``blockwise_forward`` then
    applies the control block at a given index to a hidden state and that
    projected conditioning.

    Args:
        num_layers: Number of ``BlockWiseControlBlock`` instances, indexed
            0..num_layers-1 by ``block_id``.
        in_dim: Base channel count of the conditioning input.
        additional_in_dim: Extra conditioning channels on top of ``in_dim``.
            The state-dict converter sets this to 4 for the inpaint checkpoint.
        dim: Hidden width of ``img_in``'s output and of every control block.
    """

    def __init__(
        self,
        num_layers: int = 60,
        in_dim: int = 64,
        additional_in_dim: int = 0,
        dim: int = 3072,
    ):
        super().__init__()
        self.img_in = nn.Linear(in_dim + additional_in_dim, dim)
        self.controlnet_blocks = nn.ModuleList(
            [BlockWiseControlBlock(dim) for _ in range(num_layers)]
        )

    def init_weight(self):
        """Zero-initialize ``img_in`` and each control block's output projection."""
        nn.init.zeros_(self.img_in.weight)
        nn.init.zeros_(self.img_in.bias)
        for block in self.controlnet_blocks:
            block.init_weights()

    def init_weights(self):
        """Alias of ``init_weight``, matching ``BlockWiseControlBlock.init_weights``.

        This delegates through ``self`` rather than binding the function
        object, so a subclass that overrides ``init_weight`` is honored.
        """
        self.init_weight()

    def process_controlnet_conditioning(self, controlnet_conditioning):
        """Project the conditioning tensor to the model width with ``img_in``.

        The trailing dimension must equal ``in_dim + additional_in_dim``.
        """
        return self.img_in(controlnet_conditioning)

    def blockwise_forward(self, img, controlnet_conditioning, block_id):
        """Apply control block ``block_id`` to ``img`` and the projected conditioning."""
        return self.controlnet_blocks[block_id](img, controlnet_conditioning)

    @staticmethod
    def state_dict_converter():
        """Return the converter that adapts checkpoints for this model."""
        return QwenImageBlockWiseControlNetStateDictConverter()
|
|
|
|
|
|
class QwenImageBlockWiseControlNetStateDictConverter:
    """Checkpoint adapter for ``QwenImageBlockWiseControlNet``.

    The state dict is passed through unchanged. The only extra work is to
    detect a known checkpoint variant and report the constructor kwargs it
    needs.
    """

    def __init__(self):
        """Nothing to configure; the converter is stateless."""

    def from_civitai(self, state_dict):
        """Return ``(state_dict, extra_kwargs)`` for model construction.

        The state dict's key names are hashed. The single known hash (the
        inpaint controlnet) requires ``additional_in_dim=4``. Any other hash
        gets no extra kwargs. A new dict is built on every call.
        """
        # inpaint controlnet
        is_inpaint = (
            hash_state_dict_keys(state_dict) == "a9e54e480a628f0b956a688a81c33bab"
        )
        extra_kwargs = {"additional_in_dim": 4} if is_inpaint else {}
        return state_dict, extra_kwargs
|