mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
flux series vram management
This commit is contained in:
@@ -104,6 +104,7 @@ class InfiniteYouImageProjector(nn.Module):
|
||||
def forward(self, x):
|
||||
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
latents = latents.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
|
||||
@@ -40,7 +40,8 @@ class SingleValueEncoder(torch.nn.Module):
|
||||
emb = self.prefer_proj(value).to(dtype)
|
||||
emb = self.prefer_value_embedder(emb).squeeze(0)
|
||||
base_embeddings = emb.expand(self.prefer_len, -1)
|
||||
learned_embeddings = base_embeddings + self.positional_embedding
|
||||
positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
|
||||
learned_embeddings = base_embeddings + positional_embedding
|
||||
return learned_embeddings
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -162,7 +162,7 @@ class TimestepEmbedder(nn.Module):
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(
|
||||
t, self.frequency_embedding_size, self.max_period
|
||||
).type(self.mlp[0].weight.dtype) # type: ignore
|
||||
).type(t.dtype) # type: ignore
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
@@ -656,7 +656,7 @@ class Qwen2Connector(torch.nn.Module):
|
||||
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
|
||||
x_mean = (x * mask_float).sum(
|
||||
dim=1
|
||||
) / mask_float.sum(dim=1) * (1 + self.scale_factor)
|
||||
) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(dtype=x.dtype, device=x.device))
|
||||
|
||||
global_out=self.global_proj_out(x_mean)
|
||||
encoder_hidden_states = self.S(x,t,mask)
|
||||
|
||||
@@ -24,7 +24,6 @@ from ..models.tiler import FastTileWorker
|
||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
|
||||
|
||||
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||
from ..models.flux_dit import RMSNorm
|
||||
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
|
||||
@@ -185,22 +184,18 @@ class FluxImagePipeline(BasePipeline):
|
||||
return loss
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
|
||||
self.vram_management_enabled = True
|
||||
if num_persistent_param_in_dit is not None:
|
||||
vram_limit = None
|
||||
else:
|
||||
if vram_limit is None:
|
||||
vram_limit = self.get_vram()
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
if self.text_encoder_1 is not None:
|
||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||
def _enable_vram_management_with_default_config(self, model, vram_limit):
|
||||
if model is not None:
|
||||
dtype = next(iter(model.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_1,
|
||||
model,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -212,7 +207,25 @@ class FluxImagePipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
|
||||
self.vram_management_enabled = True
|
||||
if num_persistent_param_in_dit is not None:
|
||||
vram_limit = None
|
||||
else:
|
||||
if vram_limit is None:
|
||||
vram_limit = self.get_vram()
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
|
||||
# Default config
|
||||
default_vram_management_models = ["text_encoder_1", "vae_decoder", "vae_encoder", "controlnet", "image_proj_model", "ipadapter", "lora_patcher", "value_controller", "step1x_connector"]
|
||||
for model_name in default_vram_management_models:
|
||||
self._enable_vram_management_with_default_config(getattr(self, model_name), vram_limit)
|
||||
|
||||
# Special config
|
||||
if self.text_encoder_2 is not None:
|
||||
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_2,
|
||||
@@ -261,14 +274,18 @@ class FluxImagePipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae_decoder is not None:
|
||||
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
||||
if self.ipadapter_image_encoder is not None:
|
||||
from transformers.models.siglip.modeling_siglip import SiglipVisionEmbeddings, SiglipEncoder, SiglipMultiheadAttentionPoolingHead
|
||||
dtype = next(iter(self.ipadapter_image_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_decoder,
|
||||
self.ipadapter_image_encoder,
|
||||
module_map = {
|
||||
SiglipVisionEmbeddings: AutoWrappedModule,
|
||||
SiglipEncoder: AutoWrappedModule,
|
||||
SiglipMultiheadAttentionPoolingHead: AutoWrappedModule,
|
||||
torch.nn.MultiheadAttention: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -280,14 +297,25 @@ class FluxImagePipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae_encoder is not None:
|
||||
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
||||
if self.qwenvl is not None:
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VisionPatchEmbed, Qwen2_5_VLVisionBlock, Qwen2_5_VLPatchMerger,
|
||||
Qwen2_5_VLDecoderLayer, Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
|
||||
)
|
||||
dtype = next(iter(self.qwenvl.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_encoder,
|
||||
self.qwenvl,
|
||||
module_map = {
|
||||
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
|
||||
Qwen2_5_VLVisionBlock: AutoWrappedModule,
|
||||
Qwen2_5_VLPatchMerger: AutoWrappedModule,
|
||||
Qwen2_5_VLDecoderLayer: AutoWrappedModule,
|
||||
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
|
||||
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
|
||||
Qwen2RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -774,9 +802,13 @@ class FluxImageUnit_Flex(PipelineUnit):
|
||||
|
||||
class FluxImageUnit_InfiniteYou(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("infinityou_id_image", "infinityou_guidance"))
|
||||
super().__init__(
|
||||
input_params=("infinityou_id_image", "infinityou_guidance"),
|
||||
onload_model_names=("infinityou_processor",)
|
||||
)
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance):
|
||||
pipe.load_models_to_device("infinityou_processor")
|
||||
if infinityou_id_image is not None:
|
||||
return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device)
|
||||
else:
|
||||
@@ -816,7 +848,7 @@ class InfinitYou(torch.nn.Module):
|
||||
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
|
||||
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
|
||||
self.arcface_model = init_recognition_model('arcface', device=self.device)
|
||||
self.arcface_model = init_recognition_model('arcface', device=self.device).to(torch_dtype)
|
||||
|
||||
def _detect_face(self, id_image_cv2):
|
||||
face_info = self.app_640.get(id_image_cv2)
|
||||
|
||||
Reference in New Issue
Block a user