This commit is contained in:
Artiprocher
2025-11-27 22:43:43 +08:00
parent 0b527c460f
commit 0b72c2b3ba
10 changed files with 1329 additions and 17 deletions

View File

@@ -262,3 +262,121 @@ def FluxVAEDecoderStateDictConverter(state_dict):
param = state_dict[name]
state_dict_[rename_dict[name]] = param
return state_dict_
def FluxVAEEncoderStateDictConverterDiffusers(state_dict):
# architecture
block_types = [
'ResnetBlock', 'ResnetBlock', 'DownSampler',
'ResnetBlock', 'ResnetBlock', 'DownSampler',
'ResnetBlock', 'ResnetBlock', 'DownSampler',
'ResnetBlock', 'ResnetBlock',
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
]
# Rename each parameter
local_rename_dict = {
"quant_conv": "quant_conv",
"encoder.conv_in": "conv_in",
"encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
"encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
"encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
"encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
"encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
"encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
"encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
"encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
"encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
"encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
"encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
"encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
"encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
"encoder.conv_norm_out": "conv_norm_out",
"encoder.conv_out": "conv_out",
}
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
name_prefix = ".".join(names[:-1])
if name_prefix in local_rename_dict:
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
elif name.startswith("encoder.down_blocks"):
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
block_type_with_id = ".".join(names[:5])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:5])
names = ["blocks", str(block_id[block_type])] + names[5:]
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
state_dict_[rename_dict[name]] = state_dict[name]
return state_dict_
def FluxVAEDecoderStateDictConverterDiffusers(state_dict):
# architecture
block_types = [
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
]
# Rename each parameter
local_rename_dict = {
"post_quant_conv": "post_quant_conv",
"decoder.conv_in": "conv_in",
"decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
"decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
"decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
"decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
"decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
"decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
"decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
"decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
"decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
"decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
"decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
"decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
"decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
"decoder.conv_norm_out": "conv_norm_out",
"decoder.conv_out": "conv_out",
}
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
name_prefix = ".".join(names[:-1])
if name_prefix in local_rename_dict:
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
elif name.startswith("decoder.up_blocks"):
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
block_type_with_id = ".".join(names[:5])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:5])
names = ["blocks", str(block_id[block_type])] + names[5:]
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
state_dict_[rename_dict[name]] = state_dict[name]
return state_dict_