mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
compatibility update
@@ -279,31 +279,19 @@ class SDUNet(torch.nn.Module):
         self.conv_act = torch.nn.SiLU()
         self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
 
-    def forward(self, sample, timestep, encoder_hidden_states, tiled=False, tile_size=64, tile_stride=8, additional_res_stack=None, **kwargs):
+    def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
         # 1. time
         time_emb = self.time_proj(timestep[None]).to(sample.dtype)
         time_emb = self.time_embedding(time_emb)
         time_emb = time_emb.repeat(sample.shape[0], 1)
 
         # 2. pre-process
-        height, width = sample.shape[2], sample.shape[3]
         hidden_states = self.conv_in(sample)
         text_emb = encoder_hidden_states
         res_stack = [hidden_states]
 
         # 3. blocks
         for i, block in enumerate(self.blocks):
-            if additional_res_stack is not None and i==31:
-                hidden_states += additional_res_stack.pop()
-                res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
-                additional_res_stack = None
-            if tiled:
-                hidden_states, time_emb, text_emb, res_stack = self.tiled_inference(
-                    block, hidden_states, time_emb, text_emb, res_stack,
-                    height, width, tile_size, tile_stride
-                )
-            else:
-                hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
 
         # 4. output
         hidden_states = self.conv_norm_out(hidden_states)
@@ -312,23 +300,6 @@ class SDUNet(torch.nn.Module):
 
         return hidden_states
 
-    def tiled_inference(self, block, hidden_states, time_emb, text_emb, res_stack, height, width, tile_size, tile_stride):
-        if block.__class__.__name__ in ["ResnetBlock", "AttentionBlock", "DownSampler", "UpSampler"]:
-            batch_size, inter_channel, inter_height, inter_width = hidden_states.shape
-            resize_scale = inter_height / height
-
-            hidden_states = Tiler()(
-                lambda x: block(x, time_emb, text_emb, res_stack)[0],
-                hidden_states,
-                int(tile_size * resize_scale),
-                int(tile_stride * resize_scale),
-                inter_device=hidden_states.device,
-                inter_dtype=hidden_states.dtype
-            )
-        else:
-            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
-        return hidden_states, time_emb, text_emb, res_stack
-
     def state_dict_converter(self):
         return SDUNetStateDictConverter()
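
For context, a minimal usage sketch of the simplified forward that this commit leaves behind. The tensor shapes follow the Stable Diffusion v1 conventions visible in the diff (a 4-channel latent, a 0-dim timestep that forward expands via timestep[None], CLIP text embeddings of shape (batch, 77, 768)); the import path and the no-argument constructor are assumptions, not part of this commit.

import torch
from diffsynth.models.sd_unet import SDUNet  # assumed module path

unet = SDUNet().eval()
sample = torch.randn(1, 4, 64, 64)               # latent for a 512x512 image
timestep = torch.tensor(981)                     # 0-dim tensor; forward applies timestep[None]
encoder_hidden_states = torch.randn(1, 77, 768)  # CLIP text-encoder output (assumed shape)

with torch.no_grad():
    noise_pred = unet(sample, timestep, encoder_hidden_states)
print(noise_pred.shape)  # conv_out maps 320 -> 4 channels, so (1, 4, 64, 64)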
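The removed tiled_inference ran each block over overlapping spatial tiles via a Tiler helper, rescaling tile_size and tile_stride by resize_scale = inter_height / height so that tiles track the UNet's current resolution. The sketch below illustrates the underlying overlap-and-average tiling idea under assumed semantics; it is a simplification, not DiffSynth's Tiler, and it assumes fn preserves its input's shape (the real code also routed resampling blocks through Tiler).

import torch

def _starts(size, tile, stride):
    # Tile offsets covering [0, size), with a final edge-aligned tile.
    if size <= tile:
        return [0]
    starts = list(range(0, size - tile, stride))
    starts.append(size - tile)
    return starts

def tiled_apply(fn, x, tile_size, tile_stride):
    # Apply fn to overlapping tiles of x and average the overlaps,
    # so large feature maps never pass through fn all at once.
    b, c, h, w = x.shape
    out = torch.zeros_like(x)
    weight = torch.zeros(1, 1, h, w, dtype=x.dtype, device=x.device)
    for top in _starts(h, tile_size, tile_stride):
        for left in _starts(w, tile_size, tile_stride):
            tile = x[:, :, top:top + tile_size, left:left + tile_size]
            out[:, :, top:top + tile_size, left:left + tile_size] += fn(tile)
            weight[:, :, top:top + tile_size, left:left + tile_size] += 1
    return out / weight  # average where tiles overlap

# Mirrors the lambda in the removed code (illustrative only):
# tiled_apply(lambda x: block(x, time_emb, text_emb, res_stack)[0], hidden_states, 64, 8)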