[FRONTEND] Refresh cache when the source code of outlined functions is changed (#590)
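For context on the title: the commit makes the JIT cache account for the source of the helper ("outlined") functions a kernel calls, not just the kernel itself, so editing a helper such as leaky_relu no longer serves a stale compiled kernel. The snippet below is a minimal, hypothetical illustration of that idea in plain Python (hashlib/inspect); it is not Triton's actual cache code, and the names source_key and get_or_compile are invented for the example.

import hashlib
import inspect

# Hypothetical sketch, not Triton's real cache: key a compiled artifact on the
# hash of the kernel's own source *plus* the source of the helpers it calls,
# so changing an outlined helper (e.g. leaky_relu) produces a new cache entry.
_CACHE = {}

def source_key(fn, helpers=()):
    h = hashlib.sha256(inspect.getsource(fn).encode())
    for helper in helpers:
        h.update(inspect.getsource(helper).encode())
    return h.hexdigest()

def get_or_compile(fn, helpers=()):
    key = source_key(fn, helpers)
    if key not in _CACHE:
        _CACHE[key] = f"compiled:{key[:8]}"  # stand-in for actual code generation
    return _CACHE[key]

The same commit updates the matmul tutorial, shown in the diff below: the activation is now selected by a string rather than passed in as a function object.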
@@ -236,8 +236,8 @@ def matmul_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk
     # you can fuse arbitrary activation functions here
     # while the accumulator is still in FP32!
-    if ACTIVATION:
-        accumulator = ACTIVATION(accumulator)
+    if ACTIVATION == "leaky_relu":
+        accumulator = leaky_relu(accumulator)
     c = accumulator.to(tl.float16)
 
 # -----------------------------------------------------------
@@ -261,7 +261,7 @@ def leaky_relu(x):
 # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
 
 
-def matmul(a, b, activation=None):
+def matmul(a, b, activation=""):
     # checks constraints
     assert a.shape[1] == b.shape[0], "incompatible dimensions"
     assert a.is_contiguous(), "matrix A must be contiguous"
@@ -347,7 +347,7 @@ def benchmark(M, N, K, provider):
         )
     if provider == 'triton + relu':
         ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: matmul(a, b, activation=leaky_relu)
+            lambda: matmul(a, b, activation="leaky_relu")
         )
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)
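With this change, the tutorial's matmul wrapper selects the fused epilogue by name ("leaky_relu") instead of receiving a function object. A minimal usage sketch, assuming the tutorial's matmul is in scope and PyTorch plus a CUDA device are available:

import torch

# `matmul` is the tutorial wrapper defined above.
a = torch.randn((512, 512), device="cuda", dtype=torch.float16)
b = torch.randn((512, 512), device="cuda", dtype=torch.float16)

c_plain = matmul(a, b)                           # default activation="" -> no fused epilogue
c_fused = matmul(a, b, activation="leaky_relu")  # fuses leaky_relu while the accumulator is still FP32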