mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
@@ -41,7 +41,7 @@ class RRDB(torch.nn.Module):
|
||||
|
||||
class RRDBNet(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
|
||||
super(RRDBNet, self).__init__()
|
||||
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
||||
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
|
||||
@@ -65,6 +65,21 @@ class RRDBNet(torch.nn.Module):
|
||||
feat = self.lrelu(self.conv_up2(feat))
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return RRDBNetStateDictConverter()
|
||||
|
||||
|
||||
class RRDBNetStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
|
||||
class ESRGAN(torch.nn.Module):
|
||||
@@ -73,12 +88,8 @@ class ESRGAN(torch.nn.Module):
|
||||
self.model = model
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(model_path):
|
||||
model = RRDBNet()
|
||||
state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
return ESRGAN(model)
|
||||
def from_model_manager(model_manager):
|
||||
return ESRGAN(model_manager.fetch_model("esrgan"))
|
||||
|
||||
def process_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
|
||||
|
||||
Reference in New Issue
Block a user