mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
@@ -63,7 +63,8 @@ class AutoTorchModule(torch.nn.Module):
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
def check_free_vram(self):
|
def check_free_vram(self):
|
||||||
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(self.computation_device)
|
device = self.computation_device if self.computation_device != "npu" else "npu:0"
|
||||||
|
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
|
||||||
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
|
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
|
||||||
return used_memory < self.vram_limit
|
return used_memory < self.vram_limit
|
||||||
|
|
||||||
@@ -309,6 +310,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
|||||||
self.lora_B_weights = []
|
self.lora_B_weights = []
|
||||||
self.lora_merger = None
|
self.lora_merger = None
|
||||||
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
|
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
|
||||||
|
self.computation_device_type = parse_device_type(self.computation_device)
|
||||||
|
|
||||||
if offload_dtype == "disk":
|
if offload_dtype == "disk":
|
||||||
self.disk_map = disk_map
|
self.disk_map = disk_map
|
||||||
|
|||||||
@@ -177,7 +177,8 @@ class BasePipeline(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def get_vram(self):
|
def get_vram(self):
|
||||||
return getattr(torch, self.device_type).mem_get_info(self.device)[1] / (1024 ** 3)
|
device = self.device if self.device != "npu" else "npu:0"
|
||||||
|
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
|
||||||
|
|
||||||
def get_module(self, model, name):
|
def get_module(self, model, name):
|
||||||
if "." in name:
|
if "." in name:
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ pipe = WanVideoPipeline.from_pretrained(
|
|||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||||
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
||||||
+ vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2,
|
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
video = pipe(
|
video = pipe(
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ pipe = WanVideoPipeline.from_pretrained(
|
|||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||||
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
||||||
+ vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2,
|
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
video = pipe(
|
video = pipe(
|
||||||
|
|||||||
Reference in New Issue
Block a user