vram optimization

This commit is contained in:
Artiprocher
2025-03-10 17:11:11 +08:00
parent b548d7caf2
commit a05f647633
6 changed files with 83 additions and 51 deletions

View File

@@ -60,7 +60,6 @@ class WanVideoPipeline(BasePipeline):
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config = dict(
@@ -116,7 +115,7 @@ class WanVideoPipeline(BasePipeline):
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_dtype=dtype,
computation_device=self.device,
),
)
@@ -153,17 +152,21 @@ class WanVideoPipeline(BasePipeline):
def encode_image(self, image, num_frames, height, width):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
y = torch.concat([msk, y])
return {"clip_fea": clip_context, "y": [y]}
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
y = y.to(dtype=self.torch_dtype, device=self.device)
return {"clip_feature": clip_context, "y": y}
def tensor2video(self, frames):
@@ -174,18 +177,16 @@ class WanVideoPipeline(BasePipeline):
def prepare_extra_input(self, latents=None):
return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
@@ -229,8 +230,8 @@ class WanVideoPipeline(BasePipeline):
if input_video is not None:
self.load_models_to_device(['vae'])
input_video = self.preprocess_images(input_video)
input_video = torch.stack(input_video, dim=2)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = noise