diff --git a/.gitignore b/.gitignore index 6fd0d8e..a511cf2 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ /models /scripts /diffusers +/.vscode *.pkl *.safetensors *.pth diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index face319..52f1f02 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -9,6 +9,7 @@ from ..utils.lora import GeneralLoRALoader from ..models.model_loader import ModelPool from ..utils.controlnet import ControlNetInput from ..core.device import get_device_name, IS_NPU_AVAILABLE +from .skills import load_skill_model, load_skill_data_processor class PipelineUnit: @@ -338,6 +339,14 @@ class BasePipeline(torch.nn.Module): else: noise_pred = noise_pred_posi return noise_pred + + + def load_training_skill_model(self, model_config: ModelConfig = None): + if model_config is not None: + model_config.download_if_necessary() + self.skill_model = load_skill_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device) + self.skill_data_processor = None if (p := load_skill_data_processor(model_config.path)) is None else p() + class PipelineUnitGraph: diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py index b8c6c6a..9dc90e8 100644 --- a/diffsynth/diffusion/parsers.py +++ b/diffsynth/diffusion/parsers.py @@ -60,6 +60,10 @@ def add_gradient_config(parser: argparse.ArgumentParser): parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") return parser +def add_skill_model_config(parser: argparse.ArgumentParser): + parser.add_argument("--skill_model_id_or_path", type=str, default=None, help="Model ID or path of skill models.") + return parser + def add_general_config(parser: argparse.ArgumentParser): parser = add_dataset_base_config(parser) parser = add_model_config(parser) @@ -67,4 +71,5 @@ def add_general_config(parser: argparse.ArgumentParser): parser = add_output_config(parser) parser = 
add_lora_config(parser) parser = add_gradient_config(parser) + parser = add_skill_model_config(parser) return parser diff --git a/diffsynth/diffusion/skills.py b/diffsynth/diffusion/skills.py new file mode 100644 index 0000000..ced2fe4 --- /dev/null +++ b/diffsynth/diffusion/skills.py @@ -0,0 +1,137 @@ +import torch, os, importlib.util, warnings, json +from typing import Dict, List, Tuple, Union +from ..core import ModelConfig, load_model +from ..core.device.npu_compatible_device import get_device_type + + +SkillCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]] + + +class SkillModel(torch.nn.Module): + def __init__(self): + super().__init__() + + @torch.no_grad() + def process_inputs(self, pipe=None, **kwargs): + return {} + + def forward(self, **kwargs) -> SkillCache: + raise NotImplementedError() + + +class MultiSkillModel(SkillModel): + def __init__(self, models: List[SkillModel]): + super().__init__() + if not isinstance(models, list): + models = [models] + self.models = torch.nn.ModuleList(models) + + def merge(self, kv_cache_list: List[SkillCache]) -> SkillCache: + names = {} + for kv_cache in kv_cache_list: + for name in kv_cache: + names[name] = None + kv_cache_merged = {} + for name in names: + kv_list = [kv_cache.get(name) for kv_cache in kv_cache_list] + kv_list = [kv for kv in kv_list if kv is not None] + if len(kv_list) > 0: + k = torch.concat([kv[0] for kv in kv_list], dim=1) + v = torch.concat([kv[1] for kv in kv_list], dim=1) + kv_cache_merged[name] = (k, v) + return kv_cache_merged + + @torch.no_grad() + def process_inputs(self, pipe=None, inputs: List[Dict] = None, **kwargs): + return [(i["model_id"], self.models[i["model_id"]].process_inputs(pipe=pipe, **i)) for i in inputs] + + def forward(self, inputs: List[Tuple[int, Dict]], **kwargs) -> SkillCache: + kv_cache_list = [] + for model_id, model_inputs in inputs: + kv_cache = self.models[model_id](**model_inputs) + kv_cache_list.append(kv_cache) + return self.merge(kv_cache_list) + + +def 
load_skill_model(path, torch_dtype=torch.bfloat16, device="cuda", verbose=1): + spec = importlib.util.spec_from_file_location("skill_model", os.path.join(path, "model.py")) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + model = load_model( + model_class=getattr(module, 'SKILL_MODEL'), + config=getattr(module, 'SKILL_MODEL_CONFIG') if hasattr(module, 'SKILL_MODEL_CONFIG') else None, + path=os.path.join(path, getattr(module, 'SKILL_MODEL_PATH')), + torch_dtype=torch_dtype, + device=device, + ) + if verbose > 0: + metadata = { + "model_architecture": getattr(module, 'SKILL_MODEL').__name__, + "code_path": os.path.join(path, "model.py"), + "weight_path": os.path.join(path, getattr(module, 'SKILL_MODEL_PATH')), + } + print(f"Skill model loaded: {json.dumps(metadata, indent=4)}") + return model + + +def load_skill_data_processor(path): + spec = importlib.util.spec_from_file_location("skill_model", os.path.join(path, "model.py")) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if hasattr(module, 'SKILL_DATA_PROCESSOR'): + processor = getattr(module, 'SKILL_DATA_PROCESSOR') + return processor + else: + return None + + +class SkillsPipeline(MultiSkillModel): + def __init__(self, models: List[SkillModel]): + super().__init__(models) + + @staticmethod + def check_vram_config(model_config: ModelConfig): + params = [ + model_config.offload_device, model_config.offload_dtype, + model_config.onload_device, model_config.onload_dtype, + model_config.preparing_device, model_config.preparing_dtype, + model_config.computation_device, model_config.computation_dtype, + ] + for param in params: + if param is not None: + warnings.warn("SkillsPipeline doesn't support VRAM management. 
VRAM config will be ignored.") + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + ): + models = [] + for model_config in model_configs: + SkillsPipeline.check_vram_config(model_config) + model_config.download_if_necessary() + model = load_skill_model(model_config.path, torch_dtype=torch_dtype, device=device) + models.append(model) + pipe = SkillsPipeline(models) + return pipe + + def call_single_side(self, pipe = None, inputs: List[Dict] = None): + inputs = self.process_inputs(pipe=pipe, inputs=inputs) + skill_cache = self.forward(inputs) + return skill_cache + + @torch.no_grad() + def __call__( + self, + pipe = None, + inputs: List[Dict] = None, + positive_inputs: List[Dict] = None, + negative_inputs: List[Dict] = None, + ): + shared_cache = self.call_single_side(pipe=pipe, inputs=inputs or []) + positive_cache = self.call_single_side(pipe=pipe, inputs=positive_inputs or []) + negative_cache = self.call_single_side(pipe=pipe, inputs=negative_inputs or []) + positive_cache = self.merge([positive_cache, shared_cache]) + negative_cache = self.merge([negative_cache, shared_cache]) + return {"skill_cache": positive_cache, "negative_skill_cache": negative_cache} diff --git a/diffsynth/diffusion/training_module.py b/diffsynth/diffusion/training_module.py index 0a00118..37c90d0 100644 --- a/diffsynth/diffusion/training_module.py +++ b/diffsynth/diffusion/training_module.py @@ -6,6 +6,7 @@ from peft import LoraConfig, inject_adapter_in_model class GeneralUnit_RemoveCache(PipelineUnit): + # Only used for training def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()): super().__init__(take_over=True) self.required_params = required_params @@ -27,6 +28,40 @@ class GeneralUnit_RemoveCache(PipelineUnit): return inputs_shared, inputs_posi, inputs_nega 
+class GeneralUnit_SkillProcessInputs(PipelineUnit): + # Only used for training + def __init__(self, data_processor): + super().__init__( + input_params=("skill_inputs",), + output_params=("skill_inputs",), + ) + self.data_processor = data_processor + + def process(self, pipe, skill_inputs): + if not hasattr(pipe, "skill_model"): + return {} + if self.data_processor is not None: + skill_inputs = self.data_processor(**skill_inputs) + skill_inputs = pipe.skill_model.process_inputs(pipe=pipe, **skill_inputs) + return {"skill_inputs": skill_inputs} + + +class GeneralUnit_SkillForward(PipelineUnit): + # Only used for training + def __init__(self): + super().__init__( + input_params=("skill_inputs",), + output_params=("skill_cache",), + onload_model_names=("skill_model",) + ) + + def process(self, pipe, skill_inputs): + if not hasattr(pipe, "skill_model"): + return {} + skill_cache = pipe.skill_model.forward(**skill_inputs) + return {"skill_cache": skill_cache} + + class DiffusionTrainingModule(torch.nn.Module): def __init__(self): super().__init__() @@ -209,6 +244,16 @@ class DiffusionTrainingModule(torch.nn.Module): else: lora_target_modules = lora_target_modules.split(",") return lora_target_modules + + + def load_training_skill_model(self, pipe, path_or_model_id): + if path_or_model_id is None: + return pipe + model_config = self.parse_path_or_model_id(path_or_model_id) + pipe.load_training_skill_model(model_config) + pipe.units.append(GeneralUnit_SkillProcessInputs(pipe.skill_data_processor)) + pipe.units.append(GeneralUnit_SkillForward()) + return pipe def switch_pipe_to_training_mode( diff --git a/diffsynth/models/flux2_dit.py b/diffsynth/models/flux2_dit.py index a1bd02a..8be717f 100644 --- a/diffsynth/models/flux2_dit.py +++ b/diffsynth/models/flux2_dit.py @@ -364,78 +364,7 @@ class Flux2FeedForward(nn.Module): return x -class Flux2AttnProcessor: - _attention_backend = None - _parallel_config = None - - def __init__(self): - if not hasattr(F, 
"scaled_dot_product_attention"): - raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") - - def __call__( - self, - attn: "Flux2Attention", - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( - attn, hidden_states, encoder_hidden_states - ) - - query = query.unflatten(-1, (attn.heads, -1)) - key = key.unflatten(-1, (attn.heads, -1)) - value = value.unflatten(-1, (attn.heads, -1)) - - query = attn.norm_q(query) - key = attn.norm_k(key) - - if attn.added_kv_proj_dim is not None: - encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) - encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) - encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) - - encoder_query = attn.norm_added_q(encoder_query) - encoder_key = attn.norm_added_k(encoder_key) - - query = torch.cat([encoder_query, query], dim=1) - key = torch.cat([encoder_key, key], dim=1) - value = torch.cat([encoder_value, value], dim=1) - - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) - key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - - query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype) - hidden_states = attention_forward( - query, - key, - value, - q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", - ) - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( - [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 - ) - encoder_hidden_states = 
attn.to_add_out(encoder_hidden_states) - - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - - if encoder_hidden_states is not None: - return hidden_states, encoder_hidden_states - else: - return hidden_states - - class Flux2Attention(torch.nn.Module): - _default_processor_cls = Flux2AttnProcessor - _available_processors = [Flux2AttnProcessor] - def __init__( self, query_dim: int, @@ -449,7 +378,6 @@ class Flux2Attention(torch.nn.Module): eps: float = 1e-5, out_dim: int = None, elementwise_affine: bool = True, - processor=None, ): super().__init__() @@ -485,59 +413,45 @@ class Flux2Attention(torch.nn.Module): self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) - if processor is None: - processor = self._default_processor_cls() - self.processor = processor - def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + kv_cache = None, **kwargs, ) -> torch.Tensor: - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} - return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) - - -class Flux2ParallelSelfAttnProcessor: - _attention_backend = None - _parallel_config = None - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") - - def __call__( - self, - attn: "Flux2ParallelSelfAttention", - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # Parallel in (QKV + MLP in) projection - hidden_states = attn.to_qkv_mlp_proj(hidden_states) - qkv, mlp_hidden_states = torch.split( - hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + self, hidden_states, encoder_hidden_states ) - # Handle the attention logic - query, key, value = qkv.chunk(3, dim=-1) + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) - query = query.unflatten(-1, (attn.heads, -1)) - key = key.unflatten(-1, (attn.heads, -1)) - value = value.unflatten(-1, (attn.heads, -1)) + query = self.norm_q(query) + key = self.norm_k(key) - query = attn.norm_q(query) - key = attn.norm_k(key) + if self.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (self.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (self.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (self.heads, -1)) + + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype) + if kv_cache is not None: + key = torch.concat([key, kv_cache[0]], dim=1) + value = torch.concat([value, kv_cache[1]], dim=1) hidden_states = attention_forward( query, 
key, @@ -547,30 +461,22 @@ class Flux2ParallelSelfAttnProcessor: hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) - # Handle the feedforward (FF) logic - mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + encoder_hidden_states = self.to_add_out(encoder_hidden_states) - # Concatenate and parallel output projection - hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) - hidden_states = attn.to_out(hidden_states) + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) - return hidden_states + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states class Flux2ParallelSelfAttention(torch.nn.Module): - """ - Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks. - - This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF) - input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B - paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block. 
- """ - - _default_processor_cls = Flux2ParallelSelfAttnProcessor - _available_processors = [Flux2ParallelSelfAttnProcessor] - # Does not support QKV fusion as the QKV projections are always fused - _supports_qkv_fusion = False - def __init__( self, query_dim: int, @@ -614,20 +520,54 @@ class Flux2ParallelSelfAttention(torch.nn.Module): # Fused attention output projection + MLP output projection self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias) - if processor is None: - processor = self._default_processor_cls() - self.processor = processor - def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + kv_cache = None, **kwargs, ) -> torch.Tensor: - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} - return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs) + # Parallel in (QKV + MLP in) projection + hidden_states = self.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, [3 * self.inner_dim, self.mlp_hidden_dim * self.mlp_mult_factor], dim=-1 + ) + + # Handle the attention logic + query, key, value = qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + if kv_cache is not None: + key = torch.concat([key, kv_cache[0]], dim=1) + value = torch.concat([value, kv_cache[1]], dim=1) + hidden_states = attention_forward( + query, + key, + value, + q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + ) + 
hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # Handle the feedforward (FF) logic + mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states) + + # Concatenate and parallel output projection + hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) + hidden_states = self.to_out(hidden_states) + + return hidden_states class Flux2SingleTransformerBlock(nn.Module): @@ -657,7 +597,6 @@ class Flux2SingleTransformerBlock(nn.Module): eps=eps, mlp_ratio=mlp_ratio, mlp_mult_factor=2, - processor=Flux2ParallelSelfAttnProcessor(), ) def forward( @@ -669,6 +608,7 @@ class Flux2SingleTransformerBlock(nn.Module): joint_attention_kwargs: Optional[Dict[str, Any]] = None, split_hidden_states: bool = False, text_seq_len: Optional[int] = None, + kv_cache = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already # concatenated @@ -685,6 +625,7 @@ class Flux2SingleTransformerBlock(nn.Module): attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, + kv_cache=kv_cache, **joint_attention_kwargs, ) @@ -725,7 +666,6 @@ class Flux2TransformerBlock(nn.Module): added_proj_bias=bias, out_bias=bias, eps=eps, - processor=Flux2AttnProcessor(), ) self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) @@ -742,6 +682,7 @@ class Flux2TransformerBlock(nn.Module): temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, + kv_cache = None, ) -> Tuple[torch.Tensor, torch.Tensor]: joint_attention_kwargs = joint_attention_kwargs or {} @@ -762,6 +703,7 @@ class Flux2TransformerBlock(nn.Module): hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, + kv_cache=kv_cache, **joint_attention_kwargs, ) @@ 
-969,6 +911,7 @@ class Flux2DiT(torch.nn.Module): txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, + kv_cache = None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, ): @@ -1013,7 +956,7 @@ class Flux2DiT(torch.nn.Module): ) # 4. Double Stream Transformer Blocks - for index_block, block in enumerate(self.transformer_blocks): + for block_id, block in enumerate(self.transformer_blocks): encoder_hidden_states, hidden_states = gradient_checkpoint_forward( block, use_gradient_checkpointing=use_gradient_checkpointing, @@ -1024,12 +967,13 @@ class Flux2DiT(torch.nn.Module): temb_mod_params_txt=double_stream_mod_txt, image_rotary_emb=concat_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, + kv_cache=None if kv_cache is None else kv_cache.get(f"double_{block_id}"), ) # Concatenate text and image streams for single-block inference hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) # 5. Single Stream Transformer Blocks - for index_block, block in enumerate(self.single_transformer_blocks): + for block_id, block in enumerate(self.single_transformer_blocks): hidden_states = gradient_checkpoint_forward( block, use_gradient_checkpointing=use_gradient_checkpointing, @@ -1039,6 +983,7 @@ class Flux2DiT(torch.nn.Module): temb_mod_params=single_stream_mod, image_rotary_emb=concat_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, + kv_cache=None if kv_cache is None else kv_cache.get(f"single_{block_id}"), ) # Remove text tokens from concatenated stream hidden_states = hidden_states[:, num_txt_tokens:, ...] 
diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 34f4d27..9dda9bf 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -93,6 +93,9 @@ class Flux2ImagePipeline(BasePipeline): initial_noise: torch.Tensor = None, # Steps num_inference_steps: int = 30, + # KV Cache + skill_cache = None, + negative_skill_cache = None, # Progress bar progress_bar_cmd = tqdm, ): @@ -101,9 +104,11 @@ class Flux2ImagePipeline(BasePipeline): # Parameters inputs_posi = { "prompt": prompt, + "skill_cache": skill_cache, } inputs_nega = { "negative_prompt": negative_prompt, + "skill_cache": negative_skill_cache, } inputs_shared = { "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, @@ -570,6 +575,7 @@ def model_fn_flux2( image_ids=None, edit_latents=None, edit_image_ids=None, + skill_cache=None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, **kwargs, @@ -587,6 +593,7 @@ def model_fn_flux2( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=image_ids, + kv_cache=skill_cache, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, ) diff --git a/examples/flux2/model_inference/FLUX.2-klein-base-4B-skills.py b/examples/flux2/model_inference/FLUX.2-klein-base-4B-skills.py new file mode 100644 index 0000000..fcf7992 --- /dev/null +++ b/examples/flux2/model_inference/FLUX.2-klein-base-4B-skills.py @@ -0,0 +1,56 @@ +from diffsynth.diffusion.skills import SkillsPipeline +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch +from PIL import Image + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", 
origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +skills = SkillsPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/F2KB4B-Skills-ControlNet"), + ModelConfig(model_id="DiffSynth-Studio/F2KB4B-Skills-Brightness"), + ], +) +skill_cache = skills( + positive_inputs = [ + { + "model_id": 0, + "image": Image.open("xxx.jpg"), + "prompt": "一位长发少女,四周环绕着魔法粒子", + }, + { + "model_id": 1, + "scale": 0.6, + }, + ], + negative_inputs = [ + { + "model_id": 0, + "image": Image.open("xxx.jpg"), + "prompt": "一位长发少女,四周环绕着魔法粒子", + }, + { + "model_id": 1, + "scale": 0.5, + }, + ], + pipe=pipe, +) +image = pipe( + prompt="一位长发少女,四周环绕着魔法粒子", + seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4, + height=1024, width=1024, + **skill_cache, +) +image.save("image.jpg") diff --git a/examples/flux2/model_training/full/FLUX.2-klein-base-4B-skills.sh b/examples/flux2/model_training/full/FLUX.2-klein-base-4B-skills.sh new file mode 100644 index 0000000..d56634b --- /dev/null +++ b/examples/flux2/model_training/full/FLUX.2-klein-base-4B-skills.sh @@ -0,0 +1,16 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path /mnt/nas1/duanzhongjie.dzj/dataset/ImagePulseV2 \ + --dataset_metadata_path /mnt/nas1/duanzhongjie.dzj/dataset/ImagePulseV2/metadata_example_ti2ti.jsonl \ + --extra_inputs "skill_inputs" \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ + --skill_model_id_or_path "models/base" \ + 
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ + --learning_rate 1e-4 \ + --num_epochs 999 \ + --remove_prefix_in_ckpt "pipe.skill_model." \ + --output_path "./models/train/FLUX.2-klein-base-4B-skills_full" \ + --trainable_models "skill_model" \ + --use_gradient_checkpointing \ + --save_steps 200 diff --git a/examples/flux2/model_training/scripts/convert_base_model_to_skill_model.py b/examples/flux2/model_training/scripts/convert_base_model_to_skill_model.py new file mode 100644 index 0000000..21fab7f --- /dev/null +++ b/examples/flux2/model_training/scripts/convert_base_model_to_skill_model.py @@ -0,0 +1,60 @@ +from diffsynth import load_state_dict +from safetensors.torch import save_file +import torch + + +def Flux2DiTStateDictConverter(state_dict): + rename_dict = { + "time_guidance_embed.timestep_embedder.linear_1.weight": "time_guidance_embed.timestep_embedder.0.weight", + "time_guidance_embed.timestep_embedder.linear_2.weight": "time_guidance_embed.timestep_embedder.2.weight", + "x_embedder.weight": "img_embedder.weight", + "context_embedder.weight": "txt_embedder.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + state_dict_[rename_dict[name]] = state_dict[name] + elif name.startswith("transformer_blocks"): + if name.endswith("attn.to_q.weight"): + state_dict_[name.replace("to_q", "img_to_qkv").replace(".attn.", ".")] = torch.concat([ + state_dict[name.replace("to_q", "to_q")], + state_dict[name.replace("to_q", "to_k")], + state_dict[name.replace("to_q", "to_v")], + ], dim=0) + elif name.endswith("attn.to_k.weight") or name.endswith("attn.to_v.weight"): + continue + elif name.endswith("attn.to_out.0.weight"): + state_dict_[name.replace("attn.to_out.0.weight", "img_to_out.weight")] = state_dict[name] + elif name.endswith("attn.norm_q.weight"): + state_dict_[name.replace("attn.norm_q.weight", "img_norm_q.weight")] = state_dict[name] + elif name.endswith("attn.norm_k.weight"): + 
state_dict_[name.replace("attn.norm_k.weight", "img_norm_k.weight")] = state_dict[name] + elif name.endswith("attn.norm_added_q.weight"): + state_dict_[name.replace("attn.norm_added_q.weight", "txt_norm_q.weight")] = state_dict[name] + elif name.endswith("attn.norm_added_k.weight"): + state_dict_[name.replace("attn.norm_added_k.weight", "txt_norm_k.weight")] = state_dict[name] + elif name.endswith("attn.to_add_out.weight"): + state_dict_[name.replace("attn.to_add_out.weight", "txt_to_out.weight")] = state_dict[name] + elif name.endswith("attn.add_q_proj.weight"): + state_dict_[name.replace("add_q_proj", "txt_to_qkv").replace(".attn.", ".")] = torch.concat([ + state_dict[name.replace("add_q_proj", "add_q_proj")], + state_dict[name.replace("add_q_proj", "add_k_proj")], + state_dict[name.replace("add_q_proj", "add_v_proj")], + ], dim=0) + elif ".ff." in name: + state_dict_[name.replace(".ff.", ".img_ff.")] = state_dict[name] + elif ".ff_context." in name: + state_dict_[name.replace(".ff_context.", ".txt_ff.")] = state_dict[name] + elif name.endswith("attn.add_k_proj.weight") or name.endswith("attn.add_v_proj.weight"): + continue + else: + state_dict_[name] = state_dict[name] + elif name.startswith("single_transformer_blocks"): + state_dict_[name.replace(".attn.", ".")] = state_dict[name] + else: + state_dict_[name] = state_dict[name] + return state_dict_ + + +state_dict = load_state_dict("xxx.safetensors") +save_file(Flux2DiTStateDictConverter(state_dict), "yyy.safetensors") diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py index 6101687..7a15267 100644 --- a/examples/flux2/model_training/train.py +++ b/examples/flux2/model_training/train.py @@ -18,6 +18,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule): extra_inputs=None, fp8_models=None, offload_models=None, + skill_model_id_or_path=None, device="cpu", task="sft", ): @@ -26,6 +27,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule): model_configs = 
self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/")) self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) + self.pipe = self.load_training_skill_model(self.pipe, skill_model_id_or_path) self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) # Training mode @@ -126,6 +128,7 @@ if __name__ == "__main__": extra_inputs=args.extra_inputs, fp8_models=args.fp8_models, offload_models=args.offload_models, + skill_model_id_or_path=args.skill_model_id_or_path, task=args.task, device="cpu" if args.initialize_model_on_cpu else accelerator.device, )