# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# synced 2026-03-18 22:08:13 +00:00 (122 lines, 4.6 KiB, Python).
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, QwenImageUnit_PromptEmbedder
|
|
import torch
|
|
from tqdm import tqdm
|
|
from diffsynth.models.svd_unet import TemporalTimesteps
|
|
from einops import rearrange, repeat
|
|
|
|
|
|
|
|
class ValueEncoder(torch.nn.Module):
    """Encode a scalar value into a fixed-length sequence of feature tokens.

    The scalar is first mapped through a sinusoidal (timestep-style)
    embedding, projected to ``dim_out``, broadcast to ``value_emb_length``
    positions, offset by a learned positional embedding, and passed through
    a SiLU + linear output projection.
    """

    def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32):
        super().__init__()
        # Sinusoidal embedding of the scalar input (same mechanism as
        # diffusion timestep embeddings).
        self.value_emb = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
        # Learned per-position offset, one row per output token.
        self.positional_emb = torch.nn.Parameter(torch.randn(1, value_emb_length, dim_out))
        self.proj_value = torch.nn.Linear(dim_in, dim_out)
        self.proj_out = torch.nn.Linear(dim_out, dim_out)
        self.value_emb_length = value_emb_length

    def forward(self, value):
        # Multiply by one to work on a fresh tensor rather than the caller's
        # (preserved from the original implementation).
        value = value * 1
        tokens = self.proj_value(self.value_emb(value).to(value.dtype))
        # Broadcast the single embedding across all output positions.
        tokens = repeat(tokens, "b d -> b s d", s=self.value_emb_length)
        tokens = tokens + self.positional_emb.to(dtype=tokens.dtype, device=tokens.device)
        return self.proj_out(torch.nn.functional.silu(tokens))
class TextInterpolationModel(torch.nn.Module):
    """Interpolate between two text embeddings via cross-attention.

    A scalar ``value`` is encoded into a query sequence (``ValueEncoder``);
    the two prompt embeddings ``x`` and ``y`` — tagged with learned
    source-specific key/value offsets — serve as the attention memory.
    """

    def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32, num_heads=32):
        super().__init__()
        self.to_q = ValueEncoder(dim_in=dim_in, dim_out=dim_out, value_emb_length=value_emb_length)
        # Learned offsets that mark which source (x or y) a token came from,
        # separately for the key and value paths.
        self.xk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.yk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.xv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.yv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
        self.to_k = torch.nn.Linear(dim_out, dim_out, bias=False)
        self.to_v = torch.nn.Linear(dim_out, dim_out, bias=False)
        self.to_out = torch.nn.Linear(dim_out, dim_out)
        self.num_heads = num_heads

    def forward(self, value, x, y):
        query = self.to_q(value)
        keys = self.to_k(torch.concat([x + self.xk_emb, y + self.yk_emb], dim=1))
        values = self.to_v(torch.concat([x + self.xv_emb, y + self.yv_emb], dim=1))
        # Split the channel dimension into attention heads.
        query, keys, values = (
            rearrange(t, 'b s (h d) -> b h s d', h=self.num_heads)
            for t in (query, keys, values)
        )
        attn = torch.nn.functional.scaled_dot_product_attention(query, keys, values)
        merged = rearrange(attn, 'b h s d -> b s (h d)')
        return self.to_out(merged)
def sample_tokens(emb, p):
    """Randomly keep a fraction ``p`` of the tokens along dimension 1.

    Selects ``int(seq_len * p)`` tokens (clamped to at least 0) without
    replacement, in random order.
    """
    num_tokens = emb.shape[1]
    keep = max(0, int(num_tokens * p))
    chosen = torch.randperm(num_tokens)[:keep]
    return emb[:, chosen]
def loss_fn(x, y):
    """Symmetric chamfer-style loss between two token sequences.

    Args:
        x: tensor of shape (b, s, d).
        y: tensor of shape (b, l, d).

    Returns:
        Scalar tensor: for every token of each sequence, the mean squared
        difference to its nearest token in the other sequence, averaged and
        summed over both directions.
    """
    # Broadcast instead of the previous double einops `repeat`, which
    # materialized two full (b, s, l, d) tensors; here only the difference
    # tensor is allocated, with identical numerics.
    d = torch.square(x.unsqueeze(2) - y.unsqueeze(1)).mean(dim=-1)  # (b, s, l)
    loss_x = d.min(dim=1).values.mean()  # each y token -> nearest x token
    loss_y = d.min(dim=2).values.mean()  # each x token -> nearest y token
    return loss_x + loss_y
def get_target(x, y, p):
    """Build a mixed token sequence: a (1 - p) fraction of ``x`` followed by
    a ``p`` fraction of ``y``, each randomly subsampled along dimension 1.
    """
    kept_x = sample_tokens(x, 1 - p)
    kept_y = sample_tokens(y, p)
    return torch.concat([kept_x, kept_y], dim=1)
# Load only the Qwen-Image text encoder and tokenizer — the DiT and VAE are
# not needed to compute prompt embeddings.
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
# Pipeline unit that maps a prompt string to its text-encoder embedding.
unit = QwenImageUnit_PromptEmbedder()
# Prompt pairs describing the two extremes to interpolate between.
# Each tuple is (dark-extreme prompt, bright-extreme prompt); the strings are
# model inputs and are kept verbatim.
dataset_prompt = [
    (
        "超级黑暗的画面,整体在黑暗中,暗无天日,暗淡无光,阴森黑暗,几乎全黑",
        "超级明亮的画面,爆闪,相机过曝,整个画面都是白色的眩光,几乎全是白色",
    ),
]
# Pre-compute the text embedding of every prompt once, stored on the CPU in
# float32, so the training loop never re-runs the frozen text encoder.
dataset_tensors = []
for prompt_x, prompt_y in tqdm(dataset_prompt):
    with torch.no_grad():
        x = unit.process(pipe, prompt_x)["prompt_emb"].to(dtype=torch.float32, device="cpu")
        y = unit.process(pipe, prompt_y)["prompt_emb"].to(dtype=torch.float32, device="cpu")
    dataset_tensors.append((x, y))
# Train the interpolation model in float32 on the GPU (the frozen text
# encoder above ran in bfloat16; its outputs were already cast to float32).
model = TextInterpolationModel().to(dtype=torch.float32, device="cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# Training loop: each step samples a prompt pair and a random interpolation
# coefficient in [0, 1), then pulls the model output toward x as the
# coefficient approaches 0 and toward y as it approaches 1.
# NOTE(review): `step_id` and `step` always hold the same value here.
for step_id, step in enumerate(tqdm(range(100000))):
    optimizer.zero_grad()

    # Pick a random (x, y) embedding pair and move it onto the GPU.
    data_id = torch.randint(0, len(dataset_tensors), size=(1,)).item()
    x, y = dataset_tensors[data_id]
    x, y = x.to("cuda"), y.to("cuda")

    # Random interpolation coefficient in [0, 1).
    value = torch.rand((1,), dtype=torch.float32, device="cuda")
    out = model(value, x, y)
    # Blend the two chamfer losses by the interpolation coefficient.
    loss = loss_fn(out, x) * (1 - value) + loss_fn(out, y) * value

    loss.backward()
    optimizer.step()

    # NOTE(review): indentation was lost in this copy of the file; the
    # periodic print and checkpoint save are reconstructed here as
    # every-1000-step actions — confirm against the upstream source.
    if (step_id + 1) % 1000 == 0:
        print(loss)
        torch.save(model.state_dict(), f"models/interpolate_{step+1}.pth")