Support Torch Compile (#1368)

* support simple compile

* add support for compile

* minor fix

* minor fix

* minor fix
This commit is contained in:
Hong Zhang
2026-03-24 11:19:43 +08:00
committed by GitHub
parent e2a3a987da
commit ae8cb139e8
18 changed files with 59 additions and 2 deletions

View File

@@ -1270,6 +1270,9 @@ class LLMAdapter(nn.Module):
class AnimaDiT(MiniTrainDIT):
_repeated_blocks = ["Block"]
def __init__(self):
kwargs = {'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128, 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1, 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024, 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop', 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256, 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False, 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0, 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0, 'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None, 'operations': torch.nn}
super().__init__(**kwargs)

View File

@@ -879,6 +879,9 @@ class Flux2Modulation(nn.Module):
class Flux2DiT(torch.nn.Module):
_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
def __init__(
self,
patch_size: int = 1,

View File

@@ -275,6 +275,9 @@ class AdaLayerNormContinuous(torch.nn.Module):
class FluxDiT(torch.nn.Module):
_repeated_blocks = ["FluxJointTransformerBlock", "FluxSingleTransformerBlock"]
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])

View File

@@ -1280,6 +1280,7 @@ class LTXModel(torch.nn.Module):
LTX model transformer implementation.
This class implements the transformer blocks for the LTX model.
"""
_repeated_blocks = ["BasicAVTransformerBlock"]
def __init__( # noqa: PLR0913
self,

View File

@@ -549,6 +549,9 @@ class QwenImageTransformerBlock(nn.Module):
class QwenImageDiT(torch.nn.Module):
_repeated_blocks = ["QwenImageTransformerBlock"]
def __init__(
self,
num_layers: int = 60,

View File

@@ -336,6 +336,9 @@ class WanToDanceInjector(nn.Module):
class WanModel(torch.nn.Module):
_repeated_blocks = ["DiTBlock"]
def __init__(
self,
dim: int,

View File

@@ -326,6 +326,7 @@ class RopeEmbedder:
class ZImageDiT(nn.Module):
_supports_gradient_checkpointing = True
_no_split_modules = ["ZImageTransformerBlock"]
_repeated_blocks = ["ZImageTransformerBlock"]
def __init__(
self,