# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# Synced 2026-03-19 23:08:13 +00:00
"""Qualitative evaluation of a trained LoRA-merger for FLUX.1-dev.

For each of 100 seeds, draws one item from the LoRA dataset, then renders:
  1. a baseline image with no LoRA applied,
  2. one image per individual LoRA (patched via the trained LoraPatcher),
  3. one image with a randomly-sized subset of LoRAs merged together.
All images are written to data/lora/lora_outputs/ for visual comparison.
Requires a CUDA device and the pretrained merger checkpoint on disk.
"""

import os

import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm

from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.merger import LoraPatcher
from lora.utils import load_lora

# Ensure the output directory exists; otherwise every image.save below fails
# with FileNotFoundError on a fresh checkout.
os.makedirs("data/lora/lora_outputs", exist_ok=True)

# Base FLUX.1-dev pipeline in bfloat16 on GPU, with automatic LoRA hooks enabled.
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.enable_auto_lora()

# Trained merger network, restored from the epoch-3 checkpoint.
lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda")
lora_patcher.load_state_dict(load_state_dict("models/lora_merger/epoch-3.safetensors"))

# Each dataset item bundles `loras_per_item` LoRAs plus their prompt text.
dataset = LoraDataset("data/lora/models", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4)

for seed in range(100):
    batch = dataset[0]
    # How many LoRAs to merge for this seed.
    # NOTE(review): torch.randint's upper bound is exclusive, so this samples
    # from [1, len(batch) - 1] and the full set of len(batch) LoRAs is never
    # merged — confirm whether excluding the full set is intentional.
    num_lora = int(torch.randint(1, len(batch), (1,)))
    # Load the selected LoRA weights and normalize them to DiffSynth's format.
    lora_state_dicts = [
        FluxLoRAConverter.align_to_diffsynth_format(load_lora(batch[i]["model_file"], device="cuda"))
        for i in range(num_lora)
    ]

    # 1) Baseline: no LoRA applied.
    image = pipe(
        prompt=batch[0]["text"],
        seed=seed,
    )
    image.save(f"data/lora/lora_outputs/image_{seed}_nolora.jpg")

    # 2) Each LoRA individually, applied through the trained patcher.
    for i in range(num_lora):
        image = pipe(
            prompt=batch[0]["text"],
            lora_state_dicts=[lora_state_dicts[i]],
            lora_patcher=lora_patcher,
            seed=seed,
        )
        image.save(f"data/lora/lora_outputs/image_{seed}_{i}.jpg")

    # 3) All sampled LoRAs merged together by the patcher.
    image = pipe(
        prompt=batch[0]["text"],
        lora_state_dicts=lora_state_dicts,
        lora_patcher=lora_patcher,
        seed=seed,
    )
    image.save(f"data/lora/lora_outputs/image_{seed}_merger.jpg")