compatibility update

This commit is contained in:
Artiprocher
2023-12-23 20:13:41 +08:00
parent b30d0fa412
commit 66b3e995c2
27 changed files with 1051 additions and 398 deletions

View File

@@ -1,7 +1,7 @@
import torch
from .attention import Attention
from .sd_unet import ResnetBlock, UpSampler
from .tiler import Tiler
from .tiler import TileWorker
class VAEAttentionBlock(torch.nn.Module):
@@ -79,11 +79,13 @@ class SDVAEDecoder(torch.nn.Module):
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = Tiler()(
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states