mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
flux-kontext
This commit is contained in:
@@ -23,7 +23,9 @@ from ..models.tiler import FastTileWorker
|
||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora.flux_lora import FluxLoRALoader
|
||||
|
||||
from ..vram_management import gradient_checkpoint_forward
|
||||
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||
from ..models.flux_dit import RMSNorm
|
||||
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
|
||||
|
||||
|
||||
@@ -135,7 +137,119 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
|
||||
pass
|
||||
self.vram_management_enabled = True
|
||||
if num_persistent_param_in_dit is not None:
|
||||
vram_limit = None
|
||||
else:
|
||||
if vram_limit is None:
|
||||
vram_limit = self.get_vram()
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
if self.text_encoder_1 is not None:
|
||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_1,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.text_encoder_2 is not None:
|
||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_2,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
T5LayerNorm: AutoWrappedModule,
|
||||
T5DenseActDense: AutoWrappedModule,
|
||||
T5DenseGatedActDense: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.dit is not None:
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
device = "cpu" if vram_limit is not None else self.device
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae_decoder is not None:
|
||||
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_decoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae_encoder is not None:
|
||||
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_encoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user