change klein image resize to crop

This commit is contained in:
Artiprocher
2026-01-22 10:33:29 +08:00
parent 273143136c
commit 548304667f

View File

@@ -1,4 +1,4 @@
import torch, math import torch, math, torchvision
from PIL import Image from PIL import Image
from typing import Union from typing import Union
from tqdm import tqdm from tqdm import tqdm
@@ -478,9 +478,20 @@ class Flux2Unit_EditImageEmbedder(PipelineUnit):
height = round(height / 32) * 32 height = round(height / 32) * 32
return width, height return width, height
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def edit_image_auto_resize(self, edit_image): def edit_image_auto_resize(self, edit_image):
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1]) calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
return edit_image.resize((calculated_width, calculated_height)) return self.crop_and_resize(edit_image, calculated_height, calculated_width)
def process_image_ids(self, image_latents, scale=10): def process_image_ids(self, image_latents, scale=10):
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]