This commit is contained in:
xuyixuan.xyx
2025-05-07 11:22:13 +08:00
parent 290ec469ca
commit f17558a4c4
4 changed files with 47 additions and 21 deletions

32
test.py
View File

@@ -102,31 +102,35 @@ model_manager.load_models([
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
pipe.dit.load_state_dict(state_dict, strict=False)
# state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
# pipe.dit.load_state_dict(state_dict, strict=False)
adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
adapter.load_state_dict(state_dict, strict=False)
# adapter.load_state_dict(state_dict, strict=False)
qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda")
sd = {}
for i in range(1, 6):
print(i)
sd.update(load_state_dict(f"models/nexus_v1/epoch-8/model-0000{i}-of-00005.safetensors", torch_dtype=torch.bfloat16))
pipe.dit.load_state_dict({i.replace("pipe.dit.", ""): sd[i] for i in sd if i.startswith("pipe.dit.")})
qwenvl.load_state_dict({i.replace("qwenvl.", ""): sd[i] for i in sd if i.startswith("qwenvl.")})
adapter.load_state_dict({i.replace("adapter.", ""): sd[i] for i in sd if i.startswith("adapter.")})
for i in sd:
if (not i.startswith("pipe.dit")) and (not i.startswith("qwenvl.")) and (not i.startswith("adapter.")):
print(i)
with torch.no_grad():
instruction = "Generate an image according to the following description: a beautiful Asian girl"
instruction = "Generate an image according to the following description: hyper-realistic and detailed 2010s movie still portrait of Josip Broz Tito, by Paolo Sorrentino, Leica SL2 50mm, clear color, high quality, high textured, dramatic light, cinematic"
emb = qwenvl(instruction, images=None)
emb = adapter(emb)
image = pipe("", image_emb=emb)
image = pipe("", image_emb=emb, height=512, width=512)
image.save("image_1.jpg")
with torch.no_grad():
instruction = "<|vision_start|><|image_pad|><|vision_end|> Add sunglasses."
instruction = "<|vision_start|><|image_pad|><|vision_end|> transform the image into a cartoon style with vibrant colors and a confident expression."
emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
emb = adapter(emb)
image = pipe("", image_emb=emb)
image = pipe("", image_emb=emb, height=512, width=512)
image.save("image_2.jpg")
with torch.no_grad():
instruction = "<|vision_start|><|image_pad|><|vision_end|> Let her smile."
emb = qwenvl(instruction, images=[Image.open("image_2.jpg")])
emb = adapter(emb)
image = pipe("", image_emb=emb)
image.save("image_3.jpg")