mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00

Compare commits: dpo ... qwen-image (2 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 94e7e800b2 | |
| | e85f42b474 | |
```diff
@@ -72,6 +72,7 @@ from ..models.flux_lora_encoder import FluxLoRAEncoder
 from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
 from ..models.nexus_gen import NexusGenAutoregressiveModel

+from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter
 from ..models.qwen_image_dit import QwenImageDiT
 from ..models.qwen_image_text_encoder import QwenImageTextEncoder
 from ..models.qwen_image_vae import QwenImageVAE
```
```diff
@@ -165,6 +166,7 @@ model_loader_configs = [
     (None, "63c969fd37cce769a90aa781fbff5f81", ["flux_dit", "nexus_gen_editing_adapter"], [FluxDiT, NexusGenImageEmbeddingMerger], "civitai"),
     (None, "2bd19e845116e4f875a0a048e27fc219", ["nexus_gen_llm"], [NexusGenAutoregressiveModel], "civitai"),
     (None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
+    (None, "ae9d13bfc578702baf6445d2cf3d1d46", ["qwen_image_accelerate_adapter"], [QwenImageAccelerateAdapter], "civitai"),
     (None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
     (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
 ]
```
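Each tuple in `model_loader_configs` maps a fingerprint of a checkpoint's state-dict key names to the model classes and converter format used to load it; the new entry registers the adapter under hash `ae9d13bfc578702baf6445d2cf3d1d46`. A minimal sketch of that key-hash dispatch idea (the exact hash recipe and dispatch logic are illustrative assumptions, not DiffSynth's actual implementation):

```python
# Sketch of hash-keyed loader dispatch. Assumption: the fingerprint is an MD5
# digest over the sorted, comma-joined parameter names; DiffSynth's real
# hash_state_dict_keys may differ in detail.
import hashlib

def hash_keys(state_dict: dict) -> str:
    return hashlib.md5(",".join(sorted(state_dict)).encode("utf-8")).hexdigest()

def find_loader_config(state_dict: dict, configs: list):
    digest = hash_keys(state_dict)
    for path, keys_hash, model_names, model_classes, fmt in configs:
        if keys_hash == digest:
            return model_names, model_classes, fmt
    return None  # no match: fall through to other loading strategies
```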
diffsynth/models/qwen_image_accelerate_adapter.py (new file, 63 lines)

@@ -0,0 +1,63 @@
```python
from .qwen_image_dit import QwenImageTransformerBlock, AdaLayerNorm, TimestepEmbeddings
from einops import rearrange
import torch


class QwenImageAccelerateAdapter(torch.nn.Module):
    def __init__(
        self,
        num_layers: int = 1,
    ):
        super().__init__()
        # Packed latent tokens (width 64) are projected up to the DiT hidden size (3072).
        self.proj_latents_in = torch.nn.Linear(64, 3072)
        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
        self.transformer_blocks = torch.nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=3072,
                    num_attention_heads=24,
                    attention_head_dim=128,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = AdaLayerNorm(3072, single=True)
        self.proj_out = torch.nn.Linear(3072, 64)
        self.proj_latents_out = torch.nn.Linear(64, 64)

    def forward(
        self,
        latents=None,
        image=None,
        text=None,
        image_rotary_emb=None,
        timestep=None,
    ):
        # Fold each 2x2 latent patch into one 64-wide token, matching the DiT's packing.
        latents = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        # Residual injection of the noisy latents into the hidden stream.
        image = image + self.proj_latents_in(latents)
        conditioning = self.time_text_embed(timestep, image.dtype)
        for block in self.transformer_blocks:
            text, image = block(
                image=image,
                text=text,
                temb=conditioning,
                image_rotary_emb=image_rotary_emb,
            )
        image = self.norm_out(image, conditioning)
        image = self.proj_out(image)
        # Second latent skip connection, applied in the packed 64-wide space.
        image = image + self.proj_latents_out(latents)
        return image

    @staticmethod
    def state_dict_converter():
        return QwenImageAccelerateAdapterStateDictConverter()


class QwenImageAccelerateAdapterStateDictConverter():
    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        # Checkpoint keys already match the module's parameter names; no renaming needed.
        return state_dict
```
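The adapter works in the DiT's packed token space: the first `rearrange` folds every 2x2 patch of the latent grid into one 64-wide token (16 channels x 2 x 2), which is exactly the input width `proj_latents_in` expects. A self-contained illustration of that packing and its inverse (the latent shape here is an assumption corresponding to a 512px image through an 8x VAE):

```python
import torch
from einops import rearrange

latents = torch.randn(1, 16, 64, 64)  # B C H W: assumed 16-channel, 64x64 latent grid
tokens = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
print(tokens.shape)   # torch.Size([1, 1024, 64]) -- matches proj_latents_in's input width

restored = rearrange(tokens, "B (H W) (C P Q) -> B C (H P) (W Q)", H=32, W=32, P=2, Q=2)
assert torch.equal(restored, latents)  # the packing is lossless
```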
```diff
@@ -6,6 +6,7 @@ from tqdm import tqdm
 from einops import rearrange

 from ..models import ModelManager, load_state_dict
+from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter
 from ..models.qwen_image_dit import QwenImageDiT
 from ..models.qwen_image_text_encoder import QwenImageTextEncoder
 from ..models.qwen_image_vae import QwenImageVAE
```
```diff
@@ -27,12 +28,13 @@ class QwenImagePipeline(BasePipeline):
         from transformers import Qwen2Tokenizer

         self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
+        self.accelerate_adapter: QwenImageAccelerateAdapter = None
         self.text_encoder: QwenImageTextEncoder = None
         self.dit: QwenImageDiT = None
         self.vae: QwenImageVAE = None
         self.tokenizer: Qwen2Tokenizer = None
         self.unit_runner = PipelineUnitRunner()
-        self.in_iteration_models = ("dit",)
+        self.in_iteration_models = ("accelerate_adapter", "dit",)
         self.units = [
             QwenImageUnit_ShapeChecker(),
             QwenImageUnit_NoiseInitializer(),
```
```diff
@@ -187,6 +189,7 @@ class QwenImagePipeline(BasePipeline):
         pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
         pipe.dit = model_manager.fetch_model("qwen_image_dit")
         pipe.vae = model_manager.fetch_model("qwen_image_vae")
+        pipe.accelerate_adapter = model_manager.fetch_model("qwen_image_accelerate_adapter")
         if tokenizer_config is not None and pipe.text_encoder is not None:
             tokenizer_config.download_if_necessary()
             from transformers import Qwen2Tokenizer
```
```diff
@@ -433,6 +436,7 @@ class QwenImageUnit_EntityControl(PipelineUnit):

 def model_fn_qwen_image(
     dit: QwenImageDiT = None,
+    accelerate_adapter: QwenImageAccelerateAdapter = None,
     latents=None,
     timestep=None,
     prompt_emb=None,
```
```diff
@@ -478,9 +482,12 @@ def model_fn_qwen_image(
         attention_mask=attention_mask,
         enable_fp8_attention=enable_fp8_attention,
     )

-    image = dit.norm_out(image, conditioning)
-    image = dit.proj_out(image)
+    if accelerate_adapter is not None:
+        image = accelerate_adapter(latents, image, text, image_rotary_emb, timestep)
+    else:
+        image = dit.norm_out(image, conditioning)
+        image = dit.proj_out(image)

     latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
     return latents
```
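When the adapter is active, it is called positionally as `(latents, image, text, image_rotary_emb, timestep)`, matching its `forward` signature, and it takes over the DiT's own `norm_out`/`proj_out` head. Note the token-grid arithmetic in the final `rearrange`: the VAE downsamples pixels by 8x and the 2x2 patching halves the grid again, so a `height x width` image yields `(height//16) * (width//16)` tokens. A quick check of that bookkeeping (sizes are illustrative assumptions):

```python
import torch
from einops import rearrange

height, width = 1024, 1024                                   # assumed pixel-space output size
tokens = torch.randn(1, (height // 16) * (width // 16), 64)  # 4096 tokens of width 64
latents = rearrange(tokens, "B (H W) (C P Q) -> B C (H P) (W Q)",
                    H=height // 16, W=width // 16, P=2, Q=2)
print(latents.shape)  # torch.Size([1, 16, 128, 128]) -- the 8x-downsampled latent grid
```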
|
||||
@@ -31,11 +31,13 @@ class FlowMatchScheduler():
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None):
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, random_sigmas=False):
|
||||
if shift is not None:
|
||||
self.shift = shift
|
||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||
if self.extra_one_step:
|
||||
if random_sigmas:
|
||||
self.sigmas = torch.Tensor(sorted([torch.rand((1,)).item() * (sigma_start - self.sigma_min) for i in range(num_inference_steps - 1)] + [sigma_start], reverse=True))
|
||||
elif self.extra_one_step:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
||||
else:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
||||
|
||||
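The new `random_sigmas` branch, likely intended for training-time schedules, replaces the deterministic `linspace` grid with `num_inference_steps - 1` uniform draws scaled by `(sigma_start - self.sigma_min)`, plus `sigma_start` itself, sorted descending. Note the draws are scaled but not offset by `sigma_min`, which only matters when `sigma_min > 0`; the pipeline above constructs the scheduler with `sigma_min=0`. A standalone reproduction of what it produces (exact values vary with the torch RNG):

```python
import torch

torch.manual_seed(0)
sigma_start, sigma_min, num_inference_steps = 1.0, 0.0, 4
sigmas = torch.Tensor(sorted(
    [torch.rand((1,)).item() * (sigma_start - sigma_min) for _ in range(num_inference_steps - 1)]
    + [sigma_start],
    reverse=True,
))
print(sigmas)  # e.g. tensor([1.0000, 0.7682, 0.4963, 0.0885]) -- always starts at sigma_start
```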
@@ -0,0 +1,32 @@
```python
# This script is for initializing a Qwen-Image-Accelerate-Adapter
from diffsynth import load_state_dict, hash_state_dict_keys
from diffsynth.pipelines.qwen_image import QwenImageAccelerateAdapter
import torch
from safetensors.torch import save_file


# Load the full Qwen-Image DiT state dict from its nine safetensors shards.
state_dict_dit = {}
for i in range(1, 10):
    state_dict_dit.update(load_state_dict(f"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-0000{i}-of-00009.safetensors", torch_dtype=torch.bfloat16, device="cuda"))

adapter = QwenImageAccelerateAdapter().to(dtype=torch.bfloat16, device="cuda")
state_dict_adapter = adapter.state_dict()

state_dict_init = {}
for k in state_dict_adapter:
    if k.startswith("transformer_blocks"):
        # Map adapter block 0 onto DiT block 59 and copy those weights.
        name = k.replace("transformer_blocks.0.", "transformer_blocks.59.")
        param = state_dict_dit[name]
        if "_mod." in k:
            # Zero two 3072-wide chunks of the modulation projection (likely the
            # attention and MLP gate terms), muting the copied block at init.
            # Note this mutates the loaded DiT tensor in place.
            param[2*3072: 3*3072] = 0
            param[5*3072: 6*3072] = 0
        state_dict_init[k] = param
    elif k in state_dict_dit:
        # Any other parameter whose name also exists in the DiT is copied directly.
        state_dict_init[k] = state_dict_dit[k]
    else:
        # Parameters with no DiT counterpart start at zero.
        state_dict_init[k] = torch.zeros_like(state_dict_adapter[k])
        print("Zero initialized:", k)
adapter.load_state_dict(state_dict_init)

print(hash_state_dict_keys(state_dict_init))
save_file(state_dict_init, "models/adapter.safetensors")
```
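Since the loader table registers the adapter under the key hash `ae9d13bfc578702baf6445d2cf3d1d46`, the `print(hash_state_dict_keys(state_dict_init))` above should emit that value; otherwise the saved checkpoint won't be recognized as a `qwen_image_accelerate_adapter`. A quick sanity check one could run after saving (a sketch reusing the script's own helpers and paths):

```python
# Sketch: verify the saved adapter will be matched by the new loader config.
import torch
from diffsynth import load_state_dict, hash_state_dict_keys

reloaded = load_state_dict("models/adapter.safetensors", torch_dtype=torch.bfloat16, device="cuda")
assert hash_state_dict_keys(reloaded) == "ae9d13bfc578702baf6445d2cf3d1d46", \
    "key hash must match the model_loader_configs entry"
```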
test_accelerate.py (new file, 18 lines)

@@ -0,0 +1,18 @@
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ModelConfig("models/adapter.safetensors")
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
# Prompt (Chinese): "Exquisite portrait: a girl underwater, flowing blue dress,
# drifting hair, translucent light and shadow, surrounded by bubbles, a serene
# face, fine detail, dreamlike and beautiful."
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=4, cfg_scale=1)
image.save("image.jpg")
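With the adapter attached, the smoke test renders in 4 steps with `cfg_scale=1` (no classifier-free guidance), versus the scheduler's default of 100 steps; that few-step regime is presumably the point of an "accelerate" adapter.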