mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
lora retrieval
This commit is contained in:
46
lora/test_merger.py
Normal file
46
lora/test_merger.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
from diffsynth.pipelines.flux_image import lets_dance_flux
|
||||
from lora.dataset import LoraDataset
|
||||
from lora.merger import LoraPatcher
|
||||
from lora.utils import load_lora
|
||||
import torch, os
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
pipe.enable_auto_lora()
|
||||
|
||||
lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda")
|
||||
lora_patcher.load_state_dict(load_state_dict("models/lora_merger/epoch-3.safetensors"))
|
||||
|
||||
dataset = LoraDataset("data/lora/models", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4)
|
||||
|
||||
for seed in range(100):
|
||||
batch = dataset[0]
|
||||
num_lora = torch.randint(1, len(batch), (1,))[0]
|
||||
lora_state_dicts = [
|
||||
FluxLoRAConverter.align_to_diffsynth_format(load_lora(batch[i]["model_file"], device="cuda")) for i in range(num_lora)
|
||||
]
|
||||
image = pipe(
|
||||
prompt=batch[0]["text"],
|
||||
seed=seed,
|
||||
)
|
||||
image.save(f"data/lora/lora_outputs/image_{seed}_nolora.jpg")
|
||||
for i in range(num_lora):
|
||||
image = pipe(
|
||||
prompt=batch[0]["text"],
|
||||
lora_state_dicts=[lora_state_dicts[i]],
|
||||
lora_patcher=lora_patcher,
|
||||
seed=seed,
|
||||
)
|
||||
image.save(f"data/lora/lora_outputs/image_{seed}_{i}.jpg")
|
||||
image = pipe(
|
||||
prompt=batch[0]["text"],
|
||||
lora_state_dicts=lora_state_dicts,
|
||||
lora_patcher=lora_patcher,
|
||||
seed=seed,
|
||||
)
|
||||
image.save(f"data/lora/lora_outputs/image_{seed}_merger.jpg")
|
||||
Reference in New Issue
Block a user