Support JoyAI-Image-Edit (#1393)

* auto intergrate joyimage model

* joyimage pipeline

* train

* ready

* styling

* joyai-image docs

* update readme

* pr review
This commit is contained in:
Hong Zhang
2026-04-15 16:57:11 +08:00
committed by GitHub
parent 8f18e24597
commit 079e51c9f3
21 changed files with 2040 additions and 123 deletions

View File

@@ -900,4 +900,20 @@ mova_series = [
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series
joyai_image_series = [
{
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
"model_hash": "56592ddfd7d0249d3aa527d24161a863",
"model_name": "joyai_image_dit",
"model_class": "diffsynth.models.joyai_image_dit.JoyAIImageDiT",
},
{
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model-*.safetensors")
"model_hash": "2d11bf14bba8b4e87477c8199a895403",
"model_name": "joyai_image_text_encoder",
"model_class": "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_text_encoder.JoyAIImageTextEncoderStateDictConverter",
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series

View File

@@ -279,6 +279,22 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.joyai_image_dit.Transformer3DModel": {
"diffsynth.models.joyai_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.joyai_image_dit.ModulateWan": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionModel": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
def QwenImageTextEncoder_Module_Map_Updater():

View File

@@ -159,6 +159,18 @@ class FlowMatchScheduler():
timesteps[timestep_id] = timestep
return sigmas, timesteps
@staticmethod
def set_timesteps_joyai_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
sigma_min = 0.0
sigma_max = 1.0
shift = 4.0 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
num_train_timesteps = 1000

View File

@@ -0,0 +1,636 @@
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
) -> torch.Tensor:
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
emb = scale * emb
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
return get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = nn.SiLU()
time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
self.post_act = nn.SiLU() if post_act_fn == "silu" else None
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class PixArtAlphaTextProjection(nn.Module):
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.approximate = approximate
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj(hidden_states)
hidden_states = F.gelu(hidden_states, approximate=self.approximate)
return hidden_states
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
# Build activation + projection matching diffusers pattern
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
elif activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
else:
act_fn = GELU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
self.net.append(act_fn)
self.net.append(nn.Dropout(dropout))
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
if len(args) == 0:
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = _to_tuple(args[1], dim=dim)
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij")
grid = torch.stack(grid, dim=0)
return grid
def reshape_for_broadcast(freqs_cis, x, head_first=False):
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
if head_first:
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1])
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
if head_first:
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(xq, xk, freqs_cis, head_first=False):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
return xq_out, xk_out
def get_1d_rotary_pos_embed(dim, pos, theta=10000.0, use_real=False, theta_rescale_factor=1.0, interpolation_factor=1.0):
if isinstance(pos, int):
pos = torch.arange(pos).float()
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
freqs = torch.outer(pos * interpolation_factor, freqs)
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1)
freqs_sin = freqs.sin().repeat_interleave(2, dim=1)
return freqs_cos, freqs_sin
else:
return torch.polar(torch.ones_like(freqs), freqs)
def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000.0, use_real=False,
txt_rope_size=None, theta_rescale_factor=1.0, interpolation_factor=1.0):
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
if isinstance(theta_rescale_factor, (int, float)):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
if isinstance(interpolation_factor, (int, float)):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i], grid[i].reshape(-1), theta,
use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
)
embs.append(emb)
if use_real:
vis_emb = (torch.cat([emb[0] for emb in embs], dim=1), torch.cat([emb[1] for emb in embs], dim=1))
else:
vis_emb = torch.cat(embs, dim=1)
if txt_rope_size is not None:
embs_txt = []
vis_max_ids = grid.view(-1).max().item()
grid_txt = torch.arange(txt_rope_size) + vis_max_ids + 1
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i], grid_txt, theta,
use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
)
embs_txt.append(emb)
if use_real:
txt_emb = (torch.cat([emb[0] for emb in embs_txt], dim=1), torch.cat([emb[1] for emb in embs_txt], dim=1))
else:
txt_emb = torch.cat(embs_txt, dim=1)
else:
txt_emb = None
return vis_emb, txt_emb
class ModulateWan(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
super().__init__()
self.factor = factor
factory_kwargs = {"dtype": dtype, "device": device}
self.modulate_table = nn.Parameter(
torch.zeros(1, factor, hidden_size, **factory_kwargs) / hidden_size**0.5,
requires_grad=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(x.shape) != 3:
x = x.unsqueeze(1)
return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)]
def modulate(x, shift=None, scale=None):
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def apply_gate(x, gate=None, tanh=False):
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
def load_modulation(modulate_type: str, hidden_size: int, factor: int, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
if modulate_type == 'wanx':
return ModulateWan(hidden_size, factor, **factory_kwargs)
raise ValueError(f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")
class RMSNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal dit block with separate modulation for
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
dit_modulation_type: Optional[str] = "wanx",
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.dit_modulation_type = dit_modulation_type
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = load_modulation(
modulate_type=self.dit_modulation_type,
hidden_size=hidden_size, factor=6, **factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
self.txt_mod = load_modulation(
modulate_type=self.dit_modulation_type,
hidden_size=hidden_size, factor=6, **factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
vis_freqs_cis: tuple = None,
txt_freqs_cis: tuple = None,
attn_kwargs: Optional[dict] = {},
) -> Tuple[torch.Tensor, torch.Tensor]:
(
img_mod1_shift, img_mod1_scale, img_mod1_gate,
img_mod2_shift, img_mod2_scale, img_mod2_gate,
) = self.img_mod(vec)
(
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate,
txt_mod2_shift, txt_mod2_scale, txt_mod2_gate,
) = self.txt_mod(vec)
img_modulated = self.img_norm1(img)
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
img_q, img_k = img_qq, img_kk
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
raise NotImplementedError("RoPE text is not supported for inference")
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
# Use DiffSynth unified attention
attn_out = attention_forward(
q, k, v,
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
)
attn_out = attn_out.flatten(2, 3)
img_attn, txt_attn = attn_out[:, : img.shape[1]], attn_out[:, img.shape[1]:]
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
gate=img_mod2_gate,
)
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
gate=txt_mod2_gate,
)
return img, txt
class WanTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
return temb, timestep_proj, encoder_hidden_states
class JoyAIImageDiT(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
patch_size: list = [1, 2, 2],
in_channels: int = 16,
out_channels: int = 16,
hidden_size: int = 4096,
heads_num: int = 32,
text_states_dim: int = 4096,
mlp_width_ratio: float = 4.0,
mm_double_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
rope_type: str = 'rope',
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
dit_modulation_type: str = "wanx",
theta: int = 10000,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.patch_size = patch_size
self.hidden_size = hidden_size
self.heads_num = heads_num
self.rope_dim_list = rope_dim_list
self.dit_modulation_type = dit_modulation_type
self.mm_double_blocks_depth = mm_double_blocks_depth
self.rope_type = rope_type
self.theta = theta
factory_kwargs = {"device": device, "dtype": dtype}
if hidden_size % heads_num != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
self.condition_embedder = WanTimeTextImageEmbedding(
dim=hidden_size,
time_freq_dim=256,
time_proj_dim=hidden_size * 6,
text_embed_dim=text_states_dim,
)
self.double_blocks = nn.ModuleList([
MMDoubleStreamBlock(
self.hidden_size, self.heads_num,
mlp_width_ratio=mlp_width_ratio,
dit_modulation_type=self.dit_modulation_type,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
])
self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size), **factory_kwargs)
def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None):
target_ndim = 3
if len(vis_rope_size) != target_ndim:
vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
head_dim = self.hidden_size // self.heads_num
rope_dim_list = self.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim
vis_freqs, txt_freqs = get_nd_rotary_pos_embed(
rope_dim_list, vis_rope_size,
txt_rope_size=txt_rope_size if self.rope_type == 'mrope' else None,
theta=self.theta, use_real=True, theta_rescale_factor=1,
)
return vis_freqs, txt_freqs
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
encoder_hidden_states_mask: torch.Tensor = None,
return_dict: bool = True,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
is_multi_item = (len(hidden_states.shape) == 6)
num_items = 0
if is_multi_item:
num_items = hidden_states.shape[1]
if num_items > 1:
assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1"
hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1)
hidden_states = rearrange(hidden_states, 'b n c t h w -> b c (n t) h w')
batch_size, _, ot, oh, ow = hidden_states.shape
tt, th, tw = ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2]
if encoder_hidden_states_mask is None:
encoder_hidden_states_mask = torch.ones(
(encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]),
dtype=torch.bool,
).to(encoder_hidden_states.device)
img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
temb, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
if vec.shape[-1] > self.hidden_size:
vec = vec.unflatten(1, (6, -1))
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed(
vis_rope_size=(tt, th, tw),
txt_rope_size=txt_seq_len if self.rope_type == 'mrope' else None,
)
for block in self.double_blocks:
img, txt = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
img=img, txt=txt, vec=vec,
vis_freqs_cis=vis_freqs_cis, txt_freqs_cis=txt_freqs_cis,
attn_kwargs={},
)
img_len = img.shape[1]
x = torch.cat((img, txt), 1)
img = x[:, :img_len, ...]
img = self.proj_out(self.norm_out(img))
img = self.unpatchify(img, tt, th, tw)
if is_multi_item:
img = rearrange(img, 'b c (n t) h w -> b n c t h w', n=num_items)
if num_items > 1:
img = torch.cat([img[:, 1:], img[:, :1]], dim=1)
return img
def unpatchify(self, x, t, h, w):
c = self.out_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = torch.einsum("nthwopqc->nctohpwq", x)
return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

View File

@@ -0,0 +1,82 @@
import torch
from typing import Optional
class JoyAIImageTextEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
config = Qwen3VLConfig(
text_config={
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12288,
"max_position_embeddings": 262144,
"model_type": "qwen3_vl_text",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-6,
"rope_scaling": {
"mrope_interleaved": True,
"mrope_section": [24, 20, 20],
"rope_type": "default",
},
"rope_theta": 5000000,
"use_cache": True,
"vocab_size": 151936,
},
vision_config={
"deepstack_visual_indexes": [8, 16, 24],
"depth": 27,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"in_channels": 3,
"initializer_range": 0.02,
"intermediate_size": 4304,
"model_type": "qwen3_vl",
"num_heads": 16,
"num_position_embeddings": 2304,
"out_hidden_size": 4096,
"patch_size": 16,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
},
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
vision_end_token_id=151653,
tie_word_embeddings=False,
)
self.model = Qwen3VLForConditionalGeneration(config)
self.config = config
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
pre_norm_output = [None]
def hook_fn(module, args, kwargs_output=None):
pre_norm_output[0] = args[0]
self.model.model.language_model.norm.register_forward_hook(hook_fn)
_ = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs,
)
return pre_norm_output[0]

View File

@@ -0,0 +1,282 @@
import torch
from PIL import Image
from typing import Union, Optional
from tqdm import tqdm
from einops import rearrange
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.joyai_image_dit import JoyAIImageDiT
from ..models.joyai_image_text_encoder import JoyAIImageTextEncoder
from ..models.wan_video_vae import WanVideoVAE
class JoyAIImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("Wan")
self.text_encoder: JoyAIImageTextEncoder = None
self.dit: JoyAIImageDiT = None
self.vae: WanVideoVAE = None
self.processor = None
self.in_iteration_models = ("dit",)
self.units = [
JoyAIImageUnit_ShapeChecker(),
JoyAIImageUnit_EditImageEmbedder(),
JoyAIImageUnit_PromptEmbedder(),
JoyAIImageUnit_NoiseInitializer(),
JoyAIImageUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_joyai_image
self.compilable_models = ["dit"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
# Processor
processor_config: ModelConfig = None,
# Optional
vram_limit: float = None,
):
pipe = JoyAIImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("joyai_image_text_encoder")
pipe.dit = model_pool.fetch_model("joyai_image_dit")
pipe.vae = model_pool.fetch_model("wan_video_vae")
if processor_config is not None:
processor_config.download_if_necessary()
from transformers import AutoProcessor
pipe.processor = AutoProcessor.from_pretrained(processor_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 5.0,
# Image
edit_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
# Steps
max_sequence_length: int = 4096,
num_inference_steps: int = 30,
# Tiling
tiled: Optional[bool] = False,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# Scheduler
shift: Optional[float] = 4.0,
# Progress bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=shift)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"edit_image": edit_image,
"denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "max_sequence_length": max_sequence_length,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
}
# Unit chain
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
latents = rearrange(inputs_shared["latents"], "b n c f h w -> (b n) c f h w")
image = self.vae.decode(latents, device=self.device)[0]
image = self.vae_output_to_image(image, pattern="C 1 H W")
self.load_models_to_device([])
return image
class JoyAIImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: "JoyAIImagePipeline", height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class JoyAIImageUnit_PromptEmbedder(PipelineUnit):
prompt_template_encode = {
'image':
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
'multiple_images':
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n",
'video':
"<|im_start|>system\n \\nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
}
prompt_template_encode_start_idx = {'image': 34, 'multiple_images': 34, 'video': 91}
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt", "positive": "positive"},
input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
input_params=("edit_image", "max_sequence_length"),
output_params=("prompt_embeds", "prompt_embeds_mask"),
onload_model_names=("joyai_image_text_encoder",),
)
def process(self, pipe: "JoyAIImagePipeline", prompt, positive, edit_image, max_sequence_length):
pipe.load_models_to_device(self.onload_model_names)
has_image = edit_image is not None
if has_image:
prompt_embeds, prompt_embeds_mask = self._encode_with_image(pipe, prompt, edit_image, max_sequence_length)
else:
prompt_embeds, prompt_embeds_mask = self._encode_text_only(pipe, prompt, max_sequence_length)
return {"prompt_embeds": prompt_embeds, "prompt_embeds_mask": prompt_embeds_mask}
def _encode_with_image(self, pipe, prompt, edit_image, max_sequence_length):
template = self.prompt_template_encode['multiple_images']
drop_idx = self.prompt_template_encode_start_idx['multiple_images']
image_tokens = '<image>\n'
prompt = f"<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n"
prompt = prompt.replace('<image>\n', '<|vision_start|><|image_pad|><|vision_end|>')
prompt = template.format(prompt)
inputs = pipe.processor(text=[prompt], images=[edit_image], padding=True, return_tensors="pt").to(pipe.device)
last_hidden_states = pipe.text_encoder(**inputs)
prompt_embeds = last_hidden_states[:, drop_idx:]
prompt_embeds_mask = inputs['attention_mask'][:, drop_idx:]
if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length:
prompt_embeds = prompt_embeds[:, -max_sequence_length:, :]
prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:]
return prompt_embeds, prompt_embeds_mask
def _encode_text_only(self, pipe, prompt, max_sequence_length):
# TODO: may support for text-only encoding in the future.
raise NotImplementedError("Text-only encoding is not implemented yet. Please provide edit_image for now.")
return prompt_embeds, encoder_attention_mask
class JoyAIImageUnit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "tiled", "tile_size", "tile_stride", "height", "width"),
output_params=("ref_latents", "num_items", "is_multi_item"),
onload_model_names=("wan_video_vae",),
)
def process(self, pipe: "JoyAIImagePipeline", edit_image, tiled, tile_size, tile_stride, height, width):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
# Resize edit image to match target dimensions (from ShapeChecker) to ensure ref_latents matches latents
edit_image = edit_image.resize((width, height), Image.LANCZOS)
images = [pipe.preprocess_image(edit_image).transpose(0, 1)]
latents = pipe.vae.encode(images, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=1).to(device=pipe.device, dtype=pipe.torch_dtype)
return {"ref_latents": ref_vae, "edit_image": edit_image}
class JoyAIImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("seed", "height", "width", "rand_device"),
output_params=("noise"),
)
def process(self, pipe: "JoyAIImagePipeline", seed, height, width, rand_device):
latent_h = height // pipe.vae.upsampling_factor
latent_w = width // pipe.vae.upsampling_factor
shape = (1, 1, pipe.vae.z_dim, 1, latent_h, latent_w)
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class JoyAIImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe: JoyAIImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
if input_image is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
if isinstance(input_image, Image.Image):
input_image = [input_image]
input_image = [pipe.preprocess_image(img).transpose(0, 1) for img in input_image]
latents = pipe.vae.encode(input_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
input_latents = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(input_image)))
return {"latents": noise, "input_latents": input_latents}
def model_fn_joyai_image(
dit,
latents,
timestep,
prompt_embeds,
prompt_embeds_mask,
ref_latents=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
img = torch.cat([ref_latents, latents], dim=1) if ref_latents is not None else latents
img = dit(
hidden_states=img,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
img = img[:, -latents.size(1):]
return img

View File

@@ -0,0 +1,20 @@
def JoyAIImageTextEncoderStateDictConverter(state_dict):
"""Convert HuggingFace Qwen3VL checkpoint keys to DiffSynth wrapper keys.
Mapping (checkpoint -> wrapper):
- lm_head.weight -> model.lm_head.weight
- model.language_model.* -> model.model.language_model.*
- model.visual.* -> model.model.visual.*
"""
state_dict_ = {}
for key in state_dict:
if key == "lm_head.weight":
new_key = "model.lm_head.weight"
elif key.startswith("model.language_model."):
new_key = "model.model." + key[len("model."):]
elif key.startswith("model.visual."):
new_key = "model.model." + key[len("model."):]
else:
new_key = key
state_dict_[new_key] = state_dict[key]
return state_dict_