support CogVideoX-5B (#184)

* support cogvideo

* update examples
This commit is contained in:
Zhongjie Duan
2024-09-03 11:37:54 +08:00
committed by GitHub
parent fe485b3fa1
commit d154bee18a
22 changed files with 2653 additions and 107 deletions

View File

@@ -41,7 +41,7 @@ class RRDB(torch.nn.Module):
class RRDBNet(torch.nn.Module):
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
super(RRDBNet, self).__init__()
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
@@ -65,6 +65,21 @@ class RRDBNet(torch.nn.Module):
feat = self.lrelu(self.conv_up2(feat))
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out
@staticmethod
def state_dict_converter():
return RRDBNetStateDictConverter()
class RRDBNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict, {"upcast_to_float32": True}
def from_civitai(self, state_dict):
return state_dict, {"upcast_to_float32": True}
class ESRGAN(torch.nn.Module):
@@ -73,12 +88,8 @@ class ESRGAN(torch.nn.Module):
self.model = model
@staticmethod
def from_pretrained(model_path):
model = RRDBNet()
state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
model.load_state_dict(state_dict)
model.eval()
return ESRGAN(model)
def from_model_manager(model_manager):
return ESRGAN(model_manager.fetch_model("esrgan"))
def process_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)