mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
[NPU]:Replace 'cuda' in the project with abstract interfaces
This commit is contained in:
@@ -4,3 +4,4 @@ from .gradient import *
|
|||||||
from .loader import *
|
from .loader import *
|
||||||
from .vram import *
|
from .vram import *
|
||||||
from .device import *
|
from .device import *
|
||||||
|
from .npu_patch import *
|
||||||
|
|||||||
5
diffsynth/core/npu_patch/__init__.py
Normal file
5
diffsynth/core/npu_patch/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE
|
||||||
|
from .npu_autocast_patch import npu_autocast_patch
|
||||||
|
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
npu_autocast_patch()
|
||||||
21
diffsynth/core/npu_patch/npu_autocast_patch.py
Normal file
21
diffsynth/core/npu_patch/npu_autocast_patch.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import torch
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
|
def npu_autocast_patch_wrapper(func):
|
||||||
|
@contextmanager
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
flag = False
|
||||||
|
if "npu" in args or ("device_type" in kwargs and kwargs["device_type"] == "npu"):
|
||||||
|
if torch.float32 in args or ("dtype" in kwargs and kwargs["dtype"] == torch.float32):
|
||||||
|
flag = True
|
||||||
|
with func(*args, **kwargs) as ctx:
|
||||||
|
if flag:
|
||||||
|
torch.npu.set_autocast_enabled(True)
|
||||||
|
yield ctx
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def npu_autocast_patch():
|
||||||
|
torch.amp.autocast = npu_autocast_patch_wrapper(torch.amp.autocast)
|
||||||
|
torch.autocast = npu_autocast_patch_wrapper(torch.autocast)
|
||||||
@@ -4,6 +4,7 @@ import numpy as np
|
|||||||
from einops import repeat, reduce
|
from einops import repeat, reduce
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
|
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
from ..utils.lora import GeneralLoRALoader
|
from ..utils.lora import GeneralLoRALoader
|
||||||
from ..models.model_loader import ModelPool
|
from ..models.model_loader import ModelPool
|
||||||
from ..utils.controlnet import ControlNetInput
|
from ..utils.controlnet import ControlNetInput
|
||||||
@@ -61,7 +62,7 @@ class BasePipeline(torch.nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device="cuda", torch_dtype=torch.float16,
|
device=get_device_type(), torch_dtype=torch.float16,
|
||||||
height_division_factor=64, width_division_factor=64,
|
height_division_factor=64, width_division_factor=64,
|
||||||
time_division_factor=None, time_division_remainder=None,
|
time_division_factor=None, time_division_remainder=None,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
|
|||||||
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
|
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
|
||||||
|
|
||||||
class DINOv3ImageEncoder(DINOv3ViTModel):
|
class DINOv3ImageEncoder(DINOv3ViTModel):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -70,7 +72,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
|
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
|
||||||
inputs = self.processor(images=image, return_tensors="pt")
|
inputs = self.processor(images=image, return_tensors="pt")
|
||||||
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
|
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
|
||||||
bool_masked_pos = None
|
bool_masked_pos = None
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import numpy as np
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from .wan_video_dit import flash_attention
|
from .wan_video_dit import flash_attention
|
||||||
|
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
|
||||||
from ..core.gradient import gradient_checkpoint_forward
|
from ..core.gradient import gradient_checkpoint_forward
|
||||||
|
|
||||||
|
|
||||||
@@ -373,7 +374,9 @@ class FinalLayer_FP32(nn.Module):
|
|||||||
B, N, C = x.shape
|
B, N, C = x.shape
|
||||||
T, _, _ = latent_shape
|
T, _, _ = latent_shape
|
||||||
|
|
||||||
with amp.autocast('cuda', dtype=torch.float32):
|
with amp.autocast(get_device_type(), dtype=torch.float32):
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
torch.npu.set_autocast_enabled(True)
|
||||||
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
|
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
|
||||||
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
|
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
|
||||||
x = self.linear(x)
|
x = self.linear(x)
|
||||||
@@ -583,7 +586,9 @@ class LongCatSingleStreamBlock(nn.Module):
|
|||||||
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
|
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
|
||||||
|
|
||||||
# compute modulation params in fp32
|
# compute modulation params in fp32
|
||||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
torch.npu.set_autocast_enabled(True)
|
||||||
shift_msa, scale_msa, gate_msa, \
|
shift_msa, scale_msa, gate_msa, \
|
||||||
shift_mlp, scale_mlp, gate_mlp = \
|
shift_mlp, scale_mlp, gate_mlp = \
|
||||||
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
|
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
|
||||||
@@ -602,7 +607,9 @@ class LongCatSingleStreamBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
x_s = attn_outputs
|
x_s = attn_outputs
|
||||||
|
|
||||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
torch.npu.set_autocast_enabled(True)
|
||||||
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||||
x = x.to(x_dtype)
|
x = x.to(x_dtype)
|
||||||
|
|
||||||
@@ -615,7 +622,9 @@ class LongCatSingleStreamBlock(nn.Module):
|
|||||||
# ffn with modulation
|
# ffn with modulation
|
||||||
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
|
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
|
||||||
x_s = self.ffn(x_m)
|
x_s = self.ffn(x_m)
|
||||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
torch.npu.set_autocast_enabled(True)
|
||||||
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||||
x = x.to(x_dtype)
|
x = x.to(x_dtype)
|
||||||
|
|
||||||
@@ -797,7 +806,9 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):
|
|||||||
|
|
||||||
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
|
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
|
||||||
|
|
||||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
torch.npu.set_autocast_enabled(True)
|
||||||
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
|
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
|
||||||
|
|
||||||
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
|
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
|
||||||
|
|||||||
@@ -583,7 +583,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
|
|||||||
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
|
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
|
||||||
is_compileable = is_compileable and not self.generation_config.disable_compile
|
is_compileable = is_compileable and not self.generation_config.disable_compile
|
||||||
if is_compileable and (
|
if is_compileable and (
|
||||||
self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
|
self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
|
||||||
):
|
):
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
||||||
model_forward = self.get_compiled_call(generation_config.compile_config)
|
model_forward = self.get_compiled_call(generation_config.compile_config)
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer,
|
|||||||
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
|
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from diffsynth.core.device.npu_compatible_device import get_device_type
|
||||||
|
|
||||||
|
|
||||||
class Siglip2ImageEncoder(SiglipVisionTransformer):
|
class Siglip2ImageEncoder(SiglipVisionTransformer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -47,7 +49,7 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
|
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
|
||||||
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
|
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
|
||||||
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
|
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
|
||||||
output_attentions = False
|
output_attentions = False
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from .qwen_image_text_encoder import QwenImageTextEncoder
|
from .qwen_image_text_encoder import QwenImageTextEncoder
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type, get_torch_device
|
||||||
|
|
||||||
|
|
||||||
class Step1xEditEmbedder(torch.nn.Module):
|
class Step1xEditEmbedder(torch.nn.Module):
|
||||||
def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
|
def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
@@ -77,13 +78,13 @@ User Prompt:'''
|
|||||||
self.max_length,
|
self.max_length,
|
||||||
self.model.config.hidden_size,
|
self.model.config.hidden_size,
|
||||||
dtype=torch.bfloat16,
|
dtype=torch.bfloat16,
|
||||||
device=torch.cuda.current_device(),
|
device=get_torch_device().current_device(),
|
||||||
)
|
)
|
||||||
masks = torch.zeros(
|
masks = torch.zeros(
|
||||||
len(text_list),
|
len(text_list),
|
||||||
self.max_length,
|
self.max_length,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=torch.cuda.current_device(),
|
device=get_torch_device().current_device(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def split_string(s):
|
def split_string(s):
|
||||||
@@ -158,7 +159,7 @@ User Prompt:'''
|
|||||||
else:
|
else:
|
||||||
token_list.append(token_each)
|
token_list.append(token_each)
|
||||||
|
|
||||||
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
|
new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())
|
||||||
|
|
||||||
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
|
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
|
||||||
|
|
||||||
@@ -167,15 +168,15 @@ User Prompt:'''
|
|||||||
inputs.input_ids = (
|
inputs.input_ids = (
|
||||||
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
|
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.to("cuda")
|
.to(get_device_type())
|
||||||
)
|
)
|
||||||
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
|
inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
|
||||||
outputs = self.model_forward(
|
outputs = self.model_forward(
|
||||||
self.model,
|
self.model,
|
||||||
input_ids=inputs.input_ids,
|
input_ids=inputs.input_ids,
|
||||||
attention_mask=inputs.attention_mask,
|
attention_mask=inputs.attention_mask,
|
||||||
pixel_values=inputs.pixel_values.to("cuda"),
|
pixel_values=inputs.pixel_values.to(get_device_type()),
|
||||||
image_grid_thw=inputs.image_grid_thw.to("cuda"),
|
image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -188,7 +189,7 @@ User Prompt:'''
|
|||||||
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
|
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
|
||||||
(min(self.max_length, emb.shape[1] - 217)),
|
(min(self.max_length, emb.shape[1] - 217)),
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=torch.cuda.current_device(),
|
device=get_torch_device().current_device(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return embs, masks
|
return embs, masks
|
||||||
|
|||||||
@@ -94,7 +94,6 @@ def rope_apply(x, freqs, num_heads):
|
|||||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||||
freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs
|
|
||||||
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
||||||
return x_out.to(x.dtype)
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from torch.nn.utils.rnn import pad_sequence
|
|||||||
|
|
||||||
from torch.nn import RMSNorm
|
from torch.nn import RMSNorm
|
||||||
from ..core.attention import attention_forward
|
from ..core.attention import attention_forward
|
||||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
|
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
|
||||||
from ..core.gradient import gradient_checkpoint_forward
|
from ..core.gradient import gradient_checkpoint_forward
|
||||||
|
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ class TimestepEmbedder(nn.Module):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def timestep_embedding(t, dim, max_period=10000):
|
def timestep_embedding(t, dim, max_period=10000):
|
||||||
with torch.amp.autocast("cuda", enabled=False):
|
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||||
half = dim // 2
|
half = dim // 2
|
||||||
freqs = torch.exp(
|
freqs = torch.exp(
|
||||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||||
@@ -105,7 +105,7 @@ class Attention(torch.nn.Module):
|
|||||||
|
|
||||||
# Apply RoPE
|
# Apply RoPE
|
||||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||||
with torch.amp.autocast("cuda", enabled=False):
|
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||||
freqs_cis = freqs_cis.unsqueeze(2)
|
freqs_cis = freqs_cis.unsqueeze(2)
|
||||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from einops import rearrange
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Union, List, Optional, Tuple
|
from typing import Union, List, Optional, Tuple
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
from ..diffusion import FlowMatchScheduler
|
from ..diffusion import FlowMatchScheduler
|
||||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||||
@@ -18,7 +19,7 @@ from ..models.flux2_vae import Flux2VAE
|
|||||||
|
|
||||||
class Flux2ImagePipeline(BasePipeline):
|
class Flux2ImagePipeline(BasePipeline):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device=device, torch_dtype=torch_dtype,
|
device=device, torch_dtype=torch_dtype,
|
||||||
height_division_factor=16, width_division_factor=16,
|
height_division_factor=16, width_division_factor=16,
|
||||||
@@ -42,7 +43,7 @@ class Flux2ImagePipeline(BasePipeline):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
torch_dtype: torch.dtype = torch.bfloat16,
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
device: Union[str, torch.device] = "cuda",
|
device: Union[str, torch.device] = get_device_type(),
|
||||||
model_configs: list[ModelConfig] = [],
|
model_configs: list[ModelConfig] = [],
|
||||||
tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
|
tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
|
||||||
vram_limit: float = None,
|
vram_limit: float = None,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from einops import rearrange, repeat
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
from ..diffusion import FlowMatchScheduler
|
from ..diffusion import FlowMatchScheduler
|
||||||
from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
|
from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
|
||||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||||
@@ -55,7 +56,7 @@ class MultiControlNet(torch.nn.Module):
|
|||||||
|
|
||||||
class FluxImagePipeline(BasePipeline):
|
class FluxImagePipeline(BasePipeline):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device=device, torch_dtype=torch_dtype,
|
device=device, torch_dtype=torch_dtype,
|
||||||
height_division_factor=16, width_division_factor=16,
|
height_division_factor=16, width_division_factor=16,
|
||||||
@@ -117,7 +118,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
torch_dtype: torch.dtype = torch.bfloat16,
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
device: Union[str, torch.device] = "cuda",
|
device: Union[str, torch.device] = get_device_type(),
|
||||||
model_configs: list[ModelConfig] = [],
|
model_configs: list[ModelConfig] = [],
|
||||||
tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
|
tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
|
||||||
tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
|
tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
|
||||||
@@ -377,7 +378,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
|
|||||||
text_encoder_2,
|
text_encoder_2,
|
||||||
prompt,
|
prompt,
|
||||||
positive=True,
|
positive=True,
|
||||||
device="cuda",
|
device=get_device_type(),
|
||||||
t5_sequence_length=512,
|
t5_sequence_length=512,
|
||||||
):
|
):
|
||||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
|
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
|
||||||
@@ -558,7 +559,7 @@ class FluxImageUnit_EntityControl(PipelineUnit):
|
|||||||
text_encoder_2,
|
text_encoder_2,
|
||||||
prompt,
|
prompt,
|
||||||
positive=True,
|
positive=True,
|
||||||
device="cuda",
|
device=get_device_type(),
|
||||||
t5_sequence_length=512,
|
t5_sequence_length=512,
|
||||||
):
|
):
|
||||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
|
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
|
||||||
@@ -793,7 +794,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
|
|||||||
|
|
||||||
|
|
||||||
class InfinitYou(torch.nn.Module):
|
class InfinitYou(torch.nn.Module):
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
from facexlib.recognition import init_recognition_model
|
from facexlib.recognition import init_recognition_model
|
||||||
from insightface.app import FaceAnalysis
|
from insightface.app import FaceAnalysis
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from einops import rearrange
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from math import prod
|
from math import prod
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
from ..diffusion import FlowMatchScheduler
|
from ..diffusion import FlowMatchScheduler
|
||||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||||
@@ -22,7 +23,7 @@ from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
|
|||||||
|
|
||||||
class QwenImagePipeline(BasePipeline):
|
class QwenImagePipeline(BasePipeline):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device=device, torch_dtype=torch_dtype,
|
device=device, torch_dtype=torch_dtype,
|
||||||
height_division_factor=16, width_division_factor=16,
|
height_division_factor=16, width_division_factor=16,
|
||||||
@@ -60,7 +61,7 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
torch_dtype: torch.dtype = torch.bfloat16,
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
device: Union[str, torch.device] = "cuda",
|
device: Union[str, torch.device] = get_device_type(),
|
||||||
model_configs: list[ModelConfig] = [],
|
model_configs: list[ModelConfig] = [],
|
||||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||||
processor_config: ModelConfig = None,
|
processor_config: ModelConfig = None,
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from typing import Optional
|
|||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
from transformers import Wav2Vec2Processor
|
from transformers import Wav2Vec2Processor
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
from ..diffusion import FlowMatchScheduler
|
from ..diffusion import FlowMatchScheduler
|
||||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
@@ -30,7 +31,7 @@ from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
|
|||||||
|
|
||||||
class WanVideoPipeline(BasePipeline):
|
class WanVideoPipeline(BasePipeline):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device=device, torch_dtype=torch_dtype,
|
device=device, torch_dtype=torch_dtype,
|
||||||
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
|
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
|
||||||
@@ -98,7 +99,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
torch_dtype: torch.dtype = torch.bfloat16,
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
device: Union[str, torch.device] = "cuda",
|
device: Union[str, torch.device] = get_device_type(),
|
||||||
model_configs: list[ModelConfig] = [],
|
model_configs: list[ModelConfig] = [],
|
||||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||||
audio_processor_config: ModelConfig = None,
|
audio_processor_config: ModelConfig = None,
|
||||||
@@ -960,7 +961,7 @@ class WanVideoUnit_AnimateInpaint(PipelineUnit):
|
|||||||
onload_model_names=("vae",)
|
onload_model_names=("vae",)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
|
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
|
||||||
if mask_pixel_values is None:
|
if mask_pixel_values is None:
|
||||||
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
|
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from einops import rearrange
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Union, List, Optional, Tuple, Iterable, Dict
|
from typing import Union, List, Optional, Tuple, Iterable, Dict
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
from ..diffusion import FlowMatchScheduler
|
from ..diffusion import FlowMatchScheduler
|
||||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||||
from ..core.data.operators import ImageCropAndResize
|
from ..core.data.operators import ImageCropAndResize
|
||||||
@@ -25,7 +26,7 @@ from ..models.z_image_image2lora import ZImageImage2LoRAModel
|
|||||||
|
|
||||||
class ZImagePipeline(BasePipeline):
|
class ZImagePipeline(BasePipeline):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device=device, torch_dtype=torch_dtype,
|
device=device, torch_dtype=torch_dtype,
|
||||||
height_division_factor=16, width_division_factor=16,
|
height_division_factor=16, width_division_factor=16,
|
||||||
@@ -58,7 +59,7 @@ class ZImagePipeline(BasePipeline):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
torch_dtype: torch.dtype = torch.bfloat16,
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
device: Union[str, torch.device] = "cuda",
|
device: Union[str, torch.device] = get_device_type(),
|
||||||
model_configs: list[ModelConfig] = [],
|
model_configs: list[ModelConfig] = [],
|
||||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
vram_limit: float = None,
|
vram_limit: float = None,
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
from typing_extensions import Literal, TypeAlias
|
from typing_extensions import Literal, TypeAlias
|
||||||
|
|
||||||
|
from diffsynth.core.device.npu_compatible_device import get_device_type
|
||||||
|
|
||||||
Processor_id: TypeAlias = Literal[
|
Processor_id: TypeAlias = Literal[
|
||||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
|
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
|
||||||
]
|
]
|
||||||
|
|
||||||
class Annotator:
|
class Annotator:
|
||||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
|
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
|
||||||
if not skip_processor:
|
if not skip_processor:
|
||||||
if processor_id == "canny":
|
if processor_id == "canny":
|
||||||
from controlnet_aux.processor import CannyDetector
|
from controlnet_aux.processor import CannyDetector
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ def rope_apply(x, freqs, num_heads):
|
|||||||
sp_rank = get_sequence_parallel_rank()
|
sp_rank = get_sequence_parallel_rank()
|
||||||
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
||||||
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
||||||
freqs_rank = freqs_rank.to(torch.complex64) if IS_NPU_AVAILABLE else freqs_rank
|
|
||||||
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||||
return x_out.to(x.dtype)
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user