mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
flex full support
This commit is contained in:
@@ -362,7 +362,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
return {}, controlnet_image
|
return {}, controlnet_image
|
||||||
|
|
||||||
|
|
||||||
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, tiled=False, tile_size=64, tile_stride=32):
|
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32):
|
||||||
if self.dit.input_dim == 196:
|
if self.dit.input_dim == 196:
|
||||||
if flex_inpaint_image is None:
|
if flex_inpaint_image is None:
|
||||||
flex_inpaint_image = torch.zeros_like(latents)
|
flex_inpaint_image = torch.zeros_like(latents)
|
||||||
@@ -375,13 +375,16 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
|
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
|
||||||
flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
|
flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
|
||||||
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
|
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
|
||||||
|
flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
|
||||||
if flex_control_image is None:
|
if flex_control_image is None:
|
||||||
flex_control_image = torch.zeros_like(latents)
|
flex_control_image = torch.zeros_like(latents)
|
||||||
else:
|
else:
|
||||||
flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
|
flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
|
||||||
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
|
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
|
||||||
flex_kwargs = {"flex_condition": flex_condition}
|
flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
|
||||||
|
flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))]
|
||||||
|
flex_kwargs = {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
|
||||||
else:
|
else:
|
||||||
flex_kwargs = {}
|
flex_kwargs = {}
|
||||||
return flex_kwargs
|
return flex_kwargs
|
||||||
@@ -427,6 +430,8 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
flex_inpaint_image=None,
|
flex_inpaint_image=None,
|
||||||
flex_inpaint_mask=None,
|
flex_inpaint_mask=None,
|
||||||
flex_control_image=None,
|
flex_control_image=None,
|
||||||
|
flex_control_strength=0.5,
|
||||||
|
flex_control_stop=0.5,
|
||||||
# TeaCache
|
# TeaCache
|
||||||
tea_cache_l1_thresh=None,
|
tea_cache_l1_thresh=None,
|
||||||
# Tile
|
# Tile
|
||||||
@@ -635,6 +640,8 @@ def lets_dance_flux(
|
|||||||
id_emb=None,
|
id_emb=None,
|
||||||
infinityou_guidance=None,
|
infinityou_guidance=None,
|
||||||
flex_condition=None,
|
flex_condition=None,
|
||||||
|
flex_uncondition=None,
|
||||||
|
flex_control_stop_timestep=None,
|
||||||
tea_cache: TeaCache = None,
|
tea_cache: TeaCache = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -686,8 +693,12 @@ def lets_dance_flux(
|
|||||||
controlnet_frames, **controlnet_extra_kwargs
|
controlnet_frames, **controlnet_extra_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Flex
|
||||||
if flex_condition is not None:
|
if flex_condition is not None:
|
||||||
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
if timestep.tolist()[0] >= flex_control_stop_timestep:
|
||||||
|
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
|
||||||
|
|
||||||
if image_ids is None:
|
if image_ids is None:
|
||||||
image_ids = dit.prepare_image_ids(hidden_states)
|
image_ids = dit.prepare_image_ids(hidden_states)
|
||||||
|
|||||||
@@ -23,22 +23,27 @@ image = pipe(
|
|||||||
image.save("image_1.jpg")
|
image.save("image_1.jpg")
|
||||||
|
|
||||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||||
mask[0:300, 300:800] = 255
|
mask[200:400, 400:700] = 255
|
||||||
mask = Image.fromarray(mask)
|
mask = Image.fromarray(mask)
|
||||||
mask.save("image_mask.jpg")
|
mask.save("image_mask.jpg")
|
||||||
|
|
||||||
inpaint_image = np.array(image)
|
inpaint_image = image
|
||||||
inpaint_image[0:300, 300:800] = 0
|
|
||||||
inpaint_image = Image.fromarray(inpaint_image)
|
image = pipe(
|
||||||
inpaint_image.save("image_inpaint.jpg")
|
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
|
||||||
|
num_inference_steps=50, embedded_guidance=3.5,
|
||||||
|
flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
|
||||||
|
seed=4
|
||||||
|
)
|
||||||
|
image.save("image_2.jpg")
|
||||||
|
|
||||||
control_image = Annotator("canny")(image)
|
control_image = Annotator("canny")(image)
|
||||||
control_image.save("image_control.jpg")
|
control_image.save("image_control.jpg")
|
||||||
|
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt="portrait of a beautiful Asian girl, long hair, yellow t-shirt, sunshine, beach",
|
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
|
||||||
num_inference_steps=50, embedded_guidance=3.5,
|
num_inference_steps=50, embedded_guidance=3.5,
|
||||||
flex_control_image=control_image,
|
flex_control_image=control_image,
|
||||||
seed=4
|
seed=4
|
||||||
)
|
)
|
||||||
image.save("image_2.jpg")
|
image.save("image_3.jpg")
|
||||||
|
|||||||
Reference in New Issue
Block a user