fix base model

This commit is contained in:
Artiprocher
2025-04-14 13:11:45 +08:00
parent 7a06a58f49
commit 2a5355b7cb
4 changed files with 47 additions and 9 deletions

View File

@@ -9,6 +9,7 @@ class FluxReferenceEmbedder(torch.nn.Module):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.idx_embedder = TimestepEmbeddings(256, 256)
self.proj = torch.nn.Linear(3072, 3072)
def forward(self, image_ids, idx, dtype, device):
pos_emb = self.pos_embedder(image_ids, device=device)
@@ -18,3 +19,13 @@ class FluxReferenceEmbedder(torch.nn.Module):
idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
image_rotary_emb = pos_emb + idx_emb
return image_rotary_emb
def init(self):
self.idx_embedder.timestep_embedder[-1].load_state_dict({
"weight": torch.zeros((256, 256)),
"bias": torch.zeros((256,))
}),
self.proj.load_state_dict({
"weight": torch.eye(3072),
"bias": torch.zeros((3072,))
})

View File

@@ -709,6 +709,7 @@ def lets_dance_flux(
hidden_states_ref = dit.patchify(hidden_states_ref)
hidden_states_ref = dit.x_embedder(hidden_states_ref)
hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
hidden_states_ref = reference_embedder.proj(hidden_states_ref)
hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
# TeaCache

26
test.py Normal file
View File

@@ -0,0 +1,26 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
from PIL import Image
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.reference_embedder = FluxReferenceEmbedder().to(dtype=torch.bfloat16, device="cuda")
pipe.reference_embedder.init()
for i in range(4):
image = pipe(
prompt="a girl.",
num_inference_steps=30, embedded_guidance=3.5,
height=512, width=512,
reference_images=[Image.open("data/example4.jpg").resize((512, 512))]
)
image.save(f"image_{i}.jpg")

View File

@@ -41,15 +41,15 @@ class LightningModel(LightningModelForT2ILoRA):
self.freeze_parameters()
self.pipe.reference_embedder.requires_grad_(True)
self.pipe.reference_embedder.train()
self.add_lora_to_model(
self.pipe.denoising_model(),
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_target_modules=lora_target_modules,
init_lora_weights=init_lora_weights,
pretrained_lora_path=pretrained_lora_path,
state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
)
# self.add_lora_to_model(
# self.pipe.denoising_model(),
# lora_rank=lora_rank,
# lora_alpha=lora_alpha,
# lora_target_modules=lora_target_modules,
# init_lora_weights=init_lora_weights,
# pretrained_lora_path=pretrained_lora_path,
# state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
# )
def training_step(self, batch, batch_idx):