This PR merges the new runtime back into the `triton-mlir` branch. This adds caching and just-in-time compilation functionality to the triton-mlir project, and paves the way for re-using tests from the master branch.
58 lines
1.5 KiB
Python
58 lines
1.5 KiB
Python
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
# TODO: functions with no arguments don't work
|
|
# Kernel that adds float32 zero operands of shape [], [2], [2, 1] and [2, 2]
# (plus the Python literal 0.0) in the pairings listed below and returns every
# sum. `X` is never read; presumably it is only here because of the TODO about
# argument-less functions -- confirm before removing it.
@triton.jit
def cast_check(X):
    # Operands, one per shape under test.
    t0 = tl.zeros([], dtype=tl.float32)
    t1 = tl.zeros([2], dtype=tl.float32)
    m21 = tl.zeros([2, 1], dtype=tl.float32)
    m22 = tl.zeros([2, 2], dtype=tl.float32)

    # --- Python literal on one side ---
    # literal + literal: expected to stay a scalar
    lit_lit = 0.0 + 0.0
    # literal with a 0-D tensor: expected 0-D
    lit_t0 = 0.0 + t0
    t0_lit = t0 + 0.0
    # literal with a 1-D tensor: expected 1-D
    lit_t1 = 0.0 + t1
    t1_lit = t1 + 0.0
    # literal with a 2-D tensor: expected 2-D
    lit_m22 = 0.0 + m22
    m22_lit = m22 + 0.0

    # --- 0-D tensor on one side ---
    # 0-D with 0-D: expected 0-D
    t0_t0 = t0 + t0
    # 0-D with 1-D: expected 1-D
    t0_t1 = t0 + t1
    t1_t0 = t1 + t0
    # 0-D with 2-D: expected 2-D
    t0_m22 = t0 + m22
    m22_t0 = m22 + t0

    # --- 1-D tensor on one side ---
    # 1-D with 1-D: expected 1-D
    t1_t1 = t1 + t1
    # 1-D with 2-D: expected 2-D
    t1_m21 = t1 + m21
    t1_m22 = t1 + m22
    m21_t1 = m21 + t1
    m22_t1 = m22 + t1

    # --- 2-D tensors on both sides: expected 2-D ---
    m21_m21 = m21 + m21
    m22_m22 = m22 + m22
    m21_m22 = m21 + m22
    m22_m21 = m22 + m21

    # Results are returned in the same order in which they were computed.
    return (
        lit_lit, lit_t0, t0_lit, lit_t1, t1_lit, lit_m22, m22_lit,
        t0_t0, t0_t1, t1_t0, t0_m22, m22_t0,
        t1_t1, t1_m21, t1_m22, m21_t1, m22_t1,
        m21_m21, m22_m22, m21_m22, m22_m21,
    )
|
|
|
|
|
|
def test_cast_check():
    """Compile `cast_check` and assert that the compiler returned a truthy result."""
    # One "*fp32" entry for the kernel's single parameter X (presumably a
    # pointer-to-float32 spec -- confirm); output="ttgir" selects the form
    # that _compile returns.
    compile_options = dict(
        signature="*fp32",
        device=0,
        output="ttgir",
    )
    compiled = triton.compiler._compile(cast_check, **compile_options)
    assert compiled
    # TODO: Check types of the results
|