mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support flux any training
This commit is contained in:
@@ -277,7 +277,7 @@ class FluxLoRAConverter:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, alpha=1.0):
|
||||
def align_to_opensource_format(state_dict, alpha=None):
|
||||
prefix_rename_dict = {
|
||||
"single_blocks": "lora_unet_single_blocks",
|
||||
"blocks": "lora_unet_double_blocks",
|
||||
@@ -316,7 +316,8 @@ class FluxLoRAConverter:
|
||||
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
|
||||
state_dict_[rename] = param
|
||||
if rename.endswith("lora_up.weight"):
|
||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
|
||||
lora_alpha = alpha if alpha is not None else param.shape[-1]
|
||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0]
|
||||
return state_dict_
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -704,7 +704,8 @@ class FluxImageUnit_Step1x(PipelineUnit):
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
image = pipe.vae_encoder(image)
|
||||
inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image})
|
||||
inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
|
||||
if inputs_shared.get("cfg_scale", 1) != 1:
|
||||
inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
@@ -727,6 +728,8 @@ class FluxImageUnit_Flex(PipelineUnit):
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride):
|
||||
if pipe.dit.input_dim == 196:
|
||||
if flex_control_stop is None:
|
||||
flex_control_stop = 1
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if flex_inpaint_image is None:
|
||||
flex_inpaint_image = torch.zeros_like(latents)
|
||||
@@ -760,14 +763,15 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance):
|
||||
if infinityou_id_image is not None:
|
||||
return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance)
|
||||
return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device)
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
class InfinitYou:
|
||||
class InfinitYou(torch.nn.Module):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__()
|
||||
from facexlib.recognition import init_recognition_model
|
||||
from insightface.app import FaceAnalysis
|
||||
self.device = device
|
||||
@@ -791,16 +795,16 @@ class InfinitYou:
|
||||
face_info = self.app_160.get(id_image_cv2)
|
||||
return face_info
|
||||
|
||||
def extract_arcface_bgr_embedding(self, in_image, landmark):
|
||||
def extract_arcface_bgr_embedding(self, in_image, landmark, device):
|
||||
from insightface.utils import face_align
|
||||
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
|
||||
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
|
||||
arc_face_image = 2 * arc_face_image - 1
|
||||
arc_face_image = arc_face_image.contiguous().to(self.device)
|
||||
arc_face_image = arc_face_image.contiguous().to(device=device, dtype=self.torch_dtype)
|
||||
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
|
||||
return face_emb
|
||||
|
||||
def prepare_infinite_you(self, model, id_image, infinityou_guidance):
|
||||
def prepare_infinite_you(self, model, id_image, infinityou_guidance, device):
|
||||
import cv2
|
||||
if id_image is None:
|
||||
return {'id_emb': None}
|
||||
@@ -809,9 +813,9 @@ class InfinitYou:
|
||||
if len(face_info) == 0:
|
||||
raise ValueError('No face detected in the input ID image')
|
||||
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
|
||||
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
|
||||
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark, device)
|
||||
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
|
||||
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
|
||||
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=device, dtype=self.torch_dtype)
|
||||
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user