mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-16 15:28:21 +00:00
Support JoyAI-Image-Edit (#1393)
* auto intergrate joyimage model * joyimage pipeline * train * ready * styling * joyai-image docs * update readme * pr review
This commit is contained in:
82
diffsynth/models/joyai_image_text_encoder.py
Normal file
82
diffsynth/models/joyai_image_text_encoder.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class JoyAIImageTextEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
|
||||
|
||||
config = Qwen3VLConfig(
|
||||
text_config={
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 4096,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 12288,
|
||||
"max_position_embeddings": 262144,
|
||||
"model_type": "qwen3_vl_text",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"rope_scaling": {
|
||||
"mrope_interleaved": True,
|
||||
"mrope_section": [24, 20, 20],
|
||||
"rope_type": "default",
|
||||
},
|
||||
"rope_theta": 5000000,
|
||||
"use_cache": True,
|
||||
"vocab_size": 151936,
|
||||
},
|
||||
vision_config={
|
||||
"deepstack_visual_indexes": [8, 16, 24],
|
||||
"depth": 27,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"in_channels": 3,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "qwen3_vl",
|
||||
"num_heads": 16,
|
||||
"num_position_embeddings": 2304,
|
||||
"out_hidden_size": 4096,
|
||||
"patch_size": 16,
|
||||
"spatial_merge_size": 2,
|
||||
"temporal_patch_size": 2,
|
||||
},
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
vision_start_token_id=151652,
|
||||
vision_end_token_id=151653,
|
||||
tie_word_embeddings=False,
|
||||
)
|
||||
|
||||
self.model = Qwen3VLForConditionalGeneration(config)
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
pre_norm_output = [None]
|
||||
def hook_fn(module, args, kwargs_output=None):
|
||||
pre_norm_output[0] = args[0]
|
||||
self.model.model.language_model.norm.register_forward_hook(hook_fn)
|
||||
_ = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
**kwargs,
|
||||
)
|
||||
return pre_norm_output[0]
|
||||
Reference in New Issue
Block a user