support wan tensor parallel (preview)

This commit is contained in:
Artiprocher
2025-03-17 19:39:45 +08:00
parent 39890f023f
commit 04d03500ff
5 changed files with 147 additions and 9 deletions

View File

@@ -108,6 +108,16 @@ class RMSNorm(nn.Module):
return self.norm(x.float()).to(dtype) * self.weight
class AttentionModule(nn.Module):
def __init__(self, num_heads):
super().__init__()
self.num_heads = num_heads
def forward(self, q, k, v):
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
return x
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
super().__init__()
@@ -121,17 +131,16 @@ class SelfAttention(nn.Module):
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads)
def forward(self, x, freqs):
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
x = flash_attention(
q=rope_apply(q, freqs, self.num_heads),
k=rope_apply(k, freqs, self.num_heads),
v=v,
num_heads=self.num_heads
)
q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
x = self.attn(q, k, v)
return self.o(x)
@@ -153,6 +162,8 @@ class CrossAttention(nn.Module):
self.k_img = nn.Linear(dim, dim)
self.v_img = nn.Linear(dim, dim)
self.norm_k_img = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads)
def forward(self, x: torch.Tensor, y: torch.Tensor):
if self.has_image_input:
@@ -163,7 +174,7 @@ class CrossAttention(nn.Module):
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(ctx))
v = self.v(ctx)
x = flash_attention(q, k, v, num_heads=self.num_heads)
x = self.attn(q, k, v)
if self.has_image_input:
k_img = self.norm_k_img(self.k_img(img))
v_img = self.v_img(img)

View File

@@ -225,7 +225,7 @@ class WanVideoPipeline(BasePipeline):
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
# Initialize noise
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)

View File

@@ -37,7 +37,7 @@ class FlowMatchScheduler():
self.linear_timesteps_weights = bsmntw_weighing
def step(self, model_output, timestep, sample, to_final=False):
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())