rebuild base modules

This commit is contained in:
Artiprocher
2024-07-26 12:15:40 +08:00
parent 9471bff8a4
commit e3f8a576cf
76 changed files with 3253 additions and 3563 deletions

View File

@@ -123,21 +123,23 @@ class MotionBucketManager:
class LightningModel(pl.LightningModule):
def __init__(self, learning_rate=1e-5, svd_ckpt_path=None, add_positional_conv=128, contrast_enhance_scale=1.01):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.float16, device=self.device)
model_manager.load_stable_video_diffusion(state_dict=load_state_dict(svd_ckpt_path), add_positional_conv=add_positional_conv)
state_dict = load_state_dict(svd_ckpt_path)
self.image_encoder: SVDImageEncoder = model_manager.image_encoder
self.image_encoder = SVDImageEncoder().to(dtype=torch.float16, device=self.device)
self.image_encoder.load_state_dict(SVDImageEncoder.state_dict_converter().from_civitai(state_dict))
self.image_encoder.eval()
self.image_encoder.requires_grad_(False)
self.unet: SVDUNet = model_manager.unet
self.unet = SVDUNet(add_positional_conv=add_positional_conv).to(dtype=torch.float16, device=self.device)
self.unet.load_state_dict(SVDUNet.state_dict_converter().from_civitai(state_dict), strict=False)
self.unet.train()
self.unet.requires_grad_(False)
for block in self.unet.blocks:
if isinstance(block, TemporalAttentionBlock):
block.requires_grad_(True)
self.vae_encoder: SVDVAEEncoder = model_manager.vae_encoder
self.vae_encoder = SVDVAEEncoder.to(dtype=torch.float16, device=self.device)
self.vae_encoder.load_state_dict(SVDVAEEncoder.state_dict_converter().from_civitai(state_dict))
self.vae_encoder.eval()
self.vae_encoder.requires_grad_(False)