qwen-image

This commit is contained in:
Artiprocher
2025-08-04 20:24:13 +08:00
parent db124fa6bc
commit c35f2d8bda
22 changed files with 2905 additions and 14 deletions

View File

@@ -45,6 +45,7 @@ def get_timestep_embedding(
scale: float = 1,
max_period: int = 10000,
computation_device = None,
align_dtype_to_timestep = False,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -63,6 +64,8 @@ def get_timestep_embedding(
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent).to(timesteps.device)
if align_dtype_to_timestep:
emb = emb.to(timesteps.dtype)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
@@ -82,12 +85,14 @@ def get_timestep_embedding(
class TemporalTimesteps(torch.nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.computation_device = computation_device
self.scale = scale
self.align_dtype_to_timestep = align_dtype_to_timestep
def forward(self, timesteps):
t_emb = get_timestep_embedding(
@@ -96,6 +101,8 @@ class TemporalTimesteps(torch.nn.Module):
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
computation_device=self.computation_device,
scale=self.scale,
align_dtype_to_timestep=self.align_dtype_to_timestep,
)
return t_emb