diff --git a/diffsynth/models/flux_reference_embedder.py b/diffsynth/models/flux_reference_embedder.py
index e5c7a53..1593a79 100644
--- a/diffsynth/models/flux_reference_embedder.py
+++ b/diffsynth/models/flux_reference_embedder.py
@@ -9,6 +9,7 @@ class FluxReferenceEmbedder(torch.nn.Module):
         super().__init__()
         self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
         self.idx_embedder = TimestepEmbeddings(256, 256)
+        self.proj = torch.nn.Linear(3072, 3072)
 
     def forward(self, image_ids, idx, dtype, device):
         pos_emb = self.pos_embedder(image_ids, device=device)
@@ -18,3 +19,13 @@ class FluxReferenceEmbedder(torch.nn.Module):
         idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
         image_rotary_emb = pos_emb + idx_emb
         return image_rotary_emb
+
+    def init(self):
+        self.idx_embedder.timestep_embedder[-1].load_state_dict({
+            "weight": torch.zeros((256, 256)),
+            "bias": torch.zeros((256,))
+        })
+        self.proj.load_state_dict({
+            "weight": torch.eye(3072),
+            "bias": torch.zeros((3072,))
+        })
diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index a165e13..03c8b3c 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -709,6 +709,7 @@
             hidden_states_ref = dit.patchify(hidden_states_ref)
             hidden_states_ref = dit.x_embedder(hidden_states_ref)
             hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
+            hidden_states_ref = reference_embedder.proj(hidden_states_ref)
             hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
 
         # TeaCache
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..da929c5
--- /dev/null
+++ b/test.py
@@ -0,0 +1,26 @@
+import torch
+from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict
+from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
+from PIL import Image
+
+
+model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
+model_manager.load_models([
+    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
+    "models/FLUX/FLUX.1-dev/text_encoder_2",
+    "models/FLUX/FLUX.1-dev/ae.safetensors",
+    "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
+])
+pipe = FluxImagePipeline.from_model_manager(model_manager)
+
+pipe.reference_embedder = FluxReferenceEmbedder().to(dtype=torch.bfloat16, device="cuda")
+pipe.reference_embedder.init()
+
+for i in range(4):
+    image = pipe(
+        prompt="a girl.",
+        num_inference_steps=30, embedded_guidance=3.5,
+        height=512, width=512,
+        reference_images=[Image.open("data/example4.jpg").resize((512, 512))]
+    )
+    image.save(f"image_{i}.jpg")
\ No newline at end of file
diff --git a/train_flux_reference.py b/train_flux_reference.py
index 18b47a6..8d15102 100644
--- a/train_flux_reference.py
+++ b/train_flux_reference.py
@@ -41,15 +41,15 @@ class LightningModel(LightningModelForT2ILoRA):
         self.freeze_parameters()
         self.pipe.reference_embedder.requires_grad_(True)
         self.pipe.reference_embedder.train()
-        self.add_lora_to_model(
-            self.pipe.denoising_model(),
-            lora_rank=lora_rank,
-            lora_alpha=lora_alpha,
-            lora_target_modules=lora_target_modules,
-            init_lora_weights=init_lora_weights,
-            pretrained_lora_path=pretrained_lora_path,
-            state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
-        )
+        # self.add_lora_to_model(
+        #     self.pipe.denoising_model(),
+        #     lora_rank=lora_rank,
+        #     lora_alpha=lora_alpha,
+        #     lora_target_modules=lora_target_modules,
+        #     init_lora_weights=init_lora_weights,
+        #     pretrained_lora_path=pretrained_lora_path,
+        #     state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
+        # )
 
 
     def training_step(self, batch, batch_idx):