Support Torch Compile (#1368)

* support simple compile

* add support for compile

* minor fix

* minor fix

* minor fix
This commit is contained in:
Hong Zhang
2026-03-24 11:19:43 +08:00
committed by GitHub
parent e2a3a987da
commit ae8cb139e8
18 changed files with 59 additions and 2 deletions

View File

@@ -339,6 +339,38 @@ class BasePipeline(torch.nn.Module):
noise_pred = noise_pred_posi
return noise_pred
def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs):
"""
compile the pipeline with torch.compile. The models that will be compiled are determined by the `compilable_models` attribute of the pipeline.
If a model has `_repeated_blocks` attribute, we will compile these blocks with regional compilation. Otherwise, we will compile the whole model.
See https://docs.pytorch.org/docs/stable/generated/torch.compile.html#torch.compile for details about compilation arguments.
Args:
mode: The compilation mode, which will be passed to `torch.compile`, options are "default", "reduce-overhead", "max-autotune" and "max-autotune-no-cudagraphs. Default to "default".
dynamic: Whether to enable dynamic graph compilation to support dynamic input shapes, which will be passed to `torch.compile`. Default to True (recommended).
fullgraph: Whether to use full graph compilation, which will be passed to `torch.compile`. Default to False (recommended).
compile_models: The list of model names to be compiled. If None, we will compile the models in `pipeline.compilable_models`. Default to None.
**kwargs: Other arguments for `torch.compile`.
"""
compile_models = compile_models or getattr(self, "compilable_models", [])
if len(compile_models) == 0:
print("No compilable models in the pipeline. Skip compilation.")
return
for name in compile_models:
model = getattr(self, name, None)
if model is None:
print(f"Model '{name}' not found in the pipeline.")
continue
repeated_blocks = getattr(model, "_repeated_blocks", None)
# regional compilation for repeated blocks.
if repeated_blocks is not None:
for submod in model.modules():
if submod.__class__.__name__ in repeated_blocks:
submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
# compile the whole model.
else:
model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
print(f"{name} is compiled with mode={mode}, dynamic={dynamic}, fullgraph={fullgraph}.")
class PipelineUnitGraph:
def __init__(self):