mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
@@ -41,7 +41,7 @@ class RRDB(torch.nn.Module):
|
||||
|
||||
class RRDBNet(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
|
||||
super(RRDBNet, self).__init__()
|
||||
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
||||
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
|
||||
@@ -65,6 +65,21 @@ class RRDBNet(torch.nn.Module):
|
||||
feat = self.lrelu(self.conv_up2(feat))
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return RRDBNetStateDictConverter()
|
||||
|
||||
|
||||
class RRDBNetStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
|
||||
class ESRGAN(torch.nn.Module):
|
||||
@@ -73,12 +88,8 @@ class ESRGAN(torch.nn.Module):
|
||||
self.model = model
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(model_path):
|
||||
model = RRDBNet()
|
||||
state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
return ESRGAN(model)
|
||||
def from_model_manager(model_manager):
|
||||
return ESRGAN(model_manager.fetch_model("esrgan"))
|
||||
|
||||
def process_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
|
||||
|
||||
@@ -58,7 +58,7 @@ class IFBlock(nn.Module):
|
||||
|
||||
|
||||
class IFNet(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
super(IFNet, self).__init__()
|
||||
self.block0 = IFBlock(7+4, c=90)
|
||||
self.block1 = IFBlock(7+4, c=90)
|
||||
@@ -113,7 +113,7 @@ class IFNetStateDictConverter:
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
return self.from_diffusers(state_dict), {"upcast_to_float32": True}
|
||||
|
||||
|
||||
class RIFEInterpolater:
|
||||
@@ -125,7 +125,7 @@ class RIFEInterpolater:
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager):
|
||||
return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
|
||||
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
|
||||
|
||||
def process_image(self, image):
|
||||
width, height = image.size
|
||||
@@ -203,7 +203,7 @@ class RIFESmoother(RIFEInterpolater):
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager):
|
||||
return RIFESmoother(model_manager.RIFE, device=model_manager.device)
|
||||
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
|
||||
|
||||
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
|
||||
output_tensor = []
|
||||
|
||||
Reference in New Issue
Block a user