class UpsampleConformerEncoder(Module):
  __parameters__ = []
  __buffers__ = []
  embed : __torch__.cosyvoice.transformer.subsampling.___torch_mangle_11.LinearNoSubsampling
  up_embed : __torch__.cosyvoice.transformer.subsampling.___torch_mangle_11.LinearNoSubsampling
  def forward(self: __torch__.cosyvoice.transformer.upsample_encoder.___torch_mangle_10.UpsampleConformerEncoder,
    xs: Tensor,
    xs_lens: Tensor,
    decoding_chunk_size: int=0,
    num_decoding_left_chunks: int=-1) -> Tuple[Tensor, Tensor]:
    _0 = "Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"
    T = torch.size(xs, 1)
    batch_size = torch.size(xs_lens, 0)
    if torch.gt(T, 0):
      max_len = T
    else:
      max_len = torch.item(torch.max(xs_lens))
    seq_range = torch.arange(0, max_len, dtype=4, layout=None, device=ops.prim.device(xs_lens))
    seq_range_expand = torch.expand(torch.unsqueeze(seq_range, 0), [batch_size, int(max_len)])
    seq_length_expand = torch.unsqueeze(xs_lens, -1)
    mask = torch.ge(seq_range_expand, seq_length_expand)
    masks = torch.bitwise_not(torch.unsqueeze(mask, 1))
    embed = self.embed
    _1 = torch.add(torch.matmul(xs, CONSTANTS.c0), CONSTANTS.c1)
    input = torch.layer_norm(_1, [512], CONSTANTS.c2, CONSTANTS.c3)
    pos_enc = embed.pos_enc
    pe = pos_enc.pe
    _2 = torch.size(pe, 1)
    _3 = torch.size(input, 1)
    _4 = torch.ge(_2, torch.sub(torch.mul(_3, 2), 1))
    if _4:
      pe0 = pos_enc.pe
      _5 = ops.prim.dtype(pe0)
      _6 = ops.prim.dtype(input)
      if torch.ne(_5, _6):
        _7 = True
      else:
        pe1 = pos_enc.pe
        _8 = torch.ne(ops.prim.device(pe1), ops.prim.device(input))
        _7 = _8
      if _7:
        pe2 = pos_enc.pe
        _9 = torch.to(pe2, ops.prim.device(input), _6)
        pos_enc.pe = _9
      else:
        pass
    else:
      _10 = [_3, 512]
      pe_positive = torch.zeros(_10)
      pe_negative = torch.zeros(_10)
      position = torch.unsqueeze(torch.arange(0, _3, dtype=6), 1)
      _11 = torch.mul(position, CONSTANTS.c4)
      _12 = torch.sin(_11)
      _13 = torch.slice(torch.slice(pe_positive), 1, 0, None, 2)
      _14 = torch.copy_(_13, _12)
      _15 = torch.cos(_11)
      _16 = torch.slice(torch.slice(pe_positive), 1, 1, None, 2)
      _17 = torch.copy_(_16, _15)
      _18 = torch.mul(torch.mul(position, -1), CONSTANTS.c4)
      _19 = torch.sin(_18)
      _20 = torch.slice(torch.slice(pe_negative), 1, 0, None, 2)
      _21 = torch.copy_(_20, _19)
      _22 = torch.cos(_18)
      _23 = torch.slice(torch.slice(pe_negative), 1, 1, None, 2)
      _24 = torch.copy_(_23, _22)
      pe_positive0 = torch.unsqueeze(torch.flip(pe_positive, [0]), 0)
      pe_negative0 = torch.unsqueeze(torch.slice(pe_negative, 0, 1), 0)
      pe3 = torch.cat([pe_positive0, pe_negative0], 1)
      _25 = torch.to(pe3, ops.prim.device(input), ops.prim.dtype(input))
      pos_enc.pe = _25
    x = torch.mul(input, 22.627416997969522)
    _26 = torch.size(x, 1)
    pe4 = pos_enc.pe
    _27 = torch.slice(pe4)
    pe5 = pos_enc.pe
    _28 = torch.floordiv(torch.size(pe5, 1), 2)
    _29 = torch.add(torch.sub(_28, _26), 1)
    pe6 = pos_enc.pe
    _30 = torch.floordiv(torch.size(pe6, 1), 2)
    pos_emb = torch.slice(_27, 1, _29, torch.add(_30, _26))
    _31 = torch.size(x, 1)
    _32 = ops.prim.device(x)
    chunk_masks = torch.zeros([_31, _31], dtype=11, layout=None, device=_32)
    for i in range(_31):
      _33 = torch.lt(num_decoding_left_chunks, 0)
      if _33:
        start = 0
      else:
        _34 = torch.sub(torch.floordiv(i, 50), num_decoding_left_chunks)
        start0 = ops.prim.max(torch.mul(_34, 50), 0)
        start = start0
      _35 = torch.add(torch.floordiv(i, 50), 1)
      ending = ops.prim.min(torch.mul(_35, 50), _31)
      _36 = torch.slice(torch.select(chunk_masks, 0, i), 0, start, ending)
      _37 = torch.tensor(True, dtype=ops.prim.dtype(_36), device=ops.prim.device(_36))
      _38 = torch.copy_(_36, _37)
    chunk_masks0 = torch.unsqueeze(chunk_masks, 0)
    chunk_masks1 = torch.__and__(masks, chunk_masks0)
    outputs = torch.contiguous(torch.transpose(x, 1, 2))
    outputs0 = torch.pad(outputs, [0, 3], "constant", 0.)
    _39 = torch.conv1d(outputs0, CONSTANTS.c5, CONSTANTS.c6)
    result = torch.leaky_relu(_39)
    outputs1 = torch.pad(result, [2, 0], "constant", 0.)
    outputs2 = torch.conv1d(outputs1, CONSTANTS.c7, CONSTANTS.c8)
    outputs3 = torch.contiguous(torch.transpose(outputs2, 1, 2))
    xs0 = torch.add(outputs3, x)
    x0 = torch.layer_norm(xs0, [512], CONSTANTS.c9, CONSTANTS.c10, 9.9999999999999998e-13)
    n_batch = torch.size(x0, 0)
    _40 = torch.add(torch.matmul(x0, CONSTANTS.c11), CONSTANTS.c12)
    _41 = torch.slice(_40, -1, 1024, 1536)
    _42 = torch.slice(_40, -1, 512, 1024)
    _43 = torch.slice(_40, -1, 0, 512)
    _44 = [n_batch, -1, 8, 64]
    q = torch.view(_43, _44)
    k = torch.view(_42, _44)
    v = torch.view(_41, _44)
    q0 = torch.transpose(q, 1, 2)
    k0 = torch.transpose(k, 1, 2)
    v0 = torch.transpose(v, 1, 2)
    q1 = torch.transpose(q0, 1, 2)
    n_batch_pos = torch.size(pos_emb, 0)
    _45 = torch.matmul(pos_emb, CONSTANTS.c13)
    _46 = [n_batch_pos, -1, 8, 64]
    p = torch.view(_45, _46)
    p0 = torch.transpose(p, 1, 2)
    q_with_bias_u = torch.transpose(torch.add(q1, CONSTANTS.c14), 1, 2)
    q_with_bias_v = torch.transpose(torch.add(q1, CONSTANTS.c15), 1, 2)
    matrix_ac = torch.matmul(q_with_bias_u, torch.transpose(k0, -2, -1))
    matrix_bd = torch.matmul(q_with_bias_v, torch.transpose(p0, -2, -1))
    _47 = torch.size(matrix_ac)
    _48 = torch.size(matrix_bd)
    if torch.ne(_47, _48):
      _49 = _48[0]
      _50 = _48[1]
      _51 = _48[2]
      _52 = ops.prim.device(matrix_bd)
      _53 = ops.prim.dtype(matrix_bd)
      zero_pad = torch.zeros([_49, _50, _51, 1], dtype=_53, layout=None, device=_52)
      x_padded = torch.cat([zero_pad, matrix_bd], -1)
      _54 = torch.add(torch.size(matrix_bd, 3), 1)
      _55 = [_49, _50, _54, torch.size(matrix_bd, 2)]
      x_padded0 = torch.view(x_padded, _55)
      _56 = torch.slice(torch.slice(x_padded0), 1)
      _57 = torch.view_as(torch.slice(_56, 2, 1), matrix_bd)
      _58 = torch.slice(torch.slice(torch.slice(_57), 1), 2)
      _59 = torch.floordiv(torch.size(matrix_bd, -1), 2)
      matrix_bd1 = torch.slice(_58, 3, None, torch.add(_59, 1))
      matrix_bd0 = matrix_bd1
    else:
      matrix_bd0 = matrix_bd
    scores = torch.div(torch.add(matrix_ac, matrix_bd0), 8.)
    n_batch0 = torch.size(v0, 0)
    _60 = torch.gt(torch.size(chunk_masks1, 2), 0)
    if _60:
      mask0 = torch.eq(torch.unsqueeze(chunk_masks1, 1), 0)
      _61 = torch.slice(torch.slice(mask0), 1)
      mask1 = torch.slice(torch.slice(_61, 2), 3, None, torch.size(scores, -1))
      scores0 = torch.masked_fill(scores, mask1, -inf)
      attn0 = torch.masked_fill(torch.softmax(scores0, -1), mask1, 0.)
      attn = attn0
    else:
      attn = torch.softmax(scores, -1)
    x1 = torch.matmul(attn, v0)
    _62 = torch.contiguous(torch.transpose(x1, 1, 2))
    x2 = torch.view(_62, [n_batch0, -1, 512])
    _63 = torch.add(torch.matmul(x2, CONSTANTS.c16), CONSTANTS.c17)
    x3 = torch.add(xs0, _63)
    x4 = torch.layer_norm(x3, [512], CONSTANTS.c18, CONSTANTS.c19, 9.9999999999999998e-13)
    _64 = torch.add(torch.matmul(x4, CONSTANTS.c20), CONSTANTS.c21)
    _65 = torch.matmul(torch.silu(_64), CONSTANTS.c22)
    _66 = torch.mul(torch.add(_65, CONSTANTS.c23), 1.)
    x5 = torch.add(x3, _66)
    x6 = torch.layer_norm(x5, [512], CONSTANTS.c24, CONSTANTS.c25, 9.9999999999999998e-13)
    n_batch1 = torch.size(x6, 0)
    _67 = torch.add(torch.matmul(x6, CONSTANTS.c26), CONSTANTS.c27)
    _68 = torch.slice(_67, -1, 1024, 1536)
    _69 = torch.slice(_67, -1, 512, 1024)
    _70 = torch.slice(_67, -1, 0, 512)
    _71 = [n_batch1, -1, 8, 64]
    q2 = torch.view(_70, _71)
    k1 = torch.view(_69, _71)
    v1 = torch.view(_68, _71)
    q3 = torch.transpose(q2, 1, 2)
    k2 = torch.transpose(k1, 1, 2)
    v2 = torch.transpose(v1, 1, 2)
    q4 = torch.transpose(q3, 1, 2)
    _72 = torch.matmul(pos_emb, CONSTANTS.c28)
    p1 = torch.view(_72, _46)
    p2 = torch.transpose(p1, 1, 2)
    q_with_bias_u0 = torch.transpose(torch.add(q4, CONSTANTS.c29), 1, 2)
    q_with_bias_v0 = torch.transpose(torch.add(q4, CONSTANTS.c30), 1, 2)
    matrix_ac0 = torch.matmul(q_with_bias_u0, torch.transpose(k2, -2, -1))
    matrix_bd2 = torch.matmul(q_with_bias_v0, torch.transpose(p2, -2, -1))
    _73 = torch.size(matrix_ac0)
    _74 = torch.size(matrix_bd2)
    if torch.ne(_73, _74):
      _75 = _74[0]
      _76 = _74[1]
      _77 = _74[2]
      _78 = ops.prim.device(matrix_bd2)
      _79 = ops.prim.dtype(matrix_bd2)
      zero_pad0 = torch.zeros([_75, _76, _77, 1], dtype=_79, layout=None, device=_78)
      x_padded1 = torch.cat([zero_pad0, matrix_bd2], -1)
      _80 = torch.add(torch.size(matrix_bd2, 3), 1)
      _81 = [_75, _76, _80, torch.size(matrix_bd2, 2)]
      x_padded2 = torch.view(x_padded1, _81)
      _82 = torch.slice(torch.slice(x_padded2), 1)
      _83 = torch.view_as(torch.slice(_82, 2, 1), matrix_bd2)
      _84 = torch.slice(torch.slice(torch.slice(_83), 1), 2)
      _85 = torch.floordiv(torch.size(matrix_bd2, -1), 2)
      matrix_bd4 = torch.slice(_84, 3, None, torch.add(_85, 1))
      matrix_bd3 = matrix_bd4
    else:
      matrix_bd3 = matrix_bd2
    scores1 = torch.div(torch.add(matrix_ac0, matrix_bd3), 8.)
    n_batch2 = torch.size(v2, 0)
    _86 = torch.gt(torch.size(chunk_masks1, 2), 0)
    if _86:
      mask2 = torch.eq(torch.unsqueeze(chunk_masks1, 1), 0)
      _87 = torch.slice(torch.slice(mask2), 1)
      mask3 = torch.slice(torch.slice(_87, 2), 3, None, torch.size(scores1, -1))
      scores2 = torch.masked_fill(scores1, mask3, -inf)
      attn2 = torch.masked_fill(torch.softmax(scores2, -1), mask3, 0.)
      attn1 = attn2
    else:
      attn1 = torch.softmax(scores1, -1)
    x7 = torch.matmul(attn1, v2)
    _88 = torch.contiguous(torch.transpose(x7, 1, 2))
    x8 = torch.view(_88, [n_batch2, -1, 512])
    _89 = torch.add(torch.matmul(x8, CONSTANTS.c31), CONSTANTS.c32)
    x9 = torch.add(x5, _89)
    x10 = torch.layer_norm(x9, [512], CONSTANTS.c33, CONSTANTS.c34, 9.9999999999999998e-13)
    _90 = torch.add(torch.matmul(x10, CONSTANTS.c35), CONSTANTS.c36)
    _91 = torch.matmul(torch.silu(_90), CONSTANTS.c37)
    _92 = torch.mul(torch.add(_91, CONSTANTS.c38), 1.)
    x11 = torch.add(x9, _92)
    x12 = torch.layer_norm(x11, [512], CONSTANTS.c39, CONSTANTS.c40, 9.9999999999999998e-13)
    n_batch3 = torch.size(x12, 0)
    _93 = torch.add(torch.matmul(x12, CONSTANTS.c41), CONSTANTS.c42)
    _94 = torch.slice(_93, -1, 1024, 1536)
    _95 = torch.slice(_93, -1, 512, 1024)
    _96 = torch.slice(_93, -1, 0, 512)
    _97 = [n_batch3, -1, 8, 64]
    q5 = torch.view(_96, _97)
    k3 = torch.view(_95, _97)
    v3 = torch.view(_94, _97)
    q6 = torch.transpose(q5, 1, 2)
    k4 = torch.transpose(k3, 1, 2)
    v4 = torch.transpose(v3, 1, 2)
    q7 = torch.transpose(q6, 1, 2)
    _98 = torch.matmul(pos_emb, CONSTANTS.c43)
    p3 = torch.view(_98, _46)
    p4 = torch.transpose(p3, 1, 2)
    q_with_bias_u1 = torch.transpose(torch.add(q7, CONSTANTS.c44), 1, 2)
    q_with_bias_v1 = torch.transpose(torch.add(q7, CONSTANTS.c45), 1, 2)
    matrix_ac1 = torch.matmul(q_with_bias_u1, torch.transpose(k4, -2, -1))
    matrix_bd5 = torch.matmul(q_with_bias_v1, torch.transpose(p4, -2, -1))
    _99 = torch.size(matrix_ac1)
    _100 = torch.size(matrix_bd5)
    if torch.ne(_99, _100):
      _101 = _100[0]
      _102 = _100[1]
      _103 = _100[2]
      _104 = ops.prim.device(matrix_bd5)
      _105 = ops.prim.dtype(matrix_bd5)
      zero_pad1 = torch.zeros([_101, _102, _103, 1], dtype=_105, layout=None, device=_104)
      x_padded3 = torch.cat([zero_pad1, matrix_bd5], -1)
      _106 = torch.add(torch.size(matrix_bd5, 3), 1)
      _107 = [_101, _102, _106, torch.size(matrix_bd5, 2)]
      x_padded4 = torch.view(x_padded3, _107)
      _108 = torch.slice(torch.slice(x_padded4), 1)
      _109 = torch.view_as(torch.slice(_108, 2, 1), matrix_bd5)
      _110 = torch.slice(torch.slice(torch.slice(_109), 1), 2)
      _111 = torch.floordiv(torch.size(matrix_bd5, -1), 2)
      matrix_bd7 = torch.slice(_110, 3, None, torch.add(_111, 1))
      matrix_bd6 = matrix_bd7
    else:
      matrix_bd6 = matrix_bd5
    scores3 = torch.div(torch.add(matrix_ac1, matrix_bd6), 8.)
    n_batch4 = torch.size(v4, 0)
    _112 = torch.gt(torch.size(chunk_masks1, 2), 0)
    if _112:
      mask4 = torch.eq(torch.unsqueeze(chunk_masks1, 1), 0)
      _113 = torch.slice(torch.slice(mask4), 1)
      mask5 = torch.slice(torch.slice(_113, 2), 3, None, torch.size(scores3, -1))
      scores4 = torch.masked_fill(scores3, mask5, -inf)
      attn4 = torch.masked_fill(torch.softmax(scores4, -1), mask5, 0.)
      attn3 = attn4
    else:
      attn3 = torch.softmax(scores3, -1)
    x13 = torch.matmul(attn3, v4)
    _114 = torch.contiguous(torch.transpose(x13, 1, 2))
    x14 = torch.view(_114, [n_batch4, -1, 512])
    _115 = torch.add(torch.matmul(x14, CONSTANTS.c46), CONSTANTS.c47)
    x15 = torch.add(x11, _115)
    x16 = torch.layer_norm(x15, [512], CONSTANTS.c48, CONSTANTS.c49, 9.9999999999999998e-13)
    _116 = torch.add(torch.matmul(x16, CONSTANTS.c50), CONSTANTS.c51)
    _117 = torch.matmul(torch.silu(_116), CONSTANTS.c52)
    _118 = torch.mul(torch.add(_117, CONSTANTS.c53), 1.)
    x17 = torch.add(x15, _118)
    x18 = torch.layer_norm(x17, [512], CONSTANTS.c54, CONSTANTS.c55, 9.9999999999999998e-13)
    n_batch5 = torch.size(x18, 0)
    _119 = torch.add(torch.matmul(x18, CONSTANTS.c56), CONSTANTS.c57)
    _120 = torch.slice(_119, -1, 1024, 1536)
    _121 = torch.slice(_119, -1, 512, 1024)
    _122 = torch.slice(_119, -1, 0, 512)
    _123 = [n_batch5, -1, 8, 64]
    q8 = torch.view(_122, _123)
    k5 = torch.view(_121, _123)
    v5 = torch.view(_120, _123)
    q9 = torch.transpose(q8, 1, 2)
    k6 = torch.transpose(k5, 1, 2)
    v6 = torch.transpose(v5, 1, 2)
    q10 = torch.transpose(q9, 1, 2)
    _124 = torch.matmul(pos_emb, CONSTANTS.c58)
    p5 = torch.view(_124, _46)
    p6 = torch.transpose(p5, 1, 2)
    q_with_bias_u2 = torch.transpose(torch.add(q10, CONSTANTS.c59), 1, 2)
    q_with_bias_v2 = torch.transpose(torch.add(q10, CONSTANTS.c60), 1, 2)
    matrix_ac2 = torch.matmul(q_with_bias_u2, torch.transpose(k6, -2, -1))
    matrix_bd8 = torch.matmul(q_with_bias_v2, torch.transpose(p6, -2, -1))
    _125 = torch.size(matrix_ac2)
    _126 = torch.size(matrix_bd8)
    if torch.ne(_125, _126):
      _127 = _126[0]
      _128 = _126[1]
      _129 = _126[2]
      _130 = ops.prim.device(matrix_bd8)
      _131 = ops.prim.dtype(matrix_bd8)
      zero_pad2 = torch.zeros([_127, _128, _129, 1], dtype=_131, layout=None, device=_130)
      x_padded5 = torch.cat([zero_pad2, matrix_bd8], -1)
      _132 = torch.add(torch.size(matrix_bd8, 3), 1)
      _133 = [_127, _128, _132, torch.size(matrix_bd8, 2)]
      x_padded6 = torch.view(x_padded5, _133)
      _134 = torch.slice(torch.slice(x_padded6), 1)
      _135 = torch.view_as(torch.slice(_134, 2, 1), matrix_bd8)
      _136 = torch.slice(torch.slice(torch.slice(_135), 1), 2)
      _137 = torch.floordiv(torch.size(matrix_bd8, -1), 2)
      matrix_bd10 = torch.slice(_136, 3, None, torch.add(_137, 1))
      matrix_bd9 = matrix_bd10
    else:
      matrix_bd9 = matrix_bd8
    scores5 = torch.div(torch.add(matrix_ac2, matrix_bd9), 8.)
    n_batch6 = torch.size(v6, 0)
    _138 = torch.gt(torch.size(chunk_masks1, 2), 0)
    if _138:
      mask6 = torch.eq(torch.unsqueeze(chunk_masks1, 1), 0)
      _139 = torch.slice(torch.slice(mask6), 1)
      mask7 = torch.slice(torch.slice(_139, 2), 3, None, torch.size(scores5, -1))
      scores6 = torch.masked_fill(scores5, mask7, -inf)
      attn6 = torch.masked_fill(torch.softmax(scores6, -1), mask7, 0.)
      attn5 = attn6
    else:
      attn5 = torch.softmax(scores5, -1)
    x19 = torch.matmul(attn5, v6)
    _140 = torch.contiguous(torch.transpose(x19, 1, 2))
    x20 = torch.view(_140, [n_batch6, -1, 512])
    _141 = torch.add(torch.matmul(x20, CONSTANTS.c61), CONSTANTS.c62)
    x21 = torch.add(x17, _141)
    x22 = torch.layer_norm(x21, [512], CONSTANTS.c63, CONSTANTS.c64, 9.9999999999999998e-13)
    _142 = torch.add(torch.matmul(x22, CONSTANTS.c65), CONSTANTS.c66)
    _143 = torch.matmul(torch.silu(_142), CONSTANTS.c67)
    _144 = torch.mul(torch.add(_143, CONSTANTS.c68), 1.)
    x23 = torch.add(x21, _144)
    x24 = torch.layer_norm(x23, [512], CONSTANTS.c69, CONSTANTS.c70, 9.9999999999999998e-13)
    n_batch7 = torch.size(x24, 0)
    _145 = torch.add(torch.matmul(x24, CONSTANTS.c71), CONSTANTS.c72)
    _146 = torch.slice(_145, -1, 1024, 1536)
    _147 = torch.slice(_145, -1, 512, 1024)
    _148 = torch.slice(_145, -1, 0, 512)
    _149 = [n_batch7, -1, 8, 64]
    q11 = torch.view(_148, _149)
    k7 = torch.view(_147, _149)
    v7 = torch.view(_146, _149)
    q12 = torch.transpose(q11, 1, 2)
    k8 = torch.transpose(k7, 1, 2)
    v8 = torch.transpose(v7, 1, 2)
    q13 = torch.transpose(q12, 1, 2)
    _150 = torch.matmul(pos_emb, CONSTANTS.c73)
    p7 = torch.view(_150, _46)
    p8 = torch.transpose(p7, 1, 2)
    q_with_bias_u3 = torch.transpose(torch.add(q13, CONSTANTS.c74), 1, 2)
    q_with_bias_v3 = torch.transpose(torch.add(q13, CONSTANTS.c75), 1, 2)
    matrix_ac3 = torch.matmul(q_with_bias_u3, torch.transpose(k8, -2, -1))
    matrix_bd11 = torch.matmul(q_with_bias_v3, torch.transpose(p8, -2, -1))
    _151 = torch.size(matrix_ac3)
    _152 = torch.size(matrix_bd11)
    if torch.ne(_151, _152):
      _153 = _152[0]
      _154 = _152[1]
      _155 = _152[2]
      _156 = ops.prim.device(matrix_bd11)
      _157 = ops.prim.dtype(matrix_bd11)
      zero_pad3 = torch.zeros([_153, _154, _155, 1], dtype=_157, layout=None, device=_156)
      x_padded7 = torch.cat([zero_pad3, matrix_bd11], -1)
      _158 = torch.add(torch.size(matrix_bd11, 3), 1)
      _159 = [_153, _154, _158, torch.size(matrix_bd11, 2)]
      x_padded8 = torch.view(x_padded7, _159)
      _160 = torch.slice(torch.slice(x_padded8), 1)
      _161 = torch.view_as(torch.slice(_160, 2, 1), matrix_bd11)
      _162 = torch.slice(torch.slice(torch.slice(_161), 1), 2)
      _163 = torch.floordiv(torch.size(matrix_bd11, -1), 2)
      matrix_bd13 = torch.slice(_162, 3, None, torch.add(_163, 1))
      matrix_bd12 = matrix_bd13
    else:
      matrix_bd12 = matrix_bd11
    scores7 = torch.div(torch.add(matrix_ac3, matrix_bd12), 8.)
    n_batch8 = torch.size(v8, 0)
    _164 = torch.gt(torch.size(chunk_masks1, 2), 0)
    if _164:
      mask8 = torch.eq(torch.unsqueeze(chunk_masks1, 1), 0)
      _165 = torch.slice(torch.slice(mask8), 1)
      mask9 = torch.slice(torch.slice(_165, 2), 3, None, torch.size(scores7, -1))
      scores8 = torch.masked_fill(scores7, mask9, -inf)
      attn8 = torch.masked_fill(torch.softmax(scores8, -1), mask9, 0.)
      attn7 = attn8
    else:
      attn7 = torch.softmax(scores7, -1)
    x25 = torch.matmul(attn7, v8)
    _166 = torch.contiguous(torch.transpose(x25, 1, 2))
    x26 = torch.view(_166, [n_batch8, -1, 512])
    _167 = torch.add(torch.matmul(x26, CONSTANTS.c76), CONSTANTS.c77)
    x27 = torch.add(x23, _167)
    x28 = torch.layer_norm(x27, [512], CONSTANTS.c78, CONSTANTS.c79, 9.9999999999999998e-13)
    _168 = torch.add(torch.matmul(x28, CONSTANTS.c80), CONSTANTS.c81)
    _169 = torch.matmul(torch.silu(_168), CONSTANTS.c82)
    _170 = torch.mul(torch.add(_169, CONSTANTS.c83), 1.)
    x29 = torch.add(x27, _170)
    x30 = torch.layer_norm(x29, [512], CONSTANTS.c84, CONSTANTS.c85, 9.9999999999999998e-13)
    n_batch9 = torch.size(x30, 0)
    _171 = torch.add(torch.matmul(x30, CONSTANTS.c86), CONSTANTS.c87)
    _172 = torch.slice(_171, -1, 1024, 1536)
    _173 = torch.slice(_171, -1, 512, 1024)
    _174 = torch.slice(_171, -1, 0, 512)
    _175 = [n_batch9, -1, 8, 64]
    q14 = torch.view(_174, _175)
    k9 = torch.view(_173, _175)
    v9 = torch.view(_172, _175)
    q15 = torch.transpose(q14, 1, 2)
    k10 = torch.transpose(k9, 1, 2)
    v10 = torch.transpose(v9, 1, 2)
    q16 = torch.transpose(q15, 1, 2)
    _176 = torch.matmul(pos_emb, CONSTANTS.c88)
    p9 = torch.view(_176, _46)
    p10 = torch.transpose(p9, 1, 2)
    q_with_bias_u4 = torch.transpose(torch.add(q16, CONSTANTS.c89), 1, 2)
    q_with_bias_v4 = torch.transpose(torch.add(q16, CONSTANTS.c90), 1, 2)
    matrix_ac4 = torch.matmul(q_with_bias_u4, torch.transpose(k10, -2, -1))
    matrix_bd14 = torch.matmul(q_with_bias_v4, torch.transpose(p10, -2, -1))
    _177 = torch.size(matrix_ac4)
    _178 = torch.size(matrix_bd14)
    if torch.ne(_177, _178):
      _179 = _178[0]
      _180 = _178[1]
      _181 = _178[2]
      _182 = ops.prim.device(matrix_bd14)
      _183 = ops.prim.dtype(matrix_bd14)
      zero_pad4 = torch.zeros([_179, _180, _181, 1], dtype=_183, layout=None, device=_182)
      x_padded9 = torch.cat([zero_pad4, matrix_bd14], -1)
      _184 = torch.add(torch.size(matrix_bd14, 3), 1)
      _185 = [_179, _180, _184, torch.size(matrix_bd14, 2)]
      x_padded10 = torch.view(x_padded9, _185)
      _186 = torch.slice(torch.slice(x_padded10), 1)
      _187 = torch.view_as(torch.slice(_186, 2, 1), matrix_bd14)
      _188 = torch.slice(torch.slice(torch.slice(_187), 1), 2)
      _189 = torch.floordiv(torch.size(matrix_bd14, -1), 2)
      matrix_bd16 = torch.slice(_188, 3, None, torch.add(_189, 1))
      matrix_bd15 = matrix_bd16
    else:
      matrix_bd15 = matrix_bd14
    scores9 = torch.div(torch.add(matrix_ac4, matrix_bd15), 8.)
    n_batch10 = torch.size(v10, 0)
    _190 = torch.gt(torch.size(chunk_masks1, 2), 0)
    if _190:
      mask10 = torch.eq(torch.unsqueeze(chunk_masks1, 1), 0)
      _191 = torch.slice(torch.slice(mask10), 1)
      mask11 = torch.slice(torch.slice(_191, 2), 3, None, torch.size(scores9, -1))
      scores10 = torch.masked_fill(scores9, mask11, -inf)
      attn10 = torch.masked_fill(torch.softmax(scores10, -1), mask11, 0.)
      attn9 = attn10
    else:
      attn9 = torch.softmax(scores9, -1)
    x31 = torch.matmul(attn9, v10)
    _192 = torch.contiguous(torch.transpose(x31, 1, 2))
    x32 = torch.view(_192, [n_batch10, -1, 512])
    _193 = torch.add(torch.matmul(x32, CONSTANTS.c91), CONSTANTS.c92)
    x33 = torch.add(x29, _193)
    x34 = torch.layer_norm(x33, [512], CONSTANTS.c93, CONSTANTS.c94, 9.9999999999999998e-13)
    _194 = torch.add(torch.matmul(x34, CONSTANTS.c95), CONSTANTS.c96)
    _195 = torch.matmul(torch.silu(_194), CONSTANTS.c97)
    _196 = torch.mul(torch.add(_195, CONSTANTS.c98), 1.)
    x35 = torch.add(x33, _196)
    xs1 = torch.contiguous(torch.transpose(x35, 1, 2))
    _197 = uninitialized(Tensor)
    _198 = torch.dim(xs1)
    dim = torch.sub(_198, 2)
    scale_factors = annotate(List[float], [])
    for _199 in range(dim):
      _200 = torch.append(scale_factors, 2.)
    if torch.eq(_198, 3):
      _201 = torch.upsample_nearest1d(xs1, None, scale_factors)
      outputs4 = _201
    else:
      if torch.eq(_198, 4):
        _203 = torch.upsample_nearest2d(xs1, None, scale_factors)
        _202 = _203
      else:
        if torch.eq(_198, 5):
          _205 = torch.upsample_nearest3d(xs1, None, scale_factors)
          _204 = _205
        else:
          _206 = torch.format(_0, _198, "nearest")
          ops.prim.RaiseException(_206, "builtins.NotImplementedError")
          _204 = _197
        _202 = _204
      outputs4 = _202
    outputs5 = torch.pad(outputs4, [4, 0], "constant", 0.)
    outputs6 = torch.conv1d(outputs5, CONSTANTS.c99, CONSTANTS.c100)
    _207 = torch.mul(xs_lens, 2)
    xs2 = torch.contiguous(torch.transpose(outputs6, 1, 2))
    T0 = torch.size(xs2, 1)
    batch_size0 = torch.size(_207, 0)
    if torch.gt(T0, 0):
      max_len0 = T0
    else:
      max_len0 = torch.item(torch.max(_207))
    seq_range0 = torch.arange(0, max_len0, dtype=4, layout=None, device=ops.prim.device(_207))
    seq_range_expand0 = torch.expand(torch.unsqueeze(seq_range0, 0), [batch_size0, int(max_len0)])
    seq_length_expand0 = torch.unsqueeze(_207, -1)
    mask12 = torch.ge(seq_range_expand0, seq_length_expand0)
    masks0 = torch.bitwise_not(torch.unsqueeze(mask12, 1))
    up_embed = self.up_embed
    _208 = torch.add(torch.matmul(xs2, CONSTANTS.c101), CONSTANTS.c102)
    input0 = torch.layer_norm(_208, [512], CONSTANTS.c103, CONSTANTS.c104)
    pos_enc0 = up_embed.pos_enc
    pe7 = pos_enc0.pe
    _209 = torch.size(pe7, 1)
    _210 = torch.size(input0, 1)
    _211 = torch.ge(_209, torch.sub(torch.mul(_210, 2), 1))
    if _211:
      pe8 = pos_enc0.pe
      _212 = ops.prim.dtype(pe8)
      _213 = ops.prim.dtype(input0)
      if torch.ne(_212, _213):
        _214 = True
      else:
        pe9 = pos_enc0.pe
        _215 = torch.ne(ops.prim.device(pe9), ops.prim.device(input0))
        _214 = _215
      if _214:
        pe10 = pos_enc0.pe
        _216 = torch.to(pe10, ops.prim.device(input0), _213)
        pos_enc0.pe = _216
      else:
        pass
    else:
      _217 = [_210, 512]
      pe_positive1 = torch.zeros(_217)
      pe_negative1 = torch.zeros(_217)
      position0 = torch.unsqueeze(torch.arange(0, _210, dtype=6), 1)
      _218 = torch.mul(position0, CONSTANTS.c4)
      _219 = torch.sin(_218)
      _220 = torch.slice(torch.slice(pe_positive1), 1, 0, None, 2)
      _221 = torch.copy_(_220, _219)
      _222 = torch.cos(_218)
      _223 = torch.slice(torch.slice(pe_positive1), 1, 1, None, 2)
      _224 = torch.copy_(_223, _222)
      _225 = torch.mul(torch.mul(position0, -1), CONSTANTS.c4)
      _226 = torch.sin(_225)
      _227 = torch.slice(torch.slice(pe_negative1), 1, 0, None, 2)
      _228 = torch.copy_(_227, _226)
      _229 = torch.cos(_225)
      _230 = torch.slice(torch.slice(pe_negative1), 1, 1, None, 2)
      _231 = torch.copy_(_230, _229)
      pe_positive2 = torch.unsqueeze(torch.flip(pe_positive1, [0]), 0)
      pe_negative2 = torch.unsqueeze(torch.slice(pe_negative1, 0, 1), 0)
      pe11 = torch.cat([pe_positive2, pe_negative2], 1)
      _232 = torch.to(pe11, ops.prim.device(input0), ops.prim.dtype(input0))
      pos_enc0.pe = _232
    x36 = torch.mul(input0, 22.627416997969522)
    _233 = torch.size(x36, 1)
    pe12 = pos_enc0.pe
    _234 = torch.slice(pe12)
    pe13 = pos_enc0.pe
    _235 = torch.floordiv(torch.size(pe13, 1), 2)
    _236 = torch.add(torch.sub(_235, _233), 1)
    pe14 = pos_enc0.pe
    _237 = torch.floordiv(torch.size(pe14, 1), 2)
    pos_emb0 = torch.slice(_234, 1, _236, torch.add(_237, _233))
    _238 = torch.size(x36, 1)
    _239 = ops.prim.device(x36)
    chunk_masks2 = torch.zeros([_238, _238], dtype=11, layout=None, device=_239)
    for i0 in range(_238):
      _240 = torch.lt(num_decoding_left_chunks, 0)
      if _240:
        start1 = 0
      else:
        _241 = torch.sub(torch.floordiv(i0, 100), num_decoding_left_chunks)
        start2 = ops.prim.max(torch.mul(_241, 100), 0)
        start1 = start2
      _242 = torch.add(torch.floordiv(i0, 100), 1)
      ending0 = ops.prim.min(torch.mul(_242, 100), _238)
      _243 = torch.slice(torch.select(chunk_masks2, 0, i0), 0, start1, ending0)
      _244 = torch.tensor(True, dtype=ops.prim.dtype(_243), device=ops.prim.device(_243))
      _245 = torch.copy_(_243, _244)
    chunk_masks3 = torch.unsqueeze(chunk_masks2, 0)
    chunk_masks4 = torch.__and__(masks0, chunk_masks3)
    x37 = torch.layer_norm(x36, [512], CONSTANTS.c105, CONSTANTS.c106, 9.9999999999999998e-13)
    n_batch11 = torch.size(x37, 0)
    _246 = torch.add(torch.matmul(x37, CONSTANTS.c107), CONSTANTS.c108)
    _247 = torch.slice(_246, -1, 1024, 1536)
    _248 = torch.slice(_246, -1, 512, 1024)
    _249 = torch.slice(_246, -1, 0, 512)
    _250 = [n_batch11, -1, 8, 64]
    q17 = torch.view(_249, _250)
    k11 = torch.view(_248, _250)
    v11 = torch.view(_247, _250)
    q18 = torch.transpose(q17, 1, 2)
    k12 = torch.transpose(k11, 1, 2)
    v12 = torch.transpose(v11, 1, 2)
    q19 = torch.transpose(q18, 1, 2)
    n_batch_pos0 = torch.size(pos_emb0, 0)
    _251 = torch.matmul(pos_emb0, CONSTANTS.c109)
    _252 = [n_batch_pos0, -1, 8, 64]
    p11 = torch.view(_251, _252)
    p12 = torch.transpose(p11, 1, 2)
    q_with_bias_u5 = torch.transpose(torch.add(q19, CONSTANTS.c110), 1, 2)
    q_with_bias_v5 = torch.transpose(torch.add(q19, CONSTANTS.c111), 1, 2)
    matrix_ac5 = torch.matmul(q_with_bias_u5, torch.transpose(k12, -2, -1))
    matrix_bd17 = torch.matmul(q_with_bias_v5, torch.transpose(p12, -2, -1))
    _253 = torch.size(matrix_ac5)
    _254 = torch.size(matrix_bd17)
    if torch.ne(_253, _254):
      _255 = _254[0]
      _256 = _254[1]
      _257 = _254[2]
      _258 = ops.prim.device(matrix_bd17)
      _259 = ops.prim.dtype(matrix_bd17)
      zero_pad5 = torch.zeros([_255, _256, _257, 1], dtype=_259, layout=None, device=_258)
      x_padded11 = torch.cat([zero_pad5, matrix_bd17], -1)
      _260 = torch.add(torch.size(matrix_bd17, 3), 1)
      _261 = [_255, _256, _260, torch.size(matrix_bd17, 2)]
      x_padded12 = torch.view(x_padded11, _261)
      _262 = torch.slice(torch.slice(x_padded12), 1)
      _263 = torch.view_as(torch.slice(_262, 2, 1), matrix_bd17)
      _264 = torch.slice(torch.slice(torch.slice(_263), 1), 2)
      _265 = torch.floordiv(torch.size(matrix_bd17, -1), 2)
      matrix_bd19 = torch.slice(_264, 3, None, torch.add(_265, 1))
      matrix_bd18 = matrix_bd19
    else:
      matrix_bd18 = matrix_bd17
    scores11 = torch.div(torch.add(matrix_ac5, matrix_bd18), 8.)
    n_batch12 = torch.size(v12, 0)
    _266 = torch.gt(torch.size(chunk_masks4, 2), 0)
    if _266:
      mask13 = torch.eq(torch.unsqueeze(chunk_masks4, 1), 0)
      _267 = torch.slice(torch.slice(mask13), 1)
      mask14 = torch.slice(torch.slice(_267, 2), 3, None, torch.size(scores11, -1))
      scores12 = torch.masked_fill(scores11, mask14, -inf)
      attn12 = torch.masked_fill(torch.softmax(scores12, -1), mask14, 0.)
      attn11 = attn12
    else:
      attn11 = torch.softmax(scores11, -1)
    x38 = torch.matmul(attn11, v12)
    _268 = torch.contiguous(torch.transpose(x38, 1, 2))
    x39 = torch.view(_268, [n_batch12, -1, 512])
    _269 = torch.add(torch.matmul(x39, CONSTANTS.c112), CONSTANTS.c113)
    x40 = torch.add(x36, _269)
    x41 = torch.layer_norm(x40, [512], CONSTANTS.c114, CONSTANTS.c115, 9.9999999999999998e-13)
    _270 = torch.add(torch.matmul(x41, CONSTANTS.c116), CONSTANTS.c117)
    _271 = torch.matmul(torch.silu(_270), CONSTANTS.c118)
    _272 = torch.mul(torch.add(_271, CONSTANTS.c119), 1.)
    x42 = torch.add(x40, _272)
    x43 = torch.layer_norm(x42, [512], CONSTANTS.c120, CONSTANTS.c121, 9.9999999999999998e-13)
    n_batch13 = torch.size(x43, 0)
    _273 = torch.add(torch.matmul(x43, CONSTANTS.c122), CONSTANTS.c123)
    _274 = torch.slice(_273, -1, 1024, 1536)
    _275 = torch.slice(_273, -1, 512, 1024)
    _276 = torch.slice(_273, -1, 0, 512)
    _277 = [n_batch13, -1, 8, 64]
    q20 = torch.view(_276, _277)
    k13 = torch.view(_275, _277)
    v13 = torch.view(_274, _277)
    q21 = torch.transpose(q20, 1, 2)
    k14 = torch.transpose(k13, 1, 2)
    v14 = torch.transpose(v13, 1, 2)
    q22 = torch.transpose(q21, 1, 2)
    _278 = torch.matmul(pos_emb0, CONSTANTS.c124)
    p13 = torch.view(_278, _252)
    p14 = torch.transpose(p13, 1, 2)
    q_with_bias_u6 = torch.transpose(torch.add(q22, CONSTANTS.c125), 1, 2)
    q_with_bias_v6 = torch.transpose(torch.add(q22, CONSTANTS.c126), 1, 2)
    matrix_ac6 = torch.matmul(q_with_bias_u6, torch.transpose(k14, -2, -1))
    matrix_bd20 = torch.matmul(q_with_bias_v6, torch.transpose(p14, -2, -1))
    _279 = torch.size(matrix_ac6)
    _280 = torch.size(matrix_bd20)
    if torch.ne(_279, _280):
      _281 = _280[0]
      _282 = _280[1]
      _283 = _280[2]
      _284 = ops.prim.device(matrix_bd20)
      _285 = ops.prim.dtype(matrix_bd20)
      zero_pad6 = torch.zeros([_281, _282, _283, 1], dtype=_285, layout=None, device=_284)
      x_padded13 = torch.cat([zero_pad6, matrix_bd20], -1)
      _286 = torch.add(torch.size(matrix_bd20, 3), 1)
      _287 = [_281, _282, _286, torch.size(matrix_bd20, 2)]
      x_padded14 = torch.view(x_padded13, _287)
      _288 = torch.slice(torch.slice(x_padded14), 1)
      _289 = torch.view_as(torch.slice(_288, 2, 1), matrix_bd20)
      _290 = torch.slice(torch.slice(torch.slice(_289), 1), 2)
      _291 = torch.floordiv(torch.size(matrix_bd20, -1), 2)
      matrix_bd22 = torch.slice(_290, 3, None, torch.add(_291, 1))
      matrix_bd21 = matrix_bd22
    else:
      matrix_bd21 = matrix_bd20
    scores13 = torch.div(torch.add(matrix_ac6, matrix_bd21), 8.)
    n_batch14 = torch.size(v14, 0)
    _292 = torch.gt(torch.size(chunk_masks4, 2), 0)
    if _292:
      mask15 = torch.eq(torch.unsqueeze(chunk_masks4, 1), 0)
      _293 = torch.slice(torch.slice(mask15), 1)
      mask16 = torch.slice(torch.slice(_293, 2), 3, None, torch.size(scores13, -1))
      scores14 = torch.masked_fill(scores13, mask16, -inf)
      attn14 = torch.masked_fill(torch.softmax(scores14, -1), mask16, 0.)
      attn13 = attn14
    else:
      attn13 = torch.softmax(scores13, -1)
    x44 = torch.matmul(attn13, v14)
    _294 = torch.contiguous(torch.transpose(x44, 1, 2))
    x45 = torch.view(_294, [n_batch14, -1, 512])
    _295 = torch.add(torch.matmul(x45, CONSTANTS.c127), CONSTANTS.c128)
    x46 = torch.add(x42, _295)
    x47 = torch.layer_norm(x46, [512], CONSTANTS.c129, CONSTANTS.c130, 9.9999999999999998e-13)
    _296 = torch.add(torch.matmul(x47, CONSTANTS.c131), CONSTANTS.c132)
    _297 = torch.matmul(torch.silu(_296), CONSTANTS.c133)
    _298 = torch.mul(torch.add(_297, CONSTANTS.c134), 1.)
    x48 = torch.add(x46, _298)
    x49 = torch.layer_norm(x48, [512], CONSTANTS.c135, CONSTANTS.c136, 9.9999999999999998e-13)
    n_batch15 = torch.size(x49, 0)
    _299 = torch.add(torch.matmul(x49, CONSTANTS.c137), CONSTANTS.c138)
    _300 = torch.slice(_299, -1, 1024, 1536)
    _301 = torch.slice(_299, -1, 512, 1024)
    _302 = torch.slice(_299, -1, 0, 512)
    _303 = [n_batch15, -1, 8, 64]
    q23 = torch.view(_302, _303)
    k15 = torch.view(_301, _303)
    v15 = torch.view(_300, _303)
    q24 = torch.transpose(q23, 1, 2)
    k16 = torch.transpose(k15, 1, 2)
    v16 = torch.transpose(v15, 1, 2)
    q25 = torch.transpose(q24, 1, 2)
    _304 = torch.matmul(pos_emb0, CONSTANTS.c139)
    p15 = torch.view(_304, _252)
    p16 = torch.transpose(p15, 1, 2)
    q_with_bias_u7 = torch.transpose(torch.add(q25, CONSTANTS.c140), 1, 2)
    q_with_bias_v7 = torch.transpose(torch.add(q25, CONSTANTS.c141), 1, 2)
    matrix_ac7 = torch.matmul(q_with_bias_u7, torch.transpose(k16, -2, -1))
    matrix_bd23 = torch.matmul(q_with_bias_v7, torch.transpose(p16, -2, -1))
    _305 = torch.size(matrix_ac7)
    _306 = torch.size(matrix_bd23)
    if torch.ne(_305, _306):
      _307 = _306[0]
      _308 = _306[1]
      _309 = _306[2]
      _310 = ops.prim.device(matrix_bd23)
      _311 = ops.prim.dtype(matrix_bd23)
      zero_pad7 = torch.zeros([_307, _308, _309, 1], dtype=_311, layout=None, device=_310)
      x_padded15 = torch.cat([zero_pad7, matrix_bd23], -1)
      _312 = torch.add(torch.size(matrix_bd23, 3), 1)
      _313 = [_307, _308, _312, torch.size(matrix_bd23, 2)]
      x_padded16 = torch.view(x_padded15, _313)
      _314 = torch.slice(torch.slice(x_padded16), 1)
      _315 = torch.view_as(torch.slice(_314, 2, 1), matrix_bd23)
      _316 = torch.slice(torch.slice(torch.slice(_315), 1), 2)
      _317 = torch.floordiv(torch.size(matrix_bd23, -1), 2)
      matrix_bd25 = torch.slice(_316, 3, None, torch.add(_317, 1))
      matrix_bd24 = matrix_bd25
    else:
      matrix_bd24 = matrix_bd23
    scores15 = torch.div(torch.add(matrix_ac7, matrix_bd24), 8.)
    n_batch16 = torch.size(v16, 0)
    _318 = torch.gt(torch.size(chunk_masks4, 2), 0)
    if _318:
      mask17 = torch.eq(torch.unsqueeze(chunk_masks4, 1), 0)
      _319 = torch.slice(torch.slice(mask17), 1)
      mask18 = torch.slice(torch.slice(_319, 2), 3, None, torch.size(scores15, -1))
      scores16 = torch.masked_fill(scores15, mask18, -inf)
      attn16 = torch.masked_fill(torch.softmax(scores16, -1), mask18, 0.)
      attn15 = attn16
    else:
      attn15 = torch.softmax(scores15, -1)
    x50 = torch.matmul(attn15, v16)
    _320 = torch.contiguous(torch.transpose(x50, 1, 2))
    x51 = torch.view(_320, [n_batch16, -1, 512])
    _321 = torch.add(torch.matmul(x51, CONSTANTS.c142), CONSTANTS.c143)
    x52 = torch.add(x48, _321)
    x53 = torch.layer_norm(x52, [512], CONSTANTS.c144, CONSTANTS.c145, 9.9999999999999998e-13)
    _322 = torch.add(torch.matmul(x53, CONSTANTS.c146), CONSTANTS.c147)
    _323 = torch.matmul(torch.silu(_322), CONSTANTS.c148)
    _324 = torch.mul(torch.add(_323, CONSTANTS.c149), 1.)
    x54 = torch.add(x52, _324)
    x55 = torch.layer_norm(x54, [512], CONSTANTS.c150, CONSTANTS.c151, 9.9999999999999998e-13)
    n_batch17 = torch.size(x55, 0)
    _325 = torch.add(torch.matmul(x55, CONSTANTS.c152), CONSTANTS.c153)
    _326 = torch.slice(_325, -1, 1024, 1536)
    _327 = torch.slice(_325, -1, 512, 1024)
    _328 = torch.slice(_325, -1, 0, 512)
    _329 = [n_batch17, -1, 8, 64]
    q26 = torch.view(_328, _329)
    k17 = torch.view(_327, _329)
    v17 = torch.view(_326, _329)
    q27 = torch.transpose(q26, 1, 2)
    k18 = torch.transpose(k17, 1, 2)
    v18 = torch.transpose(v17, 1, 2)
    q28 = torch.transpose(q27, 1, 2)
    _330 = torch.matmul(pos_emb0, CONSTANTS.c154)
    p17 = torch.view(_330, _252)
    p18 = torch.transpose(p17, 1, 2)
    q_with_bias_u8 = torch.transpose(torch.add(q28, CONSTANTS.c155), 1, 2)
    q_with_bias_v8 = torch.transpose(torch.add(q28, CONSTANTS.c156), 1, 2)
    matrix_ac8 = torch.matmul(q_with_bias_u8, torch.transpose(k18, -2, -1))
    matrix_bd26 = torch.matmul(q_with_bias_v8, torch.transpose(p18, -2, -1))
    _331 = torch.size(matrix_ac8)
    _332 = torch.size(matrix_bd26)
    if torch.ne(_331, _332):
      _333 = _332[0]
      _334 = _332[1]
      _335 = _332[2]
      _336 = ops.prim.device(matrix_bd26)
      _337 = ops.prim.dtype(matrix_bd26)
      zero_pad8 = torch.zeros([_333, _334, _335, 1], dtype=_337, layout=None, device=_336)
      x_padded17 = torch.cat([zero_pad8, matrix_bd26], -1)
      _338 = torch.add(torch.size(matrix_bd26, 3), 1)
      _339 = [_333, _334, _338, torch.size(matrix_bd26, 2)]
      x_padded18 = torch.view(x_padded17, _339)
      _340 = torch.slice(torch.slice(x_padded18), 1)
      _341 = torch.view_as(torch.slice(_340, 2, 1), matrix_bd26)
      _342 = torch.slice(torch.slice(torch.slice(_341), 1), 2)
      _343 = torch.floordiv(torch.size(matrix_bd26, -1), 2)
      matrix_bd28 = torch.slice(_342, 3, None, torch.add(_343, 1))
      matrix_bd27 = matrix_bd28
    else:
      matrix_bd27 = matrix_bd26
    scores17 = torch.div(torch.add(matrix_ac8, matrix_bd27), 8.)
    n_batch18 = torch.size(v18, 0)
    _344 = torch.gt(torch.size(chunk_masks4, 2), 0)
    if _344:
      mask19 = torch.eq(torch.unsqueeze(chunk_masks4, 1), 0)
      _345 = torch.slice(torch.slice(mask19), 1)
      mask20 = torch.slice(torch.slice(_345, 2), 3, None, torch.size(scores17, -1))
      scores18 = torch.masked_fill(scores17, mask20, -inf)
      attn18 = torch.masked_fill(torch.softmax(scores18, -1), mask20, 0.)
      attn17 = attn18
    else:
      attn17 = torch.softmax(scores17, -1)
    x56 = torch.matmul(attn17, v18)
    _346 = torch.contiguous(torch.transpose(x56, 1, 2))
    x57 = torch.view(_346, [n_batch18, -1, 512])
    _347 = torch.add(torch.matmul(x57, CONSTANTS.c157), CONSTANTS.c158)
    x58 = torch.add(x54, _347)
    x59 = torch.layer_norm(x58, [512], CONSTANTS.c159, CONSTANTS.c160, 9.9999999999999998e-13)
    _348 = torch.add(torch.matmul(x59, CONSTANTS.c161), CONSTANTS.c162)
    _349 = torch.matmul(torch.silu(_348), CONSTANTS.c163)
    _350 = torch.mul(torch.add(_349, CONSTANTS.c164), 1.)
    x60 = torch.add(x58, _350)
    xs3 = torch.layer_norm(x60, [512], CONSTANTS.c165, CONSTANTS.c166)
    return (xs3, masks0)
