diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 610e30a..d5dc35b 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -1,4 +1,4 @@ -import torch, math +import torch, math, torchvision from PIL import Image from typing import Union from tqdm import tqdm @@ -477,10 +477,21 @@ class Flux2Unit_EditImageEmbedder(PipelineUnit): width = round(width / 32) * 32 height = round(height / 32) * 32 return width, height + + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width / width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image def edit_image_auto_resize(self, edit_image): calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1]) - return edit_image.resize((calculated_width, calculated_height)) + return self.crop_and_resize(edit_image, calculated_height, calculated_width) def process_image_ids(self, image_latents, scale=10): t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]