mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
rebuild base modules
This commit is contained in:
@@ -123,21 +123,23 @@ class MotionBucketManager:
|
||||
class LightningModel(pl.LightningModule):
|
||||
def __init__(self, learning_rate=1e-5, svd_ckpt_path=None, add_positional_conv=128, contrast_enhance_scale=1.01):
|
||||
super().__init__()
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device=self.device)
|
||||
model_manager.load_stable_video_diffusion(state_dict=load_state_dict(svd_ckpt_path), add_positional_conv=add_positional_conv)
|
||||
state_dict = load_state_dict(svd_ckpt_path)
|
||||
|
||||
self.image_encoder: SVDImageEncoder = model_manager.image_encoder
|
||||
self.image_encoder = SVDImageEncoder().to(dtype=torch.float16, device=self.device)
|
||||
self.image_encoder.load_state_dict(SVDImageEncoder.state_dict_converter().from_civitai(state_dict))
|
||||
self.image_encoder.eval()
|
||||
self.image_encoder.requires_grad_(False)
|
||||
|
||||
self.unet: SVDUNet = model_manager.unet
|
||||
self.unet = SVDUNet(add_positional_conv=add_positional_conv).to(dtype=torch.float16, device=self.device)
|
||||
self.unet.load_state_dict(SVDUNet.state_dict_converter().from_civitai(state_dict), strict=False)
|
||||
self.unet.train()
|
||||
self.unet.requires_grad_(False)
|
||||
for block in self.unet.blocks:
|
||||
if isinstance(block, TemporalAttentionBlock):
|
||||
block.requires_grad_(True)
|
||||
|
||||
self.vae_encoder: SVDVAEEncoder = model_manager.vae_encoder
|
||||
self.vae_encoder = SVDVAEEncoder.to(dtype=torch.float16, device=self.device)
|
||||
self.vae_encoder.load_state_dict(SVDVAEEncoder.state_dict_converter().from_civitai(state_dict))
|
||||
self.vae_encoder.eval()
|
||||
self.vae_encoder.requires_grad_(False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user