Files
DiffSynth-Studio/diffsynth/models/sdxl_vae_decoder.py
2024-07-26 12:15:40 +08:00

25 lines
762 B
Python

from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
class SDXLVAEDecoder(SDVAEDecoder):
    """VAE decoder variant for SDXL, built on top of ``SDVAEDecoder``.

    The only difference introduced here is the value assigned to
    ``scaling_factor`` after the parent constructor has run.
    """

    def __init__(self, upcast_to_float32=True):
        """Build the parent decoder, then set the SDXL scaling factor.

        Args:
            upcast_to_float32: Accepted for interface compatibility but not
                read inside this constructor.
                NOTE(review): presumably consumed by whatever code
                instantiates the model from the converter's extra kwargs --
                confirm against the loader.
        """
        # Parent initialisation must come first so that the assignment below
        # is the final value of ``scaling_factor``.
        super().__init__()
        self.scaling_factor = 0.13025

    @staticmethod
    def state_dict_converter():
        """Return a fresh converter instance for this decoder's checkpoints."""
        converter = SDXLVAEDecoderStateDictConverter()
        return converter
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
    """State-dict converter for ``SDXLVAEDecoder``.

    Key conversion is delegated entirely to ``SDVAEDecoderStateDictConverter``;
    each method here additionally returns a dict of extra keyword arguments.
    NOTE(review): the extra dict is presumably forwarded to the model
    constructor by the loader -- confirm against the caller.
    """

    def __init__(self):
        """Initialise the parent converter; no additional state is added."""
        super().__init__()

    def from_diffusers(self, state_dict):
        """Convert via the parent's ``from_diffusers`` and attach extra kwargs.

        Args:
            state_dict: The state dict handed to the parent converter.

        Returns:
            A 2-tuple ``(converted_state_dict, extra_kwargs)`` where
            ``extra_kwargs`` is ``{"upcast_to_float32": True}``.
        """
        converted = super().from_diffusers(state_dict)
        extra_kwargs = {"upcast_to_float32": True}
        return converted, extra_kwargs

    def from_civitai(self, state_dict):
        """Convert via the parent's ``from_civitai`` and attach extra kwargs.

        Args:
            state_dict: The state dict handed to the parent converter.

        Returns:
            A 2-tuple ``(converted_state_dict, extra_kwargs)`` where
            ``extra_kwargs`` is ``{"upcast_to_float32": True}``.
        """
        converted = super().from_civitai(state_dict)
        extra_kwargs = {"upcast_to_float32": True}
        return converted, extra_kwargs