mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 16:18:13 +00:00
update flux pipeline
This commit is contained in:
@@ -364,6 +364,7 @@ class FluxDiT(torch.nn.Module):
|
|||||||
|
|
||||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||||
if self.guidance_embedder is not None:
|
if self.guidance_embedder is not None:
|
||||||
|
guidance = guidance * 1000
|
||||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||||
prompt_emb = self.context_embedder(prompt_emb)
|
prompt_emb = self.context_embedder(prompt_emb)
|
||||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||||
|
|||||||
@@ -65,9 +65,11 @@ class BasePipeline(torch.nn.Module):
|
|||||||
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
|
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
|
||||||
return prompt, local_prompts, masks, mask_scales
|
return prompt, local_prompts, masks, mask_scales
|
||||||
|
|
||||||
|
|
||||||
def enable_cpu_offload(self):
|
def enable_cpu_offload(self):
|
||||||
self.cpu_offload = True
|
self.cpu_offload = True
|
||||||
|
|
||||||
|
|
||||||
def load_models_to_device(self, loadmodel_names=[]):
|
def load_models_to_device(self, loadmodel_names=[]):
|
||||||
# only load models to device if cpu_offload is enabled
|
# only load models to device if cpu_offload is enabled
|
||||||
if not self.cpu_offload:
|
if not self.cpu_offload:
|
||||||
@@ -85,3 +87,9 @@ class BasePipeline(torch.nn.Module):
|
|||||||
model.to(self.device)
|
model.to(self.device)
|
||||||
# fresh the cuda cache
|
# fresh the cuda cache
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
|
||||||
|
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
|
||||||
|
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
return noise
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt(self, prompt, positive=True, t5_sequence_length=256):
|
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
|
||||||
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
||||||
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
||||||
)
|
)
|
||||||
@@ -80,7 +80,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
mask_scales= None,
|
mask_scales= None,
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
cfg_scale=1.0,
|
cfg_scale=1.0,
|
||||||
embedded_guidance=1.0,
|
embedded_guidance=3.5,
|
||||||
input_image=None,
|
input_image=None,
|
||||||
denoising_strength=1.0,
|
denoising_strength=1.0,
|
||||||
height=1024,
|
height=1024,
|
||||||
@@ -90,6 +90,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
tiled=False,
|
tiled=False,
|
||||||
tile_size=128,
|
tile_size=128,
|
||||||
tile_stride=64,
|
tile_stride=64,
|
||||||
|
seed=None,
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
progress_bar_st=None,
|
progress_bar_st=None,
|
||||||
):
|
):
|
||||||
@@ -104,10 +105,10 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.load_models_to_device(['vae_encoder'])
|
self.load_models_to_device(['vae_encoder'])
|
||||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
latents = self.encode_image(image, **tiler_kwargs)
|
latents = self.encode_image(image, **tiler_kwargs)
|
||||||
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||||
else:
|
else:
|
||||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||||
|
|
||||||
# Extend prompt
|
# Extend prompt
|
||||||
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
|
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class FluxPrompter(BasePrompter):
|
|||||||
prompt,
|
prompt,
|
||||||
positive=True,
|
positive=True,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
t5_sequence_length=256,
|
t5_sequence_length=512,
|
||||||
):
|
):
|
||||||
prompt = self.process_prompt(prompt, positive=positive)
|
prompt = self.process_prompt(prompt, positive=positive)
|
||||||
|
|
||||||
|
|||||||
@@ -12,30 +12,30 @@ model_manager.load_models([
|
|||||||
])
|
])
|
||||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
|
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||||
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
|
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||||
|
|
||||||
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
|
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
|
||||||
torch.manual_seed(6)
|
torch.manual_seed(9)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
num_inference_steps=30, embedded_guidance=3.5
|
num_inference_steps=50, embedded_guidance=3.5
|
||||||
)
|
)
|
||||||
image.save("image_1024.jpg")
|
image.save("image_1024.jpg")
|
||||||
|
|
||||||
# Enable classifier-free guidance
|
# Enable classifier-free guidance
|
||||||
torch.manual_seed(6)
|
torch.manual_seed(9)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt=prompt, negative_prompt=negative_prompt,
|
prompt=prompt, negative_prompt=negative_prompt,
|
||||||
num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5
|
num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5
|
||||||
)
|
)
|
||||||
image.save("image_1024_cfg.jpg")
|
image.save("image_1024_cfg.jpg")
|
||||||
|
|
||||||
# Highres-fix
|
# Highres-fix
|
||||||
torch.manual_seed(7)
|
torch.manual_seed(10)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
num_inference_steps=30, embedded_guidance=3.5,
|
num_inference_steps=50, embedded_guidance=3.5,
|
||||||
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
|
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
|
||||||
)
|
)
|
||||||
image.save("image_2048_highres.jpg")
|
image.save("image_2048_highres.jpg")
|
||||||
|
|||||||
@@ -22,30 +22,30 @@ pipe = FluxImagePipeline.from_model_manager(model_manager, device="cuda")
|
|||||||
pipe.enable_cpu_offload()
|
pipe.enable_cpu_offload()
|
||||||
pipe.dit.quantize()
|
pipe.dit.quantize()
|
||||||
|
|
||||||
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
|
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||||
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
|
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||||
|
|
||||||
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
|
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
|
||||||
torch.manual_seed(6)
|
torch.manual_seed(9)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
num_inference_steps=30, embedded_guidance=3.5
|
num_inference_steps=50, embedded_guidance=3.5
|
||||||
)
|
)
|
||||||
image.save("image_1024.jpg")
|
image.save("image_1024.jpg")
|
||||||
|
|
||||||
# Enable classifier-free guidance
|
# Enable classifier-free guidance
|
||||||
torch.manual_seed(6)
|
torch.manual_seed(9)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt=prompt, negative_prompt=negative_prompt,
|
prompt=prompt, negative_prompt=negative_prompt,
|
||||||
num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5
|
num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5
|
||||||
)
|
)
|
||||||
image.save("image_1024_cfg.jpg")
|
image.save("image_1024_cfg.jpg")
|
||||||
|
|
||||||
# Highres-fix
|
# Highres-fix
|
||||||
torch.manual_seed(7)
|
torch.manual_seed(10)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
num_inference_steps=30, embedded_guidance=3.5,
|
num_inference_steps=50, embedded_guidance=3.5,
|
||||||
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
|
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
|
||||||
)
|
)
|
||||||
image.save("image_2048_highres.jpg")
|
image.save("image_2048_highres.jpg")
|
||||||
|
|||||||
Reference in New Issue
Block a user