mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
Compare commits
2 Commits
qwen-image
...
value-cont
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ba421a9ab9 | ||
|
|
6c30a7f080 |
@@ -64,6 +64,8 @@ from ..models.wan_video_vace import VaceWanModel
|
|||||||
|
|
||||||
from ..models.step1x_connector import Qwen2Connector
|
from ..models.step1x_connector import Qwen2Connector
|
||||||
|
|
||||||
|
from ..models.flux_value_control import SingleValueEncoder
|
||||||
|
|
||||||
|
|
||||||
model_loader_configs = [
|
model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
@@ -102,6 +104,7 @@ model_loader_configs = [
|
|||||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||||
|
(None, "3ede90c44b2c161240b659f3b8393c9d", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
|
||||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||||
|
|||||||
58
diffsynth/models/flux_value_control.py
Normal file
58
diffsynth/models/flux_value_control.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||||
|
|
||||||
|
|
||||||
|
class MultiValueEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, encoders=()):
|
||||||
|
super().__init__()
|
||||||
|
self.encoders = torch.nn.ModuleList(encoders)
|
||||||
|
|
||||||
|
def __call__(self, values, dtype):
|
||||||
|
emb = []
|
||||||
|
for encoder, value in zip(self.encoders, values):
|
||||||
|
if value is not None:
|
||||||
|
value = value.unsqueeze(0)
|
||||||
|
emb.append(encoder(value, dtype))
|
||||||
|
emb = torch.concat(emb, dim=0)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class SingleValueEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
|
||||||
|
super().__init__()
|
||||||
|
self.prefer_len = prefer_len
|
||||||
|
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||||
|
self.prefer_value_embedder = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||||
|
)
|
||||||
|
self.positional_embedding = torch.nn.Parameter(
|
||||||
|
torch.randn(self.prefer_len, dim_in)
|
||||||
|
)
|
||||||
|
self._initialize_weights()
|
||||||
|
|
||||||
|
def _initialize_weights(self):
|
||||||
|
last_linear = self.prefer_value_embedder[-1]
|
||||||
|
torch.nn.init.zeros_(last_linear.weight)
|
||||||
|
torch.nn.init.zeros_(last_linear.bias)
|
||||||
|
|
||||||
|
def forward(self, value, dtype):
|
||||||
|
emb = self.prefer_proj(value).to(dtype)
|
||||||
|
emb = emb.expand(self.prefer_len, -1)
|
||||||
|
emb = emb + self.positional_embedding
|
||||||
|
emb = self.prefer_value_embedder(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return SingleValueEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class SingleValueEncoderStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return state_dict
|
||||||
@@ -18,6 +18,7 @@ from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEnc
|
|||||||
from ..models.step1x_connector import Qwen2Connector
|
from ..models.step1x_connector import Qwen2Connector
|
||||||
from ..models.flux_controlnet import FluxControlNet
|
from ..models.flux_controlnet import FluxControlNet
|
||||||
from ..models.flux_ipadapter import FluxIpAdapter
|
from ..models.flux_ipadapter import FluxIpAdapter
|
||||||
|
from ..models.flux_value_control import MultiValueEncoder
|
||||||
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||||
from ..models.tiler import FastTileWorker
|
from ..models.tiler import FastTileWorker
|
||||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||||
@@ -94,6 +95,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.unit_runner = PipelineUnitRunner()
|
self.unit_runner = PipelineUnitRunner()
|
||||||
self.qwenvl = None
|
self.qwenvl = None
|
||||||
self.step1x_connector: Qwen2Connector = None
|
self.step1x_connector: Qwen2Connector = None
|
||||||
|
self.value_controller: MultiValueEncoder = None
|
||||||
self.infinityou_processor: InfinitYou = None
|
self.infinityou_processor: InfinitYou = None
|
||||||
self.image_proj_model: InfiniteYouImageProjector = None
|
self.image_proj_model: InfiniteYouImageProjector = None
|
||||||
self.in_iteration_models = ("dit", "step1x_connector", "controlnet")
|
self.in_iteration_models = ("dit", "step1x_connector", "controlnet")
|
||||||
@@ -112,6 +114,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
FluxImageUnit_TeaCache(),
|
FluxImageUnit_TeaCache(),
|
||||||
FluxImageUnit_Flex(),
|
FluxImageUnit_Flex(),
|
||||||
FluxImageUnit_Step1x(),
|
FluxImageUnit_Step1x(),
|
||||||
|
FluxImageUnit_ValueControl(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_flux_image
|
self.model_fn = model_fn_flux_image
|
||||||
|
|
||||||
@@ -295,7 +298,16 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
for model_name, model in zip(model_manager.model_name, model_manager.model):
|
for model_name, model in zip(model_manager.model_name, model_manager.model):
|
||||||
if model_name == "flux_controlnet":
|
if model_name == "flux_controlnet":
|
||||||
controlnets.append(model)
|
controlnets.append(model)
|
||||||
pipe.controlnet = MultiControlNet(controlnets)
|
if len(controlnets) > 0:
|
||||||
|
pipe.controlnet = MultiControlNet(controlnets)
|
||||||
|
|
||||||
|
# Value Controller
|
||||||
|
value_controllers = []
|
||||||
|
for model_name, model in zip(model_manager.model_name, model_manager.model):
|
||||||
|
if model_name == "flux_value_controller":
|
||||||
|
value_controllers.append(model)
|
||||||
|
if len(value_controllers) > 0:
|
||||||
|
pipe.value_controller = MultiValueEncoder(value_controllers)
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
@@ -347,6 +359,8 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
flex_control_image: Image.Image = None,
|
flex_control_image: Image.Image = None,
|
||||||
flex_control_strength: float = 0.5,
|
flex_control_strength: float = 0.5,
|
||||||
flex_control_stop: float = 0.5,
|
flex_control_stop: float = 0.5,
|
||||||
|
# Value Controller
|
||||||
|
value_controller_inputs: list[float] = None,
|
||||||
# Step1x
|
# Step1x
|
||||||
step1x_reference_image: Image.Image = None,
|
step1x_reference_image: Image.Image = None,
|
||||||
# TeaCache
|
# TeaCache
|
||||||
@@ -380,6 +394,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
|
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
|
||||||
"infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance,
|
"infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance,
|
||||||
"flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
|
"flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
|
||||||
|
"value_controller_inputs": value_controller_inputs,
|
||||||
"step1x_reference_image": step1x_reference_image,
|
"step1x_reference_image": step1x_reference_image,
|
||||||
"tea_cache_l1_thresh": tea_cache_l1_thresh,
|
"tea_cache_l1_thresh": tea_cache_l1_thresh,
|
||||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||||
@@ -720,6 +735,27 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FluxImageUnit_ValueControl(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
take_over=True,
|
||||||
|
onload_model_names=("value_controller",)
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||||
|
if inputs_shared.get("value_controller_inputs", None) is None:
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
value_controller_inputs = torch.tensor(inputs_shared["value_controller_inputs"]).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
pipe.load_models_to_device(["value_controller_inputs"])
|
||||||
|
value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)
|
||||||
|
value_emb = value_emb.unsqueeze(0)
|
||||||
|
value_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
|
||||||
|
inputs_posi["prompt_emb"] = torch.concat([inputs_posi["prompt_emb"], value_emb], dim=1)
|
||||||
|
inputs_posi["text_ids"] = torch.concat([inputs_posi["text_ids"], value_text_ids], dim=1)
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class InfinitYou:
|
class InfinitYou:
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||||
from facexlib.recognition import init_recognition_model
|
from facexlib.recognition import init_recognition_model
|
||||||
|
|||||||
20
examples/flux/model_inference/FLUX.1-dev-ValueControl.py
Normal file
20
examples/flux/model_inference/FLUX.1-dev-ValueControl.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
pipe = FluxImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
from diffsynth.models.flux_value_control import SingleValueEncoder, MultiValueEncoder
|
||||||
|
pipe.value_controller = MultiValueEncoder(encoders=[SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder()]).to(dtype=torch.bfloat16, device="cuda")
|
||||||
|
|
||||||
|
image = pipe(prompt="a cat", seed=0, value_controller_inputs=[0.5, 0.5, 1, 0])
|
||||||
|
image.save("flux.jpg")
|
||||||
120
examples/flux/model_training/train_value_controller.py
Normal file
120
examples/flux/model_training/train_value_controller.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import torch, os, json
|
||||||
|
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||||
|
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
|
||||||
|
from diffsynth.models.lora import FluxLoRAConverter
|
||||||
|
from diffsynth.models.flux_value_control import SingleValueEncoder, MultiValueEncoder
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FluxTrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_paths=None, model_id_with_origin_paths=None,
|
||||||
|
trainable_models=None,
|
||||||
|
lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32,
|
||||||
|
use_gradient_checkpointing=True,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
extra_inputs=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Load models
|
||||||
|
model_configs = []
|
||||||
|
if model_paths is not None:
|
||||||
|
model_paths = json.loads(model_paths)
|
||||||
|
model_configs += [ModelConfig(path=path) for path in model_paths]
|
||||||
|
if model_id_with_origin_paths is not None:
|
||||||
|
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||||
|
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
|
||||||
|
self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
||||||
|
|
||||||
|
self.pipe.value_controller = MultiValueEncoder(encoders=[SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder(), SingleValueEncoder()]).to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Reset training scheduler
|
||||||
|
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
|
||||||
|
# Freeze untrainable models
|
||||||
|
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||||
|
|
||||||
|
# Add LoRA to the base models
|
||||||
|
if lora_base_model is not None:
|
||||||
|
model = self.add_lora_to_model(
|
||||||
|
getattr(self.pipe, lora_base_model),
|
||||||
|
target_modules=lora_target_modules.split(","),
|
||||||
|
lora_rank=lora_rank
|
||||||
|
)
|
||||||
|
setattr(self.pipe, lora_base_model, model)
|
||||||
|
|
||||||
|
# Store other configs
|
||||||
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
|
|
||||||
|
|
||||||
|
def forward_preprocess(self, data):
|
||||||
|
# CFG-sensitive parameters
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {}
|
||||||
|
|
||||||
|
# CFG-unsensitive parameters
|
||||||
|
inputs_shared = {
|
||||||
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
# Please do not modify the following parameters
|
||||||
|
# unless you clearly know what this will cause.
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"embedded_guidance": 1,
|
||||||
|
"t5_sequence_length": 512,
|
||||||
|
"tiled": False,
|
||||||
|
"rand_device": self.pipe.device,
|
||||||
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Extra inputs
|
||||||
|
for extra_input in self.extra_inputs:
|
||||||
|
inputs_shared[extra_input] = data[extra_input]
|
||||||
|
|
||||||
|
# Pipeline units will automatically process the input parameters.
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
return {**inputs_shared, **inputs_posi}
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, data, inputs=None):
|
||||||
|
if inputs is None: inputs = self.forward_preprocess(data)
|
||||||
|
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||||
|
loss = self.pipe.training_loss(**models, **inputs)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = flux_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
dataset = ImageDataset(args=args)
|
||||||
|
model = FluxTrainingModule(
|
||||||
|
model_paths=args.model_paths,
|
||||||
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
trainable_models=args.trainable_models,
|
||||||
|
lora_base_model=args.lora_base_model,
|
||||||
|
lora_target_modules=args.lora_target_modules,
|
||||||
|
lora_rank=args.lora_rank,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
|
extra_inputs=args.extra_inputs,
|
||||||
|
)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
args.output_path,
|
||||||
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
|
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
|
||||||
|
)
|
||||||
|
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
||||||
|
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||||
|
launch_training_task(
|
||||||
|
dataset, model, model_logger, optimizer, scheduler,
|
||||||
|
num_epochs=args.num_epochs,
|
||||||
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user