mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Support Wan tensor parallelism (preview)
This commit is contained in:
@@ -108,6 +108,16 @@ class RMSNorm(nn.Module):
|
||||
return self.norm(x.float()).to(dtype) * self.weight
|
||||
|
||||
|
||||
class AttentionModule(nn.Module):
    """Thin ``nn.Module`` wrapper around the ``flash_attention`` call.

    The head count is stored at construction time, so callers only pass
    the query, key and value tensors to ``forward``.
    """

    def __init__(self, num_heads):
        super().__init__()
        # Head count forwarded to flash_attention on every forward pass.
        self.num_heads = num_heads

    def forward(self, q, k, v):
        """Return ``flash_attention`` of q, k, v using the stored head count."""
        return flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
@@ -121,17 +131,16 @@ class SelfAttention(nn.Module):
|
||||
self.o = nn.Linear(dim, dim)
|
||||
self.norm_q = RMSNorm(dim, eps=eps)
|
||||
self.norm_k = RMSNorm(dim, eps=eps)
|
||||
|
||||
self.attn = AttentionModule(self.num_heads)
|
||||
|
||||
def forward(self, x, freqs):
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(x))
|
||||
v = self.v(x)
|
||||
x = flash_attention(
|
||||
q=rope_apply(q, freqs, self.num_heads),
|
||||
k=rope_apply(k, freqs, self.num_heads),
|
||||
v=v,
|
||||
num_heads=self.num_heads
|
||||
)
|
||||
q = rope_apply(q, freqs, self.num_heads)
|
||||
k = rope_apply(k, freqs, self.num_heads)
|
||||
x = self.attn(q, k, v)
|
||||
return self.o(x)
|
||||
|
||||
|
||||
@@ -153,6 +162,8 @@ class CrossAttention(nn.Module):
|
||||
self.k_img = nn.Linear(dim, dim)
|
||||
self.v_img = nn.Linear(dim, dim)
|
||||
self.norm_k_img = RMSNorm(dim, eps=eps)
|
||||
|
||||
self.attn = AttentionModule(self.num_heads)
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||
if self.has_image_input:
|
||||
@@ -163,7 +174,7 @@ class CrossAttention(nn.Module):
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(ctx))
|
||||
v = self.v(ctx)
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||
x = self.attn(q, k, v)
|
||||
if self.has_image_input:
|
||||
k_img = self.norm_k_img(self.k_img(img))
|
||||
v_img = self.v_img(img)
|
||||
|
||||
@@ -225,7 +225,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||
|
||||
# Initialize noise
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
|
||||
|
||||
@@ -37,7 +37,7 @@ class FlowMatchScheduler():
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
|
||||
|
||||
def step(self, model_output, timestep, sample, to_final=False):
|
||||
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
|
||||
Reference in New Issue
Block a user