This commit is contained in:
Artiprocher
2025-11-21 20:26:49 +08:00
parent 0b7dd55ff3
commit 0336551544
13 changed files with 46 additions and 20 deletions

View File

@@ -202,7 +202,7 @@ class DiffusionTrainingModule(torch.nn.Module):
if name not in controlnet_inputs: if name not in controlnet_inputs:
controlnet_inputs[name] = {} controlnet_inputs[name] = {}
controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input] controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
break break
else: else:
inputs_shared[extra_input] = data[extra_input] inputs_shared[extra_input] = data[extra_input]
for name, params in controlnet_inputs.items(): for name, params in controlnet_inputs.items():

View File

@@ -145,7 +145,8 @@ class FluxImagePipeline(BasePipeline):
value_controllers = model_pool.fetch_model("flux_value_controller") value_controllers = model_pool.fetch_model("flux_value_controller")
if value_controllers is not None: if value_controllers is not None:
pipe.value_controller = MultiValueEncoder(value_controllers) pipe.value_controller = MultiValueEncoder(value_controllers)
pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled if hasattr(pipe.value_controller.encoders[0], "vram_management_enabled"):
pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled
controlnets = model_pool.fetch_model("flux_controlnet") controlnets = model_pool.fetch_model("flux_controlnet")
if controlnets is not None: pipe.controlnet = MultiControlNet(controlnets) if controlnets is not None: pipe.controlnet = MultiControlNet(controlnets)
pipe.ipadapter = model_pool.fetch_model("flux_ipadapter") pipe.ipadapter = model_pool.fetch_model("flux_ipadapter")
@@ -295,7 +296,7 @@ class FluxImagePipeline(BasePipeline):
class FluxImageUnit_ShapeChecker(PipelineUnit): class FluxImageUnit_ShapeChecker(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__(input_params=("height", "width")) super().__init__(input_params=("height", "width"), output_params=("height", "width"))
def process(self, pipe: FluxImagePipeline, height, width): def process(self, pipe: FluxImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width) height, width = pipe.check_resize_height_width(height, width)
@@ -305,7 +306,7 @@ class FluxImageUnit_ShapeChecker(PipelineUnit):
class FluxImageUnit_NoiseInitializer(PipelineUnit): class FluxImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__(input_params=("height", "width", "seed", "rand_device")) super().__init__(input_params=("height", "width", "seed", "rand_device"), output_params=("noise",))
def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device): def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device) noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device)
@@ -317,6 +318,7 @@ class FluxImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
output_params=("latents", "input_latents"),
onload_model_names=("vae_encoder",) onload_model_names=("vae_encoder",)
) )
@@ -341,6 +343,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
input_params_posi={"prompt": "prompt", "positive": "positive"}, input_params_posi={"prompt": "prompt", "positive": "positive"},
input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
input_params=("t5_sequence_length",), input_params=("t5_sequence_length",),
output_params=("prompt_emb", "pooled_prompt_emb", "text_ids"),
onload_model_names=("text_encoder_1", "text_encoder_2") onload_model_names=("text_encoder_1", "text_encoder_2")
) )
@@ -396,7 +399,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
class FluxImageUnit_ImageIDs(PipelineUnit): class FluxImageUnit_ImageIDs(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__(input_params=("latents",)) super().__init__(input_params=("latents",), output_params=("image_ids",))
def process(self, pipe: FluxImagePipeline, latents): def process(self, pipe: FluxImagePipeline, latents):
latent_image_ids = pipe.dit.prepare_image_ids(latents) latent_image_ids = pipe.dit.prepare_image_ids(latents)
@@ -406,7 +409,7 @@ class FluxImageUnit_ImageIDs(PipelineUnit):
class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit): class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__(input_params=("embedded_guidance", "latents")) super().__init__(input_params=("embedded_guidance", "latents"), output_params=("guidance",))
def process(self, pipe: FluxImagePipeline, embedded_guidance, latents): def process(self, pipe: FluxImagePipeline, embedded_guidance, latents):
guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
@@ -416,7 +419,11 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
class FluxImageUnit_Kontext(PipelineUnit): class FluxImageUnit_Kontext(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__(input_params=("kontext_images", "tiled", "tile_size", "tile_stride")) super().__init__(
input_params=("kontext_images", "tiled", "tile_size", "tile_stride"),
output_params=("kontext_latents", "kontext_image_ids"),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride): def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride):
if kontext_images is None: if kontext_images is None:
@@ -444,6 +451,7 @@ class FluxImageUnit_ControlNet(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"), input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
output_params=("controlnet_conditionings",),
onload_model_names=("vae_encoder",) onload_model_names=("vae_encoder",)
) )
@@ -486,6 +494,8 @@ class FluxImageUnit_IPAdapter(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
take_over=True, take_over=True,
input_params=("ipadapter_images", "ipadapter_scale"),
output_params=("ipadapter_kwargs_list",),
onload_model_names=("ipadapter_image_encoder", "ipadapter") onload_model_names=("ipadapter_image_encoder", "ipadapter")
) )
@@ -513,6 +523,8 @@ class FluxImageUnit_EntityControl(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
take_over=True, take_over=True,
input_params=("eligen_entity_prompts", "eligen_entity_masks", "eligen_enable_on_negative", "width", "height", "t5_sequence_length", "cfg_scale"),
output_params=("entity_prompt_emb", "entity_masks"),
onload_model_names=("text_encoder_1", "text_encoder_2") onload_model_names=("text_encoder_1", "text_encoder_2")
) )
@@ -603,6 +615,8 @@ class FluxImageUnit_NexusGen(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
take_over=True, take_over=True,
input_params=("nexus_gen_reference_image", "prompt", "latents"),
output_params=("prompt_emb", "text_ids"),
onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"), onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"),
) )
@@ -655,7 +669,12 @@ class FluxImageUnit_NexusGen(PipelineUnit):
class FluxImageUnit_Step1x(PipelineUnit): class FluxImageUnit_Step1x(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder")) super().__init__(
take_over=True,
input_params=("step1x_reference_image", "prompt", "negative_prompt"),
output_params=("step1x_llm_embedding", "step1x_mask", "step1x_reference_latents"),
onload_model_names=("qwenvl","vae_encoder")
)
def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict): def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict):
image = inputs_shared.get("step1x_reference_image",None) image = inputs_shared.get("step1x_reference_image",None)
@@ -678,7 +697,7 @@ class FluxImageUnit_Step1x(PipelineUnit):
class FluxImageUnit_TeaCache(PipelineUnit): class FluxImageUnit_TeaCache(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh")) super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"), output_params=("tea_cache",))
def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh): def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh):
if tea_cache_l1_thresh is None: if tea_cache_l1_thresh is None:
@@ -690,6 +709,7 @@ class FluxImageUnit_Flex(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"), input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"),
output_params=("flex_condition", "flex_uncondition", "flex_control_stop_timestep"),
onload_model_names=("vae_encoder",) onload_model_names=("vae_encoder",)
) )
@@ -728,6 +748,7 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
input_params=("infinityou_id_image", "infinityou_guidance"), input_params=("infinityou_id_image", "infinityou_guidance"),
output_params=("id_emb", "infinityou_guidance"),
onload_model_names=("infinityou_processor",) onload_model_names=("infinityou_processor",)
) )
@@ -747,6 +768,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"}, input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"}, input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
input_params=("value_controller_inputs",), input_params=("value_controller_inputs",),
output_params=("prompt_emb", "text_ids"),
onload_model_names=("value_controller",) onload_model_names=("value_controller",)
) )
@@ -825,6 +847,8 @@ class FluxImageUnit_LoRAEncode(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
take_over=True, take_over=True,
input_params=("lora_encoder_inputs", "lora_encoder_scale"),
output_params=("prompt_emb", "text_ids"),
onload_model_names=("lora_encoder",) onload_model_names=("lora_encoder",)
) )

View File

@@ -4,7 +4,7 @@ accelerate launch examples/flux/model_training/train.py \
--data_file_keys "image,ipadapter_images" \ --data_file_keys "image,ipadapter_images" \
--max_pixels 1048576 \ --max_pixels 1048576 \
--dataset_repeat 100 \ --dataset_repeat 100 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \ --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:model.safetensors" \
--learning_rate 1e-5 \ --learning_rate 1e-5 \
--num_epochs 1 \ --num_epochs 1 \
--remove_prefix_in_ckpt "pipe.ipadapter." \ --remove_prefix_in_ckpt "pipe.ipadapter." \

View File

@@ -4,7 +4,7 @@ accelerate launch --config_file examples/flux/model_training/full/accelerate_con
--data_file_keys "image,step1x_reference_image" \ --data_file_keys "image,step1x_reference_image" \
--max_pixels 1048576 \ --max_pixels 1048576 \
--dataset_repeat 400 \ --dataset_repeat 400 \
--model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \ --model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:model-*.safetensors,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \
--learning_rate 1e-5 \ --learning_rate 1e-5 \
--num_epochs 1 \ --num_epochs 1 \
--remove_prefix_in_ckpt "pipe.dit." \ --remove_prefix_in_ckpt "pipe.dit." \

View File

@@ -4,7 +4,7 @@ accelerate launch examples/flux/model_training/train.py \
--data_file_keys "image,ipadapter_images" \ --data_file_keys "image,ipadapter_images" \
--max_pixels 1048576 \ --max_pixels 1048576 \
--dataset_repeat 50 \ --dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \ --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:model.safetensors" \
--learning_rate 1e-4 \ --learning_rate 1e-4 \
--num_epochs 5 \ --num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \ --remove_prefix_in_ckpt "pipe.dit." \

View File

@@ -4,7 +4,7 @@ accelerate launch examples/flux/model_training/train.py \
--data_file_keys "image,step1x_reference_image" \ --data_file_keys "image,step1x_reference_image" \
--max_pixels 1048576 \ --max_pixels 1048576 \
--dataset_repeat 50 \ --dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \ --model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:model-*.safetensors,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \
--learning_rate 1e-4 \ --learning_rate 1e-4 \
--num_epochs 5 \ --num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \ --remove_prefix_in_ckpt "pipe.dit." \

View File

@@ -13,7 +13,7 @@ pipe = FluxImagePipeline.from_pretrained(
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"), ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"),
ModelConfig(model_id="google/siglip-so400m-patch14-384"), ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors"),
], ],
) )
state_dict = load_state_dict("models/train/FLUX.1-dev-IP-Adapter_full/epoch-0.safetensors") state_dict = load_state_dict("models/train/FLUX.1-dev-IP-Adapter_full/epoch-0.safetensors")

View File

@@ -6,7 +6,7 @@ pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device="cuda", device="cuda",
model_configs=[ model_configs=[
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"), ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),

View File

@@ -12,7 +12,7 @@ pipe = FluxImagePipeline.from_pretrained(
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"), ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"),
ModelConfig(model_id="google/siglip-so400m-patch14-384"), ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors"),
], ],
) )
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-IP-Adapter_lora/epoch-4.safetensors", alpha=1) pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-IP-Adapter_lora/epoch-4.safetensors", alpha=1)

View File

@@ -7,7 +7,7 @@ pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device="cuda", device="cuda",
model_configs=[ model_configs=[
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"), ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"), ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
], ],

View File

@@ -1,4 +1,5 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, load_state_dict from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.core import load_state_dict
from modelscope import snapshot_download from modelscope import snapshot_download
import torch, math import torch, math

View File

@@ -1,4 +1,5 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, load_state_dict from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.core import load_state_dict
from modelscope import snapshot_download from modelscope import snapshot_download
import torch, math import torch, math

View File

@@ -110,5 +110,5 @@ def test_flux():
if __name__ == "__main__": if __name__ == "__main__":
test_qwen_image() test_qwen_image()
test_wan()
test_flux() test_flux()
test_wan()