support z-image-omni-base training

This commit is contained in:
Artiprocher
2026-01-05 20:04:00 +08:00
parent 5745c9f200
commit 32449a6aa0
9 changed files with 128 additions and 4 deletions

View File

@@ -626,7 +626,7 @@ class ZImageDiT(nn.Module):
# Pad token
feats_cat = torch.cat(feats, dim=0)
feats_cat[torch.cat(inner_pad_mask)] = pad_token
feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device)
feats = list(feats_cat.split(item_seqlens, dim=0))
# RoPE