bugfix & update doc

This commit is contained in:
Artiprocher
2025-12-04 10:35:14 +08:00
parent 4a15618080
commit 7747f38561
6 changed files with 768 additions and 23 deletions

View File

@@ -62,7 +62,7 @@ class DiskMap:
param = self.files[file_id].get_tensor(name)
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
param = param.to(self.torch_dtype)
if param.device == "cpu":
if isinstance(param, torch.Tensor) and param.device == "cpu":
param = param.clone()
if isinstance(param, torch.Tensor):
self.num_params += param.numel()

View File

@@ -537,7 +537,7 @@ class ZImageDiT(nn.Module):
x = torch.cat(x, dim=0)
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token.to(dtype=x.dtype, device=x.device)
x = list(x.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
@@ -565,7 +565,7 @@ class ZImageDiT(nn.Module):
cap_feats = torch.cat(cap_feats, dim=0)
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token.to(dtype=x.dtype, device=x.device)
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))