mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
rebuild base modules
This commit is contained in:
@@ -50,6 +50,8 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
original_dtype = sample.dtype
|
||||
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -71,6 +73,7 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
hidden_states = self.quant_conv(hidden_states)
|
||||
hidden_states = hidden_states[:, :4]
|
||||
hidden_states *= self.scaling_factor
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -91,7 +94,8 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
hidden_states = torch.concat(hidden_states, dim=2)
|
||||
return hidden_states
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user