support qwen-image-layered

Artiprocher
2025-12-19 19:06:37 +08:00
parent 11315d7a40
commit c6722b3f56
18 changed files with 417 additions and 27 deletions


@@ -19,7 +19,7 @@ def get_timestep_embedding(
     )
     exponent = exponent / (half_dim - downscale_freq_shift)
-    emb = torch.exp(exponent).to(timesteps.device)
+    emb = torch.exp(exponent)
     if align_dtype_to_timestep:
         emb = emb.to(timesteps.dtype)
     emb = timesteps[:, None].float() * emb[None, :]
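
Note on this hunk: the removed line cast the frequency table to timesteps.device immediately after torch.exp. Below is a minimal, self-contained sketch of this kind of sinusoidal embedding; the construction of exponent is not shown in the diff and is an assumption based on common diffusers-style implementations. If exponent is already created on timesteps.device, as in the sketch, the dropped .to(timesteps.device) was a no-op and removing it simply avoids a redundant cast.

import math
import torch

def sketch_timestep_embedding(timesteps, embedding_dim, max_period=10000,
                              downscale_freq_shift=0.0, align_dtype_to_timestep=False):
    # Reconstruction for illustration only; names mirror the diff context.
    half_dim = embedding_dim // 2
    # Assumed upstream construction: frequencies built directly on timesteps.device,
    # which is what would make the removed .to(timesteps.device) redundant.
    exponent = -math.log(max_period) * torch.arange(
        half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)
    emb = torch.exp(exponent)          # already on timesteps.device
    if align_dtype_to_timestep:
        emb = emb.to(timesteps.dtype)  # e.g. keep float64 timesteps in float64
    emb = timesteps[:, None].float() * emb[None, :]
    return torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)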
@@ -78,7 +78,7 @@ class DiffusersCompatibleTimestepProj(torch.nn.Module):
 class TimestepEmbeddings(torch.nn.Module):
-    def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
+    def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
         super().__init__()
         self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
         if diffusers_compatible_format:
@@ -87,10 +87,16 @@ class TimestepEmbeddings(torch.nn.Module):
             self.timestep_embedder = torch.nn.Sequential(
                 torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
             )
+        if use_additional_t_cond:
+            self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
 
-    def forward(self, timestep, dtype):
+    def forward(self, timestep, dtype, addition_t_cond=None):
         time_emb = self.time_proj(timestep).to(dtype)
         time_emb = self.timestep_embedder(time_emb)
+        if addition_t_cond is not None:
+            addition_t_emb = self.addition_t_embedding(addition_t_cond)
+            addition_t_emb = addition_t_emb.to(dtype=dtype)
+            time_emb = time_emb + addition_t_emb
         return time_emb
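
The net effect of the two TimestepEmbeddings hunks: an optional learned embedding over a binary condition is added on top of the usual time embedding. The standalone sketch below shows that path; the sinusoidal stand-in for TemporalTimesteps and the semantics of the two indices (presumably distinguishing two kinds of inputs for qwen-image-layered) are assumptions, and only the Sequential embedder and the addition path mirror the diff.

import math
import torch

class TimestepEmbeddingsSketch(torch.nn.Module):
    def __init__(self, dim_in, dim_out, use_additional_t_cond=False):
        super().__init__()
        self.dim_in = dim_in
        self.timestep_embedder = torch.nn.Sequential(
            torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
        )
        if use_additional_t_cond:
            # One learned row per value of the binary condition, as in the diff.
            self.addition_t_embedding = torch.nn.Embedding(2, dim_out)

    def forward(self, timestep, dtype, addition_t_cond=None):
        # Stand-in for self.time_proj(timestep): sinusoidal features of size dim_in.
        half_dim = self.dim_in // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timestep.device) / half_dim
        )
        args = timestep[:, None].float() * freqs[None, :]
        time_emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
        time_emb = self.timestep_embedder(time_emb)
        if addition_t_cond is not None:
            # Look up the per-sample condition embedding and add it to the time embedding.
            addition_t_emb = self.addition_t_embedding(addition_t_cond).to(dtype=dtype)
            time_emb = time_emb + addition_t_emb
        return time_emb

# Hypothetical usage: a 0/1 flag per sample alongside the timestep.
emb = TimestepEmbeddingsSketch(dim_in=256, dim_out=1024, use_additional_t_cond=True)
t = torch.tensor([999.0, 500.0])
cond = torch.tensor([0, 1])
out = emb(t, dtype=torch.float32, addition_t_cond=cond)
print(out.shape)  # torch.Size([2, 1024])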