diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 3355f2e..b076ca9 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -404,6 +404,77 @@ class FluxDiT(torch.nn.Module): hidden_states = self.unpatchify(hidden_states, height, width) return hidden_states + + + def quantize(self): + def cast_to(weight, dtype=None, device=None, copy=False): + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + return weight.to(dtype=dtype, copy=copy) + + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + def cast_weight(s, input=None, dtype=None, device=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if device is None: + device = input.device + weight = cast_to(s.weight, dtype, device) + return weight + + def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if bias_dtype is None: + bias_dtype = dtype + if device is None: + device = input.device + bias = None + weight = cast_to(s.weight, dtype, device) + bias = cast_to(s.bias, bias_dtype, device) + return weight, bias + + class quantized_layer: + class Linear(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self,input,**kwargs): + weight,bias= cast_bias_weight(self.module,input) + return torch.nn.functional.linear(input,weight,bias) + + class RMSNorm(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self,hidden_states,**kwargs): + weight= cast_weight(self.module,hidden_states) + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) + hidden_states = hidden_states.to(input_dtype) * weight + return hidden_states + + def replace_layer(model): + for name, module in model.named_children(): + if isinstance(module, torch.nn.Linear): + new_layer = quantized_layer.Linear(module) + setattr(model, name, new_layer) + elif isinstance(module, RMSNorm): + new_layer = quantized_layer.RMSNorm(module) + setattr(model, name, new_layer) + else: + replace_layer(module) + + replace_layer(self) @staticmethod diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 150565d..7f5eef8 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -415,8 +415,10 @@ class ModelManager: break - def load_model(self, file_path, model_names=None): + def load_model(self, file_path, model_names=None, device=None, torch_dtype=None): print(f"Loading models from: {file_path}") + if device is None: device = self.device + if torch_dtype is None: torch_dtype = self.torch_dtype if os.path.isfile(file_path): state_dict = load_state_dict(file_path) else: @@ -425,7 +427,7 @@ class ModelManager: if model_detector.match(file_path, state_dict): model_names, models = model_detector.load( file_path, state_dict, - device=self.device, torch_dtype=self.torch_dtype, + device=device, torch_dtype=torch_dtype, allowed_model_names=model_names, model_manager=self ) for model_name, model in zip(model_names, models): @@ -438,9 +440,9 @@ class ModelManager: print(f" We cannot detect the model type. No models are loaded.") - def load_models(self, file_path_list, model_names=None): + def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None): for file_path in file_path_list: - self.load_model(file_path, model_names) + self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype) def fetch_model(self, model_name, file_path=None, require_model_path=False): diff --git a/examples/image_synthesis/flux_float8_e4m3fn.py b/examples/image_synthesis/flux_float8_e4m3fn.py deleted file mode 100644 index 2286cdd..0000000 --- a/examples/image_synthesis/flux_float8_e4m3fn.py +++ /dev/null @@ -1,116 +0,0 @@ -import torch -from torch import nn -from diffsynth import download_models, ModelManager, OmostPromter, FluxImagePipeline - -from diffsynth.models.flux_dit import RMSNorm - - -def cast_to(weight, dtype=None, device=None, copy=False): - if device is None or weight.device == device: - if not copy: - if dtype is None or weight.dtype == dtype: - return weight - return weight.to(dtype=dtype, copy=copy) - - r = torch.empty_like(weight, dtype=dtype, device=device) - r.copy_(weight) - return r - -def cast_weight(s, input=None, dtype=None, device=None): - if input is not None: - if dtype is None: - dtype = input.dtype - if device is None: - device = input.device - weight = cast_to(s.weight, dtype, device) - return weight - -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): - if input is not None: - if dtype is None: - dtype = input.dtype - if bias_dtype is None: - bias_dtype = dtype - if device is None: - device = input.device - bias = None - weight = cast_to(s.weight, dtype, device) - bias = cast_to(s.bias, bias_dtype, device) - return weight, bias - -class quantized_layer: - class Linear(torch.nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self,input,**kwargs): - weight,bias= cast_bias_weight(self.module,input) - return torch.nn.functional.linear(input,weight,bias) - - class RMSNorm(torch.nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self,hidden_states,**kwargs): - weight= cast_weight(self.module,hidden_states) - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) - hidden_states = hidden_states.to(input_dtype) * weight - return hidden_states - -def replace_layer(model): - for name, module in model.named_children(): - if isinstance(module, nn.Linear): - new_layer = quantized_layer.Linear(module) - setattr(model, name, new_layer) - elif isinstance(module, RMSNorm): - new_layer = quantized_layer.RMSNorm(module) - setattr(model, name, new_layer) - else: - replace_layer(module) - -def print_layers(model): - for name, module in model.named_modules(): - print(type(module)) - - -def fetch_models(self, model_manager: ModelManager, model_manager2: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]): - self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1") - self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2") - self.dit = model_manager2.fetch_model("flux_dit") - self.vae_decoder = model_manager.fetch_model("flux_vae_decoder") - self.vae_encoder = model_manager.fetch_model("flux_vae_encoder") - self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2) - self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes) - self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes) - -download_models(["FLUX.1-dev"]) - -model_manager = ModelManager(torch_dtype=torch.bfloat16, device='cuda') -model_manager.load_models([ - "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", - "models/FLUX/FLUX.1-dev/text_encoder_2", - "models/FLUX/FLUX.1-dev/ae.safetensors", -]) - -model_manager2 = ModelManager(torch_dtype=torch.float8_e4m3fn, device="cuda") -model_manager2.load_models(["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"]) - -pipe =FluxImagePipeline(device="cuda",torch_dtype=torch.bfloat16) -fetch_models(pipe,model_manager,model_manager2) -# pipe.enable_cpu_offload() - -trans = pipe.dit -replace_layer(trans) - -prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." - -torch.manual_seed(6) -image = pipe( - prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5 -) -image.save("image_1024_float8.jpg") \ No newline at end of file diff --git a/examples/image_synthesis/flux_text_to_image_cpu_offload.py b/examples/image_synthesis/flux_text_to_image_low_vram.py similarity index 77% rename from examples/image_synthesis/flux_text_to_image_cpu_offload.py rename to examples/image_synthesis/flux_text_to_image_low_vram.py index 4298a2f..b98929c 100644 --- a/examples/image_synthesis/flux_text_to_image_cpu_offload.py +++ b/examples/image_synthesis/flux_text_to_image_low_vram.py @@ -1,17 +1,26 @@ import torch -from diffsynth import ModelManager, FluxImagePipeline, download_models +from diffsynth import download_models, ModelManager, FluxImagePipeline download_models(["FLUX.1-dev"]) -model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + +model_manager = ModelManager( + torch_dtype=torch.bfloat16, + device="cpu" # To reduce VRAM required, we load models to RAM. +) model_manager.load_models([ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2", "models/FLUX/FLUX.1-dev/ae.safetensors", - "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" ]) -pipe = FluxImagePipeline.from_model_manager(model_manager, device='cuda') +model_manager.load_models( + ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"], + torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format. +) + +pipe = FluxImagePipeline.from_model_manager(model_manager, device="cuda") pipe.enable_cpu_offload() +pipe.dit.quantize() prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," @@ -39,4 +48,4 @@ image = pipe( num_inference_steps=30, embedded_guidance=3.5, input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True ) -image.save("image_2048_highres.jpg") +image.save("image_2048_highres.jpg") \ No newline at end of file