DiffSynth-Studio/diffsynth/models/ace_step_text_encoder.py

import torch


class AceStepTextEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        from transformers import Qwen3Config, Qwen3Model

        # Qwen3-0.6B-style configuration: 28 full-attention layers,
        # hidden size 1024, GQA with 16 query heads over 8 KV heads.
        config = Qwen3Config(
            attention_bias=False,
            attention_dropout=0.0,
            bos_token_id=151643,
            dtype="bfloat16",
            eos_token_id=151643,
            head_dim=128,
            hidden_act="silu",
            hidden_size=1024,
            initializer_range=0.02,
            intermediate_size=3072,
            layer_types=["full_attention"] * 28,
            max_position_embeddings=32768,
            max_window_layers=28,
            model_type="qwen3",
            num_attention_heads=16,
            num_hidden_layers=28,
            num_key_value_heads=8,
            pad_token_id=151643,
            rms_norm_eps=1e-06,
            rope_scaling=None,
            rope_theta=1000000,
            sliding_window=None,
            tie_word_embeddings=True,
            use_cache=True,
            use_sliding_window=False,
            vocab_size=151669,
        )
        self.model = Qwen3Model(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
    ):
        # Encode the prompt tokens and return the final hidden states
        # of shape (batch, seq_len, hidden_size) as the text conditioning.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return outputs.last_hidden_state
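

# A minimal usage sketch, not part of the original file. The tokenizer
# checkpoint name below is an assumption: any Qwen3-family tokenizer with a
# 151669-token vocabulary matches the config above. The encoder is randomly
# initialized here; in practice the trained weights would be loaded from a
# state dict before use.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")  # assumed checkpoint
    encoder = AceStepTextEncoder().eval()

    batch = tokenizer("a dreamy synth-pop track with warm pads", return_tensors="pt")
    hidden = encoder(batch["input_ids"], batch["attention_mask"])
    print(hidden.shape)  # torch.Size([1, seq_len, 1024])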