update flux pipeline

This commit is contained in:
Artiprocher
2024-10-10 17:05:04 +08:00
parent 41ea2f811a
commit fa0fa95bb6
6 changed files with 32 additions and 22 deletions

View File

@@ -364,6 +364,7 @@ class FluxDiT(torch.nn.Module):
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

View File

@@ -65,9 +65,11 @@ class BasePipeline(torch.nn.Module):
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):
self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled
if not self.cpu_offload:
@@ -85,3 +87,9 @@ class BasePipeline(torch.nn.Module):
model.to(self.device)
# fresh the cuda cache
torch.cuda.empty_cache()
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return noise

View File

@@ -58,7 +58,7 @@ class FluxImagePipeline(BasePipeline):
return image
def encode_prompt(self, prompt, positive=True, t5_sequence_length=256):
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
)
@@ -80,7 +80,7 @@ class FluxImagePipeline(BasePipeline):
mask_scales= None,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=1.0,
embedded_guidance=3.5,
input_image=None,
denoising_strength=1.0,
height=1024,
@@ -90,6 +90,7 @@ class FluxImagePipeline(BasePipeline):
tiled=False,
tile_size=128,
tile_stride=64,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
@@ -104,10 +105,10 @@ class FluxImagePipeline(BasePipeline):
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])

View File

@@ -57,7 +57,7 @@ class FluxPrompter(BasePrompter):
prompt,
positive=True,
device="cuda",
t5_sequence_length=256,
t5_sequence_length=512,
):
prompt = self.process_prompt(prompt, positive=positive)

View File

@@ -12,30 +12,30 @@ model_manager.load_models([
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(6)
torch.manual_seed(9)
image = pipe(
prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5
num_inference_steps=50, embedded_guidance=3.5
)
image.save("image_1024.jpg")
# Enable classifier-free guidance
torch.manual_seed(6)
torch.manual_seed(9)
image = pipe(
prompt=prompt, negative_prompt=negative_prompt,
num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5
num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5
)
image.save("image_1024_cfg.jpg")
# Highres-fix
torch.manual_seed(7)
torch.manual_seed(10)
image = pipe(
prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5,
num_inference_steps=50, embedded_guidance=3.5,
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
)
image.save("image_2048_highres.jpg")

View File

@@ -22,30 +22,30 @@ pipe = FluxImagePipeline.from_model_manager(model_manager, device="cuda")
pipe.enable_cpu_offload()
pipe.dit.quantize()
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(6)
torch.manual_seed(9)
image = pipe(
prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5
num_inference_steps=50, embedded_guidance=3.5
)
image.save("image_1024.jpg")
# Enable classifier-free guidance
torch.manual_seed(6)
torch.manual_seed(9)
image = pipe(
prompt=prompt, negative_prompt=negative_prompt,
num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5
num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5
)
image.save("image_1024_cfg.jpg")
# Highres-fix
torch.manual_seed(7)
torch.manual_seed(10)
image = pipe(
prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5,
num_inference_steps=50, embedded_guidance=3.5,
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
)
image.save("image_2048_highres.jpg")
image.save("image_2048_highres.jpg")