optimize stepvideo vae

This commit is contained in:
Artiprocher
2025-02-18 17:28:05 +08:00
parent f191353cf4
commit 9cff769fbd
7 changed files with 197 additions and 28 deletions

View File

@@ -13,7 +13,7 @@ from PIL import Image
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from transformers.models.bert.modeling_bert import BertEmbeddings
from ..models.stepvideo_dit import RMSNorm
from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Resnet3DBlock, AttnBlock, Res3DBlockUpsample, Upsample2D
from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Upsample2D, BaseGroupNorm
@@ -100,10 +100,8 @@ class StepVideoPipeline(BasePipeline):
torch.nn.Conv3d: AutoWrappedModule,
CausalConv: AutoWrappedModule,
CausalConvAfterNorm: AutoWrappedModule,
Resnet3DBlock: AutoWrappedModule,
AttnBlock: AutoWrappedModule,
Res3DBlockUpsample: AutoWrappedModule,
Upsample2D: AutoWrappedModule,
BaseGroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
@@ -143,7 +141,7 @@ class StepVideoPipeline(BasePipeline):
def tensor2video(self, frames):
frames = rearrange(frames, "T C H W -> T H W C")
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
@@ -163,9 +161,16 @@ class StepVideoPipeline(BasePipeline):
num_frames=204,
cfg_scale=9.0,
num_inference_steps=30,
tiled=True,
tile_size=(34, 34),
tile_stride=(16, 16),
smooth_scale=0.6,
progress_bar_cmd=lambda x: x,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
@@ -197,7 +202,7 @@ class StepVideoPipeline(BasePipeline):
# Decode
self.load_models_to_device(['vae'])
frames = self.vae.decode(latents)
frames = self.vae.decode(latents, device=self.device, smooth_scale=smooth_scale, **tiler_kwargs)
self.load_models_to_device([])
frames = self.tensor2video(frames[0])