nexus-gen

This commit is contained in:
Artiprocher
2025-04-30 17:09:15 +08:00
parent ef2a7abad4
commit f7737aff98
9 changed files with 3200 additions and 7 deletions

View File

@@ -0,0 +1,64 @@
import torch
from diffsynth import ModelManager
from .flux_image_pipeline import FluxImagePipelineAll2All
class FluxDecoder:
def __init__(self, flux_all2all_modelpath, flux_path, device='cuda', torch_dtype=torch.bfloat16):
self.device = device
self.torch_dtype = torch_dtype
self.pipe, self.adapter = self.get_pipe(flux_all2all_modelpath, flux_path, device, torch_dtype)
def get_pipe(self, flux_all2all_modelpath, flux_path, device="cuda", torch_dtype=torch.bfloat16):
model_manager = ModelManager(torch_dtype=torch_dtype, device=device)
model_manager.load_models([
f"{flux_path}/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
f"{flux_path}/FLUX/FLUX.1-dev/text_encoder_2",
f"{flux_path}/FLUX/FLUX.1-dev/ae.safetensors",
f"{flux_path}/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
state_dict = torch.load(flux_all2all_modelpath, weights_only=True, map_location='cpu')
adapter_states = ['0.weight', '0.bias', '1.weight', '1.bias', '3.weight', '3.bias', '4.weight', '4.bias']
adapter_state_dict = {}
for key in adapter_states:
adapter_state_dict[key] = state_dict.pop(key)
in_channel = 3584
out_channel = 4096
expand_ratio = 1
adapter = torch.nn.Sequential(torch.nn.Linear(in_channel, out_channel * expand_ratio),
torch.nn.LayerNorm(out_channel * expand_ratio), torch.nn.ReLU(),
torch.nn.Linear(out_channel * expand_ratio, out_channel),
torch.nn.LayerNorm(out_channel))
adapter.load_state_dict(adapter_state_dict)
adapter.to(device, dtype=torch_dtype)
pipe = FluxImagePipelineAll2All.from_model_manager(model_manager)
pipe.dit.load_state_dict(state_dict)
return pipe, adapter
@torch.no_grad()
def decode_image_embeds(self,
output_image_embeddings,
height=512,
width=512,
num_inference_steps=50,
seed=42,
negative_prompt="",
cfg_scale=1.0,
**pipe_kwargs):
output_image_embeddings = output_image_embeddings.to(device=self.device, dtype=self.torch_dtype)
image_embed = self.adapter(output_image_embeddings)
image = self.pipe(prompt="",
image_embed=image_embed,
num_inference_steps=num_inference_steps,
embedded_guidance=3.5,
negative_prompt=negative_prompt,
cfg_scale=cfg_scale,
height=height,
width=width,
seed=seed,
**pipe_kwargs)
return image