support lora

This commit is contained in:
Artiprocher
2024-01-10 14:34:02 +08:00
parent 8a460497fa
commit 9698e3988f
4 changed files with 87 additions and 5 deletions

View File

@@ -2,6 +2,7 @@ import torch, os, io
import numpy as np
from PIL import Image
import streamlit as st
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline

View File

@@ -1,4 +1,5 @@
import streamlit as st
st.set_page_config(layout="wide")
from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames
import torch, os, json
import numpy as np
@@ -9,11 +10,11 @@ class Runner:
pass
def load_pipeline(self, model_list, textual_inversion_folder, device, controlnet_units):
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
model_manager.load_textual_inversions(textual_inversion_folder)
model_manager.load_models(model_list)
model_manager.load_models(model_list, lora_alphas=lora_alphas)
pipe = SDVideoPipeline.from_model_manager(
model_manager,
[
@@ -100,6 +101,7 @@ config = {
"model_list": [],
"textual_inversion_folder": "models/textual_inversion",
"device": "cuda",
"lora_alphas": [],
"controlnet_units": []
},
"data": {
@@ -122,6 +124,14 @@ with st.expander("Model", expanded=True):
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
if animatediff_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
column_lora, column_lora_alpha = st.columns([2, 1])
with column_lora:
sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora"))
with column_lora_alpha:
lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1)
if sd_lora_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt))
config["models"]["lora_alphas"].append(lora_alpha)
with st.expander("Data", expanded=True):