diff --git a/README.md b/README.md
index 15d6597..704e9ac 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,9 @@ We believe that a well-developed open-source code framework can lower the thresh
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
+
+- **March 2, 2026** Added support for [Anima](https://modelscope.cn/models/circlestone-labs/Anima), an anime-style image generation model. See the [documentation](/docs/en/Model_Details/Anima.md) for details. We look forward to its future updates.
+
- **February 26, 2026** Added full and lora training support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details.
- **February 10, 2026** Added inference support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details. Support for model training will be implemented in the future.
@@ -343,6 +346,60 @@ Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
+#### Anima: [/docs/en/Model_Details/Anima.md](/docs/en/Model_Details/Anima.md)
+
+**Quick Start**
+
+Run the following code to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 8GB VRAM.
+
+```python
+from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+ "offload_dtype": "disk",
+ "offload_device": "disk",
+ "onload_dtype": "disk",
+ "onload_device": "disk",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+pipe = AnimaImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
+negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
+image = pipe(prompt, negative_prompt=negative_prompt, seed=0, num_inference_steps=50)
+image.save("image.jpg")
+```
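+
+The `vram_config` above offloads inactive weights to disk. As a sketch of an assumed alternative (the key names are taken verbatim from the config above; whether `"cpu"` is accepted as an offload target is an assumption, not something this README confirms), offloaded weights could instead be kept in system RAM, trading memory footprint for faster onloading:
+
+```python
+import torch
+
+# Hypothetical variant of the vram_config shown above: offload to CPU RAM
+# instead of disk. Assumes the same keys accept "cpu" as a device value.
+vram_config_cpu_offload = {
+    "offload_dtype": torch.bfloat16,   # dtype used while weights are offloaded
+    "offload_device": "cpu",           # keep offloaded weights in system RAM
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+```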
+
+**Examples**
+
+Example code for Anima is located at: [/examples/anima/](/examples/anima/)
+
+| Model ID | Inference | Low VRAM Inference | Full Training | Validation after Full Training | LoRA Training | Validation after LoRA Training |
+|-|-|-|-|-|-|-|
+|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|
+
+
+
#### Qwen-Image: [/docs/en/Model_Details/Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)
diff --git a/README_zh.md b/README_zh.md
index ec3cae7..9c95503 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -32,6 +32,9 @@ DiffSynth 目前包括两个开源项目:
> DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
+
+- **2026年3月2日** 新增对[Anima](https://modelscope.cn/models/circlestone-labs/Anima)的支持,详见[文档](docs/zh/Model_Details/Anima.md)。这是一个有趣的动漫风格图像生成模型,我们期待其后续的模型更新。
+
- **2026年2月26日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型全量微调与LoRA训练支持,详见[文档](docs/zh/Model_Details/LTX-2.md)。
- **2026年2月10日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型的推理支持,详见[文档](docs/zh/Model_Details/LTX-2.md),后续将推进模型训练的支持。
@@ -343,6 +346,60 @@ FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
+#### Anima: [/docs/zh/Model_Details/Anima.md](/docs/zh/Model_Details/Anima.md)
+
+**快速开始**
+
+运行以下代码可以快速加载 [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
+
+```python
+from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+ "offload_dtype": "disk",
+ "offload_device": "disk",
+ "onload_dtype": "disk",
+ "onload_device": "disk",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+pipe = AnimaImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
+negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
+image = pipe(prompt, negative_prompt=negative_prompt, seed=0, num_inference_steps=50)
+image.save("image.jpg")
+```
+
+**示例代码**
+
+Anima 的示例代码位于:[/examples/anima/](/examples/anima/)
+
+|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
+|-|-|-|-|-|-|-|
+|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|
+
+
+
#### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)
diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index fbca133..f9fa595 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -719,4 +719,20 @@ ltx2_series = [
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
},
]
-MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series
+anima_series = [
+ {
+        # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors")
+ "model_hash": "a9995952c2d8e63cf82e115005eb61b9",
+ "model_name": "z_image_text_encoder",
+ "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
+ "extra_kwargs": {"model_size": "0.6B"},
+ },
+ {
+ # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors")
+ "model_hash": "417673936471e79e31ed4d186d7a3f4a",
+ "model_name": "anima_dit",
+ "model_class": "diffsynth.models.anima_dit.AnimaDiT",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter",
+ }
+]
+MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series
diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py
index 0f360ef..d86f5fa 100644
--- a/diffsynth/configs/vram_management_module_maps.py
+++ b/diffsynth/configs/vram_management_module_maps.py
@@ -243,4 +243,10 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
+ "diffsynth.models.anima_dit.AnimaDiT": {
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+ },
}
diff --git a/diffsynth/models/anima_dit.py b/diffsynth/models/anima_dit.py
new file mode 100644
index 0000000..dbd1407
--- /dev/null
+++ b/diffsynth/models/anima_dit.py
@@ -0,0 +1,1304 @@
+# original code from: comfy/ldm/cosmos/predict2.py
+
+import torch
+from torch import nn
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+import logging
+from typing import Callable, Optional, Tuple, List
+import math
+from torchvision import transforms
+from ..core.attention import attention_forward
+from ..core.gradient import gradient_checkpoint_forward
+
+
+class VideoPositionEmb(nn.Module):
+    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
+ """
+ It delegates the embedding generation to generate_embeddings function.
+ """
+ B_T_H_W_C = x_B_T_H_W_C.shape
+ embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
+
+ return embeddings
+
+    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None):
+ raise NotImplementedError
+
+
+def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
+ """
+ Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.
+
+ Args:
+ x (torch.Tensor): The input tensor to normalize.
+ dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
+ eps (float, optional): A small constant to ensure numerical stability during division.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+ """
+ if dim is None:
+ dim = list(range(1, x.ndim))
+ norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
+ norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
+ return x / norm.to(x.dtype)
+
+
+class LearnablePosEmbAxis(VideoPositionEmb):
+ def __init__(
+ self,
+ *, # enforce keyword arguments
+ interpolation: str,
+ model_channels: int,
+ len_h: int,
+ len_w: int,
+ len_t: int,
+ device=None,
+ dtype=None,
+ **kwargs,
+ ):
+ """
+ Args:
+            interpolation (str): we currently only support "crop"; ideally, when we need extrapolation capacity, we should adjust the frequency or use other more advanced methods. They are not implemented yet.
+ """
+ del kwargs # unused
+ super().__init__()
+ self.interpolation = interpolation
+ assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
+
+ self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
+ self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
+ self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
+
+    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
+ B, T, H, W, _ = B_T_H_W_C
+ if self.interpolation == "crop":
+ emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
+ emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
+ emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
+ emb = (
+ repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
+ + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
+ + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
+ )
+ assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
+ else:
+ raise ValueError(f"Unknown interpolation method {self.interpolation}")
+
+ return normalize(emb, dim=-1, eps=1e-6)
+
+
+class VideoRopePosition3DEmb(VideoPositionEmb):
+ def __init__(
+ self,
+ *, # enforce keyword arguments
+ head_dim: int,
+ len_h: int,
+ len_w: int,
+ len_t: int,
+ base_fps: int = 24,
+ h_extrapolation_ratio: float = 1.0,
+ w_extrapolation_ratio: float = 1.0,
+ t_extrapolation_ratio: float = 1.0,
+ enable_fps_modulation: bool = True,
+ device=None,
+ **kwargs, # used for compatibility with other positional embeddings; unused in this class
+ ):
+ del kwargs
+ super().__init__()
+ self.base_fps = base_fps
+ self.max_h = len_h
+ self.max_w = len_w
+ self.enable_fps_modulation = enable_fps_modulation
+
+ dim = head_dim
+ dim_h = dim // 6 * 2
+ dim_w = dim_h
+ dim_t = dim - 2 * dim_h
+ assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
+ self.register_buffer(
+ "dim_spatial_range",
+ torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
+ persistent=False,
+ )
+ self.register_buffer(
+ "dim_temporal_range",
+ torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
+ persistent=False,
+ )
+
+ self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
+ self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
+ self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))
+
+ def generate_embeddings(
+ self,
+ B_T_H_W_C: torch.Size,
+ fps: Optional[torch.Tensor] = None,
+ h_ntk_factor: Optional[float] = None,
+ w_ntk_factor: Optional[float] = None,
+ t_ntk_factor: Optional[float] = None,
+ device=None,
+ dtype=None,
+ ):
+ """
+ Generate embeddings for the given input size.
+
+ Args:
+ B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
+ fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
+ h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
+ w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
+ t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.
+
+        Returns:
+            torch.Tensor: Rotary embedding factors of shape (T * H * W, head_dim // 2, 2, 2).
+ """
+ h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
+ w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
+ t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor
+
+ h_theta = 10000.0 * h_ntk_factor
+ w_theta = 10000.0 * w_ntk_factor
+ t_theta = 10000.0 * t_ntk_factor
+
+ h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
+ w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
+ temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
+
+ B, T, H, W, _ = B_T_H_W_C
+ seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
+ uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
+ assert (
+ uniform_fps or B == 1 or T == 1
+ ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
+ half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
+ half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)
+
+ # apply sequence scaling in temporal dimension
+ if fps is None or self.enable_fps_modulation is False: # image case
+ half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
+ else:
+ half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
+
+ half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
+ half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
+ half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)
+
+ em_T_H_W_D = torch.cat(
+ [
+ repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
+ repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
+ repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
+ ]
+ , dim=-2,
+ )
+
+ return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
+
+
+def apply_rotary_pos_emb(
+ t: torch.Tensor,
+ freqs: torch.Tensor,
+) -> torch.Tensor:
+ t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
+ t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
+ t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
+ return t_out
+
+
+# ---------------------- Feed Forward Network -----------------------
+class GPT2FeedForward(nn.Module):
+ def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
+ super().__init__()
+ self.activation = nn.GELU()
+ self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
+ self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
+
+ self._layer_id = None
+ self._dim = d_model
+ self._hidden_dim = d_ff
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.layer1(x)
+
+ x = self.activation(x)
+ x = self.layer2(x)
+ return x
+
+
+def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = None) -> torch.Tensor:
+ """Computes multi-head attention using PyTorch's native implementation.
+
+ This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
+ It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
+ attention, and rearranges the output back to the original format.
+
+ The input tensor names use the following dimension conventions:
+
+ - B: batch size
+ - S: sequence length
+ - H: number of attention heads
+ - D: head dimension
+
+ Args:
+ q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
+ k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
+ v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
+
+ Returns:
+ Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
+ """
+ in_q_shape = q_B_S_H_D.shape
+ in_k_shape = k_B_S_H_D.shape
+ q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
+ k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
+ v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
+ return attention_forward(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, out_pattern="b s (n d)")
+
+
+class Attention(nn.Module):
+ """
+ A flexible attention module supporting both self-attention and cross-attention mechanisms.
+
+ This module implements a multi-head attention layer that can operate in either self-attention
+ or cross-attention mode. The mode is determined by whether a context dimension is provided.
+ The implementation uses scaled dot-product attention and supports optional bias terms and
+ dropout regularization.
+
+ Args:
+ query_dim (int): The dimensionality of the query vectors.
+ context_dim (int, optional): The dimensionality of the context (key/value) vectors.
+ If None, the module operates in self-attention mode using query_dim. Default: None
+ n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
+ head_dim (int, optional): The dimension of each attention head. Default: 64
+ dropout (float, optional): Dropout probability applied to the output. Default: 0.0
+
+ Examples:
+ >>> # Self-attention with 512 dimensions and 8 heads
+ >>> self_attn = Attention(query_dim=512)
+ >>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim)
+ >>> out = self_attn(x) # (32, 16, 512)
+
+ >>> # Cross-attention
+ >>> cross_attn = Attention(query_dim=512, context_dim=256)
+ >>> query = torch.randn(32, 16, 512)
+ >>> context = torch.randn(32, 8, 256)
+ >>> out = cross_attn(query, context) # (32, 16, 512)
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ context_dim: Optional[int] = None,
+ n_heads: int = 8,
+ head_dim: int = 64,
+ dropout: float = 0.0,
+ device=None,
+ dtype=None,
+ operations=None,
+ ) -> None:
+ super().__init__()
+ logging.debug(
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{n_heads} heads with a dimension of {head_dim}."
+ )
+ self.is_selfattn = context_dim is None # self attention
+
+ context_dim = query_dim if context_dim is None else context_dim
+ inner_dim = head_dim * n_heads
+
+ self.n_heads = n_heads
+ self.head_dim = head_dim
+ self.query_dim = query_dim
+ self.context_dim = context_dim
+
+ self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
+ self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
+
+ self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
+ self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
+
+ self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
+ self.v_norm = nn.Identity()
+
+ self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
+ self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
+
+ self.attn_op = torch_attention_op
+
+ self._query_dim = query_dim
+ self._context_dim = context_dim
+ self._inner_dim = inner_dim
+
+ def compute_qkv(
+ self,
+ x: torch.Tensor,
+ context: Optional[torch.Tensor] = None,
+ rope_emb: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ q = self.q_proj(x)
+ context = x if context is None else context
+ k = self.k_proj(context)
+ v = self.v_proj(context)
+ q, k, v = map(
+ lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
+ (q, k, v),
+ )
+
+ def apply_norm_and_rotary_pos_emb(
+ q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ v = self.v_norm(v)
+ if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
+ q = apply_rotary_pos_emb(q, rope_emb)
+ k = apply_rotary_pos_emb(k, rope_emb)
+ return q, k, v
+
+ q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
+
+ return q, k, v
+
+    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = None) -> torch.Tensor:
+ result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
+ return self.output_dropout(self.output_proj(result))
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: Optional[torch.Tensor] = None,
+ rope_emb: Optional[torch.Tensor] = None,
+        transformer_options: Optional[dict] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): The query tensor of shape [B, Mq, K]
+ context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
+ """
+ q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
+ return self.compute_attention(q, k, v, transformer_options=transformer_options)
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int):
+ super().__init__()
+ self.num_channels = num_channels
+
+ def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
+ assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
+ timesteps = timesteps_B_T.flatten().float()
+ half_dim = self.num_channels // 2
+ exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
+        exponent = exponent / half_dim
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ sin_emb = torch.sin(emb)
+ cos_emb = torch.cos(emb)
+ emb = torch.cat([cos_emb, sin_emb], dim=-1)
+
+ return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
+ super().__init__()
+ logging.debug(
+ f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
+ )
+ self.in_dim = in_features
+ self.out_dim = out_features
+ self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
+ self.activation = nn.SiLU()
+ self.use_adaln_lora = use_adaln_lora
+ if use_adaln_lora:
+ self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
+ else:
+ self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
+
+ def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ emb = self.linear_1(sample)
+ emb = self.activation(emb)
+ emb = self.linear_2(emb)
+
+ if self.use_adaln_lora:
+ adaln_lora_B_T_3D = emb
+ emb_B_T_D = sample
+ else:
+ adaln_lora_B_T_3D = None
+ emb_B_T_D = emb
+
+ return emb_B_T_D, adaln_lora_B_T_3D
+
+
+class PatchEmbed(nn.Module):
+ """
+    PatchEmbed is a module for embedding patches from an input tensor by flattening each spatio-temporal patch
+    and projecting it with a linear layer. This module can process inputs with temporal (video) and spatial (image) dimensions,
+ making it suitable for video and image processing tasks. It supports dividing the input into patches
+ and embedding each patch into a vector of size `out_channels`.
+
+ Parameters:
+ - spatial_patch_size (int): The size of each spatial patch.
+ - temporal_patch_size (int): The size of each temporal patch.
+ - in_channels (int): Number of input channels. Default: 3.
+ - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
+ """
+
+ def __init__(
+ self,
+ spatial_patch_size: int,
+ temporal_patch_size: int,
+ in_channels: int = 3,
+ out_channels: int = 768,
+ device=None, dtype=None, operations=None
+ ):
+ super().__init__()
+ self.spatial_patch_size = spatial_patch_size
+ self.temporal_patch_size = temporal_patch_size
+
+ self.proj = nn.Sequential(
+ Rearrange(
+ "b c (t r) (h m) (w n) -> b t h w (c r m n)",
+ r=temporal_patch_size,
+ m=spatial_patch_size,
+ n=spatial_patch_size,
+ ),
+ operations.Linear(
+ in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
+ ),
+ )
+ self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the PatchEmbed module.
+
+ Parameters:
+ - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
+ B is the batch size,
+ C is the number of channels,
+ T is the temporal dimension,
+ H is the height, and
+ W is the width of the input.
+
+ Returns:
+ - torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
+ """
+ assert x.dim() == 5
+ _, _, T, H, W = x.shape
+ assert (
+ H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
+ ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
+ assert T % self.temporal_patch_size == 0
+ x = self.proj(x)
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of video DiT.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ spatial_patch_size: int,
+ temporal_patch_size: int,
+ out_channels: int,
+ use_adaln_lora: bool = False,
+ adaln_lora_dim: int = 256,
+ device=None, dtype=None, operations=None
+ ):
+ super().__init__()
+ self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = operations.Linear(
+ hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
+ )
+ self.hidden_size = hidden_size
+ self.n_adaln_chunks = 2
+ self.use_adaln_lora = use_adaln_lora
+ self.adaln_lora_dim = adaln_lora_dim
+ if use_adaln_lora:
+ self.adaln_modulation = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
+ operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
+ )
+ else:
+ self.adaln_modulation = nn.Sequential(
+ nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
+ )
+
+ def forward(
+ self,
+ x_B_T_H_W_D: torch.Tensor,
+ emb_B_T_D: torch.Tensor,
+ adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
+ ):
+ if self.use_adaln_lora:
+ assert adaln_lora_B_T_3D is not None
+ shift_B_T_D, scale_B_T_D = (
+ self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
+ ).chunk(2, dim=-1)
+ else:
+ shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
+
+ shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
+ scale_B_T_D, "b t d -> b t 1 1 d"
+ )
+
+ def _fn(
+ _x_B_T_H_W_D: torch.Tensor,
+ _norm_layer: nn.Module,
+ _scale_B_T_1_1_D: torch.Tensor,
+ _shift_B_T_1_1_D: torch.Tensor,
+ ) -> torch.Tensor:
+ return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
+
+ x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
+ x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
+ return x_B_T_H_W_O
+
+
+class Block(nn.Module):
+ """
+ A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
+ Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.
+
+ Parameters:
+ x_dim (int): Dimension of input features
+ context_dim (int): Dimension of context features for cross-attention
+ num_heads (int): Number of attention heads
+ mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
+ use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
+ adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
+
+ The block applies the following sequence:
+ 1. Self-attention with AdaLN modulation
+ 2. Cross-attention with AdaLN modulation
+ 3. MLP with AdaLN modulation
+
+ Each component uses skip connections and layer normalization.
+ """
+
+ def __init__(
+ self,
+ x_dim: int,
+ context_dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ use_adaln_lora: bool = False,
+ adaln_lora_dim: int = 256,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+ super().__init__()
+ self.x_dim = x_dim
+ self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
+ self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
+
+ self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
+ self.cross_attn = Attention(
+ x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
+ )
+
+ self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
+ self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
+
+ self.use_adaln_lora = use_adaln_lora
+ if self.use_adaln_lora:
+ self.adaln_modulation_self_attn = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
+ operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
+ )
+ self.adaln_modulation_cross_attn = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
+ operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
+ )
+ self.adaln_modulation_mlp = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
+ operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
+ )
+ else:
+ self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
+ self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
+ self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
+
+ def forward(
+ self,
+ x_B_T_H_W_D: torch.Tensor,
+ emb_B_T_D: torch.Tensor,
+ crossattn_emb: torch.Tensor,
+ rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
+ adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
+ extra_per_block_pos_emb: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
+ ) -> torch.Tensor:
+ residual_dtype = x_B_T_H_W_D.dtype
+ compute_dtype = emb_B_T_D.dtype
+ if extra_per_block_pos_emb is not None:
+ x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
+
+ if self.use_adaln_lora:
+ shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
+ self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
+ ).chunk(3, dim=-1)
+ shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
+ self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
+ ).chunk(3, dim=-1)
+ shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
+ self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
+ ).chunk(3, dim=-1)
+ else:
+ shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
+ emb_B_T_D
+ ).chunk(3, dim=-1)
+ shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
+ emb_B_T_D
+ ).chunk(3, dim=-1)
+ shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
+
+ # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
+ shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
+ scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
+ gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
+
+ shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
+ scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
+ gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
+
+ shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
+ scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
+ gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
+
+ B, T, H, W, D = x_B_T_H_W_D.shape
+
+ def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
+ return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
+
+ normalized_x_B_T_H_W_D = _fn(
+ x_B_T_H_W_D,
+ self.layer_norm_self_attn,
+ scale_self_attn_B_T_1_1_D,
+ shift_self_attn_B_T_1_1_D,
+ )
+ result_B_T_H_W_D = rearrange(
+ self.self_attn(
+ rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
+ None,
+ rope_emb=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
+ ),
+ "b (t h w) d -> b t h w d",
+ t=T,
+ h=H,
+ w=W,
+ )
+ x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
+
+ def _x_fn(
+ _x_B_T_H_W_D: torch.Tensor,
+ layer_norm_cross_attn: Callable,
+ _scale_cross_attn_B_T_1_1_D: torch.Tensor,
+ _shift_cross_attn_B_T_1_1_D: torch.Tensor,
+ transformer_options: Optional[dict] = {},
+ ) -> torch.Tensor:
+ _normalized_x_B_T_H_W_D = _fn(
+ _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
+ )
+ _result_B_T_H_W_D = rearrange(
+ self.cross_attn(
+ rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
+ crossattn_emb,
+ rope_emb=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
+ ),
+ "b (t h w) d -> b t h w d",
+ t=T,
+ h=H,
+ w=W,
+ )
+ return _result_B_T_H_W_D
+
+ result_B_T_H_W_D = _x_fn(
+ x_B_T_H_W_D,
+ self.layer_norm_cross_attn,
+ scale_cross_attn_B_T_1_1_D,
+ shift_cross_attn_B_T_1_1_D,
+ transformer_options=transformer_options,
+ )
+ x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
+
+ normalized_x_B_T_H_W_D = _fn(
+ x_B_T_H_W_D,
+ self.layer_norm_mlp,
+ scale_mlp_B_T_1_1_D,
+ shift_mlp_B_T_1_1_D,
+ )
+ result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
+ x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
+ return x_B_T_H_W_D
+
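Each branch in `Block.forward` applies the same gated AdaLN update: normalize, scale by `(1 + scale)`, shift, run the branch, then fold the result back into the residual stream through `gate`. A minimal scalar sketch of that update (plain Python, illustrative names, with the parameter-free LayerNorm omitted):

```python
def adaln_branch(x, branch, scale, shift, gate):
    # _fn with the LayerNorm omitted: scale by (1 + scale), then shift
    modulated = x * (1 + scale) + shift
    # gated residual connection back into the stream
    return x + gate * branch(modulated)

# With gate = 0 the branch contributes nothing; with gate = 1 the modulated
# branch output (2.0 * 1.5 + 0.1 = 3.1) is added onto the residual.
assert adaln_branch(2.0, lambda v: v, 0.5, 0.1, 0.0) == 2.0
```

Zero-initializing the gate projections is the usual way such blocks start as an identity map; here the gates come from the AdaLN (or AdaLN-LoRA) modulation networks instead.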
+
+class MiniTrainDIT(nn.Module):
+ """
+ A clean implementation of DiT that can load and reproduce the training results of the original DiT model in Cosmos 1.
+ A general implementation of an AdaLN-modulated ViT-like (DiT) transformer for video processing.
+
+ Args:
+ max_img_h (int): Maximum height of the input images.
+ max_img_w (int): Maximum width of the input images.
+ max_frames (int): Maximum number of frames in the video sequence.
+ in_channels (int): Number of input channels (e.g., RGB channels for color images).
+ out_channels (int): Number of output channels.
+ patch_spatial (int): Spatial resolution of patches for input processing.
+ patch_temporal (int): Temporal resolution of patches for input processing.
+ concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
+ model_channels (int): Base number of channels used throughout the model.
+ num_blocks (int): Number of transformer blocks.
+ num_heads (int): Number of heads in the multi-head attention layers.
+ mlp_ratio (float): Expansion ratio for MLP blocks.
+ crossattn_emb_channels (int): Number of embedding channels for cross-attention.
+ pos_emb_cls (str): Type of positional embeddings.
+ pos_emb_learnable (bool): Whether positional embeddings are learnable.
+ pos_emb_interpolation (str): Method for interpolating positional embeddings.
+ min_fps (int): Minimum frames per second.
+ max_fps (int): Maximum frames per second.
+ use_adaln_lora (bool): Whether to use AdaLN-LoRA.
+ adaln_lora_dim (int): Dimension for AdaLN-LoRA.
+ rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
+ rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
+ rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
+ extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
+ extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
+ extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
+ extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
+ """
+
+ def __init__(
+ self,
+ max_img_h: int,
+ max_img_w: int,
+ max_frames: int,
+ in_channels: int,
+ out_channels: int,
+ patch_spatial: int,
+ patch_temporal: int,
+ concat_padding_mask: bool = True,
+ # attention settings
+ model_channels: int = 768,
+ num_blocks: int = 10,
+ num_heads: int = 16,
+ mlp_ratio: float = 4.0,
+ # cross attention settings
+ crossattn_emb_channels: int = 1024,
+ # positional embedding settings
+ pos_emb_cls: str = "sincos",
+ pos_emb_learnable: bool = False,
+ pos_emb_interpolation: str = "crop",
+ min_fps: int = 1,
+ max_fps: int = 30,
+ use_adaln_lora: bool = False,
+ adaln_lora_dim: int = 256,
+ rope_h_extrapolation_ratio: float = 1.0,
+ rope_w_extrapolation_ratio: float = 1.0,
+ rope_t_extrapolation_ratio: float = 1.0,
+ extra_per_block_abs_pos_emb: bool = False,
+ extra_h_extrapolation_ratio: float = 1.0,
+ extra_w_extrapolation_ratio: float = 1.0,
+ extra_t_extrapolation_ratio: float = 1.0,
+ rope_enable_fps_modulation: bool = True,
+ image_model=None,
+ device=None,
+ dtype=None,
+ operations=None,
+ ) -> None:
+ super().__init__()
+ self.dtype = dtype
+ self.max_img_h = max_img_h
+ self.max_img_w = max_img_w
+ self.max_frames = max_frames
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.patch_spatial = patch_spatial
+ self.patch_temporal = patch_temporal
+ self.num_heads = num_heads
+ self.num_blocks = num_blocks
+ self.model_channels = model_channels
+ self.concat_padding_mask = concat_padding_mask
+ # positional embedding settings
+ self.pos_emb_cls = pos_emb_cls
+ self.pos_emb_learnable = pos_emb_learnable
+ self.pos_emb_interpolation = pos_emb_interpolation
+ self.min_fps = min_fps
+ self.max_fps = max_fps
+ self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
+ self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
+ self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
+ self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
+ self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
+ self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
+ self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
+ self.rope_enable_fps_modulation = rope_enable_fps_modulation
+
+ self.build_pos_embed(device=device, dtype=dtype)
+ self.use_adaln_lora = use_adaln_lora
+ self.adaln_lora_dim = adaln_lora_dim
+ self.t_embedder = nn.Sequential(
+ Timesteps(model_channels),
+ TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
+ )
+
+ in_channels = in_channels + 1 if concat_padding_mask else in_channels
+ self.x_embedder = PatchEmbed(
+ spatial_patch_size=patch_spatial,
+ temporal_patch_size=patch_temporal,
+ in_channels=in_channels,
+ out_channels=model_channels,
+ device=device, dtype=dtype, operations=operations,
+ )
+
+ self.blocks = nn.ModuleList(
+ [
+ Block(
+ x_dim=model_channels,
+ context_dim=crossattn_emb_channels,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ use_adaln_lora=use_adaln_lora,
+ adaln_lora_dim=adaln_lora_dim,
+ device=device, dtype=dtype, operations=operations,
+ )
+ for _ in range(num_blocks)
+ ]
+ )
+
+ self.final_layer = FinalLayer(
+ hidden_size=self.model_channels,
+ spatial_patch_size=self.patch_spatial,
+ temporal_patch_size=self.patch_temporal,
+ out_channels=self.out_channels,
+ use_adaln_lora=self.use_adaln_lora,
+ adaln_lora_dim=self.adaln_lora_dim,
+ device=device, dtype=dtype, operations=operations,
+ )
+
+ self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
+
+ def build_pos_embed(self, device=None, dtype=None) -> None:
+ if self.pos_emb_cls == "rope3d":
+ cls_type = VideoRopePosition3DEmb
+ else:
+ raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
+
+ logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
+ kwargs = dict(
+ model_channels=self.model_channels,
+ len_h=self.max_img_h // self.patch_spatial,
+ len_w=self.max_img_w // self.patch_spatial,
+ len_t=self.max_frames // self.patch_temporal,
+ max_fps=self.max_fps,
+ min_fps=self.min_fps,
+ is_learnable=self.pos_emb_learnable,
+ interpolation=self.pos_emb_interpolation,
+ head_dim=self.model_channels // self.num_heads,
+ h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
+ w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
+ t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
+ enable_fps_modulation=self.rope_enable_fps_modulation,
+ device=device,
+ )
+ self.pos_embedder = cls_type(
+ **kwargs, # type: ignore
+ )
+
+ if self.extra_per_block_abs_pos_emb:
+ kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
+ kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
+ kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
+ kwargs["device"] = device
+ kwargs["dtype"] = dtype
+ self.extra_pos_embedder = LearnablePosEmbAxis(
+ **kwargs, # type: ignore
+ )
+
+ def prepare_embedded_sequence(
+ self,
+ x_B_C_T_H_W: torch.Tensor,
+ fps: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
+
+ Args:
+ x_B_C_T_H_W (torch.Tensor): Input video tensor of shape (B, C, T, H, W).
+ fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
+ If None, a default value (`self.base_fps`) will be used.
+ padding_mask (Optional[torch.Tensor]): Padding mask; when `self.concat_padding_mask` is True it is resized and concatenated to the input as an extra channel.
+
+ Returns:
+ Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ - A tensor of shape (B, T, H, W, D) with the embedded sequence.
+ - An optional rotary positional embedding tensor, returned only if the positional embedding class
+ (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
+ - An optional extra per-block absolute positional embedding tensor, or None if
+ `self.extra_per_block_abs_pos_emb` is False.
+
+ Notes:
+ - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
+ - If 'rope' is in `self.pos_emb_cls` (case insensitive), rotary positional embeddings are generated by
+ `self.pos_embedder` for the [T, H, W] token grid and returned separately.
+ - Otherwise, the positional embeddings are added to the embedded sequence directly.
+ """
+ if self.concat_padding_mask:
+ if padding_mask is None:
+ padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
+ else:
+ padding_mask = transforms.functional.resize(
+ padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
+ )
+ x_B_C_T_H_W = torch.cat(
+ [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
+ )
+ x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
+
+ if self.extra_per_block_abs_pos_emb:
+ extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
+ else:
+ extra_pos_emb = None
+
+ if "rope" in self.pos_emb_cls.lower():
+ return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
+ x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
+
+ return x_B_T_H_W_D, None, extra_pos_emb
+
+ def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
+ x_B_C_Tt_Hp_Wp = rearrange(
+ x_B_T_H_W_M,
+ "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
+ p1=self.patch_spatial,
+ p2=self.patch_spatial,
+ t=self.patch_temporal,
+ )
+ return x_B_C_Tt_Hp_Wp
+
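`unpatchify` inverts the patch embedding: each token of dimension `p1 * p2 * t * C` expands back into a `(t, p1, p2)` block of `C`-channel latent pixels, so the output sizes are the token-grid sizes multiplied by the patch sizes. A plain-Python sketch of the shape arithmetic (the concrete sizes are hypothetical):

```python
def unpatchified_shape(T, H, W, p_spatial, p_temporal, out_channels):
    # (B, T, H, W, p1*p2*t*C) -> (B, C, T*t, H*p1, W*p2), batch dim omitted
    return (out_channels, T * p_temporal, H * p_spatial, W * p_spatial)

# 16 latent channels, 2x2 spatial patches, temporal patch of 1:
# a 1x64x64 token grid maps back to a 16-channel 1x128x128 latent.
assert unpatchified_shape(1, 64, 64, 2, 1, 16) == (16, 1, 128, 128)
```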
+ def pad_to_patch_size(self, img, patch_size=(2, 2), padding_mode="circular"):
+ if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
+ padding_mode = "reflect"
+
+ pad = ()
+ for i in range(img.ndim - 2):
+ pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad
+
+ return torch.nn.functional.pad(img, pad, mode=padding_mode)
+
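The per-dimension pad amount computed in `pad_to_patch_size` is `(patch - size % patch) % patch`, which is zero when the size is already a multiple of the patch size. A quick plain-Python check of that formula:

```python
def pad_amount(size, patch):
    # amount appended so that `size` becomes a multiple of `patch`
    return (patch - size % patch) % patch

assert pad_amount(128, 2) == 0  # already aligned, no padding
assert pad_amount(129, 2) == 1
assert pad_amount(130, 4) == 2
```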
+ def forward(
+ self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ context: torch.Tensor,
+ fps: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ **kwargs,
+ ):
+ """
+ Args:
+ x: (B, C, T, H, W) tensor of spatio-temporal inputs
+ timesteps: (B,) or (B, T) tensor of timesteps
+ context: (B, N, D) tensor of cross-attention embeddings
+ """
+ orig_shape = list(x.shape)
+ x = self.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
+ x_B_C_T_H_W = x
+ timesteps_B_T = timesteps
+ crossattn_emb = context
+ x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
+ x_B_C_T_H_W,
+ fps=fps,
+ padding_mask=padding_mask,
+ )
+
+ if timesteps_B_T.ndim == 1:
+ timesteps_B_T = timesteps_B_T.unsqueeze(1)
+ t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
+ t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
+
+ # for logging purpose
+ affline_scale_log_info = {}
+ affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
+ self.affline_scale_log_info = affline_scale_log_info
+ self.affline_emb = t_embedding_B_T_D
+ self.crossattn_emb = crossattn_emb
+
+ if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
+ assert (
+ x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
+ ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
+
+ block_kwargs = {
+ "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
+ "adaln_lora_B_T_3D": adaln_lora_B_T_3D,
+ "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
+ "transformer_options": kwargs.get("transformer_options", {}),
+ }
+
+ # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
+ # in fp32, but run attention and MLP modules in fp16.
+ # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
+ # quality degradation and visual artifacts.
+ if x_B_T_H_W_D.dtype == torch.float16:
+ x_B_T_H_W_D = x_B_T_H_W_D.float()
+
+ for block in self.blocks:
+ x_B_T_H_W_D = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ x_B_T_H_W_D=x_B_T_H_W_D,
+ emb_B_T_D=t_embedding_B_T_D,
+ crossattn_emb=crossattn_emb,
+ **block_kwargs,
+ )
+
+ x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
+ x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
+ return x_B_C_Tt_Hp_Wp
+
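After patch embedding, the sequence length the transformer blocks attend over is the product of the per-axis token counts. A plain-Python sketch (the 8x VAE downsampling factor below is an assumption about the paired Wan VAE, not something stated in this file):

```python
def token_count(frames, height, width, p_spatial, p_temporal):
    # tokens per sample after PatchEmbed: (T/pt) * (H/ps) * (W/ps)
    return (frames // p_temporal) * (height // p_spatial) * (width // p_spatial)

# A single 1024x1024 image gives a 128x128 latent under an assumed 8x VAE
# downsampling; with patch_spatial=2 and patch_temporal=1 that is 4096 tokens.
assert token_count(1, 128, 128, 2, 1) == 4096
```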
+
+def rotate_half(x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb2(x, cos, sin, unsqueeze_dim=1):
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ x_embed = (x * cos) + (rotate_half(x) * sin)
+ return x_embed
+
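`rotate_half` negates the second half of the last dimension and swaps it in front of the first half; combined with the cos/sin terms in `apply_rotary_pos_emb2`, this realizes the pairwise 2D rotations of RoPE. The slicing can be mirrored on a plain list:

```python
def rotate_half_list(x):
    # mirror of rotate_half: concat(-x2, x1) along the last axis
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1

assert rotate_half_list([1, 2, 3, 4]) == [-3, -4, 1, 2]
```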
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, head_dim):
+ super().__init__()
+ self.rope_theta = 10000
+ inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
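The buffer built in `RotaryEmbedding.__init__` follows the standard RoPE frequency schedule `theta^(-2i/d)` with `theta = 10000`: early dimension pairs rotate quickly, later ones ever more slowly. The same schedule computed with stdlib floats:

```python
def rope_inv_freq(head_dim, theta=10000.0):
    # one inverse frequency per pair of dimensions, as in RotaryEmbedding
    return [1.0 / theta ** (2 * i / head_dim) for i in range(head_dim // 2)]

freqs = rope_inv_freq(8)
assert len(freqs) == 4        # head_dim // 2 frequency pairs
assert freqs[0] == 1.0        # the first pair rotates at unit rate
assert freqs[-1] < freqs[0]   # later pairs rotate more slowly
```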
+
+class LLMAdapterAttention(nn.Module):
+ def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
+ super().__init__()
+
+ inner_dim = head_dim * n_heads
+ self.n_heads = n_heads
+ self.head_dim = head_dim
+ self.query_dim = query_dim
+ self.context_dim = context_dim
+
+ self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
+ self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
+
+ self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
+ self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
+
+ self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
+
+ self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
+
+ def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
+ context = x if context is None else context
+ input_shape = x.shape[:-1]
+ q_shape = (*input_shape, self.n_heads, self.head_dim)
+ context_shape = context.shape[:-1]
+ kv_shape = (*context_shape, self.n_heads, self.head_dim)
+
+ query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
+ key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
+ value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
+
+ if position_embeddings is not None:
+ assert position_embeddings_context is not None
+ cos, sin = position_embeddings
+ query_states = apply_rotary_pos_emb2(query_states, cos, sin)
+ cos, sin = position_embeddings_context
+ key_states = apply_rotary_pos_emb2(key_states, cos, sin)
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
+
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output
+
+ def init_weights(self):
+ torch.nn.init.zeros_(self.o_proj.weight)
+
+
+class LLMAdapterTransformerBlock(nn.Module):
+ def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
+ super().__init__()
+ self.use_self_attn = use_self_attn
+
+ if self.use_self_attn:
+ self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
+ self.self_attn = LLMAdapterAttention(
+ query_dim=model_dim,
+ context_dim=model_dim,
+ n_heads=num_heads,
+ head_dim=model_dim//num_heads,
+ device=device,
+ dtype=dtype,
+ operations=operations,
+ )
+
+ self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
+ self.cross_attn = LLMAdapterAttention(
+ query_dim=model_dim,
+ context_dim=source_dim,
+ n_heads=num_heads,
+ head_dim=model_dim//num_heads,
+ device=device,
+ dtype=dtype,
+ operations=operations,
+ )
+
+ self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
+ self.mlp = nn.Sequential(
+ operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),
+ nn.GELU(),
+ operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype)
+ )
+
+ def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
+ if self.use_self_attn:
+ normed = self.norm_self_attn(x)
+ attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)
+ x = x + attn_out
+
+ normed = self.norm_cross_attn(x)
+ attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
+ x = x + attn_out
+
+ x = x + self.mlp(self.norm_mlp(x))
+ return x
+
+ def init_weights(self):
+ torch.nn.init.zeros_(self.mlp[2].weight)
+ self.cross_attn.init_weights()
+
+
+class LLMAdapter(nn.Module):
+ def __init__(
+ self,
+ source_dim=1024,
+ target_dim=1024,
+ model_dim=1024,
+ num_layers=6,
+ num_heads=16,
+ use_self_attn=True,
+ layer_norm=False,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+ super().__init__()
+
+ self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
+ if model_dim != target_dim:
+ self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
+ else:
+ self.in_proj = nn.Identity()
+ self.rotary_emb = RotaryEmbedding(model_dim//num_heads)
+ self.blocks = nn.ModuleList([
+ LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)
+ ])
+ self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
+ self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)
+
+ def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
+ if target_attention_mask is not None:
+ target_attention_mask = target_attention_mask.to(torch.bool)
+ if target_attention_mask.ndim == 2:
+ target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
+
+ if source_attention_mask is not None:
+ source_attention_mask = source_attention_mask.to(torch.bool)
+ if source_attention_mask.ndim == 2:
+ source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
+
+ context = source_hidden_states
+ x = self.in_proj(self.embed(target_input_ids).to(context.dtype))
+ position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
+ position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
+ position_embeddings = self.rotary_emb(x, position_ids)
+ position_embeddings_context = self.rotary_emb(x, position_ids_context)
+ for block in self.blocks:
+ x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
+ return self.norm(self.out_proj(x))
+
+
+class AnimaDiT(MiniTrainDIT):
+ def __init__(self):
+ kwargs = {
+ 'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128,
+ 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1,
+ 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024,
+ 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop',
+ 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256,
+ 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False,
+ 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0,
+ 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0,
+ 'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None, 'operations': torch.nn,
+ }
+ super().__init__(**kwargs)
+ self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
+
+ def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
+ if text_ids is not None:
+ out = self.llm_adapter(text_embeds, text_ids)
+ if t5xxl_weights is not None:
+ out = out * t5xxl_weights
+
+ if out.shape[1] < 512:
+ out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
+ return out
+ else:
+ return text_embeds
+
+ def forward(
+ self,
+ x, timesteps, context,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ **kwargs
+ ):
+ t5xxl_ids = kwargs.pop("t5xxl_ids", None)
+ if t5xxl_ids is not None:
+ context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
+ return super().forward(
+ x, timesteps, context,
+ use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ **kwargs
+ )
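`preprocess_text_embeds` right-pads the adapter output along the token axis up to 512 tokens but never truncates longer sequences. A plain-Python sketch of that length rule (scalars stand in for per-token feature vectors):

```python
def pad_tokens(tokens, target=512):
    # mirrors the F.pad call: right-pad with zeros up to `target`, never truncate
    if len(tokens) < target:
        tokens = tokens + [0.0] * (target - len(tokens))
    return tokens

assert len(pad_tokens([1.0] * 77)) == 512   # short prompts are padded out
assert len(pad_tokens([1.0] * 600)) == 600  # long prompts pass through
```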
diff --git a/diffsynth/pipelines/anima_image.py b/diffsynth/pipelines/anima_image.py
new file mode 100644
index 0000000..732ede5
--- /dev/null
+++ b/diffsynth/pipelines/anima_image.py
@@ -0,0 +1,261 @@
+import torch, math
+from PIL import Image
+from typing import Union
+from tqdm import tqdm
+from einops import rearrange
+import numpy as np
+from math import prod
+from transformers import AutoTokenizer
+
+from ..core.device.npu_compatible_device import get_device_type
+from ..diffusion import FlowMatchScheduler
+from ..core import ModelConfig, gradient_checkpoint_forward
+from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
+from ..utils.lora.merge import merge_lora
+
+from ..models.anima_dit import AnimaDiT
+from ..models.z_image_text_encoder import ZImageTextEncoder
+from ..models.wan_video_vae import WanVideoVAE
+
+
+class AnimaImagePipeline(BasePipeline):
+
+ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
+ super().__init__(
+ device=device, torch_dtype=torch_dtype,
+ height_division_factor=16, width_division_factor=16,
+ )
+ self.scheduler = FlowMatchScheduler("Z-Image")
+ self.text_encoder: ZImageTextEncoder = None
+ self.dit: AnimaDiT = None
+ self.vae: WanVideoVAE = None
+ self.tokenizer: AutoTokenizer = None
+ self.tokenizer_t5xxl: AutoTokenizer = None
+ self.in_iteration_models = ("dit",)
+ self.units = [
+ AnimaUnit_ShapeChecker(),
+ AnimaUnit_NoiseInitializer(),
+ AnimaUnit_InputImageEmbedder(),
+ AnimaUnit_PromptEmbedder(),
+ ]
+ self.model_fn = model_fn_anima
+
+
+ @staticmethod
+ def from_pretrained(
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: Union[str, torch.device] = get_device_type(),
+ model_configs: list[ModelConfig] = [],
+ tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ tokenizer_t5xxl_config: ModelConfig = ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
+ vram_limit: float = None,
+ ):
+ # Initialize pipeline
+ pipe = AnimaImagePipeline(device=device, torch_dtype=torch_dtype)
+ model_pool = pipe.download_and_load_models(model_configs, vram_limit)
+
+ # Fetch models
+ pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
+ pipe.dit = model_pool.fetch_model("anima_dit")
+ pipe.vae = model_pool.fetch_model("wan_video_vae")
+ if tokenizer_config is not None:
+ tokenizer_config.download_if_necessary()
+ pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
+ if tokenizer_t5xxl_config is not None:
+ tokenizer_t5xxl_config.download_if_necessary()
+ pipe.tokenizer_t5xxl = AutoTokenizer.from_pretrained(tokenizer_t5xxl_config.path)
+ # VRAM Management
+ pipe.vram_management_enabled = pipe.check_vram_management_state()
+ return pipe
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ # Prompt
+ prompt: str,
+ negative_prompt: str = "",
+ cfg_scale: float = 4.0,
+ # Image
+ input_image: Image.Image = None,
+ denoising_strength: float = 1.0,
+ # Shape
+ height: int = 1024,
+ width: int = 1024,
+ # Randomness
+ seed: int = None,
+ rand_device: str = "cpu",
+ # Steps
+ num_inference_steps: int = 30,
+ sigma_shift: float = None,
+ # Progress bar
+ progress_bar_cmd = tqdm,
+ ):
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+
+ # Parameters
+ inputs_posi = {
+ "prompt": prompt,
+ }
+ inputs_nega = {
+ "negative_prompt": negative_prompt,
+ }
+ inputs_shared = {
+ "cfg_scale": cfg_scale,
+ "input_image": input_image, "denoising_strength": denoising_strength,
+ "height": height, "width": width,
+ "seed": seed, "rand_device": rand_device,
+ "num_inference_steps": num_inference_steps,
+ }
+ for unit in self.units:
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+
+ # Denoise
+ self.load_models_to_device(self.in_iteration_models)
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+ noise_pred = self.cfg_guided_model_fn(
+ self.model_fn, cfg_scale,
+ inputs_shared, inputs_posi, inputs_nega,
+ **models, timestep=timestep, progress_id=progress_id
+ )
+ inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ image = self.vae.decode(inputs_shared["latents"].unsqueeze(2), device=self.device).squeeze(2)
+ image = self.vae_output_to_image(image)
+ self.load_models_to_device([])
+
+ return image
+
+
+class AnimaUnit_ShapeChecker(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("height", "width"),
+ output_params=("height", "width"),
+ )
+
+ def process(self, pipe: AnimaImagePipeline, height, width):
+ height, width = pipe.check_resize_height_width(height, width)
+ return {"height": height, "width": width}
+
+
+
+class AnimaUnit_NoiseInitializer(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("height", "width", "seed", "rand_device"),
+ output_params=("noise",),
+ )
+
+ def process(self, pipe: AnimaImagePipeline, height, width, seed, rand_device):
+ noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+ return {"noise": noise}
+
+
+
+class AnimaUnit_InputImageEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("input_image", "noise"),
+ output_params=("latents", "input_latents"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: AnimaImagePipeline, input_image, noise):
+ if input_image is None:
+ return {"latents": noise, "input_latents": None}
+ pipe.load_models_to_device(['vae'])
+        if isinstance(input_image, list):
+            input_latents = []
+            for image in input_image:
+                image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
+                # Match the single-image branch: add and remove the temporal dim for the video VAE.
+                input_latents.append(pipe.vae.encode(image.unsqueeze(2), device=pipe.device).squeeze(2))
+            input_latents = torch.concat(input_latents, dim=0)
+ else:
+ image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
+ input_latents = pipe.vae.encode(image.unsqueeze(2), device=pipe.device).squeeze(2)
+ if pipe.scheduler.training:
+ return {"latents": noise, "input_latents": input_latents}
+ else:
+ latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
+ return {"latents": latents, "input_latents": input_latents}
+
+
+class AnimaUnit_PromptEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ seperate_cfg=True,
+ input_params_posi={"prompt": "prompt"},
+ input_params_nega={"prompt": "negative_prompt"},
+            output_params=("prompt_emb", "t5xxl_ids"),
+ onload_model_names=("text_encoder",)
+ )
+
+ def encode_prompt(
+ self,
+ pipe: AnimaImagePipeline,
+ prompt,
+ device = None,
+ max_sequence_length: int = 512,
+ ):
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ text_inputs = pipe.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids.to(device)
+ prompt_masks = text_inputs.attention_mask.to(device).bool()
+
+ prompt_embeds = pipe.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=prompt_masks,
+ output_hidden_states=True,
+ ).hidden_states[-1]
+
+ t5xxl_text_inputs = pipe.tokenizer_t5xxl(
+ prompt,
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ t5xxl_ids = t5xxl_text_inputs.input_ids.to(device)
+
+ return prompt_embeds.to(pipe.torch_dtype), t5xxl_ids
+
+ def process(self, pipe: AnimaImagePipeline, prompt):
+ pipe.load_models_to_device(self.onload_model_names)
+ prompt_embeds, t5xxl_ids = self.encode_prompt(pipe, prompt, pipe.device)
+ return {"prompt_emb": prompt_embeds, "t5xxl_ids": t5xxl_ids}
+
+
+def model_fn_anima(
+ dit: AnimaDiT = None,
+ latents=None,
+ timestep=None,
+ prompt_emb=None,
+ t5xxl_ids=None,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ **kwargs
+):
+ latents = latents.unsqueeze(2)
+ timestep = timestep / 1000
+ model_output = dit(
+ x=latents,
+ timesteps=timestep,
+ context=prompt_emb,
+ t5xxl_ids=t5xxl_ids,
+ )
+ model_output = model_output.squeeze(2)
+ return model_output
diff --git a/diffsynth/utils/state_dict_converters/anima_dit.py b/diffsynth/utils/state_dict_converters/anima_dit.py
new file mode 100644
index 0000000..16afc76
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/anima_dit.py
@@ -0,0 +1,6 @@
+def AnimaDiTStateDictConverter(state_dict):
+ new_state_dict = {}
+ for key in state_dict:
+ value = state_dict[key]
+ new_state_dict[key.replace("net.", "")] = value
+ return new_state_dict
diff --git a/docs/en/Model_Details/Anima.md b/docs/en/Model_Details/Anima.md
new file mode 100644
index 0000000..0f3ae5a
--- /dev/null
+++ b/docs/en/Model_Details/Anima.md
@@ -0,0 +1,139 @@
+# Anima
+
+Anima is an image generation model trained and open-sourced by CircleStone Labs and Comfy Org.
+
+## Installation
+
+Before using this project for model inference and training, please install DiffSynth-Studio first.
+
+```shell
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+
+For more installation information, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
+
+## Quick Start
+
+The following code demonstrates how to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model for inference. VRAM management is enabled, allowing the framework to automatically control model parameter loading based on available VRAM. A minimum of 8GB of VRAM is required.
+
+```python
+from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+ "offload_dtype": "disk",
+ "offload_device": "disk",
+ "onload_dtype": "disk",
+ "onload_device": "disk",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+pipe = AnimaImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
+negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
+image = pipe(prompt, negative_prompt=negative_prompt, seed=0, num_inference_steps=50)
+image.save("image.jpg")
+```
+
+## Model Overview
+
+|Model ID|Inference|Low VRAM Inference|Full Training|Post-Full Training Validation|LoRA Training|Post-LoRA Training Validation|
+|-|-|-|-|-|-|-|
+|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|
+
+Special training scripts:
+
+* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)
+* FP8 Precision Training: [doc](../Training/FP8_Precision.md)
+* Two-Stage Split Training: [doc](../Training/Split_Training.md)
+* End-to-End Direct Distillation: [doc](../Training/Direct_Distill.md)
+
+## Model Inference
+
+Models are loaded through `AnimaImagePipeline.from_pretrained`; see [Model Inference](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
+
+Input parameters for `AnimaImagePipeline` inference include:
+
+* `prompt`: Text description of the desired image content.
+* `negative_prompt`: Content to exclude from the generated image (default: `""`).
+* `cfg_scale`: Classifier-free guidance parameter (default: 4.0).
+* `input_image`: Input image for image-to-image generation (default: `None`).
+* `denoising_strength`: Controls similarity to input image (default: 1.0).
+* `height`: Image height (must be multiple of 16, default: 1024).
+* `width`: Image width (must be multiple of 16, default: 1024).
+* `seed`: Random seed (default: `None`, i.e., fully random).
+* `rand_device`: Device for generating the random Gaussian noise (default: `"cpu"`; setting it to `"cuda"` yields different results on different GPUs).
+* `num_inference_steps`: Number of inference steps (default: 30).
+* `sigma_shift`: Scheduler sigma offset (default: `None`).
+* `progress_bar_cmd`: Progress bar implementation (default: `tqdm.tqdm`; set it to `lambda x: x` to disable the progress bar).
+
+If VRAM is limited, enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). Recommended low-VRAM configurations are linked in the "Model Overview" table above.
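
For image-to-image generation, the pipeline encodes `input_image` into latents and blends them with noise before denoising begins; `denoising_strength` sets how far along the flow-matching trajectory generation starts. A minimal sketch of that interpolation (illustrative only; the actual scheduler also handles timestep selection and sigma shifting):

```python
def add_noise(latent, noise, sigma):
    # Flow matching interpolates linearly between data and noise:
    #   x_sigma = (1 - sigma) * x_0 + sigma * noise
    return (1.0 - sigma) * latent + sigma * noise

# denoising_strength = 1.0 starts from pure noise (plain text-to-image);
# smaller values keep more of the encoded input image.
for strength in (1.0, 0.5, 0.0):
    print(strength, add_noise(1.0, 0.0, strength))
```

With `denoising_strength=1.0` the input image contributes nothing to the starting point, which is why it is the text-to-image default.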
+
+## Model Training
+
+Anima models are trained through [`examples/anima/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/train.py) with parameters including:
+
+* General Training Parameters
+ * Dataset Configuration
+ * `--dataset_base_path`: Dataset root directory.
+ * `--dataset_metadata_path`: Metadata file path.
+ * `--dataset_repeat`: Dataset repetition per epoch.
+ * `--dataset_num_workers`: Dataloader worker count.
+ * `--data_file_keys`: Metadata fields to load (comma-separated).
+ * Model Loading
+ * `--model_paths`: Model paths (JSON format).
+    * `--model_id_with_origin_paths`: Model IDs with origin paths (e.g., `"circlestone-labs/Anima:split_files/diffusion_models/anima-preview.safetensors"`), comma-separated.
+ * `--extra_inputs`: Additional pipeline inputs (e.g., `controlnet_inputs` for ControlNet).
+ * `--fp8_models`: FP8-formatted models (same format as `--model_paths`).
+ * Training Configuration
+ * `--learning_rate`: Learning rate.
+ * `--num_epochs`: Training epochs.
+ * `--trainable_models`: Trainable components (e.g., `dit`, `vae`, `text_encoder`).
+ * `--find_unused_parameters`: Handle unused parameters in DDP training.
+ * `--weight_decay`: Weight decay value.
+ * `--task`: Training task (default: `sft`).
+ * Output Configuration
+ * `--output_path`: Model output directory.
+ * `--remove_prefix_in_ckpt`: Remove state dict prefixes.
+ * `--save_steps`: Model saving interval.
+ * LoRA Configuration
+ * `--lora_base_model`: Target model for LoRA.
+ * `--lora_target_modules`: Target modules for LoRA.
+ * `--lora_rank`: LoRA rank.
+ * `--lora_checkpoint`: LoRA checkpoint path.
+ * `--preset_lora_path`: Preloaded LoRA checkpoint path.
+ * `--preset_lora_model`: Model to merge LoRA with (e.g., `dit`).
+ * Gradient Configuration
+ * `--use_gradient_checkpointing`: Enable gradient checkpointing.
+ * `--use_gradient_checkpointing_offload`: Offload checkpointing to CPU.
+ * `--gradient_accumulation_steps`: Gradient accumulation steps.
+ * Image Resolution
+ * `--height`: Image height (empty for dynamic resolution).
+ * `--width`: Image width (empty for dynamic resolution).
+ * `--max_pixels`: Maximum pixel area for dynamic resolution.
+* Anima-Specific Parameters
+  * `--tokenizer_path`: Tokenizer path for text-to-image models (leave empty to download automatically).
+  * `--tokenizer_t5xxl_path`: T5-XXL tokenizer path (leave empty to download automatically).
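
When `--height` and `--width` are left empty, dynamic resolution applies: images whose area exceeds `--max_pixels` are scaled down while keeping their aspect ratio, and each side is snapped to the 16-pixel division factor the pipeline requires. A rough sketch of this rule (`fit_resolution` is a hypothetical helper, not the framework's actual implementation):

```python
import math

def fit_resolution(width, height, max_pixels, factor=16):
    # Scale down (never up) so that width * height <= max_pixels,
    # then round each side down to a multiple of `factor`.
    scale = min(1.0, math.sqrt(max_pixels / (width * height)))
    w = max(factor, int(width * scale) // factor * factor)
    h = max(factor, int(height * scale) // factor * factor)
    return w, h

print(fit_resolution(2048, 2048, 1048576))  # oversized image is scaled down
print(fit_resolution(512, 512, 1048576))    # small image is left unchanged
```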
+
+We provide a sample image dataset for testing:
+
+```shell
+modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
+```
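
The training module reads a `prompt` and an `image` field for each sample, so the dataset's metadata file is expected to pair image paths (relative to `--dataset_base_path`) with captions. A minimal illustration of that layout (hypothetical rows; check the downloaded dataset for the authoritative format):

```csv
image,prompt
image_1.jpg,"a portrait of a silver-haired girl underwater"
image_2.jpg,"a cat sitting on a windowsill"
```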
+
+For training script details, refer to [Model Training](../Pipeline_Usage/Model_Training.md). For advanced training techniques, see [Training Framework Documentation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/).
\ No newline at end of file
diff --git a/docs/en/index.rst b/docs/en/index.rst
index ca38620..c4e2736 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -27,6 +27,7 @@ Welcome to DiffSynth-Studio's Documentation
Model_Details/Qwen-Image
Model_Details/FLUX2
Model_Details/Z-Image
+ Model_Details/Anima
Model_Details/LTX-2
.. toctree::
diff --git a/docs/zh/Model_Details/Anima.md b/docs/zh/Model_Details/Anima.md
new file mode 100644
index 0000000..0d5576b
--- /dev/null
+++ b/docs/zh/Model_Details/Anima.md
@@ -0,0 +1,139 @@
+# Anima
+
+Anima 是由 CircleStone Labs 与 Comfy Org 训练并开源的图像生成模型。
+
+## 安装
+
+在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
+
+```shell
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+
+更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
+
+## 快速开始
+
+运行以下代码可以快速加载 [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
+
+```python
+from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+ "offload_dtype": "disk",
+ "offload_device": "disk",
+ "onload_dtype": "disk",
+ "onload_device": "disk",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+pipe = AnimaImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
+negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
+image = pipe(prompt, negative_prompt=negative_prompt, seed=0, num_inference_steps=50)
+image.save("image.jpg")
+```
+
+## 模型总览
+
+|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
+|-|-|-|-|-|-|-|
+|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|
+
+特殊训练脚本:
+
+* 差分 LoRA 训练:[doc](../Training/Differential_LoRA.md)
+* FP8 精度训练:[doc](../Training/FP8_Precision.md)
+* 两阶段拆分训练:[doc](../Training/Split_Training.md)
+* 端到端直接蒸馏:[doc](../Training/Direct_Distill.md)
+
+## 模型推理
+
+模型通过 `AnimaImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
+
+`AnimaImagePipeline` 推理的输入参数包括:
+
+* `prompt`: 提示词,描述画面中出现的内容。
+* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。
+* `cfg_scale`: Classifier-free guidance 的参数,默认值为 4.0。
+* `input_image`: 输入图像,用于图像到图像的生成。默认为 `None`。
+* `denoising_strength`: 去噪强度,控制生成图像与输入图像的相似度,默认值为 1.0。
+* `height`: 图像高度,需保证高度为 16 的倍数,默认值为 1024。
+* `width`: 图像宽度,需保证宽度为 16 的倍数,默认值为 1024。
+* `seed`: 随机种子。默认为 `None`,即完全随机。
+* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
+* `num_inference_steps`: 推理次数,默认值为 30。
+* `sigma_shift`: 调度器的 sigma 偏移量,默认为 `None`。
+* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。
+
+如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
+
+## 模型训练
+
+Anima 系列模型统一通过 [`examples/anima/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/train.py) 进行训练,脚本的参数包括:
+
+* 通用训练参数
+ * 数据集基础配置
+ * `--dataset_base_path`: 数据集的根目录。
+ * `--dataset_metadata_path`: 数据集的元数据文件路径。
+ * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
+    * `--dataset_num_workers`: 每个 Dataloader 的进程数量。
+ * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
+ * 模型加载配置
+ * `--model_paths`: 要加载的模型路径。JSON 格式。
+    * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"circlestone-labs/Anima:split_files/diffusion_models/anima-preview.safetensors"`。用逗号分隔。
+ * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练 ControlNet 模型时需要额外参数 `controlnet_inputs`,以 `,` 分隔。
+ * `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。
+ * 训练基础配置
+ * `--learning_rate`: 学习率。
+ * `--num_epochs`: 轮数(Epoch)。
+ * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
+ * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。
+ * `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。
+ * `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。
+ * 输出配置
+ * `--output_path`: 模型保存路径。
+ * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
+ * `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。
+ * LoRA 配置
+ * `--lora_base_model`: LoRA 添加到哪个模型上。
+ * `--lora_target_modules`: LoRA 添加到哪些层上。
+ * `--lora_rank`: LoRA 的秩(Rank)。
+ * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。
+ * `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。
+ * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
+ * 梯度配置
+ * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
+ * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
+ * `--gradient_accumulation_steps`: 梯度累积步数。
+ * 图像宽高配置(适用于图像生成模型和视频生成模型)
+ * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。
+ * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
+ * `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。
+* Anima 专有参数
+ * `--tokenizer_path`: tokenizer 的路径,适用于文生图模型,留空则自动从远程下载。
+ * `--tokenizer_t5xxl_path`: T5-XXL tokenizer 的路径,适用于文生图模型,留空则自动从远程下载。
+
+我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
+
+```shell
+modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
+```
+
+我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。
diff --git a/docs/zh/index.rst b/docs/zh/index.rst
index d2afefc..4ee551a 100644
--- a/docs/zh/index.rst
+++ b/docs/zh/index.rst
@@ -27,6 +27,7 @@
Model_Details/Qwen-Image
Model_Details/FLUX2
Model_Details/Z-Image
+ Model_Details/Anima
Model_Details/LTX-2
.. toctree::
diff --git a/examples/anima/model_inference/anima-preview.py b/examples/anima/model_inference/anima-preview.py
new file mode 100644
index 0000000..9440bdf
--- /dev/null
+++ b/examples/anima/model_inference/anima-preview.py
@@ -0,0 +1,19 @@
+from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
+import torch
+
+
+pipe = AnimaImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors"),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors"),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/")
+)
+prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
+negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
+image = pipe(prompt, negative_prompt=negative_prompt, seed=0, num_inference_steps=50)
+image.save("image.jpg")
diff --git a/examples/anima/model_inference_low_vram/anima-preview.py b/examples/anima/model_inference_low_vram/anima-preview.py
new file mode 100644
index 0000000..bfe8e24
--- /dev/null
+++ b/examples/anima/model_inference_low_vram/anima-preview.py
@@ -0,0 +1,30 @@
+from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
+import torch
+
+
+vram_config = {
+ "offload_dtype": "disk",
+ "offload_device": "disk",
+ "onload_dtype": "disk",
+ "onload_device": "disk",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+pipe = AnimaImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
+negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
+image = pipe(prompt, negative_prompt=negative_prompt, seed=0, num_inference_steps=50)
+image.save("image.jpg")
diff --git a/examples/anima/model_training/full/anima-preview.sh b/examples/anima/model_training/full/anima-preview.sh
new file mode 100644
index 0000000..58bf844
--- /dev/null
+++ b/examples/anima/model_training/full/anima-preview.sh
@@ -0,0 +1,14 @@
+accelerate launch examples/anima/model_training/train.py \
+ --dataset_base_path data/example_image_dataset \
+ --dataset_metadata_path data/example_image_dataset/metadata.csv \
+ --max_pixels 1048576 \
+ --dataset_repeat 50 \
+ --model_id_with_origin_paths "circlestone-labs/Anima:split_files/diffusion_models/anima-preview.safetensors,circlestone-labs/Anima:split_files/text_encoders/qwen_3_06b_base.safetensors,circlestone-labs/Anima:split_files/vae/qwen_image_vae.safetensors" \
+ --tokenizer_path "Qwen/Qwen3-0.6B:./" \
+ --tokenizer_t5xxl_path "stabilityai/stable-diffusion-3.5-large:tokenizer_3/" \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --output_path "./models/train/anima-preview_full" \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing
\ No newline at end of file
diff --git a/examples/anima/model_training/lora/anima-preview.sh b/examples/anima/model_training/lora/anima-preview.sh
new file mode 100644
index 0000000..462a844
--- /dev/null
+++ b/examples/anima/model_training/lora/anima-preview.sh
@@ -0,0 +1,16 @@
+accelerate launch examples/anima/model_training/train.py \
+ --dataset_base_path data/example_image_dataset \
+ --dataset_metadata_path data/example_image_dataset/metadata.csv \
+ --max_pixels 1048576 \
+ --dataset_repeat 50 \
+ --model_id_with_origin_paths "circlestone-labs/Anima:split_files/diffusion_models/anima-preview.safetensors,circlestone-labs/Anima:split_files/text_encoders/qwen_3_06b_base.safetensors,circlestone-labs/Anima:split_files/vae/qwen_image_vae.safetensors" \
+ --tokenizer_path "Qwen/Qwen3-0.6B:./" \
+ --tokenizer_t5xxl_path "stabilityai/stable-diffusion-3.5-large:tokenizer_3/" \
+ --learning_rate 1e-4 \
+ --num_epochs 5 \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --output_path "./models/train/anima-preview_lora" \
+ --lora_base_model "dit" \
+ --lora_target_modules "" \
+ --lora_rank 32 \
+ --use_gradient_checkpointing
\ No newline at end of file
diff --git a/examples/anima/model_training/train.py b/examples/anima/model_training/train.py
new file mode 100644
index 0000000..89e7b72
--- /dev/null
+++ b/examples/anima/model_training/train.py
@@ -0,0 +1,145 @@
+import torch, os, argparse, accelerate
+from diffsynth.core import UnifiedDataset
+from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
+from diffsynth.diffusion import *
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+class AnimaTrainingModule(DiffusionTrainingModule):
+ def __init__(
+ self,
+ model_paths=None, model_id_with_origin_paths=None,
+ tokenizer_path=None, tokenizer_t5xxl_path=None,
+ trainable_models=None,
+ lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
+ preset_lora_path=None, preset_lora_model=None,
+ use_gradient_checkpointing=True,
+ use_gradient_checkpointing_offload=False,
+ extra_inputs=None,
+ fp8_models=None,
+ offload_models=None,
+ device="cpu",
+ task="sft",
+ ):
+ super().__init__()
+ # Load models
+ model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
+ tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"))
+ tokenizer_t5xxl_config = self.parse_path_or_model_id(tokenizer_t5xxl_path, ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"))
+ self.pipe = AnimaImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, tokenizer_t5xxl_config=tokenizer_t5xxl_config)
+ self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
+
+ # Training mode
+ self.switch_pipe_to_training_mode(
+ self.pipe, trainable_models,
+ lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
+ preset_lora_path, preset_lora_model,
+ task=task,
+ )
+
+ # Other configs
+ self.use_gradient_checkpointing = use_gradient_checkpointing
+ self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
+ self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
+ self.fp8_models = fp8_models
+ self.task = task
+ self.task_to_loss = {
+ "sft:data_process": lambda pipe, *args: args,
+ "direct_distill:data_process": lambda pipe, *args: args,
+ "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
+ "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
+ "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
+ "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
+ }
+
+ def get_pipeline_inputs(self, data):
+ inputs_posi = {"prompt": data["prompt"]}
+ inputs_nega = {"negative_prompt": ""}
+ inputs_shared = {
+            # Pipeline inputs: fill these in as if running the pipeline for inference.
+ "input_image": data["image"],
+ "height": data["image"].size[1],
+ "width": data["image"].size[0],
+            # Do not modify the following parameters
+            # unless you fully understand the consequences.
+ "cfg_scale": 1,
+ "rand_device": self.pipe.device,
+ "use_gradient_checkpointing": self.use_gradient_checkpointing,
+ "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
+ }
+ inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
+ return inputs_shared, inputs_posi, inputs_nega
+
+ def forward(self, data, inputs=None):
+ if inputs is None: inputs = self.get_pipeline_inputs(data)
+ inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
+ for unit in self.pipe.units:
+ inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
+ loss = self.task_to_loss[self.task](self.pipe, *inputs)
+ return loss
+
+
+def anima_parser():
+ parser = argparse.ArgumentParser(description="Training script for Anima models.")
+ parser = add_general_config(parser)
+ parser = add_image_size_config(parser)
+ parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
+ parser.add_argument("--tokenizer_t5xxl_path", type=str, default=None, help="Path to tokenizer_t5xxl.")
+ return parser
+
+
+if __name__ == "__main__":
+ parser = anima_parser()
+ args = parser.parse_args()
+ accelerator = accelerate.Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
+ )
+ dataset = UnifiedDataset(
+ base_path=args.dataset_base_path,
+ metadata_path=args.dataset_metadata_path,
+ repeat=args.dataset_repeat,
+ data_file_keys=args.data_file_keys.split(","),
+ main_data_operator=UnifiedDataset.default_image_operator(
+ base_path=args.dataset_base_path,
+ max_pixels=args.max_pixels,
+ height=args.height,
+ width=args.width,
+ height_division_factor=16,
+ width_division_factor=16,
+ )
+ )
+ model = AnimaTrainingModule(
+ model_paths=args.model_paths,
+ model_id_with_origin_paths=args.model_id_with_origin_paths,
+ tokenizer_path=args.tokenizer_path,
+ tokenizer_t5xxl_path=args.tokenizer_t5xxl_path,
+ trainable_models=args.trainable_models,
+ lora_base_model=args.lora_base_model,
+ lora_target_modules=args.lora_target_modules,
+ lora_rank=args.lora_rank,
+ lora_checkpoint=args.lora_checkpoint,
+ preset_lora_path=args.preset_lora_path,
+ preset_lora_model=args.preset_lora_model,
+ use_gradient_checkpointing=args.use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
+ extra_inputs=args.extra_inputs,
+ fp8_models=args.fp8_models,
+ offload_models=args.offload_models,
+ task=args.task,
+ device=accelerator.device,
+ )
+ model_logger = ModelLogger(
+ args.output_path,
+ remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
+ )
+ launcher_map = {
+ "sft:data_process": launch_data_process_task,
+ "direct_distill:data_process": launch_data_process_task,
+ "sft": launch_training_task,
+ "sft:train": launch_training_task,
+ "direct_distill": launch_training_task,
+ "direct_distill:train": launch_training_task,
+ }
+ launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
\ No newline at end of file
diff --git a/examples/anima/model_training/validate_full/anima-preview.py b/examples/anima/model_training/validate_full/anima-preview.py
new file mode 100644
index 0000000..9f31a5a
--- /dev/null
+++ b/examples/anima/model_training/validate_full/anima-preview.py
@@ -0,0 +1,21 @@
+from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
+from diffsynth.core import load_state_dict
+import torch
+
+
+pipe = AnimaImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors"),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors"),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/")
+)
+state_dict = load_state_dict("./models/train/anima-preview_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
+pipe.dit.load_state_dict(state_dict)
+prompt = "a dog"
+image = pipe(prompt=prompt, seed=0)
+image.save("image.jpg")
\ No newline at end of file
diff --git a/examples/anima/model_training/validate_lora/anima-preview.py b/examples/anima/model_training/validate_lora/anima-preview.py
new file mode 100644
index 0000000..df107d2
--- /dev/null
+++ b/examples/anima/model_training/validate_lora/anima-preview.py
@@ -0,0 +1,19 @@
+from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
+import torch
+
+
+pipe = AnimaImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors"),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors"),
+ ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/")
+)
+pipe.load_lora(pipe.dit, "./models/train/anima-preview_lora/epoch-4.safetensors")
+prompt = "a dog"
+image = pipe(prompt=prompt, seed=0)
+image.save("image.jpg")
\ No newline at end of file