LexiMind / src /visualization /attention.py
OliverPerrin's picture
Style: Apply ruff formatting
ee1a8a3
"""Attention plotting utilities."""
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
def plot_attention(matrix: npt.ArrayLike, tokens: Sequence[str]) -> None:
    """Display a square attention matrix as a labelled heatmap.

    Args:
        matrix: Attention weights with shape ``(len(tokens), len(tokens))``.
            Any array-like accepted by :func:`numpy.asarray` works (a NumPy
            array, a nested list, ...); it is converted before validation.
        tokens: Labels in matrix order. The same sequence labels both the
            rows (y axis) and the columns (x axis).

    Raises:
        ValueError: If ``matrix`` is not 2-dimensional, ``tokens`` is empty,
            or the matrix shape is not ``(len(tokens), len(tokens))``.
    """
    # Normalise to an ndarray up front so nested lists and other array-likes
    # go through the same validation instead of failing on a missing ``.ndim``.
    # For an existing ndarray this returns the same object (no copy).
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError("Attention matrix must be 2-dimensional")
    token_count = len(tokens)
    if token_count == 0:
        raise ValueError("tokens must contain at least one item")
    if matrix.shape != (token_count, token_count):
        raise ValueError(
            f"Attention matrix shape {matrix.shape} must match "
            f"(len(tokens), len(tokens)) = ({token_count}, {token_count})"
        )

    fig, ax = plt.subplots()
    heatmap = ax.imshow(matrix, cmap="viridis")
    # One tick per token on each axis; x labels are rotated vertically so
    # neighbouring tokens do not overlap.
    ax.set_xticks(range(token_count))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(range(token_count))
    ax.set_yticklabels(tokens)
    cbar = fig.colorbar(heatmap, ax=ax)
    cbar.set_label("Attention Weight")
    # Adjust margins so the rotated tick labels and the colorbar fit.
    fig.tight_layout()
    plt.show()