mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
@@ -73,9 +73,10 @@ def usp_dit_forward(self,
|
|||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
# Context Parallel
|
# Context Parallel
|
||||||
x = torch.chunk(
|
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
|
||||||
x, get_sequence_parallel_world_size(),
|
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
||||||
dim=1)[get_sequence_parallel_rank()]
|
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
|
||||||
|
x = chunks[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
if self.training and use_gradient_checkpointing:
|
if self.training and use_gradient_checkpointing:
|
||||||
@@ -99,6 +100,7 @@ def usp_dit_forward(self,
|
|||||||
|
|
||||||
# Context Parallel
|
# Context Parallel
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
x = x[:, :-pad_shape] if pad_shape > 0 else x
|
||||||
|
|
||||||
# unpatchify
|
# unpatchify
|
||||||
x = self.unpatchify(x, (f, h, w))
|
x = self.unpatchify(x, (f, h, w))
|
||||||
|
|||||||
@@ -594,24 +594,33 @@ def model_fn_wan_video(
|
|||||||
# blocks
|
# blocks
|
||||||
if use_unified_sequence_parallel:
|
if use_unified_sequence_parallel:
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
|
||||||
|
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
||||||
|
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
|
||||||
|
x = chunks[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
if tea_cache_update:
|
if tea_cache_update:
|
||||||
x = tea_cache.update(x)
|
x = tea_cache.update(x)
|
||||||
else:
|
else:
|
||||||
for block_id, block in enumerate(dit.blocks):
|
for block_id, block in enumerate(dit.blocks):
|
||||||
x = block(x, context, t_mod, freqs)
|
x = block(x, context, t_mod, freqs)
|
||||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||||
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
|
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
||||||
|
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||||
|
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
|
||||||
|
x = x + current_vace_hint * vace_scale
|
||||||
if tea_cache is not None:
|
if tea_cache is not None:
|
||||||
tea_cache.store(x)
|
tea_cache.store(x)
|
||||||
|
|
||||||
if reference_latents is not None:
|
|
||||||
x = x[:, reference_latents.shape[1]:]
|
|
||||||
f -= 1
|
|
||||||
|
|
||||||
x = dit.head(x, t)
|
x = dit.head(x, t)
|
||||||
if use_unified_sequence_parallel:
|
if use_unified_sequence_parallel:
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
x = x[:, :-pad_shape] if pad_shape > 0 else x
|
||||||
|
# Remove reference latents
|
||||||
|
if reference_latents is not None:
|
||||||
|
x = x[:, reference_latents.shape[1]:]
|
||||||
|
f -= 1
|
||||||
x = dit.unpatchify(x, (f, h, w))
|
x = dit.unpatchify(x, (f, h, w))
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -1074,7 +1074,10 @@ def model_fn_wan_video(
|
|||||||
# blocks
|
# blocks
|
||||||
if use_unified_sequence_parallel:
|
if use_unified_sequence_parallel:
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
|
||||||
|
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
||||||
|
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
|
||||||
|
x = chunks[get_sequence_parallel_rank()]
|
||||||
if tea_cache_update:
|
if tea_cache_update:
|
||||||
x = tea_cache.update(x)
|
x = tea_cache.update(x)
|
||||||
else:
|
else:
|
||||||
@@ -1103,6 +1106,7 @@ def model_fn_wan_video(
|
|||||||
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
||||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||||
|
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
|
||||||
x = x + current_vace_hint * vace_scale
|
x = x + current_vace_hint * vace_scale
|
||||||
if tea_cache is not None:
|
if tea_cache is not None:
|
||||||
tea_cache.store(x)
|
tea_cache.store(x)
|
||||||
@@ -1111,6 +1115,7 @@ def model_fn_wan_video(
|
|||||||
if use_unified_sequence_parallel:
|
if use_unified_sequence_parallel:
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
x = x[:, :-pad_shape] if pad_shape > 0 else x
|
||||||
# Remove reference latents
|
# Remove reference latents
|
||||||
if reference_latents is not None:
|
if reference_latents is not None:
|
||||||
x = x[:, reference_latents.shape[1]:]
|
x = x[:, reference_latents.shape[1]:]
|
||||||
|
|||||||
@@ -178,9 +178,9 @@ class ModelConfig:
|
|||||||
is_folder = False
|
is_folder = False
|
||||||
|
|
||||||
# Download
|
# Download
|
||||||
|
if self.local_model_path is None:
|
||||||
|
self.local_model_path = "./models"
|
||||||
if not skip_download:
|
if not skip_download:
|
||||||
if self.local_model_path is None:
|
|
||||||
self.local_model_path = "./models"
|
|
||||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
self.model_id,
|
self.model_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user