# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import importlib.metadata
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import is_torch_version, logging
from einops import rearrange

try:
    from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
except ImportError:
    flash_attn_func = None
    flash_attn_qkvpacked_func = None

# (pre_attn_layout, post_attn_layout) pairs for each attention backend.
MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
    """
    Perform QKV self attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d].
        v (torch.Tensor): Value tensor with shape [b, s1, a, d].
        mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate applied to the attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross attention), or
            [b, a, s, s1] ('torch' or 'vanilla' mode). (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        max_seqlen_q (int): Maximum query sequence length in the batch; used to restore the
            batch dimension of the 'flash' output. (default: None)
        batch_size (int): Batch size; used together with max_seqlen_q in 'flash' mode. (default: 1)

    Returns:
        torch.Tensor: Output tensor after self attention with shape [b, s, a*d].
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
    elif mode == "flash":
        x = flash_attn_func(
            q,
            k,
            v,
        )
        # x with shape [(bxs), a, d]
        x = x.view(batch_size, max_seqlen_q, x.shape[-2],
                   x.shape[-1])  # reshape x to [b, s, a, d]
    elif mode == "vanilla":
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(
                b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out


class CausalConv1d(nn.Module):

    def __init__(self,
                 chan_in,
                 chan_out,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 pad_mode='replicate',
                 **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        # Left-pad the time axis only, so outputs never depend on future frames.
        padding = (kernel_size - 1, 0)  # T
        self.time_causal_padding = padding

        self.conv = nn.Conv1d(
            chan_in,
            chan_out,
            kernel_size,
            stride=stride,
            dilation=dilation,
            **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)


class MotionEncoder_tc(nn.Module):

    def __init__(self,
                 in_dim: int,
                 hidden_dim: int,
                 num_heads: int,
                 need_global=True,
                 dtype=None,
                 device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.num_heads = num_heads
        self.need_global = need_global
        self.conv1_local = CausalConv1d(
            in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
        if need_global:
            self.conv1_global = CausalConv1d(
                in_dim, hidden_dim // 4, 3, stride=1)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
        self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)

        if need_global:
            self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)

        self.norm1 = nn.LayerNorm(
            hidden_dim // 4,
            elementwise_affine=False,
            eps=1e-6,
            **factory_kwargs)
        self.norm2 = nn.LayerNorm(
            hidden_dim // 2,
            elementwise_affine=False,
            eps=1e-6,
            **factory_kwargs)
        self.norm3 = nn.LayerNorm(
            hidden_dim,
            elementwise_affine=False,
            eps=1e-6,
            **factory_kwargs)

        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        x = rearrange(x, 'b t c -> b c t')
        x_ori = x.clone()
        b, c, t = x.shape

        # Local (per-head) branch.
        x = self.conv1_local(x)
        x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = rearrange(x, '(b n) t c -> b t n c', b=b)
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        if not self.need_global:
            return x_local

        # Global branch over the original input.
        x = self.conv1_global(x_ori)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = self.final_linear(x)
        x = rearrange(x, '(b n) t c -> b t n c', b=b)

        return x, x_local
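

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only, not part of the Wan pipeline).
# It shows how MotionEncoder_tc consumes a [batch, frames, channels] motion
# feature sequence; the sizes below (in_dim=512, hidden_dim=1024, num_heads=4,
# 80 frames) are arbitrary assumptions chosen for the demo.
if __name__ == "__main__":
    torch.manual_seed(0)

    encoder = MotionEncoder_tc(
        in_dim=512, hidden_dim=1024, num_heads=4, need_global=True)
    motion_feats = torch.randn(2, 80, 512)  # [b, t, c]

    global_feats, local_feats = encoder(motion_feats)

    # conv2 and conv3 each halve the temporal axis (80 -> 40 -> 20), and one
    # learned padding token is appended along the head axis of the local output.
    print(global_feats.shape)  # torch.Size([2, 20, 1, 1024])
    print(local_feats.shape)   # torch.Size([2, 20, 5, 1024])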