refactor wan dit

This commit is contained in:
Artiprocher
2025-03-07 16:35:26 +08:00
parent 84fb61aaaf
commit b548d7caf2
3 changed files with 254 additions and 664 deletions

View File

@@ -104,5 +104,6 @@ class WanPrompter(BasePrompter):
mask = mask.to(device)
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_emb = self.text_encoder(ids, mask)
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
for i, v in enumerate(seq_lens):
prompt_emb[:, v:] = 0
return prompt_emb