mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
qwen-image-interpolate
This commit is contained in:
@@ -279,6 +279,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
tile_stride: int = 64,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
extra_prompt_emb = None,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))
|
||||
@@ -304,6 +305,9 @@ class QwenImagePipeline(BasePipeline):
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
if extra_prompt_emb is not None:
|
||||
inputs_posi["prompt_emb"] = torch.concat([inputs_posi["prompt_emb"], extra_prompt_emb], dim=1)
|
||||
inputs_posi["prompt_emb_mask"] = torch.ones((1, inputs_posi["prompt_emb"].shape[1]), dtype=inputs_posi["prompt_emb_mask"].dtype, device=inputs_posi["prompt_emb_mask"].device)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
|
||||
119
test_interpolate.py
Normal file
119
test_interpolate.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, QwenImageUnit_PromptEmbedder, load_state_dict
|
||||
import torch, os
|
||||
from tqdm import tqdm
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
|
||||
class ValueEncoder(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32):
|
||||
super().__init__()
|
||||
self.value_emb = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.positional_emb = torch.nn.Parameter(torch.randn(1, value_emb_length, dim_out))
|
||||
self.proj_value = torch.nn.Linear(dim_in, dim_out)
|
||||
self.proj_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.value_emb_length = value_emb_length
|
||||
|
||||
def forward(self, value):
|
||||
value = value * 1
|
||||
emb = self.value_emb(value).to(value.dtype)
|
||||
emb = self.proj_value(emb)
|
||||
emb = repeat(emb, "b d -> b s d", s=self.value_emb_length)
|
||||
emb = emb + self.positional_emb.to(dtype=emb.dtype, device=emb.device)
|
||||
emb = torch.nn.functional.silu(emb)
|
||||
emb = self.proj_out(emb)
|
||||
return emb
|
||||
|
||||
|
||||
class TextInterpolationModel(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32, num_heads=32):
|
||||
super().__init__()
|
||||
self.to_q = ValueEncoder(dim_in=dim_in, dim_out=dim_out, value_emb_length=value_emb_length)
|
||||
self.xk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.xv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.to_k = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_v = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, value, x, y):
|
||||
q = self.to_q(value)
|
||||
k = self.to_k(torch.concat([x + self.xk_emb, y + self.yk_emb], dim=1))
|
||||
v = self.to_v(torch.concat([x + self.xv_emb, y + self.yv_emb], dim=1))
|
||||
q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
out = rearrange(out, 'b h s d -> b s (h d)')
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
unit = QwenImageUnit_PromptEmbedder()
|
||||
|
||||
dataset_prompt = [
|
||||
(
|
||||
"超级黑暗的画面,整体在黑暗中,暗无天日,暗淡无光,阴森黑暗,几乎全黑",
|
||||
"超级明亮的画面,爆闪,相机过曝,整个画面都是白色的眩光,几乎全是白色",
|
||||
),
|
||||
]
|
||||
dataset_tensors = []
|
||||
for prompt_x, prompt_y in tqdm(dataset_prompt):
|
||||
with torch.no_grad():
|
||||
x = unit.process(pipe, prompt_x)["prompt_emb"]
|
||||
y = unit.process(pipe, prompt_y)["prompt_emb"]
|
||||
dataset_tensors.append((x, y))
|
||||
|
||||
model = TextInterpolationModel().to(dtype=torch.bfloat16, device="cuda")
|
||||
model.load_state_dict(load_state_dict("models/interpolate.pth"))
|
||||
|
||||
def sample_tokens(emb, p):
|
||||
perm = torch.randperm(emb.shape[1])[:max(0, int(emb.shape[1]*p))]
|
||||
return emb[:, perm]
|
||||
|
||||
|
||||
def loss_fn(x, y):
|
||||
s, l = x.shape[1], y.shape[1]
|
||||
x = repeat(x, "b s d -> b s l d", l=l)
|
||||
y = repeat(y, "b l d -> b s l d", s=s)
|
||||
d = torch.square(x - y).mean(dim=-1)
|
||||
loss_x = d.min(dim=1).values.mean()
|
||||
loss_y = d.min(dim=2).values.mean()
|
||||
return loss_x + loss_y
|
||||
|
||||
|
||||
def get_target(x, y, p):
|
||||
x = sample_tokens(x, 1-p)
|
||||
y = sample_tokens(y, p)
|
||||
return torch.concat([x, y], dim=1)
|
||||
|
||||
name = "brightness"
|
||||
for i in range(6):
|
||||
v = i/5
|
||||
with torch.no_grad():
|
||||
data_id = 0
|
||||
x, y = dataset_tensors[data_id]
|
||||
x, y = x.to("cuda"), y.to("cuda")
|
||||
value = torch.tensor([v], dtype=torch.bfloat16, device="cuda")
|
||||
value_emb = model(value, x, y)
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40, extra_prompt_emb=value_emb)
|
||||
os.makedirs(f"data/qwen_image_value/{name}", exist_ok=True)
|
||||
image.save(f"data/qwen_image_value/{name}/image_{v}.jpg")
|
||||
121
train_interpolate.py
Normal file
121
train_interpolate.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, QwenImageUnit_PromptEmbedder
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
|
||||
class ValueEncoder(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32):
|
||||
super().__init__()
|
||||
self.value_emb = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.positional_emb = torch.nn.Parameter(torch.randn(1, value_emb_length, dim_out))
|
||||
self.proj_value = torch.nn.Linear(dim_in, dim_out)
|
||||
self.proj_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.value_emb_length = value_emb_length
|
||||
|
||||
def forward(self, value):
|
||||
value = value * 1
|
||||
emb = self.value_emb(value).to(value.dtype)
|
||||
emb = self.proj_value(emb)
|
||||
emb = repeat(emb, "b d -> b s d", s=self.value_emb_length)
|
||||
emb = emb + self.positional_emb.to(dtype=emb.dtype, device=emb.device)
|
||||
emb = torch.nn.functional.silu(emb)
|
||||
emb = self.proj_out(emb)
|
||||
return emb
|
||||
|
||||
|
||||
class TextInterpolationModel(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32, num_heads=32):
|
||||
super().__init__()
|
||||
self.to_q = ValueEncoder(dim_in=dim_in, dim_out=dim_out, value_emb_length=value_emb_length)
|
||||
self.xk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.xv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.to_k = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_v = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, value, x, y):
|
||||
q = self.to_q(value)
|
||||
k = self.to_k(torch.concat([x + self.xk_emb, y + self.yk_emb], dim=1))
|
||||
v = self.to_v(torch.concat([x + self.xv_emb, y + self.yv_emb], dim=1))
|
||||
q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
out = rearrange(out, 'b h s d -> b s (h d)')
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
|
||||
def sample_tokens(emb, p):
|
||||
perm = torch.randperm(emb.shape[1])[:max(0, int(emb.shape[1]*p))]
|
||||
return emb[:, perm]
|
||||
|
||||
|
||||
def loss_fn(x, y):
|
||||
s, l = x.shape[1], y.shape[1]
|
||||
x = repeat(x, "b s d -> b s l d", l=l)
|
||||
y = repeat(y, "b l d -> b s l d", s=s)
|
||||
d = torch.square(x - y).mean(dim=-1)
|
||||
loss_x = d.min(dim=1).values.mean()
|
||||
loss_y = d.min(dim=2).values.mean()
|
||||
return loss_x + loss_y
|
||||
|
||||
|
||||
def get_target(x, y, p):
|
||||
x = sample_tokens(x, 1-p)
|
||||
y = sample_tokens(y, p)
|
||||
return torch.concat([x, y], dim=1)
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
unit = QwenImageUnit_PromptEmbedder()
|
||||
|
||||
|
||||
dataset_prompt = [
|
||||
(
|
||||
"超级黑暗的画面,整体在黑暗中,暗无天日,暗淡无光,阴森黑暗,几乎全黑",
|
||||
"超级明亮的画面,爆闪,相机过曝,整个画面都是白色的眩光,几乎全是白色",
|
||||
),
|
||||
]
|
||||
|
||||
dataset_tensors = []
|
||||
for prompt_x, prompt_y in tqdm(dataset_prompt):
|
||||
with torch.no_grad():
|
||||
x = unit.process(pipe, prompt_x)["prompt_emb"].to(dtype=torch.float32, device="cpu")
|
||||
y = unit.process(pipe, prompt_y)["prompt_emb"].to(dtype=torch.float32, device="cpu")
|
||||
dataset_tensors.append((x, y))
|
||||
|
||||
model = TextInterpolationModel().to(dtype=torch.float32, device="cuda")
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
|
||||
|
||||
for step_id, step in enumerate(tqdm(range(100000))):
|
||||
optimizer.zero_grad()
|
||||
|
||||
data_id = torch.randint(0, len(dataset_tensors), size=(1,)).item()
|
||||
x, y = dataset_tensors[data_id]
|
||||
x, y = x.to("cuda"), y.to("cuda")
|
||||
|
||||
value = torch.rand((1,), dtype=torch.float32, device="cuda")
|
||||
out = model(value, x, y)
|
||||
loss = loss_fn(out, x) * (1 - value) + loss_fn(out, y) * value
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if (step_id + 1) % 1000 == 0:
|
||||
print(loss)
|
||||
|
||||
torch.save(model.state_dict(), f"models/interpolate_{step+1}.pth")
|
||||
Reference in New Issue
Block a user