diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index 52f1f02..f7dc5ab 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -9,7 +9,7 @@ from ..utils.lora import GeneralLoRALoader from ..models.model_loader import ModelPool from ..utils.controlnet import ControlNetInput from ..core.device import get_device_name, IS_NPU_AVAILABLE -from .skills import load_skill_model, load_skill_data_processor +from .template import load_template_model, load_template_data_processor class PipelineUnit: @@ -320,14 +320,21 @@ class BasePipeline(torch.nn.Module): def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): + # Positive side forward if inputs_shared.get("positive_only_lora", None) is not None: - self.clear_lora(verbose=0) self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0) noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + if cfg_scale != 1.0: - if inputs_shared.get("positive_only_lora", None) is not None: - self.clear_lora(verbose=0) + # Negative side forward + if inputs_shared.get("negative_only_lora", None) is not None: + self.load_lora(self.dit, state_dict=inputs_shared["negative_only_lora"], verbose=0) noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) + if inputs_shared.get("negative_only_lora", None) is not None: + self.clear_lora(verbose=0) + if isinstance(noise_pred_posi, tuple): # Separately handling different output types of latents, eg. video and audio latents. 
noise_pred = tuple( @@ -341,11 +348,11 @@ class BasePipeline(torch.nn.Module): return noise_pred - def load_training_skill_model(self, model_config: ModelConfig = None): + def load_training_template_model(self, model_config: ModelConfig = None): if model_config is not None: model_config.download_if_necessary() - self.skill_model = load_skill_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device) - self.skill_data_processor = load_skill_data_processor(model_config.path)() + self.template_model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device) + self.template_data_processor = load_template_data_processor(model_config.path)() diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py index 9dc90e8..3bcbe4b 100644 --- a/diffsynth/diffusion/parsers.py +++ b/diffsynth/diffusion/parsers.py @@ -60,8 +60,8 @@ def add_gradient_config(parser: argparse.ArgumentParser): parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") return parser -def add_skill_model_config(parser: argparse.ArgumentParser): - parser.add_argument("--skill_model_id_or_path", type=str, default=None, help="Model ID of path of skill models.") +def add_template_model_config(parser: argparse.ArgumentParser): + parser.add_argument("--template_model_id_or_path", type=str, default=None, help="Model ID or path of template models.") return parser def add_general_config(parser: argparse.ArgumentParser): @@ -71,5 +71,5 @@ parser = add_output_config(parser) parser = add_lora_config(parser) parser = add_gradient_config(parser) - parser = add_skill_model_config(parser) + parser = add_template_model_config(parser) return parser diff --git a/diffsynth/diffusion/skills.py b/diffsynth/diffusion/skills.py deleted file mode 100644 index ced2fe4..0000000 --- a/diffsynth/diffusion/skills.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch, 
os, importlib, warnings, json -from typing import Dict, List, Tuple, Union -from ..core import ModelConfig, load_model -from ..core.device.npu_compatible_device import get_device_type - - -SkillCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]] - - -class SkillModel(torch.nn.Module): - def __init__(self): - super().__init__() - - @torch.no_grad() - def process_inputs(self, pipe=None, **kwargs): - return {} - - def forward(self, **kwargs) -> SkillCache: - raise NotImplementedError() - - -class MultiSkillModel(SkillModel): - def __init__(self, models: List[SkillModel]): - super().__init__() - if not isinstance(models, list): - models = [models] - self.models = torch.nn.ModuleList(models) - - def merge(self, kv_cache_list: List[SkillCache]) -> SkillCache: - names = {} - for kv_cache in kv_cache_list: - for name in kv_cache: - names[name] = None - kv_cache_merged = {} - for name in names: - kv_list = [kv_cache.get(name) for kv_cache in kv_cache_list] - kv_list = [kv for kv in kv_list if kv is not None] - if len(kv_list) > 0: - k = torch.concat([kv[0] for kv in kv_list], dim=1) - v = torch.concat([kv[1] for kv in kv_list], dim=1) - kv_cache_merged[name] = (k, v) - return kv_cache_merged - - @torch.no_grad() - def process_inputs(self, pipe=None, inputs: List[Dict] = None, **kwargs): - return [(i["model_id"], self.models[i["model_id"]].process_inputs(pipe=pipe, **i)) for i in inputs] - - def forward(self, inputs: List[Tuple[int, Dict]], **kwargs) -> SkillCache: - kv_cache_list = [] - for model_id, model_inputs in inputs: - kv_cache = self.models[model_id](**model_inputs) - kv_cache_list.append(kv_cache) - return self.merge(kv_cache_list) - - -def load_skill_model(path, torch_dtype=torch.bfloat16, device="cuda", verbose=1): - spec = importlib.util.spec_from_file_location("skill_model", os.path.join(path, "model.py")) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - model = load_model( - model_class=getattr(module, 'SKILL_MODEL'), - 
config=getattr(module, 'SKILL_MODEL_CONFIG') if hasattr(module, 'SKILL_MODEL_CONFIG') else None, - path=os.path.join(path, getattr(module, 'SKILL_MODEL_PATH')), - torch_dtype=torch_dtype, - device=device, - ) - if verbose > 0: - metadata = { - "model_architecture": getattr(module, 'SKILL_MODEL').__name__, - "code_path": os.path.join(path, "model.py"), - "weight_path": os.path.join(path, getattr(module, 'SKILL_MODEL_PATH')), - } - print(f"Skill model loaded: {json.dumps(metadata, indent=4)}") - return model - - -def load_skill_data_processor(path): - spec = importlib.util.spec_from_file_location("skill_model", os.path.join(path, "model.py")) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - if hasattr(module, 'SKILL_DATA_PROCESSOR'): - processor = getattr(module, 'SKILL_DATA_PROCESSOR') - return processor - else: - return None - - -class SkillsPipeline(MultiSkillModel): - def __init__(self, models: List[SkillModel]): - super().__init__(models) - - @staticmethod - def check_vram_config(model_config: ModelConfig): - params = [ - model_config.offload_device, model_config.offload_dtype, - model_config.onload_device, model_config.onload_dtype, - model_config.preparing_device, model_config.preparing_dtype, - model_config.computation_device, model_config.computation_dtype, - ] - for param in params: - if param is not None: - warnings.warn("SkillsPipeline doesn't support VRAM management. 
VRAM config will be ignored.") - - @staticmethod - def from_pretrained( - torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = get_device_type(), - model_configs: list[ModelConfig] = [], - ): - models = [] - for model_config in model_configs: - SkillsPipeline.check_vram_config(model_config) - model_config.download_if_necessary() - model = load_skill_model(model_config.path, torch_dtype=torch_dtype, device=device) - models.append(model) - pipe = SkillsPipeline(models) - return pipe - - def call_single_side(self, pipe = None, inputs: List[Dict] = None): - inputs = self.process_inputs(pipe=pipe, inputs=inputs) - skill_cache = self.forward(inputs) - return skill_cache - - @torch.no_grad() - def __call__( - self, - pipe = None, - inputs: List[Dict] = None, - positive_inputs: List[Dict] = None, - negative_inputs: List[Dict] = None, - ): - shared_cache = self.call_single_side(pipe=pipe, inputs=inputs or []) - positive_cache = self.call_single_side(pipe=pipe, inputs=positive_inputs or []) - negative_cache = self.call_single_side(pipe=pipe, inputs=negative_inputs or []) - positive_cache = self.merge([positive_cache, shared_cache]) - negative_cache = self.merge([negative_cache, shared_cache]) - return {"skill_cache": positive_cache, "negative_skill_cache": negative_cache} diff --git a/diffsynth/diffusion/template.py b/diffsynth/diffusion/template.py new file mode 100644 index 0000000..6b9a53f --- /dev/null +++ b/diffsynth/diffusion/template.py @@ -0,0 +1,176 @@ +import torch, os, importlib, warnings, json, inspect +from typing import Dict, List, Tuple, Union +from ..core import ModelConfig, load_model +from ..core.device.npu_compatible_device import get_device_type + + +KVCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]] + + +class TemplateModel(torch.nn.Module): + def __init__(self): + super().__init__() + + @torch.no_grad() + def process_inputs(self, **kwargs): + return {} + + def forward(self, **kwargs): + raise NotImplementedError() + + +def 
check_template_model_format(model): + if not hasattr(model, "process_inputs"): + raise NotImplementedError("`process_inputs` is not implemented in the Template model.") + if "kwargs" not in inspect.signature(model.process_inputs).parameters: + raise NotImplementedError("`**kwargs` is not included in `process_inputs`.") + if not hasattr(model, "forward"): + raise NotImplementedError("`forward` is not implemented in the Template model.") + if "kwargs" not in inspect.signature(model.forward).parameters: + raise NotImplementedError("`**kwargs` is not included in `forward`.") + + +def load_template_model(path, torch_dtype=torch.bfloat16, device="cuda", verbose=1): + spec = importlib.util.spec_from_file_location("template_model", os.path.join(path, "model.py")) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + template_model_path = getattr(module, 'TEMPLATE_MODEL_PATH') if hasattr(module, 'TEMPLATE_MODEL_PATH') else None + if template_model_path is not None: + # With `TEMPLATE_MODEL_PATH`, a pretrained model will be loaded. + model = load_model( + model_class=getattr(module, 'TEMPLATE_MODEL'), + config=getattr(module, 'TEMPLATE_MODEL_CONFIG') if hasattr(module, 'TEMPLATE_MODEL_CONFIG') else None, + path=os.path.join(path, getattr(module, 'TEMPLATE_MODEL_PATH')), + torch_dtype=torch_dtype, + device=device, + ) + else: + # Without `TEMPLATE_MODEL_PATH`, a randomly initialized model or a non-model module will be loaded. 
+ model = module.TEMPLATE_MODEL() + if hasattr(model, "to"): + model = model.to(dtype=torch_dtype, device=device) + if hasattr(model, "eval"): + model = model.eval() + check_template_model_format(model) + if verbose > 0: + metadata = { + "model_architecture": getattr(module, 'TEMPLATE_MODEL').__name__, + "code_path": os.path.join(path, "model.py"), + "weight_path": template_model_path, + } + print(f"Template model loaded: {json.dumps(metadata, indent=4)}") + return model + + +def load_template_data_processor(path): + spec = importlib.util.spec_from_file_location("template_model", os.path.join(path, "model.py")) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if hasattr(module, 'TEMPLATE_DATA_PROCESSOR'): + processor = getattr(module, 'TEMPLATE_DATA_PROCESSOR') + return processor + else: + return None + + +class TemplatePipeline(torch.nn.Module): + def __init__(self, models: List[TemplateModel]): + super().__init__() + self.models = torch.nn.ModuleList(models) + + def merge_kv_cache(self, kv_cache_list: List[KVCache]) -> KVCache: + names = {} + for kv_cache in kv_cache_list: + for name in kv_cache: + names[name] = None + kv_cache_merged = {} + for name in names: + kv_list = [kv_cache.get(name) for kv_cache in kv_cache_list] + kv_list = [kv for kv in kv_list if kv is not None] + if len(kv_list) > 0: + k = torch.concat([kv[0] for kv in kv_list], dim=1) + v = torch.concat([kv[1] for kv in kv_list], dim=1) + kv_cache_merged[name] = (k, v) + return kv_cache_merged + + def merge_template_cache(self, template_cache_list): + params = sorted(list(set(sum([list(template_cache.keys()) for template_cache in template_cache_list], [])))) + template_cache_merged = {} + for param in params: + data = [template_cache[param] for template_cache in template_cache_list if param in template_cache] + if param == "kv_cache": + data = self.merge_kv_cache(data) + elif len(data) == 1: + data = data[0] + else: + print(f"Conflict detected: `{param}` appears 
in the outputs of multiple Template models. Only the first one will be retained.") + data = data[0] + template_cache_merged[param] = data + return template_cache_merged + + @staticmethod + def check_vram_config(model_config: ModelConfig): + params = [ + model_config.offload_device, model_config.offload_dtype, + model_config.onload_device, model_config.onload_dtype, + model_config.preparing_device, model_config.preparing_dtype, + model_config.computation_device, model_config.computation_dtype, + ] + for param in params: + if param is not None: + warnings.warn("TemplatePipeline doesn't support VRAM management. VRAM config will be ignored.") + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + ): + models = [] + for model_config in model_configs: + TemplatePipeline.check_vram_config(model_config) + model_config.download_if_necessary() + model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device) + models.append(model) + pipe = TemplatePipeline(models) + return pipe + + @torch.no_grad() + def process_inputs(self, inputs: List[Dict], pipe=None, **kwargs): + return [(i.get("model_id", 0), self.models[i.get("model_id", 0)].process_inputs(pipe=pipe, **i)) for i in inputs] + + def forward(self, inputs: List[Tuple[int, Dict]], pipe=None, **kwargs): + template_cache = [] + for model_id, model_inputs in inputs: + kv_cache = self.models[model_id](pipe=pipe, **model_inputs) + template_cache.append(kv_cache) + return template_cache + + def call_single_side(self, pipe=None, inputs: List[Dict] = None): + inputs = self.process_inputs(pipe=pipe, inputs=inputs) + template_cache = self.forward(pipe=pipe, inputs=inputs) + template_cache = self.merge_template_cache(template_cache) + return template_cache + + @torch.no_grad() + def __call__( + self, + pipe=None, + template_inputs: List[Dict] = None, + negative_template_inputs: 
List[Dict] = None, + **kwargs, + ): + template_cache = self.call_single_side(pipe=pipe, inputs=template_inputs or []) + negative_template_cache = self.call_single_side(pipe=pipe, inputs=negative_template_inputs or []) + required_params = list(inspect.signature(pipe.__call__).parameters.keys()) + for param in template_cache: + if param in required_params: + kwargs[param] = template_cache[param] + else: + print(f"`{param}` is not included in the inputs of `{pipe.__class__.__name__}`. This parameter will be ignored.") + for param in negative_template_cache: + if "negative_" + param in required_params: + kwargs["negative_" + param] = negative_template_cache[param] + else: + print(f"`{'negative_' + param}` is not included in the inputs of `{pipe.__class__.__name__}`. This parameter will be ignored.") + return pipe(**kwargs) diff --git a/diffsynth/diffusion/training_module.py b/diffsynth/diffusion/training_module.py index 37c90d0..e1d3852 100644 --- a/diffsynth/diffusion/training_module.py +++ b/diffsynth/diffusion/training_module.py @@ -28,38 +28,45 @@ class GeneralUnit_RemoveCache(PipelineUnit): return inputs_shared, inputs_posi, inputs_nega -class GeneralUnit_SkillProcessInputs(PipelineUnit): +class GeneralUnit_TemplateProcessInputs(PipelineUnit): # Only used for training def __init__(self, data_processor): super().__init__( - input_params=("skill_inputs",), - output_params=("skill_inputs",), + input_params=("template_inputs",), + output_params=("template_inputs",), ) self.data_processor = data_processor - def process(self, pipe, skill_inputs): - if not hasattr(pipe, "skill_model"): + def process(self, pipe, template_inputs): + if not hasattr(pipe, "template_model"): return {} if self.data_processor is not None: - skill_inputs = self.data_processor(**skill_inputs) - skill_inputs = pipe.skill_model.process_inputs(pipe=pipe, **skill_inputs) - return {"skill_inputs": skill_inputs} + template_inputs = self.data_processor(**template_inputs) + template_inputs = 
pipe.template_model.process_inputs(pipe=pipe, **template_inputs) + return {"template_inputs": template_inputs} -class GeneralUnit_SkillForward(PipelineUnit): +class GeneralUnit_TemplateForward(PipelineUnit): # Only used for training - def __init__(self): + def __init__(self, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False): super().__init__( - input_params=("skill_inputs",), - output_params=("skill_cache",), - onload_model_names=("skill_model",) + input_params=("template_inputs",), + output_params=("kv_cache",), + onload_model_names=("template_model",) ) + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload - def process(self, pipe, skill_inputs): - if not hasattr(pipe, "skill_model"): + def process(self, pipe, template_inputs): + if not hasattr(pipe, "template_model"): return {} - skill_cache = pipe.skill_model.forward(**skill_inputs) - return {"skill_cache": skill_cache} + template_cache = pipe.template_model.forward( + **template_inputs, + pipe=pipe, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload, + ) + return template_cache class DiffusionTrainingModule(torch.nn.Module): @@ -246,13 +253,13 @@ class DiffusionTrainingModule(torch.nn.Module): return lora_target_modules - def load_training_skill_model(self, pipe, path_or_model_id): + def load_training_template_model(self, pipe, path_or_model_id, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False): if path_or_model_id is None: return pipe model_config = self.parse_path_or_model_id(path_or_model_id) - pipe.load_training_skill_model(model_config) - pipe.units.append(GeneralUnit_SkillProcessInputs(pipe.skill_data_processor)) - pipe.units.append(GeneralUnit_SkillForward()) + pipe.load_training_template_model(model_config) + 
pipe.units.append(GeneralUnit_TemplateProcessInputs(pipe.template_data_processor)) + pipe.units.append(GeneralUnit_TemplateForward(use_gradient_checkpointing, use_gradient_checkpointing_offload)) return pipe diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 9dda9bf..4bec241 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -40,6 +40,7 @@ class Flux2ImagePipeline(BasePipeline): Flux2Unit_InputImageEmbedder(), Flux2Unit_EditImageEmbedder(), Flux2Unit_ImageIDs(), + Flux2Unit_Inpaint(), ] self.model_fn = model_fn_flux2 @@ -94,8 +95,15 @@ class Flux2ImagePipeline(BasePipeline): # Steps num_inference_steps: int = 30, # KV Cache - skill_cache = None, - negative_skill_cache = None, + kv_cache = None, + negative_kv_cache = None, + # LoRA + lora = None, + negative_lora = None, + # Inpaint + inpaint_mask: Image.Image = None, + inpaint_blur_size: int = None, + inpaint_blur_sigma: float = None, # Progress bar progress_bar_cmd = tqdm, ): @@ -104,11 +112,11 @@ class Flux2ImagePipeline(BasePipeline): # Parameters inputs_posi = { "prompt": prompt, - "skill_cache": skill_cache, + "kv_cache": kv_cache, } inputs_nega = { "negative_prompt": negative_prompt, - "skill_cache": negative_skill_cache, + "kv_cache": negative_kv_cache, } inputs_shared = { "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, @@ -117,6 +125,9 @@ class Flux2ImagePipeline(BasePipeline): "height": height, "width": width, "seed": seed, "rand_device": rand_device, "initial_noise": initial_noise, "num_inference_steps": num_inference_steps, + "positive_only_lora": lora, + "negative_only_lora": negative_lora, + "inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -565,6 +576,26 @@ class Flux2Unit_ImageIDs(PipelineUnit): return 
{"image_ids": image_ids} +class Flux2Unit_Inpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"), + output_params=("inpaint_mask",), + ) + + def process(self, pipe: Flux2ImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma): + if inpaint_mask is None: + return {} + inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 16, height // 16)), min_value=0, max_value=1) + inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True) + if inpaint_blur_size is not None and inpaint_blur_sigma is not None: + from torchvision.transforms import GaussianBlur + blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma) + inpaint_mask = blur(inpaint_mask) + inpaint_mask = rearrange(inpaint_mask, "B C H W -> B (H W) C") + return {"inpaint_mask": inpaint_mask} + + def model_fn_flux2( dit: Flux2DiT, latents=None, @@ -575,7 +606,7 @@ def model_fn_flux2( image_ids=None, edit_latents=None, edit_image_ids=None, - skill_cache=None, + kv_cache=None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, **kwargs, @@ -593,7 +624,7 @@ def model_fn_flux2( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=image_ids, - kv_cache=skill_cache, + kv_cache=kv_cache, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, ) diff --git a/examples/flux2/model_inference/FLUX.2-klein-base-4B-skills.py b/examples/flux2/model_inference/FLUX.2-klein-base-4B-skills.py deleted file mode 100644 index fcf7992..0000000 --- a/examples/flux2/model_inference/FLUX.2-klein-base-4B-skills.py +++ /dev/null @@ -1,56 +0,0 @@ -from diffsynth.diffusion.skills import SkillsPipeline -from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig -import torch -from PIL import Image - - -pipe = Flux2ImagePipeline.from_pretrained( - 
torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), - ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"), - ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), - ], - tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), -) -skills = SkillsPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="DiffSynth-Studio/F2KB4B-Skills-ControlNet"), - ModelConfig(model_id="DiffSynth-Studio/F2KB4B-Skills-Brightness"), - ], -) -skill_cache = skills( - positive_inputs = [ - { - "model_id": 0, - "image": Image.open("xxx.jpg"), - "prompt": "一位长发少女,四周环绕着魔法粒子", - }, - { - "model_id": 1, - "scale": 0.6, - }, - ], - negative_inputs = [ - { - "model_id": 0, - "image": Image.open("xxx.jpg"), - "prompt": "一位长发少女,四周环绕着魔法粒子", - }, - { - "model_id": 1, - "scale": 0.5, - }, - ], - pipe=pipe, -) -image = pipe( - prompt="一位长发少女,四周环绕着魔法粒子", - seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4, - height=1024, width=1024, - **skill_cache, -) -image.save("image.jpg") diff --git a/examples/flux2/model_inference/Template-KleinBase4B.py b/examples/flux2/model_inference/Template-KleinBase4B.py new file mode 100644 index 0000000..5b2dd93 --- /dev/null +++ b/examples/flux2/model_inference/Template-KleinBase4B.py @@ -0,0 +1,256 @@ +from diffsynth.diffusion.template import TemplatePipeline +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch +from PIL import Image +import numpy as np + +def load_template_pipeline(model_ids): + template = TemplatePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ModelConfig(model_id=model_id) for model_id in model_ids], + 
) + return template + +# Base Model +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +# image = pipe( +# prompt="A cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# ) +# image.save("image_base.jpg") + +# template = load_template_pipeline(["DiffSynth-Studio/Template-KleinBase4B-Brightness"]) +# image = template( +# pipe, +# prompt="A cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{"scale": 0.7}], +# negative_template_inputs = [{"scale": 0.5}] +# ) +# image.save("image_Brightness_light.jpg") +# image = template( +# pipe, +# prompt="A cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{"scale": 0.5}], +# negative_template_inputs = [{"scale": 0.5}] +# ) +# image.save("image_Brightness_normal.jpg") +# image = template( +# pipe, +# prompt="A cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{"scale": 0.3}], +# negative_template_inputs = [{"scale": 0.5}] +# ) +# image.save("image_Brightness_dark.jpg") + +# template = load_template_pipeline(["DiffSynth-Studio/Template-KleinBase4B-ControlNet"]) +# image = template( +# pipe, +# prompt="A cat is sitting on a stone, bathed in bright sunshine.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{ +# "image": Image.open("data/assets/image_depth.jpg"), +# "prompt": "A cat is sitting on a stone, bathed in bright 
sunshine.", +# }], +# negative_template_inputs = [{ +# "image": Image.open("data/assets/image_depth.jpg"), +# "prompt": "", +# }], +# ) +# image.save("image_ControlNet_sunshine.jpg") +# image = template( +# pipe, +# prompt="A cat is sitting on a stone, surrounded by colorful magical particles.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{ +# "image": Image.open("data/assets/image_depth.jpg"), +# "prompt": "A cat is sitting on a stone, surrounded by colorful magical particles.", +# }], +# negative_template_inputs = [{ +# "image": Image.open("data/assets/image_depth.jpg"), +# "prompt": "", +# }], +# ) +# image.save("image_ControlNet_magic.jpg") + +# template = load_template_pipeline(["DiffSynth-Studio/Template-KleinBase4B-Edit"]) +# image = template( +# pipe, +# prompt="Put a hat on this cat.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{ +# "image": Image.open("data/assets/image_reference.jpg"), +# "prompt": "Put a hat on this cat.", +# }], +# negative_template_inputs = [{ +# "image": Image.open("data/assets/image_reference.jpg"), +# "prompt": "", +# }], +# ) +# image.save("image_Edit_hat.jpg") +# image = template( +# pipe, +# prompt="Make the cat turn its head to look to the right.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{ +# "image": Image.open("data/assets/image_reference.jpg"), +# "prompt": "Make the cat turn its head to look to the right.", +# }], +# negative_template_inputs = [{ +# "image": Image.open("data/assets/image_reference.jpg"), +# "prompt": "", +# }], +# ) +# image.save("image_Edit_head.jpg") + +# template = load_template_pipeline(["DiffSynth-Studio/Template-KleinBase4B-Upscaler"]) +# image = template( +# pipe, +# prompt="A cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{ +# "image": Image.open("data/assets/image_lowres_512.jpg"), +# "prompt": "A cat is sitting on a stone.", +# }], +# negative_template_inputs = [{ 
+# "image": Image.open("data/assets/image_lowres_512.jpg"), +# "prompt": "", +# }], +# ) +# image.save("image_Upscaler_1.png") +# image = template( +# pipe, +# prompt="A cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{ +# "image": Image.open("data/assets/image_lowres_100.jpg"), +# "prompt": "A cat is sitting on a stone.", +# }], +# negative_template_inputs = [{ +# "image": Image.open("data/assets/image_lowres_100.jpg"), +# "prompt": "", +# }], +# ) +# image.save("image_Upscaler_2.png") + +# template = load_template_pipeline(["DiffSynth-Studio/Template-KleinBase4B-SoftRGB"]) +# image = template( +# pipe, +# prompt="A cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{ +# "R": 128/255, +# "G": 128/255, +# "B": 128/255 +# }], +# ) +# image.save("image_rgb_normal.jpg") +# image = template( +# pipe, +# prompt="A cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{ +# "R": 208/255, +# "G": 185/255, +# "B": 138/255 +# }], +# ) +# image.save("image_rgb_warm.jpg") +# image = template( +# pipe, +# prompt="A cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{ +# "R": 94/255, +# "G": 163/255, +# "B": 174/255 +# }], +# ) +# image.save("image_rgb_cold.jpg") + +# template = load_template_pipeline(["DiffSynth-Studio/Template-KleinBase4B-PandaMeme"]) +# image = template( +# pipe, +# prompt="A meme with a sleepy expression.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{}], +# negative_template_inputs = [{}], +# ) +# image.save("image_PandaMeme_sleepy.jpg") +# image = template( +# pipe, +# prompt="A meme with a happy expression.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{}], +# negative_template_inputs = [{}], +# ) +# image.save("image_PandaMeme_happy.jpg") +# image = template( +# pipe, +# prompt="A meme with a surprised 
expression.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{}], +# negative_template_inputs = [{}], +# ) +# image.save("image_PandaMeme_surprised.jpg") + +# template = load_template_pipeline(["DiffSynth-Studio/Template-KleinBase4B-Sharpness"]) +# image = template( +# pipe, +# prompt="A cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{"scale": 0.1}], +# negative_template_inputs = [{"scale": 0.5}], +# ) +# image.save("image_Sharpness_0.1.jpg") +# image = template( +# pipe, +# prompt="A cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{"scale": 0.8}], +# negative_template_inputs = [{"scale": 0.5}], +# ) +# image.save("image_Sharpness_0.8.jpg") + +# template = load_template_pipeline(["DiffSynth-Studio/Template-KleinBase4B-Inpaint"]) +# image = template( +# pipe, +# prompt="An orange cat is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{ +# "image": Image.open("data/assets/image_reference.jpg"), +# "mask": Image.open("data/assets/image_mask_1.jpg"), +# "force_inpaint": True, +# }], +# negative_template_inputs = [{ +# "image": Image.open("data/assets/image_reference.jpg"), +# "mask": Image.open("data/assets/image_mask_1.jpg"), +# }], +# ) +# image.save("image_Inpaint_1.jpg") +# image = template( +# pipe, +# prompt="A cat wearing sunglasses is sitting on a stone.", +# seed=0, cfg_scale=4, num_inference_steps=50, +# template_inputs = [{ +# "image": Image.open("data/assets/image_reference.jpg"), +# "mask": Image.open("data/assets/image_mask_2.jpg"), +# }], +# negative_template_inputs = [{ +# "image": Image.open("data/assets/image_reference.jpg"), +# "mask": Image.open("data/assets/image_mask_2.jpg"), +# }], +# ) +# image.save("image_Inpaint_2.jpg") diff --git a/examples/flux2/model_training/full/FLUX.2-klein-base-4B-skills.sh b/examples/flux2/model_training/full/Template-KleinBase4B.sh similarity index 
53% rename from examples/flux2/model_training/full/FLUX.2-klein-base-4B-skills.sh rename to examples/flux2/model_training/full/Template-KleinBase4B.sh index d56634b..093f1ef 100644 --- a/examples/flux2/model_training/full/FLUX.2-klein-base-4B-skills.sh +++ b/examples/flux2/model_training/full/Template-KleinBase4B.sh @@ -1,16 +1,17 @@ accelerate launch examples/flux2/model_training/train.py \ - --dataset_base_path /mnt/nas1/duanzhongjie.dzj/dataset/ImagePulseV2 \ - --dataset_metadata_path /mnt/nas1/duanzhongjie.dzj/dataset/ImagePulseV2/metadata_example_ti2ti.jsonl \ - --extra_inputs "skill_inputs" \ + --dataset_base_path xxx \ + --dataset_metadata_path xxx/metadata.jsonl \ + --extra_inputs "template_inputs" \ --max_pixels 1048576 \ --dataset_repeat 1 \ --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ - --skill_model_id_or_path "models/base" \ + --template_model_id_or_path "xxx" \ --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ --learning_rate 1e-4 \ --num_epochs 999 \ - --remove_prefix_in_ckpt "pipe.skill_model." \ - --output_path "./models/train/FLUX.2-klein-base-4B-skills_full" \ - --trainable_models "skill_model" \ + --remove_prefix_in_ckpt "pipe.template_model." 
\ + --output_path "./models/train/Template-KleinBase4B_full" \ + --trainable_models "template_model" \ + --save_steps 1000 \ --use_gradient_checkpointing \ - --save_steps 200 + --find_unused_parameters diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py index 7a15267..5378da4 100644 --- a/examples/flux2/model_training/train.py +++ b/examples/flux2/model_training/train.py @@ -18,7 +18,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule): extra_inputs=None, fp8_models=None, offload_models=None, - skill_model_id_or_path=None, + template_model_id_or_path=None, device="cpu", task="sft", ): @@ -27,7 +27,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule): model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/")) self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) - self.pipe = self.load_training_skill_model(self.pipe, skill_model_id_or_path) + self.pipe = self.load_training_template_model(self.pipe, template_model_id_or_path, args.use_gradient_checkpointing, args.use_gradient_checkpointing_offload) self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) # Training mode @@ -128,7 +128,7 @@ if __name__ == "__main__": extra_inputs=args.extra_inputs, fp8_models=args.fp8_models, offload_models=args.offload_models, - skill_model_id_or_path=args.skill_model_id_or_path, + template_model_id_or_path=args.template_model_id_or_path, task=args.task, device="cpu" if args.initialize_model_on_cpu else accelerator.device, )