mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge pull request #1330 from modelscope/ses-doc
Research Tutorial Sec 2
This commit is contained in:
@@ -867,7 +867,7 @@ DiffSynth-Studio is not just an engineered model framework, but also an incubato
|
||||
|
||||
- Paper: [Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation
|
||||
](https://arxiv.org/abs/2602.03208)
|
||||
- Sample Code: coming soon
|
||||
- Sample Code: [/docs/en/Research_Tutorial/inference_time_scaling.md](/docs/en/Research_Tutorial/inference_time_scaling.md)
|
||||
|
||||
|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|
||||
|-|-|-|-|
|
||||
|
||||
@@ -867,7 +867,7 @@ DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果
|
||||
|
||||
- 论文:[Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation
|
||||
](https://arxiv.org/abs/2602.03208)
|
||||
- 代码样例:coming soon
|
||||
- 代码样例:[/docs/zh/Research_Tutorial/inference_time_scaling.md](/docs/zh/Research_Tutorial/inference_time_scaling.md)
|
||||
|
||||
|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|
||||
|-|-|-|-|
|
||||
|
||||
@@ -90,6 +90,7 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
initial_noise: torch.Tensor = None,
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# Progress bar
|
||||
@@ -109,7 +110,7 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"seed": seed, "rand_device": rand_device, "initial_noise": initial_noise,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
@@ -429,12 +430,15 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
|
||||
class Flux2Unit_NoiseInitializer(PipelineUnit):
    """Pipeline unit that produces the initial latent noise for denoising.

    Either draws fresh Gaussian noise via the pipeline's RNG, or — when the
    caller supplies ``initial_noise`` — reuses that tensor verbatim so results
    are reproducible (e.g. for inference-time noise search such as SES).
    """

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device", "initial_noise"),
            output_params=("noise",),
        )

    def process(self, pipe: Flux2ImagePipeline, height, width, seed, rand_device, initial_noise):
        if initial_noise is not None:
            # Caller-provided noise: clone so the reshape below cannot
            # alias/mutate the caller's tensor.
            noise = initial_noise.clone()
        else:
            # assumes 128 latent channels and a 16x spatial downsampling
            # factor — TODO confirm against the VAE config used by this pipeline.
            noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        # Flatten spatial dims and move tokens first: (1, C, H*W) -> (1, H*W, C).
        noise = noise.reshape(1, 128, height//16 * width//16).permute(0, 2, 1)
        return {"noise": noise}
|
||||
|
||||
|
||||
1
diffsynth/utils/ses/README.md
Normal file
1
diffsynth/utils/ses/README.md
Normal file
@@ -0,0 +1 @@
|
||||
Please see `docs/en/Research_Tutorial/inference_time_scaling.md` or `docs/zh/Research_Tutorial/inference_time_scaling.md` for more details.
|
||||
1
diffsynth/utils/ses/__init__.py
Normal file
1
diffsynth/utils/ses/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .ses import ses_search
|
||||
117
diffsynth/utils/ses/ses.py
Normal file
117
diffsynth/utils/ses/ses.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import pywt
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def split_dwt(z_tensor_cpu, wavelet_name, dwt_level):
    """Decompose each batch item with a 2-D multilevel discrete wavelet transform.

    Args:
        z_tensor_cpu: batched tensor; moved to CPU float32 before decomposition.
        wavelet_name: PyWavelets wavelet identifier, e.g. ``"db1"``.
        dwt_level: number of decomposition levels passed to ``pywt.wavedec2``.

    Returns:
        A pair ``(low, high)`` where ``low`` is a tensor stacking the
        low-frequency approximation coefficients of every batch item, and
        ``high`` is a list (one entry per item) of the remaining detail
        coefficient levels, kept as NumPy structures.
    """
    batch = z_tensor_cpu.cpu().float()
    low_parts = []
    high_parts = []
    for item in batch:
        coeffs = pywt.wavedec2(item.numpy(), wavelet_name, level=dwt_level, mode='symmetric', axes=(-2, -1))
        # coeffs[0] is the coarse approximation; the rest are detail levels.
        low_parts.append(coeffs[0])
        high_parts.append(coeffs[1:])
    return torch.from_numpy(np.stack(low_parts, axis=0)), high_parts
|
||||
|
||||
|
||||
def reconstruct_dwt(c_low_tensor_cpu, c_high_coeffs, wavelet_name, original_shape):
    """Invert a 2-D multilevel DWT from low- and high-frequency coefficients.

    Args:
        c_low_tensor_cpu: low-frequency (approximation) coefficients; moved to
            CPU float32. A singleton leading batch dimension is accepted.
        c_high_coeffs: detail coefficient levels, as produced by ``split_dwt``.
        wavelet_name: PyWavelets wavelet identifier.
        original_shape: ``(H, W)`` of the target reconstruction.

    Returns:
        A 4-D tensor (a batch dimension is added if the reconstruction is 3-D),
        cropped to ``original_shape`` on the last two axes.
    """
    target_h, target_w = original_shape
    low_np = c_low_tensor_cpu.cpu().float().numpy()
    # Drop a singleton batch axis so the coefficient rank matches the details.
    if low_np.ndim == 4 and low_np.shape[0] == 1:
        low_np = low_np[0]
    recon = pywt.waverec2([low_np] + c_high_coeffs, wavelet_name, mode='symmetric', axes=(-2, -1))
    # waverec2 can overshoot by one pixel per axis for odd sizes; crop back.
    if recon.shape[-2] != target_h or recon.shape[-1] != target_w:
        recon = recon[..., :target_h, :target_w]
    out = torch.from_numpy(recon)
    if out.ndim == 3:
        out = out.unsqueeze(0)
    return out
|
||||
|
||||
|
||||
def ses_search(
    base_latents,
    objective_reward_fn,
    total_eval_budget=30,
    popsize=10,
    k_elites=5,
    wavelet_name="db1",
    dwt_level=4,
):
    """Spectral Evolution Search (SES) over the low-frequency band of a latent.

    Decomposes ``base_latents`` with a 2-D DWT, keeps the high-frequency
    detail coefficients fixed, and runs an elitist cross-entropy-method search
    over the flattened low-frequency coefficients, scoring each candidate with
    ``objective_reward_fn`` (higher is better).

    Args:
        base_latents: starting latent tensor, shape ``(..., H, W)``.
        objective_reward_fn: callable mapping a reconstructed latent (same
            device/dtype as ``base_latents``) to a scalar reward.
        total_eval_budget: maximum number of reward evaluations.
        popsize: candidates sampled per generation.
        k_elites: number of top-scoring candidates the Gaussian is refit to.
        wavelet_name: PyWavelets wavelet identifier.
        dwt_level: DWT decomposition depth.

    Returns:
        The reconstructed latent of the best-scoring candidate, on the device
        and dtype of ``base_latents``.
    """
    latent_h, latent_w = base_latents.shape[-2], base_latents.shape[-1]
    c_low_init, c_high_fixed_batch = split_dwt(base_latents, wavelet_name, dwt_level)
    c_high_fixed = c_high_fixed_batch[0]  # high-frequency detail stays fixed
    c_low_shape = c_low_init.shape[1:]

    # CEM state: a diagonal Gaussian over the flattened low-frequency band.
    mu = torch.zeros_like(c_low_init.view(-1).cpu())
    sigma_sq = torch.ones_like(mu)

    # Seed with the initial low-frequency coefficients so there is always a
    # valid candidate to reconstruct, even if the budget is exhausted early.
    best_overall = {"score": -float('inf'), "c_low": c_low_init[0]}
    eval_count = 0
    elite_db = []
    # A few extra generations so partially-filled final batches cannot
    # strand unused budget.
    n_generations = (total_eval_budget // popsize) + 5
    pbar = tqdm(total=total_eval_budget, desc="[SES] Searching", unit="img")

    for gen in range(n_generations):
        if eval_count >= total_eval_budget:
            break

        # Sample a population from the current Gaussian.
        std = torch.sqrt(torch.clamp(sigma_sq, min=1e-9))
        z_noise = torch.randn(popsize, mu.shape[0])
        samples_flat = mu + z_noise * std
        samples_reshaped = samples_flat.view(popsize, *c_low_shape)

        batch_results = []
        for i in range(popsize):
            if eval_count >= total_eval_budget:
                break

            c_low_sample = samples_reshaped[i].unsqueeze(0)
            z_recon = reconstruct_dwt(c_low_sample, c_high_fixed, wavelet_name, (latent_h, latent_w))
            z_recon = z_recon.to(base_latents.device, dtype=base_latents.dtype)
            score = objective_reward_fn(z_recon)
            res = {
                "score": score,
                "c_low": c_low_sample.cpu(),
            }
            batch_results.append(res)
            if score > best_overall['score']:
                best_overall = res

            eval_count += 1
            pbar.update(1)

        if not batch_results:
            break
        # Refit the Gaussian to the top-k candidates seen so far (elitist CEM).
        elite_db.extend(batch_results)
        elite_db.sort(key=lambda x: x['score'], reverse=True)
        elite_db = elite_db[:k_elites]
        elites_flat = torch.stack([x['c_low'].view(-1) for x in elite_db])
        mu = torch.mean(elites_flat, dim=0)
        if len(elite_db) > 1:
            # Small epsilon keeps the variance strictly positive.
            sigma_sq = torch.var(elites_flat, dim=0, unbiased=True) + 1e-7
        # With a single elite there is no variance estimate; keep the old one.

    pbar.close()
    best_c_low = best_overall['c_low']
    final_latents = reconstruct_dwt(best_c_low, c_high_fixed, wavelet_name, (latent_h, latent_w))
    return final_latents.to(base_latents.device, dtype=base_latents.dtype)
|
||||
@@ -80,7 +80,7 @@ This section introduces the independent core module `diffsynth.core` in `DiffSyn
|
||||
This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.
|
||||
|
||||
* [Training models from scratch](./Research_Tutorial/train_from_scratch.md)
|
||||
* Inference improvement techniques 【coming soon】
|
||||
* [Inference improvement techniques](./Research_Tutorial/inference_time_scaling.md)
|
||||
* Designing controllable generation models 【coming soon】
|
||||
* Creating new training paradigms 【coming soon】
|
||||
|
||||
|
||||
236
docs/en/Research_Tutorial/inference_time_scaling.ipynb
Normal file
236
docs/en/Research_Tutorial/inference_time_scaling.ipynb
Normal file
@@ -0,0 +1,236 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8db54992",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Inference Optimization Techniques\n",
|
||||
"\n",
|
||||
"DiffSynth-Studio aims to drive technological innovation through its foundational framework. This article demonstrates how to build a training-free image generation enhancement solution using DiffSynth-Studio, taking Inference-time scaling as an example."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0911cad4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Image Quality Quantification\n",
|
||||
"\n",
|
||||
"First, we need to find an indicator to quantify image quality from generation models. Manual scoring is the most straightforward solution but too costly for large-scale applications. However, after collecting manual scores, training an image classification model to predict human scoring is completely feasible. PickScore [[1]](https://arxiv.org/abs/2305.01569) is such a model. Running the following code will automatically download and load the [PickScore model](https://modelscope.cn/models/AI-ModelScope/PickScore_v1)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4faca4ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from modelscope import AutoProcessor, AutoModel\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"class PickScore(torch.nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" self.processor = AutoProcessor.from_pretrained(\"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\")\n",
|
||||
" self.model = AutoModel.from_pretrained(\"AI-ModelScope/PickScore_v1\").eval().to(\"cuda\")\n",
|
||||
"\n",
|
||||
" def forward(self, image, prompt):\n",
|
||||
" image_inputs = self.processor(images=image, padding=True, truncation=True, max_length=77, return_tensors=\"pt\").to(\"cuda\")\n",
|
||||
" text_inputs = self.processor(text=prompt, padding=True, truncation=True, max_length=77, return_tensors=\"pt\").to(\"cuda\")\n",
|
||||
" with torch.inference_mode():\n",
|
||||
" image_embs = self.model.get_image_features(**image_inputs).pooler_output\n",
|
||||
" image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)\n",
|
||||
" text_embs = self.model.get_text_features(**text_inputs).pooler_output\n",
|
||||
" text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)\n",
|
||||
" score = (text_embs @ image_embs.T).flatten().item()\n",
|
||||
" return score\n",
|
||||
"\n",
|
||||
"reward_model = PickScore()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5f807cec",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Inference-time Scaling Techniques\n",
|
||||
"\n",
|
||||
"Inference-time Scaling [[2]](https://arxiv.org/abs/2504.00294) is an interesting technique aiming to improve generation quality by increasing computational costs during inference. For example, in language models, models like [Qwen/Qwen3.5-27B](https://modelscope.cn/models/Qwen/Qwen3.5-27B) and [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) use \"thinking mode\" to guide the model to spend more time considering results more carefully, producing more accurate answers. Next, we'll use the [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) model as an example to explore how to design Inference-time Scaling solutions for image generation models.\n",
|
||||
"\n",
|
||||
"> Before starting, we slightly modified the `Flux2ImagePipeline` code to allow initialization with specific Gaussian noise matrices for result reproducibility. See `Flux2Unit_NoiseInitializer` in [diffsynth/pipelines/flux2_image.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/flux2_image.py).\n",
|
||||
"\n",
|
||||
"Run the following code to load the [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c5818a87",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\n",
|
||||
"\n",
|
||||
"pipe = Flux2ImagePipeline.from_pretrained(\n",
|
||||
" torch_dtype=torch.bfloat16,\n",
|
||||
" device=\"cuda\",\n",
|
||||
" model_configs=[\n",
|
||||
" ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n",
|
||||
" ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"transformer/*.safetensors\"),\n",
|
||||
" ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n",
|
||||
" ],\n",
|
||||
" tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f58e9945",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Generate a sketch cat image using the prompt `\"sketch, a cat\"` and score it with the PickScore model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6ea2d258",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def evaluate_noise(noise, pipe, reward_model, prompt):\n",
|
||||
" # Generate an image and compute the score.\n",
|
||||
" image = pipe(\n",
|
||||
" prompt=prompt,\n",
|
||||
" num_inference_steps=4,\n",
|
||||
" initial_noise=noise,\n",
|
||||
" progress_bar_cmd=lambda x: x,\n",
|
||||
" )\n",
|
||||
" score = reward_model(image, prompt)\n",
|
||||
" return score\n",
|
||||
"\n",
|
||||
"torch.manual_seed(1)\n",
|
||||
"prompt = \"sketch, a cat\"\n",
|
||||
"noise = pipe.generate_noise((1, 128, 64, 64), rand_device=\"cuda\", rand_torch_dtype=pipe.torch_dtype)\n",
|
||||
"\n",
|
||||
"image_1 = pipe(prompt, num_inference_steps=4, initial_noise=noise)\n",
|
||||
"print(\"Score:\", reward_model(image_1, prompt))\n",
|
||||
"image_1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5e11694e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 2.1 Best-of-N Random Search\n",
|
||||
"\n",
|
||||
"Model generation results have inherent randomness. Different random seeds produce different images - sometimes high quality, sometimes low. This leads to a simple Inference-time scaling solution: generate images using multiple random seeds, score them with PickScore, and retain only the highest-scoring image."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "241f10d2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"def random_search(base_latents, objective_reward_fn, total_eval_budget):\n",
|
||||
" # Search for the noise randomly.\n",
|
||||
" best_noise = base_latents\n",
|
||||
" best_score = objective_reward_fn(base_latents)\n",
|
||||
" for it in tqdm(range(total_eval_budget - 1)):\n",
|
||||
" noise = pipe.generate_noise((1, 128, 64, 64), seed=None)\n",
|
||||
" score = objective_reward_fn(noise)\n",
|
||||
" if score > best_score:\n",
|
||||
" best_score, best_noise = score, noise\n",
|
||||
" return best_noise\n",
|
||||
"\n",
|
||||
"best_noise = random_search(\n",
|
||||
" base_latents=noise,\n",
|
||||
" objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\n",
|
||||
" total_eval_budget=50,\n",
|
||||
")\n",
|
||||
"image_2 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\n",
|
||||
"print(\"Score:\", reward_model(image_2, prompt))\n",
|
||||
"image_2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8e9bf966",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can clearly see that after multiple random searches, the final selected cat image shows richer fur details and significantly improved PickScore. However, this brute-force random search is extremely inefficient - generation time multiplies while easily hitting quality limits. Therefore, we need a more efficient search method that achieves higher scores within the same computational budget."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c9578349",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 2.2 SES Search\n",
|
||||
"\n",
|
||||
"To overcome random search limitations, we introduce the Spectral Evolution Search (SES) algorithm [[3]](https://arxiv.org/abs/2602.03208). Detailed code is available at [diffsynth/utils/ses](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/utils/ses).\n",
|
||||
"\n",
|
||||
"Image generation in diffusion models is largely determined by low-frequency components in the initial noise. The SES algorithm decomposes Gaussian noise through wavelet transforms, fixes high-frequency details, and applies an evolution search using the cross-entropy method specifically on low-frequency components to find optimal initial noise with higher efficiency.\n",
|
||||
"\n",
|
||||
"Run the following code to perform efficient best Gaussian noise matrix search using SES."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "adeed2aa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from diffsynth.utils.ses import ses_search\n",
|
||||
"\n",
|
||||
"best_noise = ses_search(\n",
|
||||
" base_latents=noise,\n",
|
||||
" objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\n",
|
||||
" total_eval_budget=50,\n",
|
||||
")\n",
|
||||
"image_3 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\n",
|
||||
"print(\"Score:\", reward_model(image_3, prompt))\n",
|
||||
"image_3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "940a97f1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Observing the results, under the same computational budget, SES achieves significantly higher PickScore compared to random search. The \"sketch cat\" demonstrates more refined overall composition and more layered contrast between light and shadow.\n",
|
||||
"\n",
|
||||
"Inference-time scaling can achieve higher image quality at the cost of longer inference time. The generated image data can then be used to train the model itself through methods like DPO [[4]](https://arxiv.org/abs/2311.12908) or differential training [[5]](https://arxiv.org/abs/2412.12888), opening another interesting research direction."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "dzj8",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.19"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
140
docs/en/Research_Tutorial/inference_time_scaling.md
Normal file
140
docs/en/Research_Tutorial/inference_time_scaling.md
Normal file
@@ -0,0 +1,140 @@
|
||||
# Inference Optimization Techniques
|
||||
|
||||
DiffSynth-Studio aims to drive technological innovation through its foundational framework. This article demonstrates how to build a training-free image generation enhancement solution using DiffSynth-Studio, taking Inference-time scaling as an example.
|
||||
|
||||
Notebook: https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/en/Research_Tutorial/inference_time_scaling.ipynb
|
||||
|
||||
## 1. Image Quality Quantification
|
||||
|
||||
First, we need to find an indicator to quantify image quality from generation models. Manual scoring is the most straightforward solution but too costly for large-scale applications. However, after collecting manual scores, training an image classification model to predict human scoring is completely feasible. PickScore [[1]](https://arxiv.org/abs/2305.01569) is such a model. Running the following code will automatically download and load the [PickScore model](https://modelscope.cn/models/AI-ModelScope/PickScore_v1).
|
||||
|
||||
```python
|
||||
from modelscope import AutoProcessor, AutoModel
|
||||
import torch
|
||||
|
||||
class PickScore(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
||||
self.model = AutoModel.from_pretrained("AI-ModelScope/PickScore_v1").eval().to("cuda")
|
||||
|
||||
def forward(self, image, prompt):
|
||||
image_inputs = self.processor(images=image, padding=True, truncation=True, max_length=77, return_tensors="pt").to("cuda")
|
||||
text_inputs = self.processor(text=prompt, padding=True, truncation=True, max_length=77, return_tensors="pt").to("cuda")
|
||||
with torch.inference_mode():
|
||||
image_embs = self.model.get_image_features(**image_inputs).pooler_output
|
||||
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
||||
text_embs = self.model.get_text_features(**text_inputs).pooler_output
|
||||
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
|
||||
score = (text_embs @ image_embs.T).flatten().item()
|
||||
return score
|
||||
|
||||
reward_model = PickScore()
|
||||
```
|
||||
|
||||
## 2. Inference-time Scaling Techniques
|
||||
|
||||
Inference-time Scaling [[2]](https://arxiv.org/abs/2504.00294) is an interesting technique aiming to improve generation quality by increasing computational costs during inference. For example, in language models, models like [Qwen/Qwen3.5-27B](https://modelscope.cn/models/Qwen/Qwen3.5-27B) and [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) use "thinking mode" to guide the model to spend more time considering results more carefully, producing more accurate answers. Next, we'll use the [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) model as an example to explore how to design Inference-time Scaling solutions for image generation models.
|
||||
|
||||
> Before starting, we slightly modified the `Flux2ImagePipeline` code to allow initialization with specific Gaussian noise matrices for result reproducibility. See `Flux2Unit_NoiseInitializer` in [diffsynth/pipelines/flux2_image.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/flux2_image.py).
|
||||
|
||||
Run the following code to load the [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) model.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
```
|
||||
|
||||
Generate a sketch cat image using the prompt `"sketch, a cat"` and score it with the PickScore model.
|
||||
|
||||
```python
|
||||
def evaluate_noise(noise, pipe, reward_model, prompt):
|
||||
# Generate an image and compute the score.
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
initial_noise=noise,
|
||||
progress_bar_cmd=lambda x: x,
|
||||
)
|
||||
score = reward_model(image, prompt)
|
||||
return score
|
||||
|
||||
torch.manual_seed(1)
|
||||
prompt = "sketch, a cat"
|
||||
noise = pipe.generate_noise((1, 128, 64, 64), rand_device="cuda", rand_torch_dtype=pipe.torch_dtype)
|
||||
|
||||
image_1 = pipe(prompt, num_inference_steps=4, initial_noise=noise)
|
||||
print("Score:", reward_model(image_1, prompt))
|
||||
image_1
|
||||
```
|
||||
|
||||

|
||||
|
||||
### 2.1 Best-of-N Random Search
|
||||
|
||||
Model generation results have inherent randomness. Different random seeds produce different images - sometimes high quality, sometimes low. This leads to a simple Inference-time scaling solution: generate images using multiple random seeds, score them with PickScore, and retain only the highest-scoring image.
|
||||
|
||||
```python
|
||||
from tqdm import tqdm
|
||||
|
||||
def random_search(base_latents, objective_reward_fn, total_eval_budget):
|
||||
# Search for the noise randomly.
|
||||
best_noise = base_latents
|
||||
best_score = objective_reward_fn(base_latents)
|
||||
for it in tqdm(range(total_eval_budget - 1)):
|
||||
noise = pipe.generate_noise((1, 128, 64, 64), seed=None)
|
||||
score = objective_reward_fn(noise)
|
||||
if score > best_score:
|
||||
best_score, best_noise = score, noise
|
||||
return best_noise
|
||||
|
||||
best_noise = random_search(
|
||||
base_latents=noise,
|
||||
objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),
|
||||
total_eval_budget=50,
|
||||
)
|
||||
image_2 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)
|
||||
print("Score:", reward_model(image_2, prompt))
|
||||
image_2
|
||||
```
|
||||
|
||||

|
||||
|
||||
We can clearly see that after multiple random searches, the final selected cat image shows richer fur details and significantly improved PickScore. However, this brute-force random search is extremely inefficient - generation time multiplies while easily hitting quality limits. Therefore, we need a more efficient search method that achieves higher scores within the same computational budget.
|
||||
|
||||
### 2.2 SES Search
|
||||
|
||||
To overcome random search limitations, we introduce the Spectral Evolution Search (SES) algorithm [[3]](https://arxiv.org/abs/2602.03208). Detailed code is available at [diffsynth/utils/ses](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/utils/ses).
|
||||
|
||||
Image generation in diffusion models is largely determined by low-frequency components in the initial noise. The SES algorithm decomposes Gaussian noise through wavelet transforms, fixes high-frequency details, and applies an evolution search using the cross-entropy method specifically on low-frequency components to find optimal initial noise with higher efficiency.
|
||||
|
||||
Run the following code to perform efficient best Gaussian noise matrix search using SES.
|
||||
|
||||
```python
|
||||
from diffsynth.utils.ses import ses_search
|
||||
|
||||
best_noise = ses_search(
|
||||
base_latents=noise,
|
||||
objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),
|
||||
total_eval_budget=50,
|
||||
)
|
||||
image_3 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)
|
||||
print("Score:", reward_model(image_3, prompt))
|
||||
image_3
|
||||
```
|
||||
|
||||

|
||||
|
||||
Observing the results, under the same computational budget, SES achieves significantly higher PickScore compared to random search. The "sketch cat" demonstrates more refined overall composition and more layered contrast between light and shadow.
|
||||
|
||||
Inference-time scaling can achieve higher image quality at the cost of longer inference time. The generated image data can then be used to train the model itself through methods like DPO [[4]](https://arxiv.org/abs/2311.12908) or differential training [[5]](https://arxiv.org/abs/2412.12888), opening another interesting research direction.
|
||||
@@ -65,6 +65,7 @@ Welcome to DiffSynth-Studio's Documentation
|
||||
:caption: Research Guide
|
||||
|
||||
Research_Tutorial/train_from_scratch
|
||||
Research_Tutorial/inference_time_scaling
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
@@ -80,7 +80,7 @@ graph LR;
|
||||
本节介绍如何利用 `DiffSynth-Studio` 训练新的模型,帮助科研工作者探索新的模型技术。
|
||||
|
||||
* [从零开始训练模型](./Research_Tutorial/train_from_scratch.md)
|
||||
* 推理改进优化技术【coming soon】
|
||||
* [推理改进优化技术](./Research_Tutorial/inference_time_scaling.md)
|
||||
* 设计可控生成模型【coming soon】
|
||||
* 创建新的训练范式【coming soon】
|
||||
|
||||
|
||||
236
docs/zh/Research_Tutorial/inference_time_scaling.ipynb
Normal file
236
docs/zh/Research_Tutorial/inference_time_scaling.ipynb
Normal file
@@ -0,0 +1,236 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8db54992",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 推理改进优化技术\n",
|
||||
"\n",
|
||||
"DiffSynth-Studio 旨在以基础框架驱动技术创新。本文以 Inference-time scaling 为例,展示如何基于 DiffSynth-Studio 构建免训练(Training-free)的图像生成增强方案。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0911cad4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. 图像质量量化\n",
|
||||
"\n",
|
||||
"首先,我们需要找到一个指标来量化图像生成模型生成的图像质量。最简单直接的方案是人工打分,但这样做的成本太高,无法大规模使用。不过,收集人工打分后,训练一个图像分类模型来预测人类的打分结果,是完全可行的。PickScore [[1]](https://arxiv.org/abs/2305.01569) 就是这样一个模型,运行下面的代码,将会自动下载并加载 [PickScore 模型](https://modelscope.cn/models/AI-ModelScope/PickScore_v1)。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4faca4ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from modelscope import AutoProcessor, AutoModel\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"class PickScore(torch.nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" self.processor = AutoProcessor.from_pretrained(\"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\")\n",
|
||||
" self.model = AutoModel.from_pretrained(\"AI-ModelScope/PickScore_v1\").eval().to(\"cuda\")\n",
|
||||
"\n",
|
||||
" def forward(self, image, prompt):\n",
|
||||
" image_inputs = self.processor(images=image, padding=True, truncation=True, max_length=77, return_tensors=\"pt\").to(\"cuda\")\n",
|
||||
" text_inputs = self.processor(text=prompt, padding=True, truncation=True, max_length=77, return_tensors=\"pt\").to(\"cuda\")\n",
|
||||
" with torch.inference_mode():\n",
|
||||
" image_embs = self.model.get_image_features(**image_inputs).pooler_output\n",
|
||||
" image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)\n",
|
||||
" text_embs = self.model.get_text_features(**text_inputs).pooler_output\n",
|
||||
" text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)\n",
|
||||
" score = (text_embs @ image_embs.T).flatten().item()\n",
|
||||
" return score\n",
|
||||
"\n",
|
||||
"reward_model = PickScore()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5f807cec",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Inference-time Scaling 技术\n",
|
||||
"\n",
|
||||
    "Inference-time Scaling [[2]](https://arxiv.org/abs/2504.00294) 是一类有趣的技术,旨在通过增加推理时的计算量来提升生成结果的质量。例如,在语言模型中,[Qwen/Qwen3.5-27B](https://modelscope.cn/models/Qwen/Qwen3.5-27B)、[deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) 等模型通过“思考模式”引导模型花更多时间仔细思考,让回答结果更准确。接下来我们以模型 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 为例,探讨如何为图像生成模型设计 Inference-time Scaling 方案。\n",
|
||||
"\n",
|
||||
"> 在开始前,我们稍微改造了 `Flux2ImagePipeline` 的代码,使其能够根据输入的特定高斯噪声矩阵进行初始化,便于复现结果,详见 [diffsynth/pipelines/flux2_image.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/flux2_image.py) 中的 `Flux2Unit_NoiseInitializer`。\n",
|
||||
"\n",
|
||||
"运行以下代码,加载模型 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c5818a87",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig\n",
|
||||
"\n",
|
||||
"pipe = Flux2ImagePipeline.from_pretrained(\n",
|
||||
" torch_dtype=torch.bfloat16,\n",
|
||||
" device=\"cuda\",\n",
|
||||
" model_configs=[\n",
|
||||
" ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"text_encoder/*.safetensors\"),\n",
|
||||
" ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"transformer/*.safetensors\"),\n",
|
||||
" ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"vae/diffusion_pytorch_model.safetensors\"),\n",
|
||||
" ],\n",
|
||||
" tokenizer_config=ModelConfig(model_id=\"black-forest-labs/FLUX.2-klein-4B\", origin_file_pattern=\"tokenizer/\"),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f58e9945",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"用提示词 `\"sketch, a cat\"` 生成一只素描猫猫,并用 PickScore 模型打分。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6ea2d258",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def evaluate_noise(noise, pipe, reward_model, prompt):\n",
|
||||
" # Generate an image and compute the score.\n",
|
||||
" image = pipe(\n",
|
||||
" prompt=prompt,\n",
|
||||
" num_inference_steps=4,\n",
|
||||
" initial_noise=noise,\n",
|
||||
" progress_bar_cmd=lambda x: x,\n",
|
||||
" )\n",
|
||||
" score = reward_model(image, prompt)\n",
|
||||
" return score\n",
|
||||
"\n",
|
||||
"torch.manual_seed(1)\n",
|
||||
"prompt = \"sketch, a cat\"\n",
|
||||
"noise = pipe.generate_noise((1, 128, 64, 64), rand_device=\"cuda\", rand_torch_dtype=pipe.torch_dtype)\n",
|
||||
"\n",
|
||||
"image_1 = pipe(prompt, num_inference_steps=4, initial_noise=noise)\n",
|
||||
"print(\"Score:\", reward_model(image_1, prompt))\n",
|
||||
"image_1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5e11694e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 2.1 Best-of-N 随机搜索\n",
|
||||
"\n",
|
||||
"模型的生成结果具有一定的随机性,如果用不同的随机种子,生成的图像结果也是不同的,有时图像质量高,有时图像质量低。那么,我们有一个简单的 Inference-time scaling 方案:使用多个不同的随机种子分别生成图像,然后利用 PickScore 进行打分,只保留分数最高的那一张。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "241f10d2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"def random_search(base_latents, objective_reward_fn, total_eval_budget):\n",
|
||||
" # Search for the noise randomly.\n",
|
||||
" best_noise = base_latents\n",
|
||||
" best_score = objective_reward_fn(base_latents)\n",
|
||||
" for it in tqdm(range(total_eval_budget - 1)):\n",
|
||||
" noise = pipe.generate_noise((1, 128, 64, 64), seed=None)\n",
|
||||
" score = objective_reward_fn(noise)\n",
|
||||
" if score > best_score:\n",
|
||||
" best_score, best_noise = score, noise\n",
|
||||
" return best_noise\n",
|
||||
"\n",
|
||||
"best_noise = random_search(\n",
|
||||
" base_latents=noise,\n",
|
||||
" objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\n",
|
||||
" total_eval_budget=50,\n",
|
||||
")\n",
|
||||
"image_2 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\n",
|
||||
"print(\"Score:\", reward_model(image_2, prompt))\n",
|
||||
"image_2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8e9bf966",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"我们可以清晰地看到,经过多次随机搜索后,最终选出的猫猫毛发细节更加丰富,PickScore 分数也有明显提升。但这种暴力的随机搜索效率极低,生成时间成倍增长,且很容易触及质量上限。因此,我们希望能够找到一种更高效的搜索方法,在同等计算预算下达到更高的分数。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c9578349",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 2.2 SES 搜索\n",
|
||||
"\n",
|
||||
"为了突破随机搜索的瓶颈,我们引入了 SES (Spectral Evolution Search) 算法 [[3]](https://arxiv.org/abs/2602.03208),详细的代码位于 [diffsynth/utils/ses](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/utils/ses)。\n",
|
||||
"\n",
|
||||
"扩散模型生成的图像,很大程度上由初始噪声的低频分量决定。SES 算法通过小波变换将高斯噪声分解,固定高频细节,专门针对低频部分使用交叉熵方法进行演化搜索,能以更高的效率找到优质的初始噪声。\n",
|
||||
"\n",
|
||||
"运行下面的代码,即可使用 SES 更高效地搜索最佳的高斯噪声矩阵。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "adeed2aa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from diffsynth.utils.ses import ses_search\n",
|
||||
"\n",
|
||||
"best_noise = ses_search(\n",
|
||||
" base_latents=noise,\n",
|
||||
" objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),\n",
|
||||
" total_eval_budget=50,\n",
|
||||
")\n",
|
||||
"image_3 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)\n",
|
||||
"print(\"Score:\", reward_model(image_3, prompt))\n",
|
||||
"image_3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "940a97f1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"可以观察到,在同样的计算预算下,相比于随机搜索,SES 的结果在 PickScore 得分上取得了显著的提升。“素描猫猫”展现出了更精致的整体构图以及更具层次感的明暗对比。\n",
|
||||
"\n",
|
||||
"Inference-time scaling 能够以更长推理时间为代价获得更高的图像质量,那么它生成的图像数据也可以用 DPO [[4]](https://arxiv.org/abs/2311.12908)、差分训练 [[5]](https://arxiv.org/abs/2412.12888) 等方式赋予模型自身,那就是另外一个有趣的探索方向了。"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "dzj8",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.19"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
140
docs/zh/Research_Tutorial/inference_time_scaling.md
Normal file
140
docs/zh/Research_Tutorial/inference_time_scaling.md
Normal file
@@ -0,0 +1,140 @@
|
||||
# 推理改进优化技术
|
||||
|
||||
DiffSynth-Studio 旨在以基础框架驱动技术创新。本文以 Inference-time scaling 为例,展示如何基于 DiffSynth-Studio 构建免训练(Training-free)的图像生成增强方案。
|
||||
|
||||
Notebook: https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/zh/Research_Tutorial/inference_time_scaling.ipynb
|
||||
|
||||
## 1. 图像质量量化
|
||||
|
||||
首先,我们需要找到一个指标来量化图像生成模型生成的图像质量。最简单直接的方案是人工打分,但这样做的成本太高,无法大规模使用。不过,收集人工打分后,训练一个图像分类模型来预测人类的打分结果,是完全可行的。PickScore [[1]](https://arxiv.org/abs/2305.01569) 就是这样一个模型,运行下面的代码,将会自动下载并加载 [PickScore 模型](https://modelscope.cn/models/AI-ModelScope/PickScore_v1)。
|
||||
|
||||
```python
|
||||
from modelscope import AutoProcessor, AutoModel
|
||||
import torch
|
||||
|
||||
class PickScore(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
||||
self.model = AutoModel.from_pretrained("AI-ModelScope/PickScore_v1").eval().to("cuda")
|
||||
|
||||
def forward(self, image, prompt):
|
||||
image_inputs = self.processor(images=image, padding=True, truncation=True, max_length=77, return_tensors="pt").to("cuda")
|
||||
text_inputs = self.processor(text=prompt, padding=True, truncation=True, max_length=77, return_tensors="pt").to("cuda")
|
||||
with torch.inference_mode():
|
||||
image_embs = self.model.get_image_features(**image_inputs).pooler_output
|
||||
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
||||
text_embs = self.model.get_text_features(**text_inputs).pooler_output
|
||||
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
|
||||
score = (text_embs @ image_embs.T).flatten().item()
|
||||
return score
|
||||
|
||||
reward_model = PickScore()
|
||||
```
|
||||
|
||||
## 2. Inference-time Scaling 技术
|
||||
|
||||
Inference-time Scaling [[2]](https://arxiv.org/abs/2504.00294) 是一类有趣的技术,旨在通过增加推理时的计算量来提升生成结果的质量。例如,在语言模型中,[Qwen/Qwen3.5-27B](https://modelscope.cn/models/Qwen/Qwen3.5-27B)、[deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) 等模型通过“思考模式”引导模型花更多时间仔细思考,让回答结果更准确。接下来我们以模型 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 为例,探讨如何为图像生成模型设计 Inference-time Scaling 方案。
|
||||
|
||||
> 在开始前,我们稍微改造了 `Flux2ImagePipeline` 的代码,使其能够根据输入的特定高斯噪声矩阵进行初始化,便于复现结果,详见 [diffsynth/pipelines/flux2_image.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/flux2_image.py) 中的 `Flux2Unit_NoiseInitializer`。
|
||||
|
||||
运行以下代码,加载模型 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
```
|
||||
|
||||
用提示词 `"sketch, a cat"` 生成一只素描猫猫,并用 PickScore 模型打分。
|
||||
|
||||
```python
|
||||
def evaluate_noise(noise, pipe, reward_model, prompt):
|
||||
# Generate an image and compute the score.
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
initial_noise=noise,
|
||||
progress_bar_cmd=lambda x: x,
|
||||
)
|
||||
score = reward_model(image, prompt)
|
||||
return score
|
||||
|
||||
torch.manual_seed(1)
|
||||
prompt = "sketch, a cat"
|
||||
noise = pipe.generate_noise((1, 128, 64, 64), rand_device="cuda", rand_torch_dtype=pipe.torch_dtype)
|
||||
|
||||
image_1 = pipe(prompt, num_inference_steps=4, initial_noise=noise)
|
||||
print("Score:", reward_model(image_1, prompt))
|
||||
image_1
|
||||
```
|
||||
|
||||

|
||||
|
||||
### 2.1 Best-of-N 随机搜索
|
||||
|
||||
模型的生成结果具有一定的随机性,如果用不同的随机种子,生成的图像结果也是不同的,有时图像质量高,有时图像质量低。那么,我们有一个简单的 Inference-time scaling 方案:使用多个不同的随机种子分别生成图像,然后利用 PickScore 进行打分,只保留分数最高的那一张。
|
||||
|
||||
```python
|
||||
from tqdm import tqdm
|
||||
|
||||
def random_search(base_latents, objective_reward_fn, total_eval_budget):
|
||||
# Search for the noise randomly.
|
||||
best_noise = base_latents
|
||||
best_score = objective_reward_fn(base_latents)
|
||||
for it in tqdm(range(total_eval_budget - 1)):
|
||||
noise = pipe.generate_noise((1, 128, 64, 64), seed=None)
|
||||
score = objective_reward_fn(noise)
|
||||
if score > best_score:
|
||||
best_score, best_noise = score, noise
|
||||
return best_noise
|
||||
|
||||
best_noise = random_search(
|
||||
base_latents=noise,
|
||||
objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),
|
||||
total_eval_budget=50,
|
||||
)
|
||||
image_2 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)
|
||||
print("Score:", reward_model(image_2, prompt))
|
||||
image_2
|
||||
```
|
||||
|
||||

|
||||
|
||||
我们可以清晰地看到,经过多次随机搜索后,最终选出的猫猫毛发细节更加丰富,PickScore 分数也有明显提升。但这种暴力的随机搜索效率极低,生成时间成倍增长,且很容易触及质量上限。因此,我们希望能够找到一种更高效的搜索方法,在同等计算预算下达到更高的分数。
|
||||
|
||||
### 2.2 SES 搜索
|
||||
|
||||
为了突破随机搜索的瓶颈,我们引入了 SES (Spectral Evolution Search) 算法 [[3]](https://arxiv.org/abs/2602.03208),详细的代码位于 [diffsynth/utils/ses](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/utils/ses)。
|
||||
|
||||
扩散模型生成的图像,很大程度上由初始噪声的低频分量决定。SES 算法通过小波变换将高斯噪声分解,固定高频细节,专门针对低频部分使用交叉熵方法进行演化搜索,能以更高的效率找到优质的初始噪声。
|
||||
|
||||
运行下面的代码,即可使用 SES 更高效地搜索最佳的高斯噪声矩阵。
|
||||
|
||||
```python
|
||||
from diffsynth.utils.ses import ses_search
|
||||
|
||||
best_noise = ses_search(
|
||||
base_latents=noise,
|
||||
objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),
|
||||
total_eval_budget=50,
|
||||
)
|
||||
image_3 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)
|
||||
print("Score:", reward_model(image_3, prompt))
|
||||
image_3
|
||||
```
|
||||
|
||||

|
||||
|
||||
可以观察到,在同样的计算预算下,相比于随机搜索,SES 的结果在 PickScore 得分上取得了显著的提升。“素描猫猫”展现出了更精致的整体构图以及更具层次感的明暗对比。
|
||||
|
||||
Inference-time scaling 能够以更长推理时间为代价获得更高的图像质量,那么它生成的图像数据也可以用 DPO [[4]](https://arxiv.org/abs/2311.12908)、差分训练 [[5]](https://arxiv.org/abs/2412.12888) 等方式赋予模型自身,那就是另外一个有趣的探索方向了。
|
||||
@@ -65,6 +65,7 @@
|
||||
:caption: 学术导引
|
||||
|
||||
Research_Tutorial/train_from_scratch
|
||||
Research_Tutorial/inference_time_scaling
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
Reference in New Issue
Block a user