File size: 960 Bytes
1fbc47b
ee1a8a3
1fbc47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""Attention plotting utilities."""

from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np


def plot_attention(
    matrix: np.ndarray,
    tokens: Sequence[str],
    *,
    cmap: str = "viridis",
) -> None:
    """Display a square attention matrix as a heatmap labelled by tokens.

    Both axes are labelled with ``tokens`` (x-axis labels rotated 90 degrees),
    a colorbar titled "Attention Weight" is attached, and the figure is shown
    with ``plt.show()``.

    Args:
        matrix: Attention weights of shape ``(len(tokens), len(tokens))``.
            Any array-like (e.g. nested lists) is accepted; it is converted
            with ``np.asarray`` before validation.
        tokens: Labels used for both the rows and the columns; must be
            non-empty.
        cmap: Matplotlib colormap name passed to ``imshow``.

    Raises:
        ValueError: If ``matrix`` is not 2-dimensional, ``tokens`` is empty,
            or the matrix shape does not equal ``(len(tokens), len(tokens))``.
    """
    # Normalize first so array-likes get the same validation as ndarrays
    # instead of failing with AttributeError on ``.ndim``.
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError("Attention matrix must be 2-dimensional")
    token_count = len(tokens)
    if token_count == 0:
        raise ValueError("tokens must contain at least one item")
    if matrix.shape != (token_count, token_count):
        raise ValueError(
            f"Attention matrix shape {matrix.shape} must match "
            f"(len(tokens), len(tokens)) = ({token_count}, {token_count})"
        )

    fig, ax = plt.subplots()
    heatmap = ax.imshow(matrix, cmap=cmap)
    # One tick per token so every label lines up with its row/column.
    ax.set_xticks(range(token_count))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(range(token_count))
    ax.set_yticklabels(tokens)
    cbar = fig.colorbar(heatmap, ax=ax)
    cbar.set_label("Attention Weight")
    fig.tight_layout()
    plt.show()