mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
rebuild base modules
This commit is contained in:
@@ -90,6 +90,8 @@ class SDVAEDecoder(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
original_dtype = sample.dtype
|
||||
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -110,10 +112,12 @@ class SDVAEDecoder(torch.nn.Module):
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user