diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 453755f..5c68827 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -48,6 +48,7 @@ from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder from ..extensions.RIFE import IFNet from ..extensions.ESRGAN import RRDBNet +from ..models.hunyuan_video_dit import HunyuanVideoDiT model_loader_configs = [ @@ -97,6 +98,7 @@ model_loader_configs = [ (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"), (None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"), (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder"], [HunyuanVideoVAEDecoder], "diffusers"), + (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/models/hunyuan_video_dit.py b/diffsynth/models/hunyuan_video_dit.py new file mode 100644 index 0000000..761c740 --- /dev/null +++ b/diffsynth/models/hunyuan_video_dit.py @@ -0,0 +1,695 @@ +import torch +from .sd3_dit import TimestepEmbeddings, RMSNorm +from .utils import init_weights_on_device +from einops import rearrange, repeat +from tqdm import tqdm + + +class PatchEmbed(torch.nn.Module): + def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072): + super().__init__() + self.proj = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class IndividualTokenRefinerBlock(torch.nn.Module): + def __init__(self, hidden_size=3072, num_heads=24): + super().__init__() + self.num_heads = num_heads + self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3) + self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size) + + self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.mlp = torch.nn.Sequential( + torch.nn.Linear(hidden_size, hidden_size * 4), + torch.nn.SiLU(), + torch.nn.Linear(hidden_size * 4, hidden_size) + ) + self.adaLN_modulation = torch.nn.Sequential( + torch.nn.SiLU(), + torch.nn.Linear(hidden_size, hidden_size * 2, device="cuda", dtype=torch.bfloat16), + ) + + def forward(self, x, c, attn_mask=None): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + qkv = self.self_attn_qkv(norm_x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + + attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + attn = rearrange(attn, "B H L D -> B L (H D)") + + x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1) + x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1) + + return x + + +class SingleTokenRefiner(torch.nn.Module): + def __init__(self, in_channels=4096, hidden_size=3072, depth=2): + super().__init__() + self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu") + self.c_embedder = torch.nn.Sequential( + torch.nn.Linear(in_channels, hidden_size), + torch.nn.SiLU(), + torch.nn.Linear(hidden_size, hidden_size) + ) + self.blocks = torch.nn.ModuleList([IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)]) + + def forward(self, x, t, mask=None): + timestep_aware_representations = self.t_embedder(t, dtype=torch.float32) + + mask_float = mask.float().unsqueeze(-1) + context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1) + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + + x = self.input_embedder(x) + + mask = mask.to(device=x.device, dtype=torch.bool) + mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1]) + mask = mask & mask.transpose(2, 3) + mask[:, :, :, 0] = True + + for block in self.blocks: + x = block(x, c, mask) + + return x + + +class ModulateDiT(torch.nn.Module): + def __init__(self, hidden_size, factor=6): + super().__init__() + self.act = torch.nn.SiLU() + self.linear = torch.nn.Linear(hidden_size, factor * hidden_size) + + def forward(self, x): + return self.linear(self.act(x)) + + +def modulate(x, shift=None, scale=None): + if scale is None and shift is None: + return x + elif shift is None: + return x * (1 + scale.unsqueeze(1)) + elif scale is None: + return x + shift.unsqueeze(1) + else: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def reshape_for_broadcast( + freqs_cis, + x: torch.Tensor, + head_first=False, +): + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + if head_first: + assert freqs_cis[0].shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [ + d if i == ndim - 2 or i == ndim - 1 else 1 + for i, d in enumerate(x.shape) + ] + else: + assert freqs_cis[0].shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + else: + # freqs_cis: values in complex space + if head_first: + assert freqs_cis.shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [ + d if i == ndim - 2 or i == ndim - 1 else 1 + for i, d in enumerate(x.shape) + ] + else: + assert freqs_cis.shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x): + x_real, x_imag = ( + x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + ) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis, + head_first: bool = False, +): + xk_out = None + if isinstance(freqs_cis, tuple): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] + cos, sin = cos.to(xq.device), sin.to(xq.device) + # real * cos - imag * sin + # imag * cos + real * sin + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + else: + # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex) + xq_ = torch.view_as_complex( + xq.float().reshape(*xq.shape[:-1], -1, 2) + ) # [B, S, H, D//2] + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to( + xq.device + ) # [S, D//2] --> [1, S, 1, D//2] + # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin) + # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + xk_ = torch.view_as_complex( + xk.float().reshape(*xk.shape[:-1], -1, 2) + ) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out + + +def attention(q, k, v): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = x.transpose(1, 2).flatten(2, 3) + return x + + +class MMDoubleStreamBlockComponent(torch.nn.Module): + def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4): + super().__init__() + self.heads_num = heads_num + + self.mod = ModulateDiT(hidden_size) + self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3) + self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6) + self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6) + self.to_out = torch.nn.Linear(hidden_size, hidden_size) + + self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = torch.nn.Sequential( + torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size) + ) + + def forward(self, hidden_states, conditioning, freqs_cis=None): + mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1) + + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale) + qkv = self.to_qkv(norm_hidden_states) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + + q = self.norm_q(q) + k = self.norm_k(k) + + if freqs_cis is not None: + q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False) + + return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate) + + def process_ff(self, hidden_states, attn_output, mod): + mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod + hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1) + hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1) + return hidden_states + + +class MMDoubleStreamBlock(torch.nn.Module): + def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4): + super().__init__() + self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio) + self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio) + + def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis): + (q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis) + (q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None) + + q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous() + k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous() + v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous() + attn_output_a = attention(q_a, k_a, v_a) + attn_output_b = attention(q_b, k_b, v_b) + attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1) + + hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a) + hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b) + return hidden_states_a, hidden_states_b + + +class MMSingleStreamBlockOriginal(torch.nn.Module): + def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4): + super().__init__() + self.hidden_size = hidden_size + self.heads_num = heads_num + self.mlp_hidden_dim = hidden_size * mlp_width_ratio + + self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.q_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6) + self.k_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6) + + self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = torch.nn.GELU(approximate="tanh") + self.modulation = ModulateDiT(hidden_size, factor=3) + + def forward(self, x, vec, freqs_cis=None, txt_len=256): + mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) + x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale) + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + q = self.q_norm(q) + k = self.k_norm(k) + + q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] + k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] + q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False) + q = torch.cat((q_a, q_b), dim=1) + k = torch.cat((k_a, k_b), dim=1) + + attn_output_a = attention(q[:, :-185].contiguous(), k[:, :-185].contiguous(), v[:, :-185].contiguous()) + attn_output_b = attention(q[:, -185:].contiguous(), k[:, -185:].contiguous(), v[:, -185:].contiguous()) + attn_output = torch.concat([attn_output_a, attn_output_b], dim=1) + + output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2)) + return x + output * mod_gate.unsqueeze(1) + + +class MMSingleStreamBlock(torch.nn.Module): + def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4): + super().__init__() + self.heads_num = heads_num + + self.mod = ModulateDiT(hidden_size, factor=3) + self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3) + self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6) + self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6) + self.to_out = torch.nn.Linear(hidden_size, hidden_size) + + self.ff = torch.nn.Sequential( + torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False) + ) + + def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256): + mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1) + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale) + qkv = self.to_qkv(norm_hidden_states) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + + q = self.norm_q(q) + k = self.norm_k(k) + + q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] + k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] + q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False) + + q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous() + k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous() + v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous() + + attn_output_a = attention(q_a, k_a, v_a) + attn_output_b = attention(q_b, k_b, v_b) + attn_output = torch.concat([attn_output_a, attn_output_b], dim=1) + + hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1) + hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1) + return hidden_states + + +class FinalLayer(torch.nn.Module): + def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16): + super().__init__() + + self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = torch.nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels) + + self.adaLN_modulation = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(hidden_size, 2 * hidden_size)) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift=shift, scale=scale) + x = self.linear(x) + return x + + +class HunyuanVideoDiT(torch.nn.Module): + def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40): + super().__init__() + self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size) + self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size) + self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") + self.vector_in = torch.nn.Sequential( + torch.nn.Linear(768, hidden_size), + torch.nn.SiLU(), + torch.nn.Linear(hidden_size, hidden_size) + ) + self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") + self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)]) + self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)]) + self.final_layer = FinalLayer(hidden_size) + + # TODO: remove these parameters + self.dtype = torch.bfloat16 + self.patch_size = [1, 2, 2] + self.hidden_size = 3072 + self.heads_num = 24 + self.rope_dim_list = [16, 56, 56] + + def unpatchify(self, x, T, H, W): + x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2) + return x + + def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"): + self.warm_device = warm_device + self.cold_device = cold_device + self.to(self.cold_device) + + def load_models_to_device(self, loadmodel_names=[], device="cpu"): + for model_name in loadmodel_names: + model = getattr(self, model_name) + if model is not None: + model.to(device) + torch.cuda.empty_cache() + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + text_states: torch.Tensor = None, + text_mask: torch.Tensor = None, + text_states_2: torch.Tensor = None, + freqs_cos: torch.Tensor = None, + freqs_sin: torch.Tensor = None, + guidance: torch.Tensor = None, + **kwargs + ): + B, C, T, H, W = x.shape + + vec = self.time_in(t, dtype=torch.float32) + self.vector_in(text_states_2) + self.guidance_in(guidance, dtype=torch.float32) + img = self.img_in(x) + txt = self.txt_in(text_states, t, text_mask) + + for block in tqdm(self.double_blocks, desc="Double stream blocks"): + img, txt = block(img, txt, vec, (freqs_cos, freqs_sin)) + + x = torch.concat([img, txt], dim=1) + for block in tqdm(self.single_blocks, desc="Single stream blocks"): + x = block(x, vec, (freqs_cos, freqs_sin)) + + img = x[:, :-256] + img = self.final_layer(img, vec) + img = self.unpatchify(img, T=T//1, H=H//2, W=W//2) + return img + + + def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"): + def cast_to(weight, dtype=None, device=None, copy=False): + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + return weight.to(dtype=dtype, copy=copy) + + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + def cast_weight(s, input=None, dtype=None, device=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if device is None: + device = input.device + weight = cast_to(s.weight, dtype, device) + return weight + + def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if bias_dtype is None: + bias_dtype = dtype + if device is None: + device = input.device + weight = cast_to(s.weight, dtype, device) + bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None + return weight, bias + + class quantized_layer: + class Linear(torch.nn.Linear): + def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs): + super().__init__(*args, **kwargs) + self.dtype = dtype + self.device = device + + def block_forward_(self, x, i, j, dtype, device): + weight_ = cast_to( + self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size], + dtype=dtype, device=device + ) + if self.bias is None or i > 0: + bias_ = None + else: + bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device) + x_ = x[..., i * self.block_size: (i + 1) * self.block_size] + y_ = torch.nn.functional.linear(x_, weight_, bias_) + del x_, weight_, bias_ + torch.cuda.empty_cache() + return y_ + + def block_forward(self, x, **kwargs): + # This feature can only reduce 2GB VRAM, so we disable it. + y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device) + for i in range((self.in_features + self.block_size - 1) // self.block_size): + for j in range((self.out_features + self.block_size - 1) // self.block_size): + y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device) + return y + + def forward(self, x, **kwargs): + weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) + return torch.nn.functional.linear(x, weight, bias) + + + class RMSNorm(torch.nn.Module): + def __init__(self, module, dtype=torch.bfloat16, device="cuda"): + super().__init__() + self.module = module + self.dtype = dtype + self.device = device + + def forward(self, hidden_states, **kwargs): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) + hidden_states = hidden_states.to(input_dtype) + if self.module.weight is not None: + weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda") + hidden_states = hidden_states * weight + return hidden_states + + class Conv3d(torch.nn.Conv3d): + def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs): + super().__init__(*args, **kwargs) + self.dtype = dtype + self.device = device + + def forward(self, x): + weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) + return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups) + + class LayerNorm(torch.nn.LayerNorm): + def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs): + super().__init__(*args, **kwargs) + self.dtype = dtype + self.device = device + + def forward(self, x): + if self.weight is not None and self.bias is not None: + weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) + return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps) + else: + return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + + def replace_layer(model, dtype=torch.bfloat16, device="cuda"): + for name, module in model.named_children(): + if isinstance(module, torch.nn.Linear): + with init_weights_on_device(): + new_layer = quantized_layer.Linear( + module.in_features, module.out_features, bias=module.bias is not None, + dtype=dtype, device=device + ) + new_layer.load_state_dict(module.state_dict(), assign=True) + setattr(model, name, new_layer) + elif isinstance(module, torch.nn.Conv3d): + with init_weights_on_device(): + new_layer = quantized_layer.Conv3d( + module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride, + dtype=dtype, device=device + ) + new_layer.load_state_dict(module.state_dict(), assign=True) + setattr(model, name, new_layer) + elif isinstance(module, RMSNorm): + new_layer = quantized_layer.RMSNorm( + module, + dtype=dtype, device=device + ) + setattr(model, name, new_layer) + elif isinstance(module, torch.nn.LayerNorm): + with init_weights_on_device(): + new_layer = quantized_layer.LayerNorm( + module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps, + dtype=dtype, device=device + ) + new_layer.load_state_dict(module.state_dict(), assign=True) + setattr(model, name, new_layer) + else: + replace_layer(module, dtype=dtype, device=device) + + replace_layer(self, dtype=dtype, device=device) + + @staticmethod + def state_dict_converter(): + return HunyuanVideoDiTStateDictConverter() + + + +class HunyuanVideoDiTStateDictConverter: + def __init__(self): + pass + + def from_civitai(self, state_dict): + if "module" in state_dict: + state_dict = state_dict["module"] + direct_dict = { + "img_in.proj": "img_in.proj", + "time_in.mlp.0": "time_in.timestep_embedder.0", + "time_in.mlp.2": "time_in.timestep_embedder.2", + "vector_in.in_layer": "vector_in.0", + "vector_in.out_layer": "vector_in.2", + "guidance_in.mlp.0": "guidance_in.timestep_embedder.0", + "guidance_in.mlp.2": "guidance_in.timestep_embedder.2", + "txt_in.input_embedder": "txt_in.input_embedder", + "txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0", + "txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2", + "txt_in.c_embedder.linear_1": "txt_in.c_embedder.0", + "txt_in.c_embedder.linear_2": "txt_in.c_embedder.2", + "final_layer.linear": "final_layer.linear", + "final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1", + } + txt_suffix_dict = { + "norm1": "norm1", + "self_attn_qkv": "self_attn_qkv", + "self_attn_proj": "self_attn_proj", + "norm2": "norm2", + "mlp.fc1": "mlp.0", + "mlp.fc2": "mlp.2", + "adaLN_modulation.1": "adaLN_modulation.1", + } + double_suffix_dict = { + "img_mod.linear": "component_a.mod.linear", + "img_attn_qkv": "component_a.to_qkv", + "img_attn_q_norm": "component_a.norm_q", + "img_attn_k_norm": "component_a.norm_k", + "img_attn_proj": "component_a.to_out", + "img_mlp.fc1": "component_a.ff.0", + "img_mlp.fc2": "component_a.ff.2", + "txt_mod.linear": "component_b.mod.linear", + "txt_attn_qkv": "component_b.to_qkv", + "txt_attn_q_norm": "component_b.norm_q", + "txt_attn_k_norm": "component_b.norm_k", + "txt_attn_proj": "component_b.to_out", + "txt_mlp.fc1": "component_b.ff.0", + "txt_mlp.fc2": "component_b.ff.2", + } + single_suffix_dict = { + "linear1": ["to_qkv", "ff.0"], + "linear2": ["to_out", "ff.2"], + "q_norm": "norm_q", + "k_norm": "norm_k", + "modulation.linear": "mod.linear", + } + # single_suffix_dict = { + # "linear1": "linear1", + # "linear2": "linear2", + # "q_norm": "q_norm", + # "k_norm": "k_norm", + # "modulation.linear": "modulation.linear", + # } + state_dict_ = {} + for name, param in state_dict.items(): + names = name.split(".") + direct_name = ".".join(names[:-1]) + if direct_name in direct_dict: + name_ = direct_dict[direct_name] + "." + names[-1] + state_dict_[name_] = param + elif names[0] == "double_blocks": + prefix = ".".join(names[:2]) + suffix = ".".join(names[2:-1]) + name_ = prefix + "." + double_suffix_dict[suffix] + "." + names[-1] + state_dict_[name_] = param + elif names[0] == "single_blocks": + prefix = ".".join(names[:2]) + suffix = ".".join(names[2:-1]) + if isinstance(single_suffix_dict[suffix], list): + if suffix == "linear1": + name_a, name_b = single_suffix_dict[suffix] + param_a, param_b = torch.split(param, (3072*3, 3072*4), dim=0) + state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a + state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b + elif suffix == "linear2": + if names[-1] == "weight": + name_a, name_b = single_suffix_dict[suffix] + param_a, param_b = torch.split(param, (3072*1, 3072*4), dim=-1) + state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a + state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b + else: + name_a, name_b = single_suffix_dict[suffix] + state_dict_[prefix + "." + name_a + "." + names[-1]] = param + else: + pass + else: + name_ = prefix + "." + single_suffix_dict[suffix] + "." + names[-1] + state_dict_[name_] = param + elif names[0] == "txt_in": + prefix = ".".join(names[:4]).replace(".individual_token_refiner.", ".") + suffix = ".".join(names[4:-1]) + name_ = prefix + "." + txt_suffix_dict[suffix] + "." + names[-1] + state_dict_[name_] = param + else: + pass + return state_dict_ diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index eebc4a2..d416864 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -306,6 +306,53 @@ class FluxLoRAConverter: state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0] return state_dict_ + @staticmethod + def align_to_diffsynth_format(state_dict): + rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight", + } + def guess_block_id(name): + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + return None, None + state_dict_ = {} + for name, param in state_dict.items(): + block_id, source_name = guess_block_id(name) + if source_name in rename_dict: + target_name = rename_dict[source_name] + target_name = target_name.replace(".blockid.", f".{block_id}.") + state_dict_[target_name] = param + else: + state_dict_[name] = param + return state_dict_ + def get_lora_loaders(): return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()] diff --git a/diffsynth/models/sd3_dit.py b/diffsynth/models/sd3_dit.py index 730e6fc..60e6be4 100644 --- a/diffsynth/models/sd3_dit.py +++ b/diffsynth/models/sd3_dit.py @@ -52,9 +52,9 @@ class PatchEmbed(torch.nn.Module): class TimestepEmbeddings(torch.nn.Module): - def __init__(self, dim_in, dim_out): + def __init__(self, dim_in, dim_out, computation_device=None): super().__init__() - self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device) self.timestep_embedder = torch.nn.Sequential( torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out) ) diff --git a/diffsynth/models/svd_unet.py b/diffsynth/models/svd_unet.py index 85c6aba..19c540a 100644 --- a/diffsynth/models/svd_unet.py +++ b/diffsynth/models/svd_unet.py @@ -44,6 +44,7 @@ def get_timestep_embedding( downscale_freq_shift: float = 1, scale: float = 1, max_period: int = 10000, + computation_device = None, ): """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. @@ -57,11 +58,11 @@ def get_timestep_embedding( half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device ) exponent = exponent / (half_dim - downscale_freq_shift) - emb = torch.exp(exponent) + emb = torch.exp(exponent).to(timesteps.device) emb = timesteps[:, None].float() * emb[None, :] # scale embeddings @@ -81,11 +82,12 @@ def get_timestep_embedding( class TemporalTimesteps(torch.nn.Module): - def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None): super().__init__() self.num_channels = num_channels self.flip_sin_to_cos = flip_sin_to_cos self.downscale_freq_shift = downscale_freq_shift + self.computation_device = computation_device def forward(self, timesteps): t_emb = get_timestep_embedding( @@ -93,6 +95,7 @@ class TemporalTimesteps(torch.nn.Module): self.num_channels, flip_sin_to_cos=self.flip_sin_to_cos, downscale_freq_shift=self.downscale_freq_shift, + computation_device=self.computation_device, ) return t_emb diff --git a/diffsynth/models/utils.py b/diffsynth/models/utils.py index e18e2dd..99f5dee 100644 --- a/diffsynth/models/utils.py +++ b/diffsynth/models/utils.py @@ -80,7 +80,7 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None): def load_state_dict_from_bin(file_path, torch_dtype=None): - state_dict = torch.load(file_path, map_location="cpu") + state_dict = torch.load(file_path, map_location="cpu", weights_only=True) if torch_dtype is not None: for i in state_dict: if isinstance(state_dict[i], torch.Tensor): diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index 3177474..5f7c723 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -3,6 +3,7 @@ from peft import LoraConfig, inject_adapter_in_model import torch, os from ..data.simple_text_image import TextImageDataset from modelscope.hub.api import HubApi +from ..models.utils import load_state_dict @@ -33,7 +34,7 @@ class LightningModelForT2ILoRA(pl.LightningModule): self.pipe.denoising_model().train() - def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian"): + def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None, state_dict_converter=None): # Add LoRA to UNet self.lora_alpha = lora_alpha if init_lora_weights == "kaiming": @@ -51,6 +52,17 @@ class LightningModelForT2ILoRA(pl.LightningModule): if param.requires_grad: param.data = param.to(torch.float32) + # Lora pretrained lora weights + if pretrained_lora_path is not None: + state_dict = load_state_dict(pretrained_lora_path) + if state_dict_converter is not None: + state_dict = state_dict_converter(state_dict) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + all_keys = [i for i, _ in model.named_parameters()] + num_updated_keys = len(all_keys) - len(missing_keys) + num_unexpected_keys = len(unexpected_keys) + print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.") + def training_step(self, batch, batch_idx): # Data @@ -229,6 +241,12 @@ def add_general_parsers(parser): default=None, help="Access key on ModelScope (https://www.modelscope.cn/). Required if you want to upload the model to ModelScope.", ) + parser.add_argument( + "--pretrained_lora_path", + type=str, + default=None, + help="Pretrained LoRA path. Required if the training is resumed.", + ) return parser diff --git a/examples/train/flux/train_flux_lora.py b/examples/train/flux/train_flux_lora.py index 4efeed3..eb5539a 100644 --- a/examples/train/flux/train_flux_lora.py +++ b/examples/train/flux/train_flux_lora.py @@ -10,7 +10,7 @@ class LightningModel(LightningModelForT2ILoRA): self, torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None, learning_rate=1e-4, use_gradient_checkpointing=True, - lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", + lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None, quantize = None ): super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter) @@ -34,7 +34,15 @@ class LightningModel(LightningModelForT2ILoRA): self.pipe.scheduler.set_timesteps(1000, training=True) self.freeze_parameters() - self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights) + self.add_lora_to_model( + self.pipe.denoising_model(), + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_target_modules=lora_target_modules, + init_lora_weights=init_lora_weights, + pretrained_lora_path=pretrained_lora_path, + state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format + ) def parse_args(): @@ -109,6 +117,7 @@ if __name__ == '__main__': lora_alpha=args.lora_alpha, lora_target_modules=args.lora_target_modules, init_lora_weights=args.init_lora_weights, + pretrained_lora_path=args.pretrained_lora_path, state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None, quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None), ) diff --git a/examples/train/hunyuan_dit/train_hunyuan_dit_lora.py b/examples/train/hunyuan_dit/train_hunyuan_dit_lora.py index 6ceba42..7764ab5 100644 --- a/examples/train/hunyuan_dit/train_hunyuan_dit_lora.py +++ b/examples/train/hunyuan_dit/train_hunyuan_dit_lora.py @@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA): self, torch_dtype=torch.float16, pretrained_weights=[], learning_rate=1e-4, use_gradient_checkpointing=True, - lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", + lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None, ): super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing) # Load models @@ -19,7 +19,14 @@ class LightningModel(LightningModelForT2ILoRA): self.pipe.scheduler.set_timesteps(1000) self.freeze_parameters() - self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights) + self.add_lora_to_model( + self.pipe.denoising_model(), + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_target_modules=lora_target_modules, + init_lora_weights=init_lora_weights, + pretrained_lora_path=pretrained_lora_path, + ) def parse_args(): @@ -57,6 +64,7 @@ if __name__ == '__main__': lora_rank=args.lora_rank, lora_alpha=args.lora_alpha, init_lora_weights=args.init_lora_weights, + pretrained_lora_path=args.pretrained_lora_path, lora_target_modules=args.lora_target_modules ) launch_training_task(model, args) diff --git a/examples/train/kolors/train_kolors_lora.py b/examples/train/kolors/train_kolors_lora.py index 120e41d..48a9892 100644 --- a/examples/train/kolors/train_kolors_lora.py +++ b/examples/train/kolors/train_kolors_lora.py @@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA): self, torch_dtype=torch.float16, pretrained_weights=[], learning_rate=1e-4, use_gradient_checkpointing=True, - lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", + lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None, ): super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing) # Load models @@ -22,7 +22,14 @@ class LightningModel(LightningModelForT2ILoRA): self.pipe.vae_encoder.to(torch_dtype) self.freeze_parameters() - self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights) + self.add_lora_to_model( + self.pipe.denoising_model(), + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_target_modules=lora_target_modules, + init_lora_weights=init_lora_weights, + pretrained_lora_path=pretrained_lora_path, + ) def parse_args(): @@ -73,6 +80,7 @@ if __name__ == '__main__': lora_rank=args.lora_rank, lora_alpha=args.lora_alpha, init_lora_weights=args.init_lora_weights, + pretrained_lora_path=args.pretrained_lora_path, lora_target_modules=args.lora_target_modules ) launch_training_task(model, args) diff --git a/examples/train/stable_diffusion/train_sd_lora.py b/examples/train/stable_diffusion/train_sd_lora.py index 8dcaf7a..dc24520 100644 --- a/examples/train/stable_diffusion/train_sd_lora.py +++ b/examples/train/stable_diffusion/train_sd_lora.py @@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA): self, torch_dtype=torch.float16, pretrained_weights=[], learning_rate=1e-4, use_gradient_checkpointing=True, - lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", + lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None, ): super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing) # Load models @@ -19,7 +19,14 @@ class LightningModel(LightningModelForT2ILoRA): self.pipe.scheduler.set_timesteps(1000) self.freeze_parameters() - self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights) + self.add_lora_to_model( + self.pipe.denoising_model(), + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_target_modules=lora_target_modules, + init_lora_weights=init_lora_weights, + pretrained_lora_path=pretrained_lora_path, + ) def parse_args(): @@ -52,6 +59,7 @@ if __name__ == '__main__': lora_rank=args.lora_rank, lora_alpha=args.lora_alpha, init_lora_weights=args.init_lora_weights, + pretrained_lora_path=args.pretrained_lora_path, lora_target_modules=args.lora_target_modules ) launch_training_task(model, args) diff --git a/examples/train/stable_diffusion_3/train_sd3_lora.py b/examples/train/stable_diffusion_3/train_sd3_lora.py index a677bcb..c9abf2b 100644 --- a/examples/train/stable_diffusion_3/train_sd3_lora.py +++ b/examples/train/stable_diffusion_3/train_sd3_lora.py @@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA): self, torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None, learning_rate=1e-4, use_gradient_checkpointing=True, - lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", + lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None, ): super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing) # Load models @@ -24,7 +24,14 @@ class LightningModel(LightningModelForT2ILoRA): model_manager.load_lora(path) self.freeze_parameters() - self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights) + self.add_lora_to_model( + self.pipe.denoising_model(), + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_target_modules=lora_target_modules, + init_lora_weights=init_lora_weights, + pretrained_lora_path=pretrained_lora_path, + ) def parse_args(): @@ -70,6 +77,7 @@ if __name__ == '__main__': lora_rank=args.lora_rank, lora_alpha=args.lora_alpha, init_lora_weights=args.init_lora_weights, + pretrained_lora_path=args.pretrained_lora_path, lora_target_modules=args.lora_target_modules ) launch_training_task(model, args) diff --git a/examples/train/stable_diffusion_xl/train_sdxl_lora.py b/examples/train/stable_diffusion_xl/train_sdxl_lora.py index 69ca71d..de0241d 100644 --- a/examples/train/stable_diffusion_xl/train_sdxl_lora.py +++ b/examples/train/stable_diffusion_xl/train_sdxl_lora.py @@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA): self, torch_dtype=torch.float16, pretrained_weights=[], learning_rate=1e-4, use_gradient_checkpointing=True, - lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", + lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None, ): super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing) # Load models @@ -19,7 +19,14 @@ class LightningModel(LightningModelForT2ILoRA): self.pipe.scheduler.set_timesteps(1000) self.freeze_parameters() - self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights) + self.add_lora_to_model( + self.pipe.denoising_model(), + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_target_modules=lora_target_modules, + init_lora_weights=init_lora_weights, + pretrained_lora_path=pretrained_lora_path, + ) def parse_args(): @@ -52,6 +59,7 @@ if __name__ == '__main__': lora_rank=args.lora_rank, lora_alpha=args.lora_alpha, init_lora_weights=args.init_lora_weights, + pretrained_lora_path=args.pretrained_lora_path, lora_target_modules=args.lora_target_modules ) launch_training_task(model, args)