diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index 09b6ee4..e8de777 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -40,6 +40,8 @@ from ..models.flux_controlnet import FluxControlNet
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT
+from ..models.omnigen import OmniGenTransformer
+
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
@@ -81,6 +83,7 @@ model_loader_configs = [
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
+ (None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers")
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
@@ -536,6 +539,15 @@ preset_models_on_modelscope = {
"RIFE": [
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
],
+ # Omnigen
+ "OmniGen-v1": [
+ ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
+ ],
# CogVideo
"CogVideoX-5B": {
"file_list": [
@@ -600,6 +612,7 @@ Preset_model_id: TypeAlias = Literal[
"OmostPrompt",
"ESRGAN_x4",
"RIFE",
+ "OmniGen-v1",
"CogVideoX-5B",
"Annotators:Depth",
"Annotators:Softedge",
diff --git a/diffsynth/models/omnigen.py b/diffsynth/models/omnigen.py
new file mode 100644
index 0000000..571d6c0
--- /dev/null
+++ b/diffsynth/models/omnigen.py
@@ -0,0 +1,803 @@
+# This code is adapted from OmniGen (https://github.com/VectorSpaceLab/OmniGen), which is in turn based on DiT.
+import os
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+from safetensors.torch import load_file
+from typing import List, Optional, Tuple, Union
+import torch.utils.checkpoint
+from huggingface_hub import snapshot_download
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers import Phi3Config, Phi3Model
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Phi3Transformer(Phi3Model):
+ """
+    Transformer decoder consisting of *config.num_hidden_layers* layers, each a [`Phi3DecoderLayer`].
+    Only the attention-mask handling is modified relative to the upstream `Phi3Model`.
+ Args:
+ config: Phi3Config
+ """
+ def prefetch_layer(self, layer_idx: int, device: torch.device):
+ "Starts prefetching the next layer cache"
+ with torch.cuda.stream(self.prefetch_stream):
+ # Prefetch next layer tensors to GPU
+ for name, param in self.layers[layer_idx].named_parameters():
+ param.data = param.data.to(device, non_blocking=True)
+
+ def evict_previous_layer(self, layer_idx: int):
+ "Moves the previous layer cache to the CPU"
+ prev_layer_idx = layer_idx - 1
+ for name, param in self.layers[prev_layer_idx].named_parameters():
+ param.data = param.data.to("cpu", non_blocking=True)
+
+    def get_offload_layer(self, layer_idx: int, device: torch.device):
+ # init stream
+ if not hasattr(self, "prefetch_stream"):
+ self.prefetch_stream = torch.cuda.Stream()
+
+ # delete previous layer
+ torch.cuda.current_stream().synchronize()
+ self.evict_previous_layer(layer_idx)
+
+ # make sure the current layer is ready
+ torch.cuda.synchronize(self.prefetch_stream)
+
+ # load next layer
+ self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
+
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ offload_model: Optional[bool] = False,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
+
+ # if inputs_embeds is None:
+ # inputs_embeds = self.embed_tokens(input_ids)
+
+ # if cache_position is None:
+ # past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ # cache_position = torch.arange(
+ # past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ # )
+ # if position_ids is None:
+ # position_ids = cache_position.unsqueeze(0)
+
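+        # OmniGen supplies a dense 3D attention mask of shape (batch, seq, seq); convert it
+        # to the additive form expected by the decoder layers (1 -> 0, 0 -> -inf).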
+ if attention_mask is not None and attention_mask.dim() == 3:
+ dtype = inputs_embeds.dtype
+ min_dtype = torch.finfo(dtype).min
+ attention_mask = (1 - attention_mask) * min_dtype
+ attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
+ else:
+ raise Exception("attention_mask parameter was unavailable or invalid")
+ # causal_mask = self._update_causal_mask(
+ # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ # )
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ layer_idx = -1
+ for decoder_layer in self.layers:
+ layer_idx += 1
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ )
+ else:
+ if offload_model and not self.training:
+                    self.get_offload_layer(layer_idx, device=inputs_embeds.device)
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t, dtype=torch.float32):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of DiT.
+ """
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
+ """
+    grid_size: int of the grid height and width.
+    return: pos_embed: [grid_size*grid_size, embed_dim] or
+    [1+grid_size*grid_size, embed_dim] (with or without cls_token)
+ """
+ if isinstance(grid_size, int):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+class PatchEmbedMR(nn.Module):
+ """ 2D Image to Patch Embedding
+ """
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_chans: int = 4,
+ embed_dim: int = 768,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
+ return x
+
+
+class OmniGenOriginalModel(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+ def __init__(
+ self,
+ transformer_config: Phi3Config,
+ patch_size=2,
+ in_channels=4,
+ pe_interpolation: float = 1.0,
+ pos_embed_max_size: int = 192,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.patch_size = patch_size
+ self.pos_embed_max_size = pos_embed_max_size
+
+ hidden_size = transformer_config.hidden_size
+
+ self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
+ self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
+
+ self.time_token = TimestepEmbedder(hidden_size)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+
+ self.pe_interpolation = pe_interpolation
+ pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
+
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+
+ self.initialize_weights()
+
+ self.llm = Phi3Transformer(config=transformer_config)
+ self.llm.config.use_cache = False
+
+ @classmethod
+ def from_pretrained(cls, model_name):
+ if not os.path.exists(model_name):
+ cache_folder = os.getenv('HF_HUB_CACHE')
+ model_name = snapshot_download(repo_id=model_name,
+ cache_dir=cache_folder,
+ ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+ config = Phi3Config.from_pretrained(model_name)
+ model = cls(config)
+ if os.path.exists(os.path.join(model_name, 'model.safetensors')):
+ print("Loading safetensors")
+ ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
+ else:
+ ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
+ model.load_state_dict(ckpt)
+ return model
+
+ def initialize_weights(self):
+ assert not hasattr(self, "llama")
+
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ w = self.input_x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+        nn.init.constant_(self.input_x_embedder.proj.bias, 0)
+
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+ nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x, h, w):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+
+ x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h, w))
+ return imgs
+
+
+ def cropped_pos_embed(self, height, width):
+ """Crops positional embeddings for SD3 compatibility."""
+ if self.pos_embed_max_size is None:
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+ height = height // self.patch_size
+ width = width // self.patch_size
+ if height > self.pos_embed_max_size:
+ raise ValueError(
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+ if width > self.pos_embed_max_size:
+ raise ValueError(
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+
+ top = (self.pos_embed_max_size - height) // 2
+ left = (self.pos_embed_max_size - width) // 2
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+ # print(top, top + height, left, left + width, spatial_pos_embed.size())
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+ return spatial_pos_embed
+
+
+ def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
+ if isinstance(latents, list):
+ return_list = False
+ if padding_latent is None:
+ padding_latent = [None] * len(latents)
+ return_list = True
+ patched_latents, num_tokens, shapes = [], [], []
+ for latent, padding in zip(latents, padding_latent):
+ height, width = latent.shape[-2:]
+ if is_input_images:
+ latent = self.input_x_embedder(latent)
+ else:
+ latent = self.x_embedder(latent)
+ pos_embed = self.cropped_pos_embed(height, width)
+ latent = latent + pos_embed
+ if padding is not None:
+ latent = torch.cat([latent, padding], dim=-2)
+ patched_latents.append(latent)
+
+ num_tokens.append(pos_embed.size(1))
+ shapes.append([height, width])
+ if not return_list:
+ latents = torch.cat(patched_latents, dim=0)
+ else:
+ latents = patched_latents
+ else:
+ height, width = latents.shape[-2:]
+ if is_input_images:
+ latents = self.input_x_embedder(latents)
+ else:
+ latents = self.x_embedder(latents)
+ pos_embed = self.cropped_pos_embed(height, width)
+ latents = latents + pos_embed
+ num_tokens = latents.size(1)
+ shapes = [height, width]
+ return latents, num_tokens, shapes
+
+
+ def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
+ """
+
+ """
+ input_is_list = isinstance(x, list)
+ x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
+ time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
+
+ if input_img_latents is not None:
+ input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
+ if input_ids is not None:
+ condition_embeds = self.llm.embed_tokens(input_ids).clone()
+ input_img_inx = 0
+ for b_inx in input_image_sizes.keys():
+ for start_inx, end_inx in input_image_sizes[b_inx]:
+ condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
+ input_img_inx += 1
+ if input_img_latents is not None:
+ assert input_img_inx == len(input_latents)
+
+ input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
+ else:
+ input_emb = torch.cat([time_token, x], dim=1)
+ output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
+ output, past_key_values = output.last_hidden_state, output.past_key_values
+ if input_is_list:
+ image_embedding = output[:, -max(num_tokens):]
+ time_emb = self.t_embedder(timestep, dtype=x.dtype)
+ x = self.final_layer(image_embedding, time_emb)
+ latents = []
+ for i in range(x.size(0)):
+ latent = x[i:i+1, :num_tokens[i]]
+ latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
+ latents.append(latent)
+ else:
+ image_embedding = output[:, -num_tokens:]
+ time_emb = self.t_embedder(timestep, dtype=x.dtype)
+ x = self.final_layer(image_embedding, time_emb)
+ latents = self.unpatchify(x, shapes[0], shapes[1])
+
+ if return_past_key_values:
+ return latents, past_key_values
+ return latents
+
+ @torch.no_grad()
+ def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
+ self.llm.config.use_cache = use_kv_cache
+ model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
+ if use_img_cfg:
+ cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
+ cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
+ model_out = [cond, cond, cond]
+ else:
+ cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
+ cond = uncond + cfg_scale * (cond - uncond)
+ model_out = [cond, cond]
+
+ return torch.cat(model_out, dim=0), past_key_values
+
+
+ @torch.no_grad()
+ def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
+ self.llm.config.use_cache = use_kv_cache
+ if past_key_values is None:
+ past_key_values = [None] * len(attention_mask)
+
+ x = torch.split(x, len(x) // len(attention_mask), dim=0)
+ timestep = timestep.to(x[0].dtype)
+ timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
+
+        model_out, new_past_key_values = [], []
+        for i in range(len(input_ids)):
+            temp_out, temp_past_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
+            model_out.append(temp_out)
+            new_past_key_values.append(temp_past_key_values)
+
+        if len(model_out) == 3:
+            cond, uncond, img_cond = model_out
+            cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
+            model_out = [cond, cond, cond]
+        elif len(model_out) == 2:
+            cond, uncond = model_out
+            cond = uncond + cfg_scale * (cond - uncond)
+            model_out = [cond, cond]
+        else:
+            return model_out[0], new_past_key_values
+
+        return torch.cat(model_out, dim=0), new_past_key_values
+
+
+
+class OmniGenTransformer(OmniGenOriginalModel):
+ def __init__(self):
+ config = {
+ "_name_or_path": "Phi-3-vision-128k-instruct",
+ "architectures": [
+ "Phi3ForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 3072,
+ "initializer_range": 0.02,
+ "intermediate_size": 8192,
+ "max_position_embeddings": 131072,
+ "model_type": "phi3",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 32,
+ "original_max_position_embeddings": 4096,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {
+ "long_factor": [
+ 1.0299999713897705,
+ 1.0499999523162842,
+ 1.0499999523162842,
+ 1.0799999237060547,
+ 1.2299998998641968,
+ 1.2299998998641968,
+ 1.2999999523162842,
+ 1.4499999284744263,
+ 1.5999999046325684,
+ 1.6499998569488525,
+ 1.8999998569488525,
+ 2.859999895095825,
+ 3.68999981880188,
+ 5.419999599456787,
+ 5.489999771118164,
+ 5.489999771118164,
+ 9.09000015258789,
+ 11.579999923706055,
+ 15.65999984741211,
+ 15.769999504089355,
+ 15.789999961853027,
+ 18.360000610351562,
+ 21.989999771118164,
+ 23.079999923706055,
+ 30.009998321533203,
+ 32.35000228881836,
+ 32.590003967285156,
+ 35.56000518798828,
+ 39.95000457763672,
+ 53.840003967285156,
+ 56.20000457763672,
+ 57.95000457763672,
+ 59.29000473022461,
+ 59.77000427246094,
+ 59.920005798339844,
+ 61.190006256103516,
+ 61.96000671386719,
+ 62.50000762939453,
+ 63.3700065612793,
+ 63.48000717163086,
+ 63.48000717163086,
+ 63.66000747680664,
+ 63.850006103515625,
+ 64.08000946044922,
+ 64.760009765625,
+ 64.80001068115234,
+ 64.81001281738281,
+ 64.81001281738281
+ ],
+ "short_factor": [
+ 1.05,
+ 1.05,
+ 1.05,
+ 1.1,
+ 1.1,
+ 1.1,
+ 1.2500000000000002,
+ 1.2500000000000002,
+ 1.4000000000000004,
+ 1.4500000000000004,
+ 1.5500000000000005,
+ 1.8500000000000008,
+ 1.9000000000000008,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.1000000000000005,
+ 2.1000000000000005,
+ 2.2,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3999999999999995,
+ 2.3999999999999995,
+ 2.6499999999999986,
+ 2.6999999999999984,
+ 2.8999999999999977,
+ 2.9499999999999975,
+ 3.049999999999997,
+ 3.049999999999997,
+ 3.049999999999997
+ ],
+ "type": "su"
+ },
+ "rope_theta": 10000.0,
+ "sliding_window": 131072,
+ "tie_word_embeddings": False,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.38.1",
+ "use_cache": True,
+ "vocab_size": 32064,
+ "_attn_implementation": "sdpa"
+ }
+ config = Phi3Config(**config)
+ super().__init__(config)
+
+
+
+ @staticmethod
+ def state_dict_converter():
+ return OmniGenTransformerStateDictConverter()
+
+
+
+class OmniGenTransformerStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ return state_dict
+
+ def from_civitai(self, state_dict):
+ return state_dict
diff --git a/diffsynth/pipelines/__init__.py b/diffsynth/pipelines/__init__.py
index 79d9fc5..7636255 100644
--- a/diffsynth/pipelines/__init__.py
+++ b/diffsynth/pipelines/__init__.py
@@ -7,5 +7,6 @@ from .hunyuan_image import HunyuanDiTImagePipeline
from .svd_video import SVDVideoPipeline
from .flux_image import FluxImagePipeline
from .cog_video import CogVideoPipeline
+from .omnigen_image import OmnigenImagePipeline
from .pipeline_runner import SDVideoPipelineRunner
KolorsImagePipeline = SDXLImagePipeline
diff --git a/diffsynth/pipelines/omnigen_image.py b/diffsynth/pipelines/omnigen_image.py
new file mode 100644
index 0000000..89fd4ac
--- /dev/null
+++ b/diffsynth/pipelines/omnigen_image.py
@@ -0,0 +1,287 @@
+from ..models.omnigen import OmniGenTransformer
+from ..models.sdxl_vae_encoder import SDXLVAEEncoder
+from ..models.sdxl_vae_decoder import SDXLVAEDecoder
+from ..models.model_manager import ModelManager
+from ..prompters.omnigen_prompter import OmniGenPrompter
+from ..schedulers import FlowMatchScheduler
+from .base import BasePipeline
+from typing import Optional, Dict, Any, Tuple, List
+from transformers.cache_utils import DynamicCache
+import torch, os
+from tqdm import tqdm
+
+
+
+class OmniGenCache(DynamicCache):
+    def __init__(self, num_tokens_for_img: int, offload_kv_cache: bool = False) -> None:
+        if not torch.cuda.is_available():
+            # A CUDA stream is required for prefetching, so this cache has no CPU fallback.
+            raise RuntimeError("OmniGenCache requires a GPU: KV-cache offloading is unavailable on CPU.")
+ super().__init__()
+ self.original_device = []
+ self.prefetch_stream = torch.cuda.Stream()
+ self.num_tokens_for_img = num_tokens_for_img
+ self.offload_kv_cache = offload_kv_cache
+
+ def prefetch_layer(self, layer_idx: int):
+ "Starts prefetching the next layer cache"
+ if layer_idx < len(self):
+ with torch.cuda.stream(self.prefetch_stream):
+ # Prefetch next layer tensors to GPU
+ device = self.original_device[layer_idx]
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
+
+
+ def evict_previous_layer(self, layer_idx: int):
+ "Moves the previous layer cache to the CPU"
+ if len(self) > 2:
+ # We do it on the default stream so it occurs after all earlier computations on these tensors are done
+ if layer_idx == 0:
+ prev_layer_idx = -1
+ else:
+ prev_layer_idx = (layer_idx - 1) % len(self)
+ self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
+ self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
+
+
+    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        "Gets the (key, value) cache for this layer on its device, prefetching the next layer and evicting the previous one."
+ if layer_idx < len(self):
+ if self.offload_kv_cache:
+ # Evict the previous layer if necessary
+ torch.cuda.current_stream().synchronize()
+ self.evict_previous_layer(layer_idx)
+ # Load current layer cache to its original device if not already there
+ original_device = self.original_device[layer_idx]
+ # self.prefetch_stream.synchronize(original_device)
+ torch.cuda.synchronize(self.prefetch_stream)
+ key_tensor = self.key_cache[layer_idx]
+ value_tensor = self.value_cache[layer_idx]
+
+ # Prefetch the next layer
+ self.prefetch_layer((layer_idx + 1) % len(self))
+ else:
+ key_tensor = self.key_cache[layer_idx]
+ value_tensor = self.value_cache[layer_idx]
+ return (key_tensor, value_tensor)
+ else:
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ # Update the cache
+ if len(self.key_cache) < layer_idx:
+ raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
+ elif len(self.key_cache) == layer_idx:
+ # only cache the states for condition tokens
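+            # (the `num_tokens_for_img` image tokens plus the single time-embedding token are recomputed every step)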
+ key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
+ value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
+
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += key_states.shape[-2]
+
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ self.original_device.append(key_states.device)
+ if self.offload_kv_cache:
+ self.evict_previous_layer(layer_idx)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ # only cache the states for condition tokens
+ key_tensor, value_tensor = self[layer_idx]
+ k = torch.cat([key_tensor, key_states], dim=-2)
+ v = torch.cat([value_tensor, value_states], dim=-2)
+ return k, v
+
+
+
+class OmnigenImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchScheduler(num_train_timesteps=1, shift=1, inverse_timesteps=True, sigma_min=0, sigma_max=1)
+ # models
+ self.vae_decoder: SDXLVAEDecoder = None
+ self.vae_encoder: SDXLVAEEncoder = None
+ self.transformer: OmniGenTransformer = None
+ self.prompter: OmniGenPrompter = None
+ self.model_names = ['transformer', 'vae_decoder', 'vae_encoder']
+
+
+ def denoising_model(self):
+ return self.transformer
+
+
+ def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
+ # Main models
+ self.transformer, model_path = model_manager.fetch_model("omnigen_transformer", require_model_path=True)
+ self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
+ self.prompter = OmniGenPrompter.from_pretrained(os.path.dirname(model_path))
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
+ pipe = OmnigenImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+        pipe.fetch_models(model_manager, prompt_refiner_classes=prompt_refiner_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def encode_images(self, images, tiled=False, tile_size=64, tile_stride=32):
+ latents = [self.encode_image(image.to(device=self.device), tiled, tile_size, tile_stride).to(self.torch_dtype) for image in images]
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
+ if isinstance(position_ids, list):
+ for i in range(len(position_ids)):
+ position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
+ else:
+ position_ids = position_ids[:, -(num_tokens_for_img+1):]
+ return position_ids
+
+
+ def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
+ if isinstance(attention_mask, list):
+ return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
+ return attention_mask[..., -(num_tokens_for_img+1):, :]
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ reference_images=[],
+ cfg_scale=2.0,
+ image_cfg_scale=2.0,
+ use_kv_cache=True,
+ offload_kv_cache=True,
+ input_image=None,
+ denoising_strength=1.0,
+ height=1024,
+ width=1024,
+ num_inference_steps=20,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
+ latents = self.encode_image(image, **tiler_kwargs)
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
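+        # Repeat the latents for the three guidance branches produced by the prompter:
+        # text-conditional, unconditional, and image-conditional.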
+ latents = latents.repeat(3, 1, 1, 1)
+
+ # Encode prompts
+ input_data = self.prompter(prompt, reference_images, height=height, width=width, use_img_cfg=True, separate_cfg_input=True, use_input_image_size_as_output=False)
+
+ # Encode images
+ reference_latents = [self.encode_images(images, **tiler_kwargs) for images in input_data['input_pixel_values']]
+
+ # Pack all parameters
+ model_kwargs = dict(input_ids=[input_ids.to(self.device) for input_ids in input_data['input_ids']],
+ input_img_latents=reference_latents,
+ input_image_sizes=input_data['input_image_sizes'],
+ attention_mask=[attention_mask.to(self.device) for attention_mask in input_data["attention_mask"]],
+ position_ids=[position_ids.to(self.device) for position_ids in input_data["position_ids"]],
+ cfg_scale=cfg_scale,
+ img_cfg_scale=image_cfg_scale,
+ use_img_cfg=True,
+ use_kv_cache=use_kv_cache,
+ offload_model=False,
+ )
+
+ # Denoise
+ self.load_models_to_device(['transformer'])
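+        # Each 2x2 latent patch becomes one token, so an output image occupies
+        # (height/16) * (width/16) tokens; one KV cache is kept per guidance branch.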
+ cache = [OmniGenCache(latents.size(-1)*latents.size(-2) // 4, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).repeat(latents.shape[0]).to(self.device)
+
+ # Forward
+ noise_pred, cache = self.transformer.forward_with_separate_cfg(latents, timestep, past_key_values=cache, **model_kwargs)
+
+ # Scheduler
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # Update KV cache
+ if progress_id == 0 and use_kv_cache:
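+                # After the first step the text/reference condition is cached, so later
+                # steps only feed the image tokens plus the time-embedding token.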
+ num_tokens_for_img = latents.size(-1)*latents.size(-2) // 4
+ if isinstance(cache, list):
+ model_kwargs['input_ids'] = [None] * len(cache)
+ else:
+ model_kwargs['input_ids'] = None
+ model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
+ model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ del cache
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+ # offload all models
+ self.load_models_to_device([])
+ return image
diff --git a/diffsynth/prompters/omnigen_prompter.py b/diffsynth/prompters/omnigen_prompter.py
new file mode 100644
index 0000000..8a6c38b
--- /dev/null
+++ b/diffsynth/prompters/omnigen_prompter.py
@@ -0,0 +1,327 @@
+import os
+import re
+from typing import Dict, List
+
+import torch
+from PIL import Image
+from torchvision import transforms
+from transformers import AutoTokenizer
+from huggingface_hub import snapshot_download
+
+import numpy as np
+
+
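+# `crop_arr` is inlined here, adapted from the reference OmniGen implementation
+# (`OmniGen/utils.py`), so this module does not depend on the external `OmniGen`
+# package, which is not listed in requirements.txt. It shrinks the image to fit
+# within `max_image_size` and crops both sides to multiples of 16.
+def crop_arr(pil_image, max_image_size):
+    while min(*pil_image.size) >= 2 * max_image_size:
+        pil_image = pil_image.resize(
+            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+        )
+    if max(*pil_image.size) > max_image_size:
+        scale = max_image_size / max(*pil_image.size)
+        pil_image = pil_image.resize(
+            tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+        )
+    if min(*pil_image.size) < 16:
+        scale = 16 / min(*pil_image.size)
+        pil_image = pil_image.resize(
+            tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+        )
+    arr = np.array(pil_image)
+    crop_y1 = (arr.shape[0] % 16) // 2
+    crop_y2 = arr.shape[0] % 16 - crop_y1
+    crop_x1 = (arr.shape[1] % 16) // 2
+    crop_x2 = arr.shape[1] % 16 - crop_x1
+    arr = arr[crop_y1 : arr.shape[0] - crop_y2, crop_x1 : arr.shape[1] - crop_x2]
+    return Image.fromarray(arr)
+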
+
+class OmniGenPrompter:
+ def __init__(self,
+ text_tokenizer,
+ max_image_size: int=1024):
+ self.text_tokenizer = text_tokenizer
+ self.max_image_size = max_image_size
+
+ self.image_transform = transforms.Compose([
+ transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+ self.collator = OmniGenCollator()
+ self.separate_collator = OmniGenSeparateCollator()
+
+ @classmethod
+ def from_pretrained(cls, model_name):
+ if not os.path.exists(model_name):
+ cache_folder = os.getenv('HF_HUB_CACHE')
+ model_name = snapshot_download(repo_id=model_name,
+ cache_dir=cache_folder,
+ allow_patterns="*.json")
+ text_tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ return cls(text_tokenizer)
+
+
+ def process_image(self, image):
+ return self.image_transform(image)
+
+ def process_multi_modal_prompt(self, text, input_images):
+ text = self.add_prefix_instruction(text)
+ if input_images is None or len(input_images) == 0:
+ model_inputs = self.text_tokenizer(text)
+ return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
+
+ pattern = r"<\|image_\d+\|>"
+ prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
+
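+        # The tokenizer prepends a BOS token (id 1, per the Phi-3 config); strip it from
+        # every chunk after the first so the chunks can be concatenated.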
+ for i in range(1, len(prompt_chunks)):
+ if prompt_chunks[i][0] == 1:
+ prompt_chunks[i] = prompt_chunks[i][1:]
+
+ image_tags = re.findall(pattern, text)
+ image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
+
+ unique_image_ids = sorted(list(set(image_ids)))
+ assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+ # total images must be the same as the number of image tags
+ assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+
+ input_images = [input_images[x-1] for x in image_ids]
+
+ all_input_ids = []
+ img_inx = []
+ idx = 0
+ for i in range(len(prompt_chunks)):
+ all_input_ids.extend(prompt_chunks[i])
+ if i != len(prompt_chunks) -1:
+ start_inx = len(all_input_ids)
+ size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
+ img_inx.append([start_inx, start_inx+size])
+ all_input_ids.extend([0]*size)
+
+ return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
+
+
+ def add_prefix_instruction(self, prompt):
+ user_prompt = '<|user|>\n'
+ generation_prompt = 'Generate an image according to the following instructions\n'
+ assistant_prompt = '<|assistant|>\n<|diffusion|>'
+ prompt_suffix = "<|end|>\n"
+ prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
+ return prompt
+
+
+ def __call__(self,
+ instructions: List[str],
+ input_images: List[List[str]] = None,
+ height: int = 1024,
+ width: int = 1024,
+ negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
+ use_img_cfg: bool = True,
+ separate_cfg_input: bool = False,
+ use_input_image_size_as_output: bool=False,
+ ) -> Dict:
+
+ if input_images is None:
+ use_img_cfg = False
+ if isinstance(instructions, str):
+ instructions = [instructions]
+ input_images = [input_images]
+
+ input_data = []
+ for i in range(len(instructions)):
+ cur_instruction = instructions[i]
+ cur_input_images = None if input_images is None else input_images[i]
+ if cur_input_images is not None and len(cur_input_images) > 0:
+ cur_input_images = [self.process_image(x) for x in cur_input_images]
+ else:
+ cur_input_images = None
+ assert "
<|image_1|>" not in cur_instruction
+
+ mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
+
+
+ neg_mllm_input, img_cfg_mllm_input = None, None
+ neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
+ if use_img_cfg:
+ if cur_input_images is not None and len(cur_input_images) >= 1:
+ img_cfg_prompt = [f"
<|image_{i+1}|>" for i in range(len(cur_input_images))]
+ img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
+ else:
+ img_cfg_mllm_input = neg_mllm_input
+
+ if use_input_image_size_as_output:
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)]))
+ else:
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
+
+ if separate_cfg_input:
+ return self.separate_collator(input_data)
+ return self.collator(input_data)
+
+
+
+
+class OmniGenCollator:
+ def __init__(self, pad_token_id=2, hidden_size=3072):
+ self.pad_token_id = pad_token_id
+ self.hidden_size = hidden_size
+
+ def create_position(self, attention_mask, num_tokens_for_output_images):
+ position_ids = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
+ position_ids.append(temp_position)
+ return torch.LongTensor(position_ids)
+
+ def create_mask(self, attention_mask, num_tokens_for_output_images):
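+        # Per sample: causal attention over the padded text tokens plus the single time
+        # token, while output-image tokens attend to the full sequence (bidirectional
+        # among themselves).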
+ extended_mask = []
+ padding_images = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
+ inx = 0
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ pad_l = text_length - temp_l
+
+ temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))
+
+ image_mask = torch.zeros(size=(temp_l+1, img_length))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
+
+ image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=0)
+
+ if pad_l > 0:
+ pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
+
+ pad_mask = torch.ones(size=(pad_l, seq_len))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
+
+ true_img_length = num_tokens_for_output_images[inx]
+ pad_img_length = img_length - true_img_length
+ if pad_img_length > 0:
+ temp_mask[:, -pad_img_length:] = 0
+ temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
+ else:
+ temp_padding_imgs = None
+
+ extended_mask.append(temp_mask.unsqueeze(0))
+ padding_images.append(temp_padding_imgs)
+ inx += 1
+ return torch.cat(extended_mask, dim=0), padding_images
+
+ def adjust_attention_for_input_images(self, attention_mask, image_sizes):
+ for b_inx in image_sizes.keys():
+ for start_inx, end_inx in image_sizes[b_inx]:
+ attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
+
+ return attention_mask
+
+ def pad_input_ids(self, input_ids, image_sizes):
+ max_l = max([len(x) for x in input_ids])
+ padded_ids = []
+ attention_mask = []
+ new_image_sizes = []
+
+ for i in range(len(input_ids)):
+ temp_ids = input_ids[i]
+ temp_l = len(temp_ids)
+ pad_l = max_l - temp_l
+ if pad_l == 0:
+ attention_mask.append([1]*max_l)
+ padded_ids.append(temp_ids)
+ else:
+ attention_mask.append([0]*pad_l+[1]*temp_l)
+ padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
+
+ if i in image_sizes:
+ new_inx = []
+ for old_inx in image_sizes[i]:
+ new_inx.append([x+pad_l for x in old_inx])
+ image_sizes[i] = new_inx
+
+ return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
+
+
+ def process_mllm_input(self, mllm_inputs, target_img_size):
+ num_tokens_for_output_images = []
+ for img_size in target_img_size:
+ num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)
+
+ pixel_values, image_sizes = [], {}
+ b_inx = 0
+ for x in mllm_inputs:
+ if x['pixel_values'] is not None:
+ pixel_values.extend(x['pixel_values'])
+ for size in x['image_sizes']:
+ if b_inx not in image_sizes:
+ image_sizes[b_inx] = [size]
+ else:
+ image_sizes[b_inx].append(size)
+ b_inx += 1
+ pixel_values = [x.unsqueeze(0) for x in pixel_values]
+
+
+ input_ids = [x['input_ids'] for x in mllm_inputs]
+ padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
+ position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
+ attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
+ attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
+
+ return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
+
+
+ def __call__(self, features):
+ mllm_inputs = [f[0] for f in features]
+ cfg_mllm_inputs = [f[1] for f in features]
+ img_cfg_mllm_input = [f[2] for f in features]
+ target_img_size = [f[3] for f in features]
+
+
+ if img_cfg_mllm_input[0] is not None:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
+ target_img_size = target_img_size + target_img_size + target_img_size
+ else:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs
+ target_img_size = target_img_size + target_img_size
+
+
+ all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
+
+ data = {"input_ids": all_padded_input_ids,
+ "attention_mask": all_attention_mask,
+ "position_ids": all_position_ids,
+ "input_pixel_values": all_pixel_values,
+ "input_image_sizes": all_image_sizes,
+ "padding_images": all_padding_images,
+ }
+ return data
+
+
+class OmniGenSeparateCollator(OmniGenCollator):
+ def __call__(self, features):
+ mllm_inputs = [f[0] for f in features]
+ cfg_mllm_inputs = [f[1] for f in features]
+ img_cfg_mllm_input = [f[2] for f in features]
+ target_img_size = [f[3] for f in features]
+
+ all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
+
+
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
+ all_padded_input_ids.append(padded_input_ids)
+ all_attention_mask.append(attention_mask)
+ all_position_ids.append(position_ids)
+ all_pixel_values.append(pixel_values)
+ all_image_sizes.append(image_sizes)
+ all_padding_images.append(padding_images)
+
+ if cfg_mllm_inputs[0] is not None:
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
+ all_padded_input_ids.append(padded_input_ids)
+ all_attention_mask.append(attention_mask)
+ all_position_ids.append(position_ids)
+ all_pixel_values.append(pixel_values)
+ all_image_sizes.append(image_sizes)
+ all_padding_images.append(padding_images)
+ if img_cfg_mllm_input[0] is not None:
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
+ all_padded_input_ids.append(padded_input_ids)
+ all_attention_mask.append(attention_mask)
+ all_position_ids.append(position_ids)
+ all_pixel_values.append(pixel_values)
+ all_image_sizes.append(image_sizes)
+ all_padding_images.append(padding_images)
+
+ data = {"input_ids": all_padded_input_ids,
+ "attention_mask": all_attention_mask,
+ "position_ids": all_position_ids,
+ "input_pixel_values": all_pixel_values,
+ "input_image_sizes": all_image_sizes,
+ "padding_images": all_padding_images,
+ }
+ return data
diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py
index ab965ee..fe6e762 100644
--- a/diffsynth/schedulers/flow_match.py
+++ b/diffsynth/schedulers/flow_match.py
@@ -4,17 +4,20 @@ import torch
class FlowMatchScheduler():
- def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002):
+ def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False):
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.sigma_max = sigma_max
self.sigma_min = sigma_min
+ self.inverse_timesteps = inverse_timesteps
self.set_timesteps(num_inference_steps)
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
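+        # `inverse_timesteps` reverses the schedule so sampling runs from sigma_min
+        # toward sigma_max, which the OmniGen pipeline's flow-matching formulation expects.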
+ if self.inverse_timesteps:
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
self.timesteps = self.sigmas * self.num_train_timesteps
if training:
@@ -31,7 +34,7 @@ class FlowMatchScheduler():
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps):
- sigma_ = 0
+ sigma_ = 1 if self.inverse_timesteps else 0
else:
sigma_ = self.sigmas[timestep_id + 1]
prev_sample = sample + model_output * (sigma_ - sigma)
diff --git a/examples/image_synthesis/README.md b/examples/image_synthesis/README.md
index 65dbe66..ac133d2 100644
--- a/examples/image_synthesis/README.md
+++ b/examples/image_synthesis/README.md
@@ -2,6 +2,14 @@
Image synthesis is the base feature of DiffSynth Studio. We can generate images with very high resolution.
+### OmniGen
+
+OmniGen is a text-image-to-image model: it synthesizes an image from a text prompt together with several reference images.
+
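+Example script: [`omnigen_text_to_image.py`](./omnigen_text_to_image.py)
+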
+|Reference image 1|Reference image 2|Synthesized image|
+|-|-|-|
+||||
+
### Example: FLUX
Example script: [`flux_text_to_image.py`](./flux_text_to_image.py) and [`flux_text_to_image_low_vram.py`](./flux_text_to_image_low_vram.py)(low VRAM).
diff --git a/examples/image_synthesis/omnigen_text_to_image.py b/examples/image_synthesis/omnigen_text_to_image.py
new file mode 100644
index 0000000..4277753
--- /dev/null
+++ b/examples/image_synthesis/omnigen_text_to_image.py
@@ -0,0 +1,25 @@
+import torch
+from diffsynth import ModelManager, OmnigenImagePipeline
+
+
+model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["OmniGen-v1"])
+pipe = OmnigenImagePipeline.from_model_manager(model_manager)
+
+image_man = pipe(
+ prompt="A portrait of a man.",
+ cfg_scale=2.5, num_inference_steps=50, seed=0
+)
+image_man.save("image_man.jpg")
+
+image_woman = pipe(
+ prompt="A portrait of an Asian woman with a white t-shirt.",
+ cfg_scale=2.5, num_inference_steps=50, seed=1
+)
+image_woman.save("image_woman.jpg")
+
+image_merged = pipe(
+ prompt="a man and a woman. The man is the man in
<|image_1|>. The woman is the woman in
<|image_2|>.",
+ reference_images=[image_man, image_woman],
+ cfg_scale=2.5, image_cfg_scale=2.5, num_inference_steps=50, seed=2
+)
+image_merged.save("image_merged.jpg")
diff --git a/requirements.txt b/requirements.txt
index df207cc..9bedb1a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
torch>=2.0.0
cupy-cuda12x
-transformers==4.44.1
+transformers==4.46.2
controlnet-aux==0.0.7
imageio
imageio[ffmpeg]