Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-19 14:58:12 +00:00.

Commit "nexus-gen" adds `modeling/decoder/flux_decoder.py` (new file, 64 lines):

@@ -0,0 +1,64 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager
|
||||
from .flux_image_pipeline import FluxImagePipelineAll2All
|
||||
|
||||
class FluxDecoder:
    """Decode image embeddings into images with a FLUX diffusion pipeline.

    A small MLP adapter projects incoming embeddings (3584-dim) into the
    4096-dim embedding space consumed by the FLUX pipeline; the pipeline
    then renders the final image. Both the fine-tuned DiT weights and the
    adapter weights are loaded from a single combined checkpoint.
    """

    # Adapter projection sizes. These must match the checkpoint layout at
    # `flux_all2all_modelpath`; presumably 3584 is the hidden size of the
    # upstream embedding model — TODO confirm against the training code.
    ADAPTER_IN_CHANNELS = 3584
    ADAPTER_OUT_CHANNELS = 4096
    ADAPTER_EXPAND_RATIO = 1  # hidden width multiplier for the adapter MLP

    def __init__(self, flux_all2all_modelpath, flux_path, device='cuda', torch_dtype=torch.bfloat16):
        """Build the pipeline and adapter.

        Args:
            flux_all2all_modelpath: Path to the combined checkpoint holding
                adapter weights (keys '0.*', '1.*', '3.*', '4.*') plus the
                fine-tuned DiT state dict.
            flux_path: Root directory containing the FLUX.1-dev model files.
            device: Device for both pipeline and adapter (default 'cuda').
            torch_dtype: Compute dtype (default torch.bfloat16).
        """
        self.device = device
        self.torch_dtype = torch_dtype
        self.pipe, self.adapter = self.get_pipe(flux_all2all_modelpath, flux_path, device, torch_dtype)

    def get_pipe(self, flux_all2all_modelpath, flux_path, device="cuda", torch_dtype=torch.bfloat16):
        """Load the FLUX pipeline and the embedding adapter.

        Returns:
            (pipe, adapter): the FluxImagePipelineAll2All instance and the
            adapter MLP, both moved to `device` / `torch_dtype`.
        """
        model_manager = ModelManager(torch_dtype=torch_dtype, device=device)
        model_manager.load_models([
            f"{flux_path}/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
            f"{flux_path}/FLUX/FLUX.1-dev/text_encoder_2",
            f"{flux_path}/FLUX/FLUX.1-dev/ae.safetensors",
            f"{flux_path}/FLUX/FLUX.1-dev/flux1-dev.safetensors"
        ])

        # The checkpoint stores both adapter and DiT weights; the adapter
        # keys are popped out first so the remainder is a clean DiT state dict.
        state_dict = torch.load(flux_all2all_modelpath, weights_only=True, map_location='cpu')
        adapter = self._build_adapter(state_dict, device, torch_dtype)

        pipe = FluxImagePipelineAll2All.from_model_manager(model_manager)
        pipe.dit.load_state_dict(state_dict)  # adapter keys already removed
        return pipe, adapter

    def _build_adapter(self, state_dict, device, torch_dtype):
        """Pop adapter weights out of `state_dict` (mutated in place) and
        return the loaded adapter MLP on `device` with `torch_dtype`."""
        # Key names follow torch.nn.Sequential child indices (ReLU at 2 has
        # no parameters, hence the gap).
        adapter_keys = ['0.weight', '0.bias', '1.weight', '1.bias',
                        '3.weight', '3.bias', '4.weight', '4.bias']
        adapter_state_dict = {key: state_dict.pop(key) for key in adapter_keys}

        hidden = self.ADAPTER_OUT_CHANNELS * self.ADAPTER_EXPAND_RATIO
        adapter = torch.nn.Sequential(
            torch.nn.Linear(self.ADAPTER_IN_CHANNELS, hidden),
            torch.nn.LayerNorm(hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, self.ADAPTER_OUT_CHANNELS),
            torch.nn.LayerNorm(self.ADAPTER_OUT_CHANNELS),
        )
        adapter.load_state_dict(adapter_state_dict)
        adapter.to(device, dtype=torch_dtype)
        return adapter

    @torch.no_grad()
    def decode_image_embeds(self,
                            output_image_embeddings,
                            height=512,
                            width=512,
                            num_inference_steps=50,
                            seed=42,
                            negative_prompt="",
                            cfg_scale=1.0,
                            **pipe_kwargs):
        """Render an image from precomputed image embeddings.

        Args:
            output_image_embeddings: Tensor of embeddings to decode; moved
                to `self.device` / `self.torch_dtype` before use. Assumed
                last dim is ADAPTER_IN_CHANNELS (3584) — TODO confirm.
            height, width: Output image size in pixels.
            num_inference_steps: Diffusion sampling steps.
            seed: RNG seed forwarded to the pipeline.
            negative_prompt: Negative text prompt (text prompt itself is
                empty — conditioning comes from the embeddings).
            cfg_scale: Classifier-free guidance scale.
            **pipe_kwargs: Extra keyword arguments forwarded to the pipeline.

        Returns:
            Whatever the pipeline returns (the rendered image).
        """
        output_image_embeddings = output_image_embeddings.to(device=self.device, dtype=self.torch_dtype)
        image_embed = self.adapter(output_image_embeddings)
        image = self.pipe(prompt="",
                          image_embed=image_embed,
                          num_inference_steps=num_inference_steps,
                          embedded_guidance=3.5,
                          negative_prompt=negative_prompt,
                          cfg_scale=cfg_scale,
                          height=height,
                          width=width,
                          seed=seed,
                          **pipe_kwargs)
        return image
|
||||
Reference in New Issue
Block a user