diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 6d3100d..fc811a3 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -20,10 +20,11 @@ class RoPEEmbedding(torch.nn.Module): self.axes_dim = axes_dim - def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + def rope(self, pos: torch.Tensor, dim: int, theta: int, device="cpu") -> torch.Tensor: assert dim % 2 == 0, "The dimension must be even." scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + scale = scale.to(device) omega = 1.0 / (theta**scale) batch_size, seq_length = pos.shape @@ -36,9 +37,9 @@ class RoPEEmbedding(torch.nn.Module): return out.float() - def forward(self, ids): + def forward(self, ids, device="cpu"): n_axes = ids.shape[-1] - emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) + emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta, device) for i in range(n_axes)], dim=-3) return emb.unsqueeze(1) diff --git a/diffsynth/models/flux_reference_embedder.py b/diffsynth/models/flux_reference_embedder.py index 994ffaa..e5c7a53 100644 --- a/diffsynth/models/flux_reference_embedder.py +++ b/diffsynth/models/flux_reference_embedder.py @@ -10,9 +10,9 @@ class FluxReferenceEmbedder(torch.nn.Module): self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) self.idx_embedder = TimestepEmbeddings(256, 256) - def forward(self, image_ids, idx, dtype): - pos_emb = self.pos_embedder(image_ids) - idx_emb = self.idx_embedder(idx, dtype=dtype) + def forward(self, image_ids, idx, dtype, device): + pos_emb = self.pos_embedder(image_ids, device=device) + idx_emb = self.idx_embedder(idx, dtype=dtype).to(device) length = pos_emb.shape[2] pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W") idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index bf41c7d..b98f34a 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -694,7 +694,7 @@ def lets_dance_flux( prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids) else: prompt_emb = dit.context_embedder(prompt_emb) - image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1), device=hidden_states.device) attention_mask = None # Reference images diff --git a/train_flux_reference.py b/train_flux_reference.py index 4cc779b..29670be 100644 --- a/train_flux_reference.py +++ b/train_flux_reference.py @@ -54,7 +54,7 @@ class LightningModel(LightningModelForT2ILoRA): def training_step(self, batch, batch_idx): # Data - text, image = batch["text"], batch["image_2"] + text, image = batch["instruction"], batch["image_2"] image_ref = batch["image_1"] # Prepare input parameters @@ -77,8 +77,9 @@ class LightningModel(LightningModelForT2ILoRA): # Compute loss noise_pred = lets_dance_flux( self.pipe.denoising_model(), + reference_embedder=self.pipe.reference_embedder, hidden_states_ref=hidden_states_ref, - latents=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, + hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, use_gradient_checkpointing=self.use_gradient_checkpointing ) loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) @@ -191,19 +192,23 @@ if __name__ == '__main__': SingleTaskDataset( "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove", metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_change_add_remove.json", + height=512, width=512, ), SingleTaskDataset( "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout", metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_zoomin_zoomout.json", + height=512, width=512, ), SingleTaskDataset( "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer", keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")), metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_style_transfer.json", + height=512, width=512, ), SingleTaskDataset( "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid", metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_faceid.json", + height=512, width=512, ), ], dataset_weight=(4, 2, 2, 1),