diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index 5fc95c3..5dbcdea 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -42,6 +42,7 @@ qwen_image_series = [
         "model_hash": "5722b5c873720009de96422993b15682",
         "model_name": "dinov3_image_encoder",
         "model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.dino_v3.DINOv3StateDictConverter",
     },
     {
         # Example:
diff --git a/diffsynth/models/dinov3_image_encoder.py b/diffsynth/models/dinov3_image_encoder.py
index 052f856..358ec1f 100644
--- a/diffsynth/models/dinov3_image_encoder.py
+++ b/diffsynth/models/dinov3_image_encoder.py
@@ -1,5 +1,5 @@
-from transformers import DINOv3ViTModel, DINOv3ViTImageProcessor
-from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
+from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTModel, DINOv3ViTConfig
+from transformers import DINOv3ViTImageProcessor
 import torch
 from ..core.device.npu_compatible_device import get_device_type
 
@@ -82,7 +82,7 @@
         hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
         position_embeddings = self.rope_embeddings(pixel_values)
 
-        for i, layer_module in enumerate(self.layer):
+        for i, layer_module in enumerate(self.model.layer):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             hidden_states = layer_module(
                 hidden_states,
diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py
index 58e1d15..c00f678 100644
--- a/diffsynth/models/siglip2_image_encoder.py
+++ b/diffsynth/models/siglip2_image_encoder.py
@@ -1,11 +1,11 @@
-from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
+from transformers.models.siglip.modeling_siglip import SiglipVisionModel, SiglipVisionConfig
 from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessor
 import torch
 from diffsynth.core.device.npu_compatible_device import get_device_type
 
 
 
-class Siglip2ImageEncoder(SiglipVisionTransformer):
+class Siglip2ImageEncoder(SiglipVisionModel):
     def __init__(self):
         config = SiglipVisionConfig(
             attention_dropout = 0.0,
diff --git a/diffsynth/utils/state_dict_converters/dino_v3.py b/diffsynth/utils/state_dict_converters/dino_v3.py
new file mode 100644
index 0000000..8ea865d
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/dino_v3.py
@@ -0,0 +1,9 @@
+def DINOv3StateDictConverter(state_dict):
+    new_state_dict = {}
+    for key in state_dict:
+        value = state_dict[key]
+        if key.startswith("layer."):
+            new_state_dict["model." + key] = value
+        else:
+            new_state_dict[key] = value
+    return new_state_dict