diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 648d23f..ffe55b8 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -364,6 +364,7 @@ class FluxDiT(torch.nn.Module): conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) if self.guidance_embedder is not None: + guidance = guidance * 1000 conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) prompt_emb = self.context_embedder(prompt_emb) image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index 956e9ba..55cfc14 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -65,9 +65,11 @@ class BasePipeline(torch.nn.Module): mask_scales += [100.0] * len(extended_prompt_dict.get("masks", [])) return prompt, local_prompts, masks, mask_scales + def enable_cpu_offload(self): self.cpu_offload = True + def load_models_to_device(self, loadmodel_names=[]): # only load models to device if cpu_offload is enabled if not self.cpu_offload: @@ -85,3 +87,9 @@ class BasePipeline(torch.nn.Module): model.to(self.device) # fresh the cuda cache torch.cuda.empty_cache() + + + def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16): + generator = None if seed is None else torch.Generator(device).manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) + return noise diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 9ccdeb9..06f5649 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -58,7 +58,7 @@ class FluxImagePipeline(BasePipeline): return image - def encode_prompt(self, prompt, positive=True, t5_sequence_length=256): + def encode_prompt(self, prompt, positive=True, t5_sequence_length=512): prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt( prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length ) @@ -80,7 +80,7 @@ class FluxImagePipeline(BasePipeline): mask_scales= None, negative_prompt="", cfg_scale=1.0, - embedded_guidance=1.0, + embedded_guidance=3.5, input_image=None, denoising_strength=1.0, height=1024, @@ -90,6 +90,7 @@ class FluxImagePipeline(BasePipeline): tiled=False, tile_size=128, tile_stride=64, + seed=None, progress_bar_cmd=tqdm, progress_bar_st=None, ): @@ -104,10 +105,10 @@ class FluxImagePipeline(BasePipeline): self.load_models_to_device(['vae_encoder']) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) latents = self.encode_image(image, **tiler_kwargs) - noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) + noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) else: - latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) + latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) # Extend prompt self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) diff --git a/diffsynth/prompters/flux_prompter.py b/diffsynth/prompters/flux_prompter.py index 07d6dc7..9a6bd7d 100644 --- a/diffsynth/prompters/flux_prompter.py +++ b/diffsynth/prompters/flux_prompter.py @@ -57,7 +57,7 @@ class FluxPrompter(BasePrompter): prompt, positive=True, device="cuda", - t5_sequence_length=256, + t5_sequence_length=512, ): prompt = self.process_prompt(prompt, positive=positive) diff --git a/examples/image_synthesis/flux_text_to_image.py b/examples/image_synthesis/flux_text_to_image.py index a2e5199..6a50df3 100644 --- a/examples/image_synthesis/flux_text_to_image.py +++ b/examples/image_synthesis/flux_text_to_image.py @@ -12,30 +12,30 @@ model_manager.load_models([ ]) pipe = FluxImagePipeline.from_model_manager(model_manager) -prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." -negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," # Disable classifier-free guidance (consistent with the original implementation of FLUX.1) -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5 + num_inference_steps=50, embedded_guidance=3.5 ) image.save("image_1024.jpg") # Enable classifier-free guidance -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, negative_prompt=negative_prompt, - num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5 + num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5 ) image.save("image_1024_cfg.jpg") # Highres-fix -torch.manual_seed(7) +torch.manual_seed(10) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5, + num_inference_steps=50, embedded_guidance=3.5, input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True ) image.save("image_2048_highres.jpg") diff --git a/examples/image_synthesis/flux_text_to_image_low_vram.py b/examples/image_synthesis/flux_text_to_image_low_vram.py index b98929c..985f009 100644 --- a/examples/image_synthesis/flux_text_to_image_low_vram.py +++ b/examples/image_synthesis/flux_text_to_image_low_vram.py @@ -22,30 +22,30 @@ pipe = FluxImagePipeline.from_model_manager(model_manager, device="cuda") pipe.enable_cpu_offload() pipe.dit.quantize() -prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." -negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," # Disable classifier-free guidance (consistent with the original implementation of FLUX.1) -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5 + num_inference_steps=50, embedded_guidance=3.5 ) image.save("image_1024.jpg") # Enable classifier-free guidance -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, negative_prompt=negative_prompt, - num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5 + num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5 ) image.save("image_1024_cfg.jpg") # Highres-fix -torch.manual_seed(7) +torch.manual_seed(10) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5, + num_inference_steps=50, embedded_guidance=3.5, input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True ) -image.save("image_2048_highres.jpg") \ No newline at end of file +image.save("image_2048_highres.jpg")