add new quality metric

This commit is contained in:
YunhongLu-ZJU
2025-02-17 14:42:20 +08:00
parent 77d0f4d297
commit 991ba162bd
69 changed files with 88 additions and 1461 deletions

View File

@@ -0,0 +1,3 @@
from .base_model import *
from .clip_model import *
from .cross_modeling import *

View File

@@ -4,13 +4,13 @@ from transformers import AutoTokenizer
from torch import nn, einsum
from trainer.models.base_model import BaseModelConfig
from .base_model import BaseModelConfig
from transformers import CLIPConfig
from typing import Any, Optional, Tuple, Union
import torch
from trainer.models.cross_modeling import Cross_model
from .cross_modeling import Cross_model
import gc
@@ -91,7 +91,7 @@ class XCLIPModel(HFCLIPModel):
@dataclass
class ClipModelConfig(BaseModelConfig):
_target_: str = "trainer.models.clip_model.CLIPModel"
_target_: str = "diffsynth.extensions.QualityMetric.trainer.models.clip_model.CLIPModel"
pretrained_model_name_or_path: str ="checkpoints/clip-vit-base-patch32"