mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
train
This commit is contained in:
32
test.py
32
test.py
@@ -102,31 +102,35 @@ model_manager.load_models([
|
||||
])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
|
||||
pipe.dit.load_state_dict(state_dict, strict=False)
|
||||
# state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
|
||||
# pipe.dit.load_state_dict(state_dict, strict=False)
|
||||
|
||||
adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
|
||||
adapter.load_state_dict(state_dict, strict=False)
|
||||
# adapter.load_state_dict(state_dict, strict=False)
|
||||
|
||||
qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda")
|
||||
|
||||
sd = {}
|
||||
for i in range(1, 6):
|
||||
print(i)
|
||||
sd.update(load_state_dict(f"models/nexus_v1/epoch-8/model-0000{i}-of-00005.safetensors", torch_dtype=torch.bfloat16))
|
||||
pipe.dit.load_state_dict({i.replace("pipe.dit.", ""): sd[i] for i in sd if i.startswith("pipe.dit.")})
|
||||
qwenvl.load_state_dict({i.replace("qwenvl.", ""): sd[i] for i in sd if i.startswith("qwenvl.")})
|
||||
adapter.load_state_dict({i.replace("adapter.", ""): sd[i] for i in sd if i.startswith("adapter.")})
|
||||
for i in sd:
|
||||
if (not i.startswith("pipe.dit")) and (not i.startswith("qwenvl.")) and (not i.startswith("adapter.")):
|
||||
print(i)
|
||||
|
||||
with torch.no_grad():
|
||||
instruction = "Generate an image according to the following description: a beautiful Asian girl"
|
||||
instruction = "Generate an image according to the following description: hyper-realistic and detailed 2010s movie still portrait of Josip Broz Tito, by Paolo Sorrentino, Leica SL2 50mm, clear color, high quality, high textured, dramatic light, cinematic"
|
||||
emb = qwenvl(instruction, images=None)
|
||||
emb = adapter(emb)
|
||||
image = pipe("", image_emb=emb)
|
||||
image = pipe("", image_emb=emb, height=512, width=512)
|
||||
image.save("image_1.jpg")
|
||||
|
||||
with torch.no_grad():
|
||||
instruction = "<|vision_start|><|image_pad|><|vision_end|> Add sunglasses."
|
||||
instruction = "<|vision_start|><|image_pad|><|vision_end|> transform the image into a cartoon style with vibrant colors and a confident expression."
|
||||
emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
|
||||
emb = adapter(emb)
|
||||
image = pipe("", image_emb=emb)
|
||||
image = pipe("", image_emb=emb, height=512, width=512)
|
||||
image.save("image_2.jpg")
|
||||
|
||||
with torch.no_grad():
|
||||
instruction = "<|vision_start|><|image_pad|><|vision_end|> Let her smile."
|
||||
emb = qwenvl(instruction, images=[Image.open("image_2.jpg")])
|
||||
emb = adapter(emb)
|
||||
image = pipe("", image_emb=emb)
|
||||
image.save("image_3.jpg")
|
||||
|
||||
Reference in New Issue
Block a user