mirror of https://github.com/modelscope/DiffSynth-Studio.git

fix base model
@@ -9,6 +9,7 @@ class FluxReferenceEmbedder(torch.nn.Module):
         super().__init__()
         self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
         self.idx_embedder = TimestepEmbeddings(256, 256)
+        self.proj = torch.nn.Linear(3072, 3072)
 
     def forward(self, image_ids, idx, dtype, device):
         pos_emb = self.pos_embedder(image_ids, device=device)
@@ -18,3 +19,13 @@ class FluxReferenceEmbedder(torch.nn.Module):
         idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
         image_rotary_emb = pos_emb + idx_emb
         return image_rotary_emb
+
+    def init(self):
+        self.idx_embedder.timestep_embedder[-1].load_state_dict({
+            "weight": torch.zeros((256, 256)),
+            "bias": torch.zeros((256,))
+        })
+        self.proj.load_state_dict({
+            "weight": torch.eye(3072),
+            "bias": torch.zeros((3072,))
+        })
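A note on the new init() method: zero-initializing the final layer of idx_embedder and loading an identity matrix into proj make the reference branch a no-op at the start of training, which is presumably what the commit message ("fix base model") refers to: the base model's outputs are preserved until the embedder learns. A standalone sketch (not part of the commit) verifying both properties:

import torch

# An identity-initialized Linear passes its input through unchanged.
proj = torch.nn.Linear(3072, 3072)
proj.load_state_dict({"weight": torch.eye(3072), "bias": torch.zeros((3072,))})
x = torch.randn(1, 4, 3072)
assert torch.allclose(proj(x), x)

# A zero-initialized final layer makes the index embedding contribute nothing.
head = torch.nn.Linear(256, 256)
head.load_state_dict({"weight": torch.zeros((256, 256)), "bias": torch.zeros((256,))})
assert torch.all(head(torch.randn(2, 256)) == 0)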
@@ -709,6 +709,7 @@ def lets_dance_flux(
         hidden_states_ref = dit.patchify(hidden_states_ref)
         hidden_states_ref = dit.x_embedder(hidden_states_ref)
         hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
+        hidden_states_ref = reference_embedder.proj(hidden_states_ref)
         hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
 
     # TeaCache
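For context, the hunk above runs the reference latents through the same patchify/x_embedder path as the main image, flattens all reference tokens into a single batch-1 sequence, applies the identity-initialized proj, and appends the result to the image token sequence so attention can mix image and reference content. A shape-only sketch with placeholder sizes (B, L, C, and the 32 main tokens are illustrative, not the pipeline's real dimensions):

import torch
from einops import rearrange

B, L, C = 2, 16, 8
hidden_states = torch.randn(1, 32, C)       # main image tokens
hidden_states_ref = torch.randn(B, L, C)    # tokens from B reference images

# Merge all reference tokens into one batch-1 sequence, then concatenate
# them after the main tokens along the sequence dimension.
hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
assert hidden_states.shape == (1, 32 + B * L, C)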
test.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+import torch
+from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict
+from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
+from PIL import Image
+
+
+model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
+model_manager.load_models([
+    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
+    "models/FLUX/FLUX.1-dev/text_encoder_2",
+    "models/FLUX/FLUX.1-dev/ae.safetensors",
+    "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
+])
+pipe = FluxImagePipeline.from_model_manager(model_manager)
+
+pipe.reference_embedder = FluxReferenceEmbedder().to(dtype=torch.bfloat16, device="cuda")
+pipe.reference_embedder.init()
+
+for i in range(4):
+    image = pipe(
+        prompt="a girl.",
+        num_inference_steps=30, embedded_guidance=3.5,
+        height=512, width=512,
+        reference_images=[Image.open("data/example4.jpg").resize((512, 512))]
+    )
+    image.save(f"image_{i}.jpg")
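test.py imports load_state_dict but never calls it, so the script appears to sample with the freshly initialized (neutral) embedder as a smoke test. If trained weights existed, they could presumably be restored before sampling; the checkpoint path below is a hypothetical placeholder, not part of the commit:

state_dict = load_state_dict("path/to/reference_embedder.safetensors")
pipe.reference_embedder.load_state_dict(state_dict)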
@@ -41,15 +41,15 @@ class LightningModel(LightningModelForT2ILoRA):
         self.freeze_parameters()
         self.pipe.reference_embedder.requires_grad_(True)
         self.pipe.reference_embedder.train()
-        self.add_lora_to_model(
-            self.pipe.denoising_model(),
-            lora_rank=lora_rank,
-            lora_alpha=lora_alpha,
-            lora_target_modules=lora_target_modules,
-            init_lora_weights=init_lora_weights,
-            pretrained_lora_path=pretrained_lora_path,
-            state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
-        )
+        # self.add_lora_to_model(
+        #     self.pipe.denoising_model(),
+        #     lora_rank=lora_rank,
+        #     lora_alpha=lora_alpha,
+        #     lora_target_modules=lora_target_modules,
+        #     init_lora_weights=init_lora_weights,
+        #     pretrained_lora_path=pretrained_lora_path,
+        #     state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
+        # )
 
 
     def training_step(self, batch, batch_idx):
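This training change freezes the whole pipeline, re-enables gradients only on the reference embedder, and comments out the LoRA injection, so this run trains the embedder alone. A generic sketch of that freeze-then-unfreeze pattern; train_only is a hypothetical helper, not DiffSynth API:

import torch

def train_only(full_model: torch.nn.Module, submodule: torch.nn.Module) -> None:
    full_model.requires_grad_(False)   # freeze every parameter
    submodule.requires_grad_(True)     # re-enable gradients on one submodule
    submodule.train()                  # and put it in training mode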