import torch


class AceStepTextEncoder(torch.nn.Module):
    """Text encoder for ACE-Step using Qwen3-Embedding-0.6B.

    Converts text/lyric tokens to hidden state embeddings that are further
    processed by the ACE-Step ConditionEncoder.

    Wraps a ``transformers`` ``Qwen3Model``. The config is constructed
    manually here, and model weights are loaded via DiffSynth's standard
    mechanism from safetensors files (the constructor itself does not load
    any weights).

    Attributes set by the constructor:
        model: the wrapped ``Qwen3Model`` instance.
        config: the ``Qwen3Config`` used to build ``model``.
        hidden_size: copy of ``config.hidden_size``.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__()
        # Imported lazily so that importing this module does not itself
        # require transformers.
        from transformers import Qwen3Config, Qwen3Model

        config = Qwen3Config(
            attention_bias=False,
            attention_dropout=0.0,
            bos_token_id=151643,
            dtype="bfloat16",
            eos_token_id=151643,
            head_dim=128,
            hidden_act="silu",
            hidden_size=1024,
            initializer_range=0.02,
            intermediate_size=3072,
            layer_types=["full_attention"] * 28,
            max_position_embeddings=32768,
            max_window_layers=28,
            model_type="qwen3",
            num_attention_heads=16,
            num_hidden_layers=28,
            num_key_value_heads=8,
            pad_token_id=151643,
            rms_norm_eps=1e-06,
            rope_scaling=None,
            rope_theta=1000000,
            sliding_window=None,
            tie_word_embeddings=True,
            use_cache=True,
            use_sliding_window=False,
            vocab_size=151669,
        )

        self.model = Qwen3Model(config)
        self.config = config
        self.hidden_size = config.hidden_size

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode text/lyric tokens to hidden states.

        Runs with gradient tracking disabled (``torch.no_grad``).

        Args:
            input_ids: [B, T] token IDs
            attention_mask: [B, T] attention mask

        Returns:
            last_hidden_state: [B, T, hidden_size]
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # Only last_hidden_state is returned below, so a key/value cache
            # would never be read. The config default is use_cache=True;
            # override it per call so no cache is built for this pass.
            use_cache=False,
            return_dict=True,
        )
        return outputs.last_hidden_state

    def to(self, *args, **kwargs) -> "AceStepTextEncoder":
        """Forward ``.to(...)`` to the wrapped model and return ``self``.

        All arguments are passed through unchanged to ``self.model.to``.
        """
        self.model.to(*args, **kwargs)
        return self