mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -31,6 +31,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.controlnet: FluxMultiControlNetManager = None
|
||||
self.ipadapter: FluxIpAdapter = None
|
||||
self.ipadapter_image_encoder: SiglipVisionModel = None
|
||||
self.infinityou_processor: InfinitYou = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||
|
||||
|
||||
@@ -162,6 +163,11 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||
|
||||
# InfiniteYou
|
||||
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
|
||||
if self.image_proj_model is not None:
|
||||
self.infinityou_processor = InfinitYou(device=self.device)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
||||
@@ -347,6 +353,13 @@ class FluxImagePipeline(BasePipeline):
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
||||
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
||||
|
||||
|
||||
def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
|
||||
if self.infinityou_processor is not None and id_image is not None:
|
||||
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
|
||||
else:
|
||||
return {}, controlnet_image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -382,6 +395,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
eligen_entity_masks=None,
|
||||
enable_eligen_on_negative=False,
|
||||
enable_eligen_inpaint=False,
|
||||
# InfiniteYou
|
||||
infinityou_id_image=None,
|
||||
infinityou_guidance=1.0,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
# Tile
|
||||
@@ -409,6 +425,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# InfiniteYou
|
||||
infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(infinityou_id_image, controlnet_image, infinityou_guidance, height, width)
|
||||
|
||||
# Entity control
|
||||
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
|
||||
|
||||
@@ -430,7 +449,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||
@@ -447,7 +466,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -467,6 +486,58 @@ class FluxImagePipeline(BasePipeline):
|
||||
# Offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
|
||||
|
||||
class InfinitYou:
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
from facexlib.recognition import init_recognition_model
|
||||
from insightface.app import FaceAnalysis
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
insightface_root_path = 'models/InfiniteYou/insightface'
|
||||
self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.app_640.prepare(ctx_id=0, det_size=(640, 640))
|
||||
self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
|
||||
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
|
||||
self.arcface_model = init_recognition_model('arcface', device=self.device)
|
||||
|
||||
def _detect_face(self, id_image_cv2):
|
||||
face_info = self.app_640.get(id_image_cv2)
|
||||
if len(face_info) > 0:
|
||||
return face_info
|
||||
face_info = self.app_320.get(id_image_cv2)
|
||||
if len(face_info) > 0:
|
||||
return face_info
|
||||
face_info = self.app_160.get(id_image_cv2)
|
||||
return face_info
|
||||
|
||||
def extract_arcface_bgr_embedding(self, in_image, landmark):
|
||||
from insightface.utils import face_align
|
||||
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
|
||||
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
|
||||
arc_face_image = 2 * arc_face_image - 1
|
||||
arc_face_image = arc_face_image.contiguous().to(self.device)
|
||||
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
|
||||
return face_emb
|
||||
|
||||
def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
|
||||
import cv2
|
||||
if id_image is None:
|
||||
return {'id_emb': None}, controlnet_image
|
||||
id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
|
||||
face_info = self._detect_face(id_image_cv2)
|
||||
if len(face_info) == 0:
|
||||
raise ValueError('No face detected in the input ID image')
|
||||
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
|
||||
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
|
||||
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
|
||||
if controlnet_image is None:
|
||||
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
|
||||
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
|
||||
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
|
||||
|
||||
|
||||
class TeaCache:
|
||||
@@ -529,6 +600,8 @@ def lets_dance_flux(
|
||||
entity_prompt_emb=None,
|
||||
entity_masks=None,
|
||||
ipadapter_kwargs_list={},
|
||||
id_emb=None,
|
||||
infinityou_guidance=None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -573,6 +646,9 @@ def lets_dance_flux(
|
||||
"tile_size": tile_size,
|
||||
"tile_stride": tile_stride,
|
||||
}
|
||||
if id_emb is not None:
|
||||
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
|
||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||
controlnet_frames, **controlnet_extra_kwargs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user