Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-20 23:58:12 +00:00)
Commit: flux series vram management
@@ -104,6 +104,7 @@ class InfiniteYouImageProjector(nn.Module):

     def forward(self, x):
         latents = self.latents.repeat(x.size(0), 1, 1)
+        latents = latents.to(dtype=x.dtype, device=x.device)
         x = self.proj_in(x)
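The added line follows the pattern this commit applies throughout: a learned parameter (`self.latents`) may be stored in float32 or kept on another device under VRAM management, while activations arrive on the GPU in a half-precision dtype, so it is cast to match the input before use. A minimal sketch of the pattern; the sizes and the `proj_in` layer beyond what the diff shows are assumptions:

import torch
import torch.nn as nn

class ImageProjectorSketch(nn.Module):
    """Illustrative only; dimensions and layer names are assumed, not the repo's."""
    def __init__(self, num_latents=64, dim=1024, in_dim=512):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_latents, dim))
        self.proj_in = nn.Linear(in_dim, dim)

    def forward(self, x):
        # One copy of the latents per batch element.
        latents = self.latents.repeat(x.size(0), 1, 1)
        # No-op when dtype/device already match; otherwise aligns the stored
        # parameter with the incoming activations.
        latents = latents.to(dtype=x.dtype, device=x.device)
        x = self.proj_in(x)
        return x, latents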
@@ -40,7 +40,8 @@ class SingleValueEncoder(torch.nn.Module):
         emb = self.prefer_proj(value).to(dtype)
         emb = self.prefer_value_embedder(emb).squeeze(0)
         base_embeddings = emb.expand(self.prefer_len, -1)
-        learned_embeddings = base_embeddings + self.positional_embedding
+        positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
+        learned_embeddings = base_embeddings + positional_embedding
         return learned_embeddings

     @staticmethod
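Same idea here: adding a float32 positional-embedding parameter to half-precision activations would silently promote the sum to float32 (and a device mismatch would raise outright), so the buffer is cast to the dtype and device of `base_embeddings` first. A runnable sketch, with `prefer_len` and the embedding width chosen arbitrarily:

import torch

prefer_len, dim = 8, 16
# Expanded base embeddings in the compute dtype.
base_embeddings = torch.randn(1, dim, dtype=torch.bfloat16).expand(prefer_len, -1)
# Positional embedding as stored (float32 by default).
positional_embedding = torch.randn(prefer_len, dim)
# Align with the base embeddings before the addition.
positional_embedding = positional_embedding.to(
    dtype=base_embeddings.dtype, device=base_embeddings.device
)
learned_embeddings = base_embeddings + positional_embedding
print(learned_embeddings.dtype)  # torch.bfloat16, not promoted to float32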
@@ -162,7 +162,7 @@ class TimestepEmbedder(nn.Module):
     def forward(self, t):
         t_freq = self.timestep_embedding(
             t, self.frequency_embedding_size, self.max_period
-        ).type(self.mlp[0].weight.dtype) # type: ignore
+        ).type(t.dtype) # type: ignore
         t_emb = self.mlp(t_freq)
         return t_emb
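This hunk changes the cast target from the first MLP weight's dtype to the input's dtype, which remains valid when the weights are quantized or offloaded and their storage dtype no longer reflects the compute dtype. For context, a standard DiT-style sinusoidal `timestep_embedding` (the usual formulation for an even `dim`; shown for reference, not copied from this repository):

import math
import torch

def timestep_embedding(t, dim, max_period=10000):
    # Geometric frequency ladder from 1 down to 1/max_period.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(t.device)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

# Cast to the timestep tensor's dtype, mirroring the change in the hunk.
t = torch.tensor([0.0, 500.0], dtype=torch.bfloat16)
emb = timestep_embedding(t, 256).type(t.dtype)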
@@ -656,7 +656,7 @@ class Qwen2Connector(torch.nn.Module):
         mask_float = mask.unsqueeze(-1) # [b, s1, 1]
         x_mean = (x * mask_float).sum(
             dim=1
-        ) / mask_float.sum(dim=1) * (1 + self.scale_factor)
+        ) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(dtype=x.dtype, device=x.device))

         global_out=self.global_proj_out(x_mean)
         encoder_hidden_states = self.S(x,t,mask)
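The masked mean pooling itself is unchanged; only `self.scale_factor`, a learned parameter that may be stored in float32 or on another device, is cast to the activations' dtype and device so the scaled result stays in the compute dtype. A self-contained sketch with made-up shapes and a plain tensor standing in for the parameter:

import torch

b, s1, d = 2, 5, 4
x = torch.randn(b, s1, d, dtype=torch.bfloat16)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
scale_factor = torch.zeros(1)  # stand-in for the float32 nn.Parameter

mask_float = mask.unsqueeze(-1)  # [b, s1, 1]
# Sum over valid positions, divide by their count, then apply the scale
# cast to the activations' dtype/device.
x_mean = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1) * (
    1 + scale_factor.to(dtype=x.dtype, device=x.device)
)
print(x_mean.dtype)  # torch.bfloat16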