mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
fix bugs
This commit is contained in:
@@ -5,3 +5,4 @@ from .sd3_prompter import SD3Prompter
|
||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
||||
from .kolors_prompter import KolorsPrompter
|
||||
from .flux_prompter import FluxPrompter
|
||||
from .omost import OmostPromter
|
||||
|
||||
@@ -37,12 +37,12 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
|
||||
|
||||
|
||||
class BasePrompter:
|
||||
def __init__(self, refiners=[],extenders = []):
|
||||
def __init__(self, refiners=[], extenders=[]):
|
||||
self.refiners = refiners
|
||||
self.extenders = extenders
|
||||
|
||||
|
||||
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]): # manager
|
||||
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
|
||||
for refiner_class in refiner_classes:
|
||||
refiner = refiner_class.from_model_manager(model_manager)
|
||||
self.refiners.append(refiner)
|
||||
@@ -63,7 +63,7 @@ class BasePrompter:
|
||||
return prompt
|
||||
|
||||
@torch.no_grad()
|
||||
def extend_prompt(self,prompt:str,positive = True):
|
||||
def extend_prompt(self, prompt:str, positive=True):
|
||||
extended_prompt = dict(prompt=prompt)
|
||||
for extender in self.extenders:
|
||||
extended_prompt = extender(extended_prompt)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
|
||||
# from .prompt_refiners import BeautifulPrompt
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
from transformers import AutoTokenizer, TextIteratorStreamer
|
||||
import difflib
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -225,10 +223,6 @@ class Canvas:
|
||||
prefixes=component['prefixes'],
|
||||
suffixes=component['suffixes']
|
||||
))
|
||||
|
||||
import pickle
|
||||
with open("tmp.pkl","wb+") as f:
|
||||
pickle.dump(bag_of_conditions,f)
|
||||
|
||||
return dict(
|
||||
initial_latent=initial_latent,
|
||||
@@ -261,10 +255,6 @@ class OmostPromter(torch.nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager):
|
||||
# model, model_path = model_manager.fetch_model("omost", require_model_path=True)
|
||||
# omost = OmostPromter(tokenizer_path=model_path, model=model)
|
||||
# return omost
|
||||
print(model_manager)
|
||||
model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
omost = OmostPromter(
|
||||
|
||||
Reference in New Issue
Block a user