[RUNTIME] Re-vamped cache so users can manually patch IR / ptx / cubin files (#845)

Also deprecates a couple of tests
This commit is contained in:
Philippe Tillet
2022-11-04 10:57:29 -07:00
committed by GitHub
parent 4218e68d74
commit b6dbe959f0
5 changed files with 93 additions and 197 deletions

View File

@@ -1,33 +0,0 @@
import triton
import triton.language as tl
@triton.jit
def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr):
offsets = tl.arange(0, BLOCK_SIZE)
x1 = tl.load(x1_ptr + offsets, mask=offsets < n)
x2 = tl.load(x2_ptr + offsets, mask=offsets < n)
x3 = tl.load(x3_ptr + offsets, mask=offsets < n)
x4 = tl.load(x4_ptr + offsets, mask=offsets < n)
y1 = tl.sin(x1)
y2 = tl.libdevice.sin(x2)
y3 = tl.libdevice.div_rn(x3, x3)
y4 = tl.libdevice.fma_rd(x4, x4, x4)
tl.store(x1_ptr + offsets, y1, mask=offsets < n)
tl.store(x2_ptr + offsets, y2, mask=offsets < n)
tl.store(x3_ptr + offsets, y3, mask=offsets < n)
tl.store(x4_ptr + offsets, y4, mask=offsets < n)
def test_empty_kernel_cubin_compile():
kernel = triton.compiler._compile(math_kernel,
"*fp32,*fp32,*fp32,*fp32,i32",
device=0,
constants={"BLOCK_SIZE": 256},
output="ttgir") # "cubin"
assert kernel
# TODO: Check if the values are correct.
# TODO: Cover all the math operators

View File

@@ -1,80 +0,0 @@
import triton
import triton.language as tl
# TODO: function with no arguments don't work
@triton.jit
def binop_type_check(X):
# 0d-tensor is not allowed.
# zero_0d = tl.zeros([], dtype=tl.float32)
zero_1d = tl.zeros([2], dtype=tl.float32)
zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)
# scalar + scalar -> scalar
a0 = 0.0 + 0.0
# # scalar + 0D -> 0D
# a1 = 0.0 + zero_0d
# a2 = zero_0d + 0.0
# scalar + 1D -> 1D
a3 = 0.0 + zero_1d
a4 = zero_1d + 0.0
# scalar + 2D -> 2D
a5 = 0.0 + zero_2d_22
a6 = zero_2d_22 + 0.0
# # 0D + 0D -> 0D
# b1 = zero_0d + zero_0d
# # 0D + 1D -> 1D
# b2 = zero_0d + zero_1d
# b3 = zero_1d + zero_0d
# # 0D + 2D -> 2D
# b4 = zero_0d + zero_2d_22
# b5 = zero_2d_22 + zero_0d
# 1D + 1D -> 1D
c1 = zero_1d + zero_1d
# 1D + 2D -> 2D
c2 = zero_1d + zero_2d_21
c3 = zero_1d + zero_2d_22
c4 = zero_2d_21 + zero_1d
c5 = zero_2d_22 + zero_1d
# 2D + 2D -> 2D
d1 = zero_2d_21 + zero_2d_21
d2 = zero_2d_22 + zero_2d_22
d3 = zero_2d_21 + zero_2d_22
d4 = zero_2d_22 + zero_2d_21
# return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
return a0, a3, a4, a5, a6, c1, c2, c3, c4, c5, d1, d2, d3, d4
def test_binop_type_check():
kernel = triton.compiler._compile(binop_type_check,
signature="*fp32",
device=0,
output="ttir")
assert (kernel)
# TODO: Check types of the results
@triton.jit
def reduce_type_check(ptr):
v_32 = tl.load(ptr + tl.arange(0, 32))
v_scalar = tl.min(v_32, axis=0)
tl.store(ptr, v_scalar)
v_64x128 = tl.load(ptr + tl.arange(0, 64)[:, None] + tl.arange(0, 128)[None, :])
v_64 = tl.max(v_64x128, axis=1)
tl.store(ptr + tl.arange(0, 64), v_64)
v_128 = tl.max(v_64x128, axis=0)
tl.store(ptr + tl.arange(0, 128), v_128)
def test_reduce_type_check():
kernel = triton.compiler._compile(reduce_type_check,
signature="*fp32",
device=0,
output="ttir")
assert (kernel)
# TODO: Check types of the results