/* * Code imported via patch from https://github.com/triton-lang/triton/pull/6570, commit 2afae45951b74785b144151b31e91e6c82b0b02f. * Copyright (c) 2018-2022 Philippe Tillet, OpenAI. * Licensed under the MIT License. */ From 2afae45951b74785b144151b31e91e6c82b0b02f Mon Sep 17 00:00:00 2001 From: Han Zhu Date: Tue, 22 Apr 2025 18:42:23 -0700 Subject: [PATCH] [autotuner] Lazily initiailize do_bench --- diff --git a/a/autotuner.py b/b/autotuner.py index 0ee6bea09..b75c5e353 100644 --- a/a/autotuner.py +++ b/b/autotuner.py @@ -4,6 +4,7 @@ import builtins import os import time import inspect +from functools import cached_property from typing import Dict, Tuple, List, Optional from .jit import KernelInterface @@ -94,6 +95,7 @@ class Autotuner(KernelInterface): while not inspect.isfunction(self.base_fn): self.base_fn = self.base_fn.fn + self._do_bench = do_bench self.num_warmups = warmup self.num_reps = rep self.use_cuda_graph = use_cuda_graph @@ -107,7 +109,7 @@ class Autotuner(KernelInterface): stacklevel=1) if use_cuda_graph: from ..testing import do_bench_cudagraph - self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( + self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( kernel_call, rep=rep if rep is not None else 100, quantiles=quantiles, @@ -115,7 +117,7 @@ class Autotuner(KernelInterface): return import triton.testing - self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( + self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( kernel_call, warmup=warmup if warmup is not None else 25, rep=rep if rep is not None else 100, @@ -123,10 +125,11 @@ class Autotuner(KernelInterface): ) return - if do_bench is None: - self.do_bench = driver.active.get_benchmarker() - else: - self.do_bench = do_bench + @cached_property + def do_bench(self): + if self._do_bench is None: + return driver.active.get_benchmarker() + return self._do_bench def _bench(self, *args, config, **meta): from ..compiler.errors import CompileTimeAssertionFailure