add Flux_ControlNet_Quantization

This commit is contained in:
tc2000731
2024-10-29 17:29:24 +08:00
parent 7e97a96840
commit 900a1c095f
7 changed files with 558 additions and 2 deletions

View File

@@ -83,8 +83,14 @@ class LoRAFromCivitai:
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
fp8=False
if state_dict_model[name].dtype == torch.float8_e4m3fn:
state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype)
fp8=True
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
if fp8:
state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
model.load_state_dict(state_dict_model)