mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
bug fix
This commit is contained in:
@@ -202,7 +202,7 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
if name not in controlnet_inputs:
|
||||
controlnet_inputs[name] = {}
|
||||
controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
|
||||
break
|
||||
break
|
||||
else:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
for name, params in controlnet_inputs.items():
|
||||
|
||||
@@ -145,7 +145,8 @@ class FluxImagePipeline(BasePipeline):
|
||||
value_controllers = model_pool.fetch_model("flux_value_controller")
|
||||
if value_controllers is not None:
|
||||
pipe.value_controller = MultiValueEncoder(value_controllers)
|
||||
pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled
|
||||
if hasattr(pipe.value_controller.encoders[0], "vram_management_enabled"):
|
||||
pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled
|
||||
controlnets = model_pool.fetch_model("flux_controlnet")
|
||||
if controlnets is not None: pipe.controlnet = MultiControlNet(controlnets)
|
||||
pipe.ipadapter = model_pool.fetch_model("flux_ipadapter")
|
||||
@@ -295,7 +296,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
class FluxImageUnit_ShapeChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("height", "width"))
|
||||
super().__init__(input_params=("height", "width"), output_params=("height", "width"))
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, height, width):
|
||||
height, width = pipe.check_resize_height_width(height, width)
|
||||
@@ -305,7 +306,7 @@ class FluxImageUnit_ShapeChecker(PipelineUnit):
|
||||
|
||||
class FluxImageUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("height", "width", "seed", "rand_device"))
|
||||
super().__init__(input_params=("height", "width", "seed", "rand_device"), output_params=("noise",))
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device)
|
||||
@@ -317,6 +318,7 @@ class FluxImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
@@ -341,6 +343,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
|
||||
input_params_posi={"prompt": "prompt", "positive": "positive"},
|
||||
input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
|
||||
input_params=("t5_sequence_length",),
|
||||
output_params=("prompt_emb", "pooled_prompt_emb", "text_ids"),
|
||||
onload_model_names=("text_encoder_1", "text_encoder_2")
|
||||
)
|
||||
|
||||
@@ -396,7 +399,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
|
||||
|
||||
class FluxImageUnit_ImageIDs(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("latents",))
|
||||
super().__init__(input_params=("latents",), output_params=("image_ids",))
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, latents):
|
||||
latent_image_ids = pipe.dit.prepare_image_ids(latents)
|
||||
@@ -406,7 +409,7 @@ class FluxImageUnit_ImageIDs(PipelineUnit):
|
||||
|
||||
class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("embedded_guidance", "latents"))
|
||||
super().__init__(input_params=("embedded_guidance", "latents"), output_params=("guidance",))
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, embedded_guidance, latents):
|
||||
guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
|
||||
@@ -416,7 +419,11 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
|
||||
|
||||
class FluxImageUnit_Kontext(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("kontext_images", "tiled", "tile_size", "tile_stride"))
|
||||
super().__init__(
|
||||
input_params=("kontext_images", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("kontext_latents", "kontext_image_ids"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride):
|
||||
if kontext_images is None:
|
||||
@@ -444,6 +451,7 @@ class FluxImageUnit_ControlNet(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("controlnet_conditionings",),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
@@ -486,6 +494,8 @@ class FluxImageUnit_IPAdapter(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("ipadapter_images", "ipadapter_scale"),
|
||||
output_params=("ipadapter_kwargs_list",),
|
||||
onload_model_names=("ipadapter_image_encoder", "ipadapter")
|
||||
)
|
||||
|
||||
@@ -513,6 +523,8 @@ class FluxImageUnit_EntityControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("eligen_entity_prompts", "eligen_entity_masks", "eligen_enable_on_negative", "width", "height", "t5_sequence_length", "cfg_scale"),
|
||||
output_params=("entity_prompt_emb", "entity_masks"),
|
||||
onload_model_names=("text_encoder_1", "text_encoder_2")
|
||||
)
|
||||
|
||||
@@ -603,6 +615,8 @@ class FluxImageUnit_NexusGen(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("nexus_gen_reference_image", "prompt", "latents"),
|
||||
output_params=("prompt_emb", "text_ids"),
|
||||
onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"),
|
||||
)
|
||||
|
||||
@@ -655,7 +669,12 @@ class FluxImageUnit_NexusGen(PipelineUnit):
|
||||
|
||||
class FluxImageUnit_Step1x(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder"))
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("step1x_reference_image", "prompt", "negative_prompt"),
|
||||
output_params=("step1x_llm_embedding", "step1x_mask", "step1x_reference_latents"),
|
||||
onload_model_names=("qwenvl","vae_encoder")
|
||||
)
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict):
|
||||
image = inputs_shared.get("step1x_reference_image",None)
|
||||
@@ -678,7 +697,7 @@ class FluxImageUnit_Step1x(PipelineUnit):
|
||||
|
||||
class FluxImageUnit_TeaCache(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"))
|
||||
super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"), output_params=("tea_cache",))
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh):
|
||||
if tea_cache_l1_thresh is None:
|
||||
@@ -690,6 +709,7 @@ class FluxImageUnit_Flex(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("flex_condition", "flex_uncondition", "flex_control_stop_timestep"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
@@ -728,6 +748,7 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("infinityou_id_image", "infinityou_guidance"),
|
||||
output_params=("id_emb", "infinityou_guidance"),
|
||||
onload_model_names=("infinityou_processor",)
|
||||
)
|
||||
|
||||
@@ -747,6 +768,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
|
||||
input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
|
||||
input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
|
||||
input_params=("value_controller_inputs",),
|
||||
output_params=("prompt_emb", "text_ids"),
|
||||
onload_model_names=("value_controller",)
|
||||
)
|
||||
|
||||
@@ -825,6 +847,8 @@ class FluxImageUnit_LoRAEncode(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("lora_encoder_inputs", "lora_encoder_scale"),
|
||||
output_params=("prompt_emb", "text_ids"),
|
||||
onload_model_names=("lora_encoder",)
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ accelerate launch examples/flux/model_training/train.py \
|
||||
--data_file_keys "image,ipadapter_images" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.ipadapter." \
|
||||
|
||||
@@ -4,7 +4,7 @@ accelerate launch --config_file examples/flux/model_training/full/accelerate_con
|
||||
--data_file_keys "image,step1x_reference_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \
|
||||
--model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:model-*.safetensors,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
|
||||
@@ -4,7 +4,7 @@ accelerate launch examples/flux/model_training/train.py \
|
||||
--data_file_keys "image,ipadapter_images" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
|
||||
@@ -4,7 +4,7 @@ accelerate launch examples/flux/model_training/train.py \
|
||||
--data_file_keys "image,step1x_reference_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \
|
||||
--model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:model-*.safetensors,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
|
||||
@@ -13,7 +13,7 @@ pipe = FluxImagePipeline.from_pretrained(
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"),
|
||||
ModelConfig(model_id="google/siglip-so400m-patch14-384"),
|
||||
ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/FLUX.1-dev-IP-Adapter_full/epoch-0.safetensors")
|
||||
|
||||
@@ -6,7 +6,7 @@ pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
|
||||
@@ -12,7 +12,7 @@ pipe = FluxImagePipeline.from_pretrained(
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"),
|
||||
ModelConfig(model_id="google/siglip-so400m-patch14-384"),
|
||||
ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-IP-Adapter_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
@@ -7,7 +7,7 @@ pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"),
|
||||
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
|
||||
],
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, load_state_dict
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.core import load_state_dict
|
||||
from modelscope import snapshot_download
|
||||
import torch, math
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, load_state_dict
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.core import load_state_dict
|
||||
from modelscope import snapshot_download
|
||||
import torch, math
|
||||
|
||||
|
||||
@@ -110,5 +110,5 @@ def test_flux():
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_qwen_image()
|
||||
test_wan()
|
||||
test_flux()
|
||||
test_wan()
|
||||
Reference in New Issue
Block a user