mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
optimize stepvideo vae
This commit is contained in:
@@ -13,7 +13,7 @@ from PIL import Image
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings
|
||||
from ..models.stepvideo_dit import RMSNorm
|
||||
from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Resnet3DBlock, AttnBlock, Res3DBlockUpsample, Upsample2D
|
||||
from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Upsample2D, BaseGroupNorm
|
||||
|
||||
|
||||
|
||||
@@ -100,10 +100,8 @@ class StepVideoPipeline(BasePipeline):
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
CausalConv: AutoWrappedModule,
|
||||
CausalConvAfterNorm: AutoWrappedModule,
|
||||
Resnet3DBlock: AutoWrappedModule,
|
||||
AttnBlock: AutoWrappedModule,
|
||||
Res3DBlockUpsample: AutoWrappedModule,
|
||||
Upsample2D: AutoWrappedModule,
|
||||
BaseGroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -143,7 +141,7 @@ class StepVideoPipeline(BasePipeline):
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
frames = rearrange(frames, "T C H W -> T H W C")
|
||||
frames = rearrange(frames, "C T H W -> T H W C")
|
||||
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
return frames
|
||||
@@ -163,9 +161,16 @@ class StepVideoPipeline(BasePipeline):
|
||||
num_frames=204,
|
||||
cfg_scale=9.0,
|
||||
num_inference_steps=30,
|
||||
tiled=True,
|
||||
tile_size=(34, 34),
|
||||
tile_stride=(16, 16),
|
||||
smooth_scale=0.6,
|
||||
progress_bar_cmd=lambda x: x,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
@@ -197,7 +202,7 @@ class StepVideoPipeline(BasePipeline):
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
frames = self.vae.decode(latents)
|
||||
frames = self.vae.decode(latents, device=self.device, smooth_scale=smooth_scale, **tiler_kwargs)
|
||||
self.load_models_to_device([])
|
||||
frames = self.tensor2video(frames[0])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user