flux series vram management

This commit is contained in:
Artiprocher
2025-07-15 20:11:02 +08:00
parent cbd10fb27d
commit af6b1d4246
16 changed files with 629 additions and 27 deletions

View File

@@ -162,7 +162,7 @@ class TimestepEmbedder(nn.Module):
def forward(self, t):
t_freq = self.timestep_embedding(
t, self.frequency_embedding_size, self.max_period
- ).type(self.mlp[0].weight.dtype) # type: ignore
+ ).type(t.dtype) # type: ignore
t_emb = self.mlp(t_freq)
return t_emb
@@ -656,7 +656,7 @@ class Qwen2Connector(torch.nn.Module):
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
x_mean = (x * mask_float).sum(
dim=1
- ) / mask_float.sum(dim=1) * (1 + self.scale_factor)
+ ) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(dtype=x.dtype, device=x.device))
global_out=self.global_proj_out(x_mean)
encoder_hidden_states = self.S(x,t,mask)