mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 18:58:11 +00:00
support reference image
This commit is contained in:
@@ -20,10 +20,11 @@ class RoPEEmbedding(torch.nn.Module):
|
|||||||
self.axes_dim = axes_dim
|
self.axes_dim = axes_dim
|
||||||
|
|
||||||
|
|
||||||
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
def rope(self, pos: torch.Tensor, dim: int, theta: int, device="cpu") -> torch.Tensor:
|
||||||
assert dim % 2 == 0, "The dimension must be even."
|
assert dim % 2 == 0, "The dimension must be even."
|
||||||
|
|
||||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||||
|
scale = scale.to(device)
|
||||||
omega = 1.0 / (theta**scale)
|
omega = 1.0 / (theta**scale)
|
||||||
|
|
||||||
batch_size, seq_length = pos.shape
|
batch_size, seq_length = pos.shape
|
||||||
@@ -36,9 +37,9 @@ class RoPEEmbedding(torch.nn.Module):
|
|||||||
return out.float()
|
return out.float()
|
||||||
|
|
||||||
|
|
||||||
def forward(self, ids):
|
def forward(self, ids, device="cpu"):
|
||||||
n_axes = ids.shape[-1]
|
n_axes = ids.shape[-1]
|
||||||
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
|
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta, device) for i in range(n_axes)], dim=-3)
|
||||||
return emb.unsqueeze(1)
|
return emb.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ class FluxReferenceEmbedder(torch.nn.Module):
|
|||||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||||
self.idx_embedder = TimestepEmbeddings(256, 256)
|
self.idx_embedder = TimestepEmbeddings(256, 256)
|
||||||
|
|
||||||
def forward(self, image_ids, idx, dtype):
|
def forward(self, image_ids, idx, dtype, device):
|
||||||
pos_emb = self.pos_embedder(image_ids)
|
pos_emb = self.pos_embedder(image_ids, device=device)
|
||||||
idx_emb = self.idx_embedder(idx, dtype=dtype)
|
idx_emb = self.idx_embedder(idx, dtype=dtype).to(device)
|
||||||
length = pos_emb.shape[2]
|
length = pos_emb.shape[2]
|
||||||
pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
|
pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
|
||||||
idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
|
idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
|
||||||
|
|||||||
@@ -694,7 +694,7 @@ def lets_dance_flux(
|
|||||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||||
else:
|
else:
|
||||||
prompt_emb = dit.context_embedder(prompt_emb)
|
prompt_emb = dit.context_embedder(prompt_emb)
|
||||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1), device=hidden_states.device)
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
|
|
||||||
# Reference images
|
# Reference images
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class LightningModel(LightningModelForT2ILoRA):
|
|||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
# Data
|
# Data
|
||||||
text, image = batch["text"], batch["image_2"]
|
text, image = batch["instruction"], batch["image_2"]
|
||||||
image_ref = batch["image_1"]
|
image_ref = batch["image_1"]
|
||||||
|
|
||||||
# Prepare input parameters
|
# Prepare input parameters
|
||||||
@@ -77,8 +77,9 @@ class LightningModel(LightningModelForT2ILoRA):
|
|||||||
# Compute loss
|
# Compute loss
|
||||||
noise_pred = lets_dance_flux(
|
noise_pred = lets_dance_flux(
|
||||||
self.pipe.denoising_model(),
|
self.pipe.denoising_model(),
|
||||||
|
reference_embedder=self.pipe.reference_embedder,
|
||||||
hidden_states_ref=hidden_states_ref,
|
hidden_states_ref=hidden_states_ref,
|
||||||
latents=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||||
use_gradient_checkpointing=self.use_gradient_checkpointing
|
use_gradient_checkpointing=self.use_gradient_checkpointing
|
||||||
)
|
)
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
@@ -191,19 +192,23 @@ if __name__ == '__main__':
|
|||||||
SingleTaskDataset(
|
SingleTaskDataset(
|
||||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
|
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
|
||||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_change_add_remove.json",
|
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_change_add_remove.json",
|
||||||
|
height=512, width=512,
|
||||||
),
|
),
|
||||||
SingleTaskDataset(
|
SingleTaskDataset(
|
||||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
|
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
|
||||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_zoomin_zoomout.json",
|
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_zoomin_zoomout.json",
|
||||||
|
height=512, width=512,
|
||||||
),
|
),
|
||||||
SingleTaskDataset(
|
SingleTaskDataset(
|
||||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
|
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
|
||||||
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
|
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
|
||||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_style_transfer.json",
|
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_style_transfer.json",
|
||||||
|
height=512, width=512,
|
||||||
),
|
),
|
||||||
SingleTaskDataset(
|
SingleTaskDataset(
|
||||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
|
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
|
||||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_faceid.json",
|
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_faceid.json",
|
||||||
|
height=512, width=512,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
dataset_weight=(4, 2, 2, 1),
|
dataset_weight=(4, 2, 2, 1),
|
||||||
|
|||||||
Reference in New Issue
Block a user