refine infinityou

This commit is contained in:
Artiprocher
2025-03-31 20:19:32 +08:00
parent e88328d152
commit 4cd51fecf2
2 changed files with 79 additions and 86 deletions

View File

@@ -4,12 +4,10 @@ from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
from typing import List
import math
import torch
from tqdm import tqdm
import numpy as np
from PIL import Image
import cv2
from ..models.tiler import FastTileWorker
from transformers import SiglipVisionModel
from copy import deepcopy
@@ -33,6 +31,7 @@ class FluxImagePipeline(BasePipeline):
self.controlnet: FluxMultiControlNetManager = None
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder: SiglipVisionModel = None
self.infinityou_processor: InfinitYou = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
@@ -167,16 +166,7 @@ class FluxImagePipeline(BasePipeline):
# InfiniteYou
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
if self.image_proj_model is not None:
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis
insightface_root_path = 'models/InfiniteYou/insightface'
self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_640.prepare(ctx_id=0, det_size=(640, 640))
self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
self.arcface_model = init_recognition_model('arcface', device=self.device)
self.infinityou_processor = InfinitYou(device=self.device)
@staticmethod
@@ -353,66 +343,6 @@ class FluxImagePipeline(BasePipeline):
return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
stickwidth = 4
limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
kps = np.array(kps)
w, h = image_pil.size
out_img = np.zeros([h, w, 3])
for i in range(len(limbSeq)):
index = limbSeq[i]
color = color_list[index[0]]
x = kps[index][:, 0]
y = kps[index][:, 1]
length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
out_img = (out_img * 0.6).astype(np.uint8)
for idx_kp, kp in enumerate(kps):
color = color_list[idx_kp]
out_img = cv2.circle(out_img.copy(), (int(kp[0]), int(kp[1])), 10, color, -1)
out_img_pil = Image.fromarray(out_img.astype(np.uint8))
return out_img_pil
def extract_arcface_bgr_embedding(self, in_image, landmark):
from insightface.utils import face_align
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
arc_face_image = 2 * arc_face_image - 1
arc_face_image = arc_face_image.contiguous().to(self.device)
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
return face_emb
def _detect_face(self, id_image_cv2):
face_info = self.app_640.get(id_image_cv2)
if len(face_info) > 0:
return face_info
face_info = self.app_320.get(id_image_cv2)
if len(face_info) > 0:
return face_info
face_info = self.app_160.get(id_image_cv2)
return face_info
def prepare_infinite_you(self, id_image, controlnet_image, controlnet_guidance, height, width):
if id_image is None:
return {'id_emb': None}, controlnet_image
id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
face_info = self._detect_face(id_image_cv2)
if len(face_info) == 0:
raise ValueError('No face detected in the input ID image')
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
id_emb = self.image_proj_model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
if controlnet_image is None:
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
controlnet_guidance = torch.Tensor([controlnet_guidance]).to(device=self.device, dtype=self.torch_dtype)
return {'id_emb': id_emb, 'controlnet_guidance': controlnet_guidance}, controlnet_image
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
@@ -423,6 +353,13 @@ class FluxImagePipeline(BasePipeline):
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
if self.infinityou_processor is not None and id_image is not None:
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
else:
return {}, controlnet_image
@torch.no_grad()
@@ -450,7 +387,6 @@ class FluxImagePipeline(BasePipeline):
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
controlnet_guidance=1.0,
# IP-Adapter
ipadapter_images=None,
ipadapter_scale=1.0,
@@ -460,7 +396,8 @@ class FluxImagePipeline(BasePipeline):
enable_eligen_on_negative=False,
enable_eligen_inpaint=False,
# InfiniteYou
id_image=None,
infinityou_id_image=None,
infinityou_guidance=1.0,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
@@ -489,7 +426,7 @@ class FluxImagePipeline(BasePipeline):
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# InfiniteYou
infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(id_image, controlnet_image, controlnet_guidance, height, width)
infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(infinityou_id_image, controlnet_image, infinityou_guidance, height, width)
# Entity control
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
@@ -529,7 +466,7 @@ class FluxImagePipeline(BasePipeline):
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -549,6 +486,58 @@ class FluxImagePipeline(BasePipeline):
# Offload all models
self.load_models_to_device([])
return image
class InfinitYou:
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis
self.device = device
self.torch_dtype = torch_dtype
insightface_root_path = 'models/InfiniteYou/insightface'
self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_640.prepare(ctx_id=0, det_size=(640, 640))
self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
self.arcface_model = init_recognition_model('arcface', device=self.device)
def _detect_face(self, id_image_cv2):
face_info = self.app_640.get(id_image_cv2)
if len(face_info) > 0:
return face_info
face_info = self.app_320.get(id_image_cv2)
if len(face_info) > 0:
return face_info
face_info = self.app_160.get(id_image_cv2)
return face_info
def extract_arcface_bgr_embedding(self, in_image, landmark):
from insightface.utils import face_align
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
arc_face_image = 2 * arc_face_image - 1
arc_face_image = arc_face_image.contiguous().to(self.device)
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
return face_emb
def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
import cv2
if id_image is None:
return {'id_emb': None}, controlnet_image
id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
face_info = self._detect_face(id_image_cv2)
if len(face_info) == 0:
raise ValueError('No face detected in the input ID image')
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
if controlnet_image is None:
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
class TeaCache:
@@ -612,7 +601,7 @@ def lets_dance_flux(
entity_masks=None,
ipadapter_kwargs_list={},
id_emb=None,
controlnet_guidance=None,
infinityou_guidance=None,
tea_cache: TeaCache = None,
**kwargs
):
@@ -659,7 +648,7 @@ def lets_dance_flux(
}
if id_emb is not None:
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': controlnet_guidance})
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_frames, **controlnet_extra_kwargs
)