mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 07:18:14 +00:00
hunyuanvideo pipeline
This commit is contained in:
@@ -2,15 +2,17 @@ import torch
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sdxl_text_encoder import SDXLTextEncoder2, SDXLTextEncoder2StateDictConverter
|
||||
|
||||
|
||||
|
||||
class SD3TextEncoder1(SDTextEncoder):
|
||||
def __init__(self, vocab_size=49408):
|
||||
super().__init__(vocab_size=vocab_size)
|
||||
|
||||
def forward(self, input_ids, clip_skip=2):
|
||||
def forward(self, input_ids, clip_skip=2, extra_mask=None):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
if extra_mask is not None:
|
||||
attn_mask[:, extra_mask[0]==0] = float("-inf")
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
|
||||
Reference in New Issue
Block a user