From 945b43492e394179ee668dbf9e5ff87dae6312ee Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 7 Mar 2025 17:43:30 +0800 Subject: [PATCH] load hunyuani2v model --- diffsynth/configs/model_config.py | 22 +++ diffsynth/models/hunyuan_video_dit.py | 12 +- .../models/hunyuan_video_text_encoder.py | 31 ++-- diffsynth/prompters/hunyuan_video_prompter.py | 147 +++++++++++++++++- .../tokenizer_2/preprocessor_config.json | 45 ++++++ examples/HunyuanVideo/hunyuanvideo_i2v.py | 88 +++++++++++ 6 files changed, 327 insertions(+), 18 deletions(-) create mode 100644 diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/preprocessor_config.json create mode 100644 examples/HunyuanVideo/hunyuanvideo_i2v.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 676af03..718ba73 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -112,6 +112,7 @@ model_loader_configs = [ (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"), (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), + (None, "ae3c22aaa28bfae6f3688f796c9814ae", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"), (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"), (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"), @@ -135,6 +136,7 @@ huggingface_model_loader_configs = [ ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"), ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"), ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"), + ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"), ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"), ] patch_model_loader_configs = [ @@ -677,6 +679,25 @@ preset_models_on_modelscope = { "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt" ], }, + "HunyuanVideoI2V":{ + "file_list": [ + ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"), + ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"), + ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"), + ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"), + ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"), + ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"), + ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"), + ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"), + ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers") + ], + "load_path": [ + "models/HunyuanVideoI2V/text_encoder/model.safetensors", + "models/HunyuanVideoI2V/text_encoder_2", + "models/HunyuanVideoI2V/vae/pytorch_model.pt", + "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt" + ], + }, "HunyuanVideo-fp8":{ "file_list": [ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"), @@ -753,4 +774,5 @@ Preset_model_id: TypeAlias = Literal[ "StableDiffusion3.5-medium", "HunyuanVideo", "HunyuanVideo-fp8", + "HunyuanVideoI2V", ] diff --git a/diffsynth/models/hunyuan_video_dit.py b/diffsynth/models/hunyuan_video_dit.py index 4f4b49c..f008a87 100644 --- a/diffsynth/models/hunyuan_video_dit.py +++ b/diffsynth/models/hunyuan_video_dit.py @@ -4,6 +4,7 @@ from .utils import init_weights_on_device from einops import rearrange, repeat from tqdm import tqdm from typing import Union, Tuple, List +from .utils import hash_state_dict_keys def HunyuanVideoRope(latents): @@ -555,7 +556,7 @@ class FinalLayer(torch.nn.Module): class HunyuanVideoDiT(torch.nn.Module): - def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40): + def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True): super().__init__() self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size) self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size) @@ -565,7 +566,7 @@ class HunyuanVideoDiT(torch.nn.Module): torch.nn.SiLU(), torch.nn.Linear(hidden_size, hidden_size) ) - self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") + self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)]) self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)]) self.final_layer = FinalLayer(hidden_size) @@ -610,7 +611,9 @@ class HunyuanVideoDiT(torch.nn.Module): ): B, C, T, H, W = x.shape - vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32) + vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + if self.guidance_in is not None: + vec += self.guidance_in(guidance * 1000, dtype=torch.float32) img = self.img_in(x) txt = self.txt_in(prompt_emb, t, text_mask) @@ -783,6 +786,7 @@ class HunyuanVideoDiTStateDictConverter: pass def from_civitai(self, state_dict): + origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True) if "module" in state_dict: state_dict = state_dict["module"] direct_dict = { @@ -882,4 +886,6 @@ class HunyuanVideoDiTStateDictConverter: state_dict_[name_] = param else: pass + if origin_hash_key == "ae3c22aaa28bfae6f3688f796c9814ae": + return state_dict_, {"in_channels": 33, "guidance_embed":False} return state_dict_ diff --git a/diffsynth/models/hunyuan_video_text_encoder.py b/diffsynth/models/hunyuan_video_text_encoder.py index df5755f..ce7a680 100644 --- a/diffsynth/models/hunyuan_video_text_encoder.py +++ b/diffsynth/models/hunyuan_video_text_encoder.py @@ -1,24 +1,18 @@ -from transformers import LlamaModel, LlamaConfig, DynamicCache +from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration from copy import deepcopy import torch class HunyuanVideoLLMEncoder(LlamaModel): + def __init__(self, config: LlamaConfig): super().__init__(config) self.auto_offload = False - def enable_auto_offload(self, **kwargs): self.auto_offload = True - - def forward( - self, - input_ids, - attention_mask, - hidden_state_skip_layer=2 - ): + def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2): embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens inputs_embeds = embed_tokens(input_ids) @@ -53,3 +47,22 @@ class HunyuanVideoLLMEncoder(LlamaModel): break return hidden_states + + +class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration): + + def __init__(self, config): + super().__init__(config) + self.auto_offload = False + + def enable_auto_offload(self, **kwargs): + self.auto_offload = True + + # TODO: implement the low VRAM inference for MLLM. + def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2): + outputs = super().forward(input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + pixel_values=pixel_values) + hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)] + return hidden_state diff --git a/diffsynth/prompters/hunyuan_video_prompter.py b/diffsynth/prompters/hunyuan_video_prompter.py index 3b5a9fe..26dc5c3 100644 --- a/diffsynth/prompters/hunyuan_video_prompter.py +++ b/diffsynth/prompters/hunyuan_video_prompter.py @@ -1,8 +1,9 @@ from .base_prompter import BasePrompter from ..models.sd3_text_encoder import SD3TextEncoder1 -from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder -from transformers import CLIPTokenizer, LlamaTokenizerFast +from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder +from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor import os, torch +from typing import Union PROMPT_TEMPLATE_ENCODE = ( "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, " @@ -18,6 +19,24 @@ PROMPT_TEMPLATE_ENCODE_VIDEO = ( "5. camera angles, movements, and transitions used in the video:<|eot_id|>" "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>") +PROMPT_TEMPLATE_ENCODE_I2V = ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) + +PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) + PROMPT_TEMPLATE = { "dit-llm-encode": { "template": PROMPT_TEMPLATE_ENCODE, @@ -27,6 +46,22 @@ PROMPT_TEMPLATE = { "template": PROMPT_TEMPLATE_ENCODE_VIDEO, "crop_start": 95, }, + "dit-llm-encode-i2v": { + "template": PROMPT_TEMPLATE_ENCODE_I2V, + "crop_start": 36, + "image_emb_start": 5, + "image_emb_end": 581, + "image_emb_len": 576, + "double_return_token_id": 271 + }, + "dit-llm-encode-video-i2v": { + "template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, + "crop_start": 103, + "image_emb_start": 5, + "image_emb_end": 581, + "image_emb_len": 576, + "double_return_token_id": 271 + }, } NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" @@ -52,13 +87,27 @@ class HunyuanVideoPrompter(BasePrompter): self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right') self.text_encoder_1: SD3TextEncoder1 = None self.text_encoder_2: HunyuanVideoLLMEncoder = None + self.i2v_mode = False self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode'] self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video'] - def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None): + def fetch_models(self, + text_encoder_1: SD3TextEncoder1 = None, + text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None): self.text_encoder_1 = text_encoder_1 self.text_encoder_2 = text_encoder_2 + if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder): + # processor + # TODO: may need to replace processor with local implementation + base_path = os.path.dirname(os.path.dirname(__file__)) + tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2") + self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path) + # template + self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v'] + self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v'] + # mode setting + self.i2v_mode = True def apply_text_to_template(self, text, template): assert isinstance(template, str) @@ -107,8 +156,91 @@ class HunyuanVideoPrompter(BasePrompter): return last_hidden_state, attention_mask + def encode_prompt_using_mllm(self, + prompt, + images, + max_length, + device, + crop_start, + hidden_state_skip_layer=2, + use_attention_mask=True, + image_embed_interleave=2): + image_outputs = self.processor(images, return_tensors="pt")[ + "pixel_values" + ].to(device) + max_length += crop_start + inputs = self.tokenizer_2(prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True) + input_ids = inputs.input_ids.to(device) + attention_mask = inputs.attention_mask.to(device) + last_hidden_state = self.text_encoder_2(input_ids=input_ids, + attention_mask=attention_mask, + hidden_state_skip_layer=hidden_state_skip_layer, + pixel_values=image_outputs) + + text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576)) + image_crop_start = self.prompt_template_video.get("image_emb_start", 5) + image_crop_end = self.prompt_template_video.get("image_emb_end", 581) + batch_indices, last_double_return_token_indices = torch.where( + input_ids == self.prompt_template_video.get("double_return_token_id", 271)) + if last_double_return_token_indices.shape[0] == 3: + # in case the prompt is too long + last_double_return_token_indices = torch.cat(( + last_double_return_token_indices, + torch.tensor([input_ids.shape[-1]]), + )) + batch_indices = torch.cat((batch_indices, torch.tensor([0]))) + last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1]) + batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1] + assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4) + assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576)) + attention_mask_assistant_crop_start = (last_double_return_token_indices - 4) + attention_mask_assistant_crop_end = last_double_return_token_indices + text_last_hidden_state = [] + text_attention_mask = [] + image_last_hidden_state = [] + image_attention_mask = [] + for i in range(input_ids.shape[0]): + text_last_hidden_state.append( + torch.cat([ + last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()], + last_hidden_state[i, assistant_crop_end[i].item():], + ])) + text_attention_mask.append( + torch.cat([ + attention_mask[ + i, + crop_start:attention_mask_assistant_crop_start[i].item(), + ], + attention_mask[i, attention_mask_assistant_crop_end[i].item():], + ]) if use_attention_mask else None) + image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end]) + image_attention_mask.append( + torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device). + to(attention_mask.dtype) if use_attention_mask else None) + + text_last_hidden_state = torch.stack(text_last_hidden_state) + text_attention_mask = torch.stack(text_attention_mask) + image_last_hidden_state = torch.stack(image_last_hidden_state) + image_attention_mask = torch.stack(image_attention_mask) + + image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :] + image_attention_mask = image_attention_mask[:, ::image_embed_interleave] + + assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and + image_last_hidden_state.shape[0] == image_attention_mask.shape[0]) + + last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1) + attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1) + + return last_hidden_state, attention_mask + def encode_prompt(self, prompt, + images=None, positive=True, device="cuda", clip_sequence_length=77, @@ -136,8 +268,11 @@ class HunyuanVideoPrompter(BasePrompter): pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device) # LLM - prompt_emb, attention_mask = self.encode_prompt_using_llm( - prompt_formated, llm_sequence_length, device, crop_start, - hidden_state_skip_layer, use_attention_mask) + if images is None: + prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start, + hidden_state_skip_layer, use_attention_mask) + else: + prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device, + crop_start, hidden_state_skip_layer, use_attention_mask) return prompt_emb, pooled_prompt_emb, attention_mask diff --git a/diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/preprocessor_config.json b/diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/preprocessor_config.json new file mode 100644 index 0000000..cf5bb0c --- /dev/null +++ b/diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/preprocessor_config.json @@ -0,0 +1,45 @@ +{ + "_valid_processor_keys": [ + "images", + "do_resize", + "size", + "resample", + "do_center_crop", + "crop_size", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format" + ], + "crop_size": { + "height": 336, + "width": 336 + }, + "do_center_crop": true, + "do_convert_rgb": true, + "do_normalize": true, + "do_rescale": true, + "do_resize": true, + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_processor_type": "CLIPImageProcessor", + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "processor_class": "LlavaProcessor", + "resample": 3, + "rescale_factor": 0.00392156862745098, + "size": { + "shortest_edge": 336 + } +} diff --git a/examples/HunyuanVideo/hunyuanvideo_i2v.py b/examples/HunyuanVideo/hunyuanvideo_i2v.py new file mode 100644 index 0000000..26d28a1 --- /dev/null +++ b/examples/HunyuanVideo/hunyuanvideo_i2v.py @@ -0,0 +1,88 @@ +import torch +from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video +from diffsynth.prompters.hunyuan_video_prompter import HunyuanVideoPrompter +from PIL import Image +import numpy as np +import torchvision.transforms as transforms + + +def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0): + num_patches = round((base_size / patch_size)**2) + assert max_ratio >= 1.0 + crop_size_list = [] + wp, hp = num_patches, 1 + while wp > 0: + if max(wp, hp) / min(wp, hp) <= max_ratio: + crop_size_list.append((wp * patch_size, hp * patch_size)) + if (hp + 1) * wp <= num_patches: + hp += 1 + else: + wp -= 1 + return crop_size_list + + +def get_closest_ratio(height: float, width: float, ratios: list, buckets: list): + aspect_ratio = float(height) / float(width) + closest_ratio_id = np.abs(ratios - aspect_ratio).argmin() + closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio)) + return buckets[closest_ratio_id], float(closest_ratio) + + +def prepare_vae_inputs(semantic_images, i2v_resolution="720p"): + if i2v_resolution == "720p": + bucket_hw_base_size = 960 + elif i2v_resolution == "540p": + bucket_hw_base_size = 720 + elif i2v_resolution == "360p": + bucket_hw_base_size = 480 + else: + raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]") + origin_size = semantic_images[0].size + + crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32) + aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list]) + closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list) + ref_image_transform = transforms.Compose([ + transforms.Resize(closest_size), + transforms.CenterCrop(closest_size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]) + ]) + + semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images] + semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2) + return semantic_image_pixel_values + + +model_manager = ModelManager() + +# The other modules are loaded in float16. + +model_manager.load_models( + [ + "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt" + ], + torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization. + device="cuda" +) + +model_manager.load_models( + [ + "models/HunyuanVideo/text_encoder/model.safetensors", + "models/HunyuanVideoI2V/text_encoder_2", + 'models/HunyuanVideoI2V/vae/pytorch_model.pt' + + ], + torch_dtype=torch.float16, + device="cuda" +) +# The computation device is "cuda". +pipe = HunyuanVideoPipeline.from_model_manager( + model_manager, + torch_dtype=torch.bfloat16, + device="cuda", + enable_vram_management=False +) +# Although you have enough VRAM, we still recommend you to enable offload. +pipe.enable_cpu_offload() +print() \ No newline at end of file