qwen-image-interpolate

Artiprocher committed 2025-08-19 10:50:56 +08:00
parent f11a91e610
commit 4abc2c1cd1
3 changed files with 244 additions and 0 deletions


@@ -279,6 +279,7 @@ class QwenImagePipeline(BasePipeline):
         tile_stride: int = 64,
         # Progress bar
         progress_bar_cmd = tqdm,
+        extra_prompt_emb = None,
     ):
         # Scheduler
         self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))
@@ -304,6 +305,9 @@ class QwenImagePipeline(BasePipeline):
         }
         for unit in self.units:
             inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+        if extra_prompt_emb is not None:
+            inputs_posi["prompt_emb"] = torch.concat([inputs_posi["prompt_emb"], extra_prompt_emb], dim=1)
+            inputs_posi["prompt_emb_mask"] = torch.ones((1, inputs_posi["prompt_emb"].shape[1]), dtype=inputs_posi["prompt_emb_mask"].dtype, device=inputs_posi["prompt_emb_mask"].device)
         # Denoise
         self.load_models_to_device(self.in_iteration_models)
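The new extra_prompt_emb hook appends extra tokens along the sequence dimension of the positive prompt embedding and rebuilds prompt_emb_mask as all ones to cover them. A minimal sketch of a call, assuming a pipe built as in test_interpolate.py below (the tensor here is a hypothetical placeholder for a learned embedding):

# hypothetical extra tokens in the 3584-dim Qwen text-embedding space
extra = torch.randn(1, 32, 3584, dtype=torch.bfloat16, device="cuda")
image = pipe("a prompt", seed=0, num_inference_steps=40, extra_prompt_emb=extra)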

test_interpolate.py Normal file

@@ -0,0 +1,119 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, QwenImageUnit_PromptEmbedder, load_state_dict
import torch, os
from tqdm import tqdm
from diffsynth.models.svd_unet import TemporalTimesteps
from einops import rearrange, repeat


class ValueEncoder(torch.nn.Module):
    # Encodes a scalar interpolation value in [0, 1] into a sequence of
    # value_emb_length tokens in the text-embedding space.
    def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32):
        super().__init__()
        self.value_emb = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.positional_emb = torch.nn.Parameter(torch.randn(1, value_emb_length, dim_out))
        self.proj_value = torch.nn.Linear(dim_in, dim_out)
        self.proj_out = torch.nn.Linear(dim_out, dim_out)
        self.value_emb_length = value_emb_length

    def forward(self, value):
        emb = self.value_emb(value).to(value.dtype)  # sinusoidal embedding of the scalar
        emb = self.proj_value(emb)
        emb = repeat(emb, "b d -> b s d", s=self.value_emb_length)
        emb = emb + self.positional_emb.to(dtype=emb.dtype, device=emb.device)
        emb = torch.nn.functional.silu(emb)
        emb = self.proj_out(emb)
        return emb
class TextInterpolationModel(torch.nn.Module):
    # Cross-attention from the value tokens (queries) to the concatenated
    # prompt embeddings x and y (keys/values), each offset by a learned
    # source embedding so the attention can distinguish the two prompts.
    def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32, num_heads=32):
        super().__init__()
        self.to_q = ValueEncoder(dim_in=dim_in, dim_out=dim_out, value_emb_length=value_emb_length)
        self.xk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.yk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.xv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.yv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.to_k = torch.nn.Linear(dim_out, dim_out, bias=False)
        self.to_v = torch.nn.Linear(dim_out, dim_out, bias=False)
        self.to_out = torch.nn.Linear(dim_out, dim_out)
        self.num_heads = num_heads

    def forward(self, value, x, y):
        q = self.to_q(value)
        k = self.to_k(torch.concat([x + self.xk_emb, y + self.yk_emb], dim=1))
        v = self.to_v(torch.concat([x + self.xv_emb, y + self.yv_emb], dim=1))
        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, "b h s d -> b s (h d)")
        out = self.to_out(out)
        return out
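# Shape walk-through (illustrative, sizes assumed): for x with s_x tokens and
# y with s_y tokens, q is (1, 32, 3584) from the value encoder and k, v are
# (1, s_x + s_y, 3584), so the 32 output tokens attend over both prompts:
#   m = TextInterpolationModel().to(torch.float32)
#   out = m(torch.rand(1), torch.randn(1, 20, 3584), torch.randn(1, 24, 3584))
#   out.shape  # torch.Size([1, 32, 3584])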
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
unit = QwenImageUnit_PromptEmbedder()
dataset_prompt = [
    (
        # "An extremely dark image, everything in darkness, lightless and dim, gloomy and dark, almost entirely black"
        "超级黑暗的画面,整体在黑暗中,暗无天日,暗淡无光,阴森黑暗,几乎全黑",
        # "An extremely bright image, blown-out flash, overexposed camera, the whole frame is white glare, almost entirely white"
        "超级明亮的画面,爆闪,相机过曝,整个画面都是白色的眩光,几乎全是白色",
    ),
]
dataset_tensors = []
for prompt_x, prompt_y in tqdm(dataset_prompt):
    with torch.no_grad():
        x = unit.process(pipe, prompt_x)["prompt_emb"]
        y = unit.process(pipe, prompt_y)["prompt_emb"]
    dataset_tensors.append((x, y))
model = TextInterpolationModel().to(dtype=torch.bfloat16, device="cuda")
model.load_state_dict(load_state_dict("models/interpolate.pth"))
# Training utilities shared with train_interpolate.py; not used at inference time.
def sample_tokens(emb, p):
    # Keep a random subset of int(p * seq_len) tokens.
    perm = torch.randperm(emb.shape[1])[:max(0, int(emb.shape[1] * p))]
    return emb[:, perm]

def loss_fn(x, y):
    # Symmetric Chamfer distance between two token sets.
    s, l = x.shape[1], y.shape[1]
    x = repeat(x, "b s d -> b s l d", l=l)
    y = repeat(y, "b l d -> b s l d", s=s)
    d = torch.square(x - y).mean(dim=-1)  # pairwise squared distances, (b, s, l)
    loss_x = d.min(dim=1).values.mean()   # each y-token to its nearest x-token
    loss_y = d.min(dim=2).values.mean()   # each x-token to its nearest y-token
    return loss_x + loss_y

def get_target(x, y, p):
    # Mix roughly (1 - p) of x's tokens with p of y's tokens.
    x = sample_tokens(x, 1 - p)
    y = sample_tokens(y, p)
    return torch.concat([x, y], dim=1)
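# Illustrative numbers (assumed): with x of 4 tokens and y of 6 tokens, loss_fn
# builds a (1, 4, 6) matrix of pairwise squared distances and averages each
# side's nearest-neighbour distances; get_target(x, y, 0.5) keeps int(4*0.5)=2
# random tokens of x and int(6*0.5)=3 of y, giving a 5-token mixed target.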
name = "brightness"
for i in range(6):
    v = i / 5  # interpolation value: 0 = dark prompt, 1 = bright prompt
    with torch.no_grad():
        data_id = 0
        x, y = dataset_tensors[data_id]
        x, y = x.to("cuda"), y.to("cuda")
        value = torch.tensor([v], dtype=torch.bfloat16, device="cuda")
        value_emb = model(value, x, y)
    # "Exquisite portrait: a girl underwater, flowing blue dress, drifting hair,
    # translucent light, surrounded by bubbles, serene face, fine detail, dreamlike beauty."
    prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
    image = pipe(prompt, seed=0, num_inference_steps=40, extra_prompt_emb=value_emb)
    os.makedirs(f"data/qwen_image_value/{name}", exist_ok=True)
    image.save(f"data/qwen_image_value/{name}/image_{v}.jpg")

train_interpolate.py Normal file

@@ -0,0 +1,121 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, QwenImageUnit_PromptEmbedder
import torch, os
from tqdm import tqdm
from diffsynth.models.svd_unet import TemporalTimesteps
from einops import rearrange, repeat


# Same model definitions as in test_interpolate.py.
class ValueEncoder(torch.nn.Module):
    def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32):
        super().__init__()
        self.value_emb = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.positional_emb = torch.nn.Parameter(torch.randn(1, value_emb_length, dim_out))
        self.proj_value = torch.nn.Linear(dim_in, dim_out)
        self.proj_out = torch.nn.Linear(dim_out, dim_out)
        self.value_emb_length = value_emb_length

    def forward(self, value):
        emb = self.value_emb(value).to(value.dtype)
        emb = self.proj_value(emb)
        emb = repeat(emb, "b d -> b s d", s=self.value_emb_length)
        emb = emb + self.positional_emb.to(dtype=emb.dtype, device=emb.device)
        emb = torch.nn.functional.silu(emb)
        emb = self.proj_out(emb)
        return emb


class TextInterpolationModel(torch.nn.Module):
    def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32, num_heads=32):
        super().__init__()
        self.to_q = ValueEncoder(dim_in=dim_in, dim_out=dim_out, value_emb_length=value_emb_length)
        self.xk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.yk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.xv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.yv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.to_k = torch.nn.Linear(dim_out, dim_out, bias=False)
        self.to_v = torch.nn.Linear(dim_out, dim_out, bias=False)
        self.to_out = torch.nn.Linear(dim_out, dim_out)
        self.num_heads = num_heads

    def forward(self, value, x, y):
        q = self.to_q(value)
        k = self.to_k(torch.concat([x + self.xk_emb, y + self.yk_emb], dim=1))
        v = self.to_v(torch.concat([x + self.xv_emb, y + self.yv_emb], dim=1))
        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, "b h s d -> b s (h d)")
        out = self.to_out(out)
        return out
def sample_tokens(emb, p):
    # Keep a random subset of int(p * seq_len) tokens.
    perm = torch.randperm(emb.shape[1])[:max(0, int(emb.shape[1] * p))]
    return emb[:, perm]

def loss_fn(x, y):
    # Symmetric Chamfer distance between two token sets.
    s, l = x.shape[1], y.shape[1]
    x = repeat(x, "b s d -> b s l d", l=l)
    y = repeat(y, "b l d -> b s l d", s=s)
    d = torch.square(x - y).mean(dim=-1)
    loss_x = d.min(dim=1).values.mean()
    loss_y = d.min(dim=2).values.mean()
    return loss_x + loss_y

def get_target(x, y, p):
    # Mix roughly (1 - p) of x's tokens with p of y's tokens (unused below).
    x = sample_tokens(x, 1 - p)
    y = sample_tokens(y, p)
    return torch.concat([x, y], dim=1)
# Only the text encoder and tokenizer are needed for training.
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
unit = QwenImageUnit_PromptEmbedder()
dataset_prompt = [
    (
        # "An extremely dark image, everything in darkness, lightless and dim, gloomy and dark, almost entirely black"
        "超级黑暗的画面,整体在黑暗中,暗无天日,暗淡无光,阴森黑暗,几乎全黑",
        # "An extremely bright image, blown-out flash, overexposed camera, the whole frame is white glare, almost entirely white"
        "超级明亮的画面,爆闪,相机过曝,整个画面都是白色的眩光,几乎全是白色",
    ),
]
dataset_tensors = []
for prompt_x, prompt_y in tqdm(dataset_prompt):
    with torch.no_grad():
        # Cache prompt embeddings on CPU in float32; training runs in float32.
        x = unit.process(pipe, prompt_x)["prompt_emb"].to(dtype=torch.float32, device="cpu")
        y = unit.process(pipe, prompt_y)["prompt_emb"].to(dtype=torch.float32, device="cpu")
    dataset_tensors.append((x, y))
model = TextInterpolationModel().to(dtype=torch.float32, device="cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
os.makedirs("models", exist_ok=True)
for step in tqdm(range(100000)):
    optimizer.zero_grad()
    data_id = torch.randint(0, len(dataset_tensors), size=(1,)).item()
    x, y = dataset_tensors[data_id]
    x, y = x.to("cuda"), y.to("cuda")
    value = torch.rand((1,), dtype=torch.float32, device="cuda")
    out = model(value, x, y)
    # Pull the output toward x when value is near 0 and toward y when near 1.
    loss = loss_fn(out, x) * (1 - value) + loss_fn(out, y) * value
    loss.backward()
    optimizer.step()
    if (step + 1) % 1000 == 0:
        print(loss.item())
        torch.save(model.state_dict(), f"models/interpolate_{step+1}.pth")
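Note that the training loop saves checkpoints as models/interpolate_{step+1}.pth, while test_interpolate.py loads models/interpolate.pth, so the chosen checkpoint must be copied into place first, e.g. (assuming the final one):

import shutil
shutil.copy("models/interpolate_100000.pth", "models/interpolate.pth")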