mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
qwen-image-acc-adapter
This commit is contained in:
63
diffsynth/models/qwen_image_accelerate_adapter.py
Normal file
63
diffsynth/models/qwen_image_accelerate_adapter.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from .qwen_image_dit import QwenImageTransformerBlock, AdaLayerNorm, TimestepEmbeddings
|
||||
from einops import rearrange
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class QwenImageAccelerateAdapter(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj_latents_in = torch.nn.Linear(64, 3072)
|
||||
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
|
||||
self.transformer_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
QwenImageTransformerBlock(
|
||||
dim=3072,
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = AdaLayerNorm(3072, single=True)
|
||||
self.proj_out = torch.nn.Linear(3072, 64)
|
||||
self.proj_latents_out = torch.nn.Linear(64, 64)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents=None,
|
||||
image=None,
|
||||
text=None,
|
||||
image_rotary_emb=None,
|
||||
timestep=None,
|
||||
):
|
||||
latents = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
image = image + self.proj_latents_in(latents)
|
||||
conditioning = self.time_text_embed(timestep, image.dtype)
|
||||
for block in self.transformer_blocks:
|
||||
text, image = block(
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
image = self.norm_out(image, conditioning)
|
||||
image = self.proj_out(image)
|
||||
image = image + self.proj_latents_out(latents)
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageAccelerateAdapterStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageAccelerateAdapterStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
Reference in New Issue
Block a user