mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
compatibility update
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, DownSampler, UpSampler
|
||||
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
class ControlNetConditioningLayer(torch.nn.Module):
|
||||
@@ -92,20 +93,37 @@ class SDControlNet(torch.nn.Module):
|
||||
|
||||
self.global_pool = global_pool
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states, conditioning):
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, conditioning,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
):
|
||||
# 1. time
|
||||
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
time_emb = time_emb.repeat(sample.shape[0], 1)
|
||||
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
if tiled and not isinstance(block, PushBlock):
|
||||
_, _, inter_height, _ = hidden_states.shape
|
||||
resize_scale = inter_height / height
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: block(x, time_emb, text_emb, res_stack)[0],
|
||||
hidden_states,
|
||||
int(tile_size * resize_scale),
|
||||
int(tile_stride * resize_scale),
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 4. ControlNet blocks
|
||||
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
|
||||
|
||||
Reference in New Issue
Block a user