hunyuanvideo dit

This commit is contained in:
Artiprocher
2024-12-17 14:45:23 +08:00
parent 7c0520d029
commit 05e2028c5d
5 changed files with 706 additions and 6 deletions

View File

@@ -52,9 +52,9 @@ class PatchEmbed(torch.nn.Module):
class TimestepEmbeddings(torch.nn.Module):
def __init__(self, dim_in, dim_out):
def __init__(self, dim_in, dim_out, computation_device=None):
super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)