support z-image-omni-base-i2L

This commit is contained in:
Artiprocher
2026-01-07 20:36:53 +08:00
parent bac39b1cd2
commit dd479e5bff
6 changed files with 413 additions and 6 deletions

View File

@@ -235,6 +235,7 @@ class BasePipeline(torch.nn.Module):
alpha=1,
hotload=None,
state_dict=None,
verbose=1,
):
if state_dict is None:
if isinstance(lora_config, str):
@@ -261,12 +262,13 @@ class BasePipeline(torch.nn.Module):
updated_num += 1
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
if verbose >= 1:
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
else:
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
def clear_lora(self):
def clear_lora(self, verbose=1):
cleared_num = 0
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
@@ -276,7 +278,8 @@ class BasePipeline(torch.nn.Module):
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
print(f"{cleared_num} LoRA layers are cleared.")
if verbose >= 1:
print(f"{cleared_num} LoRA layers are cleared.")
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
@@ -304,8 +307,13 @@ class BasePipeline(torch.nn.Module):
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if cfg_scale != 1.0:
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else: