Merge pull request #1228 from modelscope/klein-bugfix

change klein image resize to crop
This commit is contained in:
Zhongjie Duan
2026-01-22 10:34:17 +08:00
committed by GitHub

View File

@@ -1,4 +1,4 @@
import torch, math import torch, math, torchvision
from PIL import Image from PIL import Image
from typing import Union from typing import Union
from tqdm import tqdm from tqdm import tqdm
@@ -477,10 +477,21 @@ class Flux2Unit_EditImageEmbedder(PipelineUnit):
width = round(width / 32) * 32 width = round(width / 32) * 32
height = round(height / 32) * 32 height = round(height / 32) * 32
return width, height return width, height
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def edit_image_auto_resize(self, edit_image): def edit_image_auto_resize(self, edit_image):
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1]) calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
return edit_image.resize((calculated_width, calculated_height)) return self.crop_and_resize(edit_image, calculated_height, calculated_width)
def process_image_ids(self, image_latents, scale=10): def process_image_ids(self, image_latents, scale=10):
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]