refine code

This commit is contained in:
Artiprocher
2025-07-29 18:47:16 +08:00
parent 7df48fc2b5
commit 9c51623fc2
14 changed files with 124 additions and 18 deletions

View File

@@ -2,9 +2,8 @@ import math
import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import _compute_default_rope_parameters
from transformers import AutoConfig
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -39,6 +38,7 @@ class Qwen2_5_VLRotaryEmbedding(nn.Module):
self.original_max_seq_len = config.max_position_embeddings
self.config = config
from transformers.modeling_rope_utils import _compute_default_rope_parameters
self.rope_init_fn = _compute_default_rope_parameters
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
@@ -181,6 +181,7 @@ class Qwen2_5_VLAttention(nn.Module):
class Qwen2MLP(nn.Module):
def __init__(self, config):
super().__init__()
from transformers.activations import ACT2FN
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
@@ -254,6 +255,8 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
class NexusGenImageEmbeddingMerger(nn.Module):
def __init__(self, model_path="models/DiffSynth-Studio/Nexus-GenV2", num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'):
super().__init__()
from transformers import AutoConfig
from transformers.activations import ACT2FN
config = AutoConfig.from_pretrained(model_path)
self.config = config
self.num_layers = num_layers