File size: 549 Bytes
1fbc47b
ee1a8a3
2286a5e
 
1fbc47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""Metric plotting helpers."""

from __future__ import annotations

import matplotlib.pyplot as plt


def plot_curve(
    values: list[float],
    title: str,
    *,
    save_path: str | None = None,
    show: bool = True,
) -> None:
    """Plot a series of metric values against their step index.

    The values are drawn as a single line on a new figure. The figure gets
    the given title and the axis labels "Step" and "Value", and
    ``tight_layout`` is applied.

    Args:
        values: Metric values, one per step; the x-axis is the position
            in this list.
        title: Title placed on the axes.
        save_path: If not ``None``, the figure is written here with
            ``Figure.savefig`` and then closed. In that case ``show`` is
            ignored and nothing is displayed.
        show: Only consulted when ``save_path`` is ``None``. If true,
            ``plt.show()`` is called and the figure is left open.
            Otherwise the figure is closed without being displayed.

    Raises:
        Any exception raised while drawing or saving propagates unchanged.
        The figure is closed first so it is not leaked in pyplot's registry.
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(values)
        ax.set_title(title)
        ax.set_xlabel("Step")
        ax.set_ylabel("Value")
        fig.tight_layout()
        if save_path is not None:
            fig.savefig(save_path)
    except BaseException:
        # Without this, a failed plot/save leaves the figure open in pyplot
        # forever; close it, then let the caller see the original error.
        plt.close(fig)
        raise

    if save_path is None and show:
        # Leave the figure open: closing it here would tear down the window
        # when pyplot is in interactive (non-blocking) mode.
        plt.show()
    else:
        # Saved already, or the caller asked for no display: release it.
        plt.close(fig)