diff --git a/README.md b/README.md index 9508526..1250216 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ We believe that a well-developed open-source code framework can lower the thresh > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. +- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. + - **December 4, 2025** DiffSynth-Studio 2.0 released! 
Many new features online - [Documentation](/docs/en/README.md) online: Our documentation is still continuously being optimized and updated - [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) module upgraded, supporting layer-level disk offload, releasing both memory and VRAM simultaneously @@ -420,6 +422,7 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| |[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| |[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| 
+|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-| diff --git a/README_zh.md b/README_zh.md index d24f408..cdfc72f 100644 --- a/README_zh.md +++ b/README_zh.md @@ -33,6 +33,8 @@ DiffSynth 目前包括两个开源项目: > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 +- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)(Image to LoRA)。这一模型以图像为输入,以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。 + - **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线 - [文档](/docs/zh/README.md)上线:我们的文档还在持续优化更新中 - [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)模块升级,支持 Layer 级别的 Disk Offload,同时释放内存与显存 @@ -420,6 +422,7 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| 
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| |[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| +|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-| diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 1ff1d63..dca078a 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -31,6 +31,38 @@ qwen_image_series = [ "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet", "extra_kwargs": {"additional_in_dim": 4}, }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors") + "model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8", + "model_name": "siglip2_image_encoder", + "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors") + "model_hash": "5722b5c873720009de96422993b15682", + "model_name": "dinov3_image_encoder", + "model_class": 
"diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder", + }, + { + # Example: + "model_hash": "a166c33455cdbd89c0888a3645ca5c0f", + "model_name": "qwen_image_image2lora_coarse", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + }, + { + # Example: + "model_hash": "a5476e691767a4da6d3a6634a10f7408", + "model_name": "qwen_image_image2lora_fine", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + "extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64} + }, + { + # Example: + "model_hash": "0aad514690602ecaff932c701cb4b0bb", + "model_name": "qwen_image_image2lora_style", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + "extra_kwargs": {"compress_dim": 64, "use_residual": False} + }, ] wan_series = [ diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index 219983d..958dad4 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -32,6 +32,25 @@ VRAM_MANAGEMENT_MODULE_MAPS = { "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", }, + "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": { + "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": { + 
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": { "diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule", "diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule", diff --git a/diffsynth/models/dinov3_image_encoder.py b/diffsynth/models/dinov3_image_encoder.py new file mode 100644 index 0000000..be2ee58 --- /dev/null +++ b/diffsynth/models/dinov3_image_encoder.py @@ -0,0 +1,94 @@ +from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast +from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig +import torch + + +class DINOv3ImageEncoder(DINOv3ViTModel): + def __init__(self): + config = DINOv3ViTConfig( + architectures = [ + "DINOv3ViTModel" + ], + attention_dropout = 0.0, + drop_path_rate = 0.0, + dtype = "float32", + hidden_act = "silu", + hidden_size = 4096, + image_size = 224, + initializer_range = 0.02, + intermediate_size = 8192, + key_bias = False, + layer_norm_eps = 1e-05, + layerscale_value = 1.0, + mlp_bias = True, + model_type = "dinov3_vit", + num_attention_heads = 32, + num_channels = 3, + num_hidden_layers = 40, + num_register_tokens = 4, + patch_size = 16, + pos_embed_jitter = 
None, + pos_embed_rescale = 2.0, + pos_embed_shift = None, + proj_bias = True, + query_bias = False, + rope_theta = 100.0, + transformers_version = "4.56.1", + use_gated_mlp = True, + value_bias = False + ) + super().__init__(config) + self.processor = DINOv3ViTImageProcessorFast( + crop_size = None, + data_format = "channels_first", + default_to_square = True, + device = None, + disable_grouping = None, + do_center_crop = None, + do_convert_rgb = None, + do_normalize = True, + do_rescale = True, + do_resize = True, + image_mean = [ + 0.485, + 0.456, + 0.406 + ], + image_processor_type = "DINOv3ViTImageProcessorFast", + image_std = [ + 0.229, + 0.224, + 0.225 + ], + input_data_format = None, + resample = 2, + rescale_factor = 0.00392156862745098, + return_tensors = None, + size = { + "height": 224, + "width": 224 + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + inputs = self.processor(images=image, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device) + bool_masked_pos = None + head_mask = None + + pixel_values = pixel_values.to(torch_dtype) + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + position_embeddings = self.rope_embeddings(pixel_values) + + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] if head_mask is not None else None + hidden_states = layer_module( + hidden_states, + attention_mask=layer_head_mask, + position_embeddings=position_embeddings, + ) + + sequence_output = self.norm(hidden_states) + pooled_output = sequence_output[:, 0, :] + + return pooled_output diff --git a/diffsynth/models/qwen_image_image2lora.py b/diffsynth/models/qwen_image_image2lora.py new file mode 100644 index 0000000..6aefbf2 --- /dev/null +++ b/diffsynth/models/qwen_image_image2lora.py @@ -0,0 +1,128 @@ +import torch + + +class CompressedMLP(torch.nn.Module): + def __init__(self, in_dim, mid_dim, out_dim, bias=False): + super().__init__() 
+ self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias) + self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias) + + def forward(self, x, residual=None): + x = self.proj_in(x) + if residual is not None: x = x + residual + x = self.proj_out(x) + return x + + +class ImageEmbeddingToLoraMatrix(torch.nn.Module): + def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank): + super().__init__() + self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank) + self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank) + self.lora_a_dim = lora_a_dim + self.lora_b_dim = lora_b_dim + self.rank = rank + + def forward(self, x, residual=None): + lora_a = self.proj_a(x, residual).view(self.rank, self.lora_a_dim) + lora_b = self.proj_b(x, residual).view(self.lora_b_dim, self.rank) + return lora_a, lora_b + + +class SequencialMLP(torch.nn.Module): + def __init__(self, length, in_dim, mid_dim, out_dim, bias=False): + super().__init__() + self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias) + self.proj_out = torch.nn.Linear(length * mid_dim, out_dim, bias=bias) + self.length = length + self.in_dim = in_dim + self.mid_dim = mid_dim + + def forward(self, x): + x = x.view(self.length, self.in_dim) + x = self.proj_in(x) + x = x.view(1, self.length * self.mid_dim) + x = self.proj_out(x) + return x + + +class LoRATrainerBlock(torch.nn.Module): + def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024): + super().__init__() + self.lora_patterns = lora_patterns + self.block_id = block_id + self.layers = [] + for name, lora_a_dim, lora_b_dim in self.lora_patterns: + self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank)) + self.layers = torch.nn.ModuleList(self.layers) + if use_residual: + self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim) + 
else: + self.proj_residual = None + + def forward(self, x, residual=None): + lora = {} + if self.proj_residual is not None: residual = self.proj_residual(residual) + for lora_pattern, layer in zip(self.lora_patterns, self.layers): + name = lora_pattern[0] + lora_a, lora_b = layer(x, residual=residual) + lora[f"transformer_blocks.{self.block_id}.{name}.lora_A.default.weight"] = lora_a + lora[f"transformer_blocks.{self.block_id}.{name}.lora_B.default.weight"] = lora_b + return lora + + +class QwenImageImage2LoRAModel(torch.nn.Module): + def __init__(self, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + self.lora_patterns = [ + [ + ("attn.to_q", 3072, 3072), + ("attn.to_k", 3072, 3072), + ("attn.to_v", 3072, 3072), + ("attn.to_out.0", 3072, 3072), + ], + [ + ("img_mlp.net.2", 3072*4, 3072), + ("img_mod.1", 3072, 3072*6), + ], + [ + ("attn.add_q_proj", 3072, 3072), + ("attn.add_k_proj", 3072, 3072), + ("attn.add_v_proj", 3072, 3072), + ("attn.to_add_out", 3072, 3072), + ], + [ + ("txt_mlp.net.2", 3072*4, 3072), + ("txt_mod.1", 3072, 3072*6), + ], + ] + self.num_blocks = num_blocks + self.blocks = [] + for lora_patterns in self.lora_patterns: + for block_id in range(self.num_blocks): + self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim)) + self.blocks = torch.nn.ModuleList(self.blocks) + self.residual_scale = 0.05 + self.use_residual = use_residual + + def forward(self, x, residual=None): + if residual is not None: + if self.use_residual: + residual = residual * self.residual_scale + else: + residual = None + lora = {} + for block in self.blocks: + lora.update(block(x, residual)) + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if ".proj_a." 
in name: + state_dict[name] = state_dict[name] * 0.3 + elif ".proj_b.proj_out." in name: + state_dict[name] = state_dict[name] * 0 + elif ".proj_residual.proj_out." in name: + state_dict[name] = state_dict[name] * 0.3 + self.load_state_dict(state_dict) diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py new file mode 100644 index 0000000..10184f8 --- /dev/null +++ b/diffsynth/models/siglip2_image_encoder.py @@ -0,0 +1,70 @@ +from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig +from transformers import SiglipImageProcessor +import torch + + +class Siglip2ImageEncoder(SiglipVisionTransformer): + def __init__(self): + config = SiglipVisionConfig( + attention_dropout = 0.0, + dtype = "float32", + hidden_act = "gelu_pytorch_tanh", + hidden_size = 1536, + image_size = 384, + intermediate_size = 6144, + layer_norm_eps = 1e-06, + model_type = "siglip_vision_model", + num_attention_heads = 16, + num_channels = 3, + num_hidden_layers = 40, + patch_size = 16, + transformers_version = "4.56.1", + _attn_implementation = "sdpa" + ) + super().__init__(config) + self.processor = SiglipImageProcessor( + do_convert_rgb = None, + do_normalize = True, + do_rescale = True, + do_resize = True, + image_mean = [ + 0.5, + 0.5, + 0.5 + ], + image_processor_type = "SiglipImageProcessor", + image_std = [ + 0.5, + 0.5, + 0.5 + ], + processor_class = "SiglipProcessor", + resample = 2, + rescale_factor = 0.00392156862745098, + size = { + "height": 384, + "width": 384 + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"] + pixel_values = pixel_values.to(device=device, dtype=torch_dtype) + output_attentions = False + output_hidden_states = False + interpolate_pos_encoding = False + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + 
encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + + return pooler_output diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index fc6581e..f6a9bcd 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -8,11 +8,15 @@ import numpy as np from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora.merge import merge_lora from ..models.qwen_image_dit import QwenImageDiT from ..models.qwen_image_text_encoder import QwenImageTextEncoder from ..models.qwen_image_vae import QwenImageVAE from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet +from ..models.siglip2_image_encoder import Siglip2ImageEncoder +from ..models.dinov3_image_encoder import DINOv3ImageEncoder +from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel class QwenImagePipeline(BasePipeline): @@ -30,6 +34,11 @@ class QwenImagePipeline(BasePipeline): self.vae: QwenImageVAE = None self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None self.tokenizer: Qwen2Tokenizer = None + self.siglip2_image_encoder: Siglip2ImageEncoder = None + self.dinov3_image_encoder: DINOv3ImageEncoder = None + self.image2lora_style: QwenImageImage2LoRAModel = None + self.image2lora_coarse: QwenImageImage2LoRAModel = None + self.image2lora_fine: QwenImageImage2LoRAModel = None self.processor: Qwen2VLProcessor = None self.in_iteration_models = ("dit", "blockwise_controlnet") self.units = [ @@ -72,6 +81,11 @@ class QwenImagePipeline(BasePipeline): processor_config.download_if_necessary() from transformers 
import Qwen2VLProcessor pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path) + pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder") + pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder") + pipe.image2lora_style = model_pool.fetch_model("qwen_image_image2lora_style") + pipe.image2lora_coarse = model_pool.fetch_model("qwen_image_image2lora_coarse") + pipe.image2lora_fine = model_pool.fetch_model("qwen_image_image2lora_fine") # VRAM Management pipe.vram_management_enabled = pipe.check_vram_management_state() @@ -515,6 +529,116 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit): return {"edit_latents": edit_latents, "edit_image": resized_edit_image} +class QwenImageUnit_Image2LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_images",), + output_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"), + onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder", "text_encoder"), + ) + from ..core.data.operators import ImageCropAndResize + self.processor_lowres = ImageCropAndResize(height=28*8, width=28*8) + self.processor_highres = ImageCropAndResize(height=1024, width=1024) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image): + prompt = [prompt] + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 + txt = [template.format(e) for e in prompt] + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return prompt_embeds.view(1, -1) + + def encode_images_using_siglip2(self, pipe: QwenImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["siglip2_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_dinov3(self, pipe: QwenImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["dinov3_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_qwenvl(self, pipe: QwenImagePipeline, images: list[Image.Image], highres=False): + pipe.load_models_to_device(["text_encoder"]) + embs = [] + for image in images: + image = 
self.processor_highres(image) if highres else self.processor_lowres(image) + embs.append(self.encode_prompt_edit(pipe, prompt="", edit_image=image)) + embs = torch.stack(embs) + return embs + + def encode_images(self, pipe: QwenImagePipeline, images: list[Image.Image]): + if images is None: + return {} + if not isinstance(images, list): + images = [images] + embs_siglip2 = self.encode_images_using_siglip2(pipe, images) + embs_dinov3 = self.encode_images_using_dinov3(pipe, images) + x = torch.concat([embs_siglip2, embs_dinov3], dim=-1) + residual = None + residual_highres = None + if pipe.image2lora_coarse is not None: + residual = self.encode_images_using_qwenvl(pipe, images, highres=False) + if pipe.image2lora_fine is not None: + residual_highres = self.encode_images_using_qwenvl(pipe, images, highres=True) + return x, residual, residual_highres + + def process(self, pipe: QwenImagePipeline, image2lora_images): + if image2lora_images is None: + return {} + x, residual, residual_highres = self.encode_images(pipe, image2lora_images) + return {"image2lora_x": x, "image2lora_residual": residual, "image2lora_residual_highres": residual_highres} + + +class QwenImageUnit_Image2LoRADecode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"), + output_params=("lora",), + onload_model_names=("image2lora_coarse", "image2lora_fine", "image2lora_style"), + ) + + def process(self, pipe: QwenImagePipeline, image2lora_x, image2lora_residual, image2lora_residual_highres): + if image2lora_x is None: + return {} + loras = [] + if pipe.image2lora_style is not None: + pipe.load_models_to_device(["image2lora_style"]) + for x in image2lora_x: + loras.append(pipe.image2lora_style(x=x, residual=None)) + if pipe.image2lora_coarse is not None: + pipe.load_models_to_device(["image2lora_coarse"]) + for x, residual in zip(image2lora_x, image2lora_residual): + loras.append(pipe.image2lora_coarse(x=x, 
residual=residual)) + if pipe.image2lora_fine is not None: + pipe.load_models_to_device(["image2lora_fine"]) + for x, residual in zip(image2lora_x, image2lora_residual_highres): + loras.append(pipe.image2lora_fine(x=x, residual=residual)) + lora = merge_lora(loras, alpha=1 / len(image2lora_x)) + return {"lora": lora} + + class QwenImageUnit_ContextImageEmbedder(PipelineUnit): def __init__(self): super().__init__( diff --git a/diffsynth/utils/lora/__init__.py b/diffsynth/utils/lora/__init__.py index 1ebbbe5..8eb5901 100644 --- a/diffsynth/utils/lora/__init__.py +++ b/diffsynth/utils/lora/__init__.py @@ -1 +1,3 @@ from .general import GeneralLoRALoader +from .merge import merge_lora +from .reset_rank import reset_lora_rank \ No newline at end of file diff --git a/diffsynth/utils/lora/reset_rank.py b/diffsynth/utils/lora/reset_rank.py new file mode 100644 index 0000000..9522b04 --- /dev/null +++ b/diffsynth/utils/lora/reset_rank.py @@ -0,0 +1,20 @@ +import torch + +def decomposite(tensor_A, tensor_B, rank): + dtype, device = tensor_A.dtype, tensor_A.device + weight = tensor_B @ tensor_A + U, S, V = torch.pca_lowrank(weight.float(), q=rank) + tensor_A = (V.T).to(dtype=dtype, device=device).contiguous() + tensor_B = (U @ torch.diag(S)).to(dtype=dtype, device=device).contiguous() + return tensor_A, tensor_B + +def reset_lora_rank(lora, rank): + lora_merged = {} + keys = [i for i in lora.keys() if ".lora_A." 
in i] + for key in keys: + tensor_A = lora[key] + tensor_B = lora[key.replace(".lora_A.", ".lora_B.")] + tensor_A, tensor_B = decomposite(tensor_A, tensor_B, rank) + lora_merged[key] = tensor_A + lora_merged[key.replace(".lora_A.", ".lora_B.")] = tensor_B + return lora_merged \ No newline at end of file diff --git a/docs/en/Model_Details/Qwen-Image.md b/docs/en/Model_Details/Qwen-Image.md index b0d9eb6..2f6a0dc 100644 --- a/docs/en/Model_Details/Qwen-Image.md +++ b/docs/en/Model_Details/Qwen-Image.md @@ -93,6 +93,7 @@ graph LR; | [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | | [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) | | [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | 
[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - | +|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-| Special Training Scripts: diff --git a/docs/zh/Model_Details/Qwen-Image.md b/docs/zh/Model_Details/Qwen-Image.md index 70585ba..f2609ac 100644 --- a/docs/zh/Model_Details/Qwen-Image.md +++ b/docs/zh/Model_Details/Qwen-Image.md @@ -93,6 +93,7 @@ graph LR; |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| |[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| 
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| +|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-| 特殊训练脚本: diff --git a/examples/qwen_image/model_inference/Qwen-Image-i2L.py b/examples/qwen_image/model_inference/Qwen-Image-i2L.py new file mode 100644 index 0000000..87061d8 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-i2L.py @@ -0,0 +1,110 @@ +from diffsynth.pipelines.qwen_image import ( + QwenImagePipeline, ModelConfig, + QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode +) +from diffsynth.utils.lora import merge_lora +from diffsynth import load_state_dict +from modelscope import snapshot_download +from safetensors.torch import save_file +import torch +from PIL import Image + + +def demo_style(): + # Load models + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Style.safetensors"), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), + ) + + # Load images + snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-i2L", + allow_file_pattern="assets/style/1/*", + local_dir="data/examples" + ) + images = [ + Image.open("data/examples/assets/style/1/0.jpg"), + 
Image.open("data/examples/assets/style/1/1.jpg"), + Image.open("data/examples/assets/style/1/2.jpg"), + Image.open("data/examples/assets/style/1/3.jpg"), + Image.open("data/examples/assets/style/1/4.jpg"), + ] + + # Model inference + with torch.no_grad(): + embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] + save_file(lora, "model_style.safetensors") + + +def demo_coarse_fine_bias(): + # Load models + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Coarse.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Fine.safetensors"), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), + ) + + # Load images + snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-i2L", + allow_file_pattern="assets/lora/3/*", + local_dir="data/examples" + ) + images = [ + Image.open("data/examples/assets/lora/3/0.jpg"), + Image.open("data/examples/assets/lora/3/1.jpg"), + Image.open("data/examples/assets/lora/3/2.jpg"), + Image.open("data/examples/assets/lora/3/3.jpg"), + Image.open("data/examples/assets/lora/3/4.jpg"), + Image.open("data/examples/assets/lora/3/5.jpg"), + ] + + # Model inference + with torch.no_grad(): + embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] + lora_bias = 
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Bias.safetensors") + lora_bias.download_if_necessary() + lora_bias = load_state_dict(lora_bias.path, torch_dtype=torch.bfloat16, device="cuda") + lora = merge_lora([lora, lora_bias]) + save_file(lora, "model_coarse_fine_bias.safetensors") + + +def generate_image(lora_path, prompt, seed): + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + ) + pipe.load_lora(pipe.dit, lora_path) + image = pipe(prompt, seed=seed, height=1024, width=1024, num_inference_steps=50) + return image + + +demo_style() +image = generate_image("model_style.safetensors", "a cat", 0) +image.save("image_1.jpg") + +demo_coarse_fine_bias() +image = generate_image("model_coarse_fine_bias.safetensors", "bowl", 1) +image.save("image_2.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py new file mode 100644 index 0000000..b91d606 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py @@ -0,0 +1,134 @@ +from diffsynth.pipelines.qwen_image import ( + QwenImagePipeline, ModelConfig, + QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode +) +from diffsynth.utils.lora import merge_lora +from diffsynth import load_state_dict +from modelscope import snapshot_download +from safetensors.torch import save_file +import torch +from PIL import Image + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + 
"onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +vram_config_disk_offload = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +def demo_style(): + # Load models + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Style.safetensors", **vram_config_disk_offload), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, + ) + + # Load images + snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-i2L", + allow_file_pattern="assets/style/1/*", + local_dir="data/examples" + ) + images = [ + Image.open("data/examples/assets/style/1/0.jpg"), + Image.open("data/examples/assets/style/1/1.jpg"), + Image.open("data/examples/assets/style/1/2.jpg"), + Image.open("data/examples/assets/style/1/3.jpg"), + Image.open("data/examples/assets/style/1/4.jpg"), + ] + + # Model inference + with torch.no_grad(): + embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] + save_file(lora, "model_style.safetensors") + + +def demo_coarse_fine_bias(): + # Load models + pipe = 
QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Coarse.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Fine.safetensors", **vram_config_disk_offload), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, + ) + + # Load images + snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-i2L", + allow_file_pattern="assets/lora/3/*", + local_dir="data/examples" + ) + images = [ + Image.open("data/examples/assets/lora/3/0.jpg"), + Image.open("data/examples/assets/lora/3/1.jpg"), + Image.open("data/examples/assets/lora/3/2.jpg"), + Image.open("data/examples/assets/lora/3/3.jpg"), + Image.open("data/examples/assets/lora/3/4.jpg"), + Image.open("data/examples/assets/lora/3/5.jpg"), + ] + + # Model inference + with torch.no_grad(): + embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] + lora_bias = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Bias.safetensors") + lora_bias.download_if_necessary() + lora_bias = load_state_dict(lora_bias.path, torch_dtype=torch.bfloat16, device="cuda") + lora = merge_lora([lora, lora_bias]) + save_file(lora, 
"model_coarse_fine_bias.safetensors") + + +def generate_image(lora_path, prompt, seed): + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, + ) + pipe.load_lora(pipe.dit, lora_path) + image = pipe(prompt, seed=seed, height=1024, width=1024, num_inference_steps=50) + return image + + +demo_style() +image = generate_image("model_style.safetensors", "a cat", 0) +image.save("image_1.jpg") + +demo_coarse_fine_bias() +image = generate_image("model_coarse_fine_bias.safetensors", "bowl", 1) +image.save("image_2.jpg")
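Note for readers of these examples: the scripts save the generated LoRA as an ordinary `safetensors` state dict, which `pipe.load_lora` then applies on top of the base weights. The underlying math is a low-rank update: each LoRA entry is a pair of matrices (A, B) that perturbs a base weight as W' = W + scale * (B @ A). The following NumPy sketch illustrates only that update rule, with toy shapes and hypothetical names; it is not the actual Qwen-Image tensor layout, and `merge_lora`'s internals in DiffSynth may differ.

```python
import numpy as np

def apply_lora(weight, lora_A, lora_B, scale=1.0):
    # Low-rank update: W' = W + scale * (B @ A)
    # lora_A has shape (rank, in_features), lora_B has shape (out_features, rank).
    return weight + scale * (lora_B @ lora_A)

# Toy 4x4 base weight with a rank-1 update.
W = np.zeros((4, 4))
A = np.full((1, 4), 2.0)   # rank r = 1
B = np.full((4, 1), 1.0)
W_merged = apply_lora(W, A, B, scale=0.5)
print(W_merged[0, 0])  # 1.0
```

Because a rank-r update adds only r * (in_features + out_features) parameters per layer, the files produced by these scripts stay small relative to the full DiT checkpoint.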