mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
DiffSynth-Studio 2.0 major update
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import torch
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
from .general_modules import TemporalTimesteps
|
||||
|
||||
|
||||
class MultiValueEncoder(torch.nn.Module):
|
||||
def __init__(self, encoders=()):
|
||||
super().__init__()
|
||||
if not isinstance(encoders, list):
|
||||
encoders = [encoders]
|
||||
self.encoders = torch.nn.ModuleList(encoders)
|
||||
|
||||
def __call__(self, values, dtype):
|
||||
@@ -28,12 +30,6 @@ class SingleValueEncoder(torch.nn.Module):
|
||||
self.positional_embedding = torch.nn.Parameter(
|
||||
torch.randn(self.prefer_len, dim_out)
|
||||
)
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
last_linear = self.prefer_value_embedder[-1]
|
||||
torch.nn.init.zeros_(last_linear.weight)
|
||||
torch.nn.init.zeros_(last_linear.bias)
|
||||
|
||||
def forward(self, value, dtype):
|
||||
value = value * 1000
|
||||
|
||||
Reference in New Issue
Block a user