From 49f9a11eb3ce76cf30a70757cb6e8ac14f4b723e Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 14 Aug 2025 13:50:04 +0800 Subject: [PATCH 1/2] lora_checkpoint & weight_decay & qwen_image_controlnet_train --- diffsynth/trainers/utils.py | 19 +++++++++++++++-- examples/flux/README.md | 2 ++ examples/flux/README_zh.md | 2 ++ examples/flux/model_training/train.py | 12 +++++++++-- examples/qwen_image/README.md | 2 ++ examples/qwen_image/README_zh.md | 2 ++ examples/qwen_image/model_training/train.py | 23 +++++++++++++++++++-- examples/wanvideo/README.md | 2 ++ examples/wanvideo/README_zh.md | 2 ++ examples/wanvideo/model_training/train.py | 12 +++++++++-- 10 files changed, 70 insertions(+), 8 deletions(-) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index ac13b2e..0187065 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -344,8 +344,17 @@ class DiffusionTrainingModule(torch.nn.Module): lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules) model = inject_adapter_in_model(lora_config, model) return model - - + + + def mapping_lora_state_dict(self, state_dict): + new_state_dict = {} + for key, value in state_dict.items(): + if "lora_A.weight" in key or "lora_B.weight" in key: + new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight") + new_state_dict[new_key] = value + return new_state_dict + + def export_trainable_state_dict(self, state_dict, remove_prefix=None): trainable_param_names = self.trainable_param_names() state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names} @@ -467,6 +476,7 @@ def wan_parser(): parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.") parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.") parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") + parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.") parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") @@ -475,6 +485,7 @@ def wan_parser(): parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.") parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") return parser @@ -498,6 +509,7 @@ def flux_parser(): parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.") parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.") parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") + parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.") parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.") parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") @@ -506,6 +518,7 @@ def flux_parser(): parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.") parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") return parser @@ -530,6 +543,7 @@ def qwen_image_parser(): parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.") parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.") parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") + parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.") parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") @@ -537,4 +551,5 @@ def qwen_image_parser(): parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.") parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") return parser diff --git a/examples/flux/README.md b/examples/flux/README.md index 8137b70..492e073 100644 --- a/examples/flux/README.md +++ b/examples/flux/README.md @@ -255,6 +255,7 @@ The script includes the following parameters: * `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas. * Training * `--learning_rate`: Learning rate. + * `--weight_decay`: Weight decay. * `--num_epochs`: Number of epochs. * `--output_path`: Save path. * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint. @@ -265,6 +266,7 @@ The script includes the following parameters: * `--lora_base_model`: Which model to add LoRA to. * `--lora_target_modules`: Which layers to add LoRA to. * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint. * Extra Model Inputs * `--extra_inputs`: Extra model inputs, separated by commas. * VRAM Management diff --git a/examples/flux/README_zh.md b/examples/flux/README_zh.md index 6bbd6fe..65331f9 100644 --- a/examples/flux/README_zh.md +++ b/examples/flux/README_zh.md @@ -255,6 +255,7 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。 * 训练 * `--learning_rate`: 学习率。 + * `--weight_decay`:权重衰减大小。 * `--num_epochs`: 轮数(Epoch)。 * `--output_path`: 保存路径。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 @@ -265,6 +266,7 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra * `--lora_base_model`: LoRA 添加到哪个模型上。 * `--lora_target_modules`: LoRA 添加到哪一层上。 * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 * 额外模型输入 * `--extra_inputs`: 额外的模型输入,以逗号分隔。 * 显存管理 diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index 5ee4dff..e1b66c8 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -1,4 +1,5 @@ import torch, os, json +from diffsynth import load_state_dict from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser from diffsynth.models.lora import FluxLoRAConverter @@ -11,7 +12,7 @@ class FluxTrainingModule(DiffusionTrainingModule): self, model_paths=None, model_id_with_origin_paths=None, trainable_models=None, - lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32, + lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32, lora_checkpoint=None, use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, extra_inputs=None, @@ -40,6 +41,12 @@ class FluxTrainingModule(DiffusionTrainingModule): target_modules=lora_target_modules.split(","), lora_rank=lora_rank ) + if lora_checkpoint is not None: + state_dict = load_state_dict(lora_checkpoint) + state_dict = self.mapping_lora_state_dict(state_dict) + load_result = model.load_state_dict(state_dict, strict=False) + if len(load_result[1]) > 0: + print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") setattr(self.pipe, lora_base_model, model) # Store other configs @@ -106,6 +113,7 @@ if __name__ == "__main__": lora_base_model=args.lora_base_model, lora_target_modules=args.lora_target_modules, lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, use_gradient_checkpointing=args.use_gradient_checkpointing, use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, extra_inputs=args.extra_inputs, @@ -115,7 +123,7 @@ if __name__ == "__main__": remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x, ) - optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate) + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) launch_training_task( dataset, model, model_logger, optimizer, scheduler, diff --git a/examples/qwen_image/README.md b/examples/qwen_image/README.md index 7940013..80bb905 100644 --- a/examples/qwen_image/README.md +++ b/examples/qwen_image/README.md @@ -237,6 +237,7 @@ The script includes the following parameters: * `--tokenizer_path`: Tokenizer path. Leave empty to auto-download. * Training * `--learning_rate`: Learning rate. + * `--weight_decay`: Weight decay. * `--num_epochs`: Number of epochs. * `--output_path`: Save path. * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint. @@ -247,6 +248,7 @@ The script includes the following parameters: * `--lora_base_model`: Which model to add LoRA to. * `--lora_target_modules`: Which layers to add LoRA to. * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint. * Extra Model Inputs * `--extra_inputs`: Extra model inputs, separated by commas. * VRAM Management diff --git a/examples/qwen_image/README_zh.md b/examples/qwen_image/README_zh.md index 0440aae..2fb751f 100644 --- a/examples/qwen_image/README_zh.md +++ b/examples/qwen_image/README_zh.md @@ -237,6 +237,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod * `--tokenizer_path`: tokenizer 路径,留空将会自动下载。 * 训练 * `--learning_rate`: 学习率。 + * `--weight_decay`:权重衰减大小。 * `--num_epochs`: 轮数(Epoch)。 * `--output_path`: 保存路径。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 @@ -247,6 +248,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod * `--lora_base_model`: LoRA 添加到哪个模型上。 * `--lora_target_modules`: LoRA 添加到哪一层上。 * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 * 额外模型输入 * `--extra_inputs`: 额外的模型输入,以逗号分隔。 * 显存管理 diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index d8f6343..8fb3bf8 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -1,5 +1,7 @@ import torch, os, json +from diffsynth import load_state_dict from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.pipelines.flux_image_new import ControlNetInput from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -11,7 +13,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): model_paths=None, model_id_with_origin_paths=None, tokenizer_path=None, trainable_models=None, - lora_base_model=None, lora_target_modules="", lora_rank=32, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, extra_inputs=None, @@ -43,6 +45,12 @@ class QwenImageTrainingModule(DiffusionTrainingModule): target_modules=lora_target_modules.split(","), lora_rank=lora_rank ) + if lora_checkpoint is not None: + state_dict = load_state_dict(lora_checkpoint) + state_dict = self.mapping_lora_state_dict(state_dict) + load_result = model.load_state_dict(state_dict, strict=False) + if len(load_result[1]) > 0: + print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") setattr(self.pipe, lora_base_model, model) # Store other configs @@ -72,8 +80,18 @@ class QwenImageTrainingModule(DiffusionTrainingModule): } # Extra inputs + controlnet_input = {} for extra_input in self.extra_inputs: inputs_shared[extra_input] = data[extra_input] + if extra_input.startswith("blockwise_controlnet_"): + controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input] + elif extra_input.startswith("controlnet_"): + controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input] + else: + inputs_shared[extra_input] = data[extra_input] + if len(controlnet_input) > 0: + controlnet_key = "blockwise_controlnet_inputs" if "blockwise_controlnet_image" in self.extra_inputs else "controlnet_inputs" + inputs_shared[controlnet_key] = [ControlNetInput(**controlnet_input)] # Pipeline units will automatically process the input parameters. for unit in self.pipe.units: @@ -101,12 +119,13 @@ if __name__ == "__main__": lora_base_model=args.lora_base_model, lora_target_modules=args.lora_target_modules, lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, use_gradient_checkpointing=args.use_gradient_checkpointing, use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, extra_inputs=args.extra_inputs, ) model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) - optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate) + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) launch_training_task( dataset, model, model_logger, optimizer, scheduler, diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 8af8dca..4e5195a 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -288,6 +288,7 @@ The script includes the following parameters: * `--min_timestep_boundary`: Minimum value of the timestep interval, ranging from 0 to 1. Default is 1. This needs to be manually set only when training mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B). * Training * `--learning_rate`: Learning rate. + * `--weight_decay`: Weight decay. * `--num_epochs`: Number of epochs. * `--output_path`: Output save path. * `--remove_prefix_in_ckpt`: Remove prefix in ckpt. @@ -298,6 +299,7 @@ The script includes the following parameters: * `--lora_base_model`: Which model LoRA is added to. * `--lora_target_modules`: Which layers LoRA is added to. * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint. * Extra Inputs * `--extra_inputs`: Additional model inputs, comma-separated. * VRAM Management diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 06e81fa..89c539e 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -290,11 +290,13 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai * `--min_timestep_boundary`: Timestep 区间最小值,范围为 0~1,默认为 1,仅在多 DiT 的混合模型训练中需要手动设置,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。 * 训练 * `--learning_rate`: 学习率。 + * `--weight_decay`:权重衰减大小。 * `--num_epochs`: 轮数(Epoch)。 * `--output_path`: 保存路径。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 * `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次 * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 * 可训练模块 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--lora_base_model`: LoRA 添加到哪个模型上。 diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 1b79004..726243c 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -1,4 +1,5 @@ import torch, os, json +from diffsynth import load_state_dict from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -10,7 +11,7 @@ class WanTrainingModule(DiffusionTrainingModule): self, model_paths=None, model_id_with_origin_paths=None, trainable_models=None, - lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, + lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, lora_checkpoint=None, use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, extra_inputs=None, @@ -41,6 +42,12 @@ class WanTrainingModule(DiffusionTrainingModule): target_modules=lora_target_modules.split(","), lora_rank=lora_rank ) + if lora_checkpoint is not None: + state_dict = load_state_dict(lora_checkpoint) + state_dict = self.mapping_lora_state_dict(state_dict) + load_result = model.load_state_dict(state_dict, strict=False) + if len(load_result[1]) > 0: + print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") setattr(self.pipe, lora_base_model, model) # Store other configs @@ -112,6 +119,7 @@ if __name__ == "__main__": lora_base_model=args.lora_base_model, lora_target_modules=args.lora_target_modules, lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, extra_inputs=args.extra_inputs, max_timestep_boundary=args.max_timestep_boundary, @@ -121,7 +129,7 @@ if __name__ == "__main__": args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt ) - optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate) + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) launch_training_task( dataset, model, model_logger, optimizer, scheduler, From 3212c833980ba6660c0caa6ac077a10c619a55a2 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 14 Aug 2025 13:59:04 +0800 Subject: [PATCH 2/2] minor fix --- examples/qwen_image/model_training/train.py | 1 - examples/wanvideo/README_zh.md | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 8fb3bf8..56db9e2 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -82,7 +82,6 @@ class QwenImageTrainingModule(DiffusionTrainingModule): # Extra inputs controlnet_input = {} for extra_input in self.extra_inputs: - inputs_shared[extra_input] = data[extra_input] if extra_input.startswith("blockwise_controlnet_"): controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input] elif extra_input.startswith("controlnet_"): diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 89c539e..bcc076f 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -296,12 +296,12 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 * `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次 * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数 - * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 * 可训练模块 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--lora_base_model`: LoRA 添加到哪个模型上。 * `--lora_target_modules`: LoRA 添加到哪一层上。 * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 * 额外模型输入 * `--extra_inputs`: 额外的模型输入,以逗号分隔。 * 显存管理