mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
support size checker
This commit is contained in:
@@ -7,14 +7,26 @@ from torchvision.transforms import GaussianBlur
|
|||||||
|
|
||||||
class BasePipeline(torch.nn.Module):
|
class BasePipeline(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.torch_dtype = torch_dtype
|
self.torch_dtype = torch_dtype
|
||||||
|
self.height_division_factor = height_division_factor
|
||||||
|
self.width_division_factor = width_division_factor
|
||||||
self.cpu_offload = False
|
self.cpu_offload = False
|
||||||
self.model_names = []
|
self.model_names = []
|
||||||
|
|
||||||
|
|
||||||
|
def check_resize_height_width(self, height, width):
|
||||||
|
if height % self.height_division_factor != 0:
|
||||||
|
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||||
|
print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
|
||||||
|
if width % self.width_division_factor != 0:
|
||||||
|
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||||
|
print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
|
||||||
|
return height, width
|
||||||
|
|
||||||
|
|
||||||
def preprocess_image(self, image):
|
def preprocess_image(self, image):
|
||||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||||
return image
|
return image
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from ..models.tiler import FastTileWorker
|
|||||||
class FluxImagePipeline(BasePipeline):
|
class FluxImagePipeline(BasePipeline):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
|
||||||
self.scheduler = FlowMatchScheduler()
|
self.scheduler = FlowMatchScheduler()
|
||||||
self.prompter = FluxPrompter()
|
self.prompter = FluxPrompter()
|
||||||
# models
|
# models
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ class ImageSizeManager:
|
|||||||
class HunyuanDiTImagePipeline(BasePipeline):
|
class HunyuanDiTImagePipeline(BasePipeline):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
|
||||||
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
|
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
|
||||||
self.prompter = HunyuanDiTPrompter()
|
self.prompter = HunyuanDiTPrompter()
|
||||||
self.image_size_manager = ImageSizeManager()
|
self.image_size_manager = ImageSizeManager()
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from tqdm import tqdm
|
|||||||
class SD3ImagePipeline(BasePipeline):
|
class SD3ImagePipeline(BasePipeline):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
|
||||||
self.scheduler = FlowMatchScheduler()
|
self.scheduler = FlowMatchScheduler()
|
||||||
self.prompter = SD3Prompter()
|
self.prompter = SD3Prompter()
|
||||||
# models
|
# models
|
||||||
|
|||||||
Reference in New Issue
Block a user