mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 08:40:47 +00:00
...
This commit is contained in:
@@ -38,6 +38,41 @@ class AutoWrappedModule(torch.nn.Module):
|
||||
return module(*args, **kwargs)
|
||||
|
||||
|
||||
class WanAutoCastLayerNorm(torch.nn.LayerNorm):
|
||||
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
||||
with init_weights_on_device(device=torch.device("meta")):
|
||||
super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
||||
self.weight = module.weight
|
||||
self.bias = module.bias
|
||||
self.offload_dtype = offload_dtype
|
||||
self.offload_device = offload_device
|
||||
self.onload_dtype = onload_dtype
|
||||
self.onload_device = onload_device
|
||||
self.computation_dtype = computation_dtype
|
||||
self.computation_device = computation_device
|
||||
self.state = 0
|
||||
|
||||
def offload(self):
|
||||
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||
self.state = 0
|
||||
|
||||
def onload(self):
|
||||
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||
self.state = 1
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||
weight, bias = self.weight, self.bias
|
||||
else:
|
||||
weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||
with torch.amp.autocast(device_type=x.device.type):
|
||||
x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x)
|
||||
return x
|
||||
|
||||
|
||||
class AutoWrappedLinear(torch.nn.Linear):
|
||||
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
||||
with init_weights_on_device(device=torch.device("meta")):
|
||||
|
||||
Reference in New Issue
Block a user