mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -63,7 +63,8 @@ class AutoTorchModule(torch.nn.Module):
|
||||
return r
|
||||
|
||||
def check_free_vram(self):
|
||||
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(self.computation_device)
|
||||
device = self.computation_device if self.computation_device != "npu" else "npu:0"
|
||||
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
|
||||
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
|
||||
return used_memory < self.vram_limit
|
||||
|
||||
@@ -309,6 +310,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
self.lora_B_weights = []
|
||||
self.lora_merger = None
|
||||
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
|
||||
self.computation_device_type = parse_device_type(self.computation_device)
|
||||
|
||||
if offload_dtype == "disk":
|
||||
self.disk_map = disk_map
|
||||
|
||||
@@ -177,7 +177,8 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
|
||||
def get_vram(self):
|
||||
return getattr(torch, self.device_type).mem_get_info(self.device)[1] / (1024 ** 3)
|
||||
device = self.device if self.device != "npu" else "npu:0"
|
||||
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
|
||||
|
||||
def get_module(self, model, name):
|
||||
if "." in name:
|
||||
|
||||
@@ -46,7 +46,7 @@ pipe = WanVideoPipeline.from_pretrained(
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
||||
+ vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2,
|
||||
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
|
||||
)
|
||||
|
||||
video = pipe(
|
||||
|
||||
@@ -46,7 +46,7 @@ pipe = WanVideoPipeline.from_pretrained(
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
||||
+ vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2,
|
||||
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
|
||||
)
|
||||
|
||||
video = pipe(
|
||||
|
||||
Reference in New Issue
Block a user