[FRONTEND] Fixed inliner and got more tests to pass (#822)
This adds a `DialectInlinerInterface` to the Triton dialect. Together with a few other minor semantic changes, this fixes our tests on call instructions. It also adds the option to set an `LLVM_SYSPATH` environment variable to link against a local build of LLVM, which was useful for debugging this issue.
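As a usage note (not part of the commit itself), here is a minimal sketch of how the new environment-variable override is meant to be driven; the LLVM build path and the exact install command are hypothetical placeholders, and `PYBIND11_SYSPATH` works the same way for pybind11:

    # Minimal sketch, assuming a local LLVM build at a hypothetical path; the exact
    # install command for your checkout may differ.
    import os
    import subprocess

    env = dict(os.environ)
    env["LLVM_SYSPATH"] = "/path/to/llvm-project/build"  # hypothetical local build directory

    # With LLVM_SYSPATH set, get_thirdparty_packages() in setup.py (see the hunks
    # below) is intended to point the LLVM package directory at this path instead of
    # the cached clang+llvm tarball, so the include/library flags passed to the build
    # come from the local LLVM.
    subprocess.run(["pip", "install", "-e", "."], cwd="python", env=env, check=True)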
@@ -8,7 +8,7 @@ namespace triton {
 
 std::unique_ptr<Pass> createCombineOpsPass();
 
-}
+} // namespace triton
 
 #define GEN_PASS_REGISTRATION
 #include "triton/Dialect/Triton/Transforms/Passes.h.inc"
@@ -8,11 +8,30 @@
 #include "mlir/IR/DialectImplementation.h"
 
+#include "mlir/Transforms/InliningUtils.h"
 #include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
 
 using namespace mlir;
 using namespace mlir::triton;
 
+//===----------------------------------------------------------------------===//
+// TritonDialect Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TritonInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+                       BlockAndValueMapping &valueMapping) const final {
+    return true;
+  }
+  bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+};
+} // namespace
+
 void TritonDialect::initialize() {
   registerTypes();
 
@@ -22,6 +41,7 @@ void TritonDialect::initialize() {
       >();
 
   // We can also add interface here.
+  addInterfaces<TritonInlinerInterface>();
 }
 
 Operation *TritonDialect::materializeConstant(OpBuilder &builder,
@@ -38,12 +38,13 @@ class Package(NamedTuple):
     test_file: str
     include_flag: str
     lib_flag: str
+    syspath_var_name: str
 
 
 def get_pybind11_package_info():
     name = "pybind11-2.10.0"
     url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz"
-    return Package("pybind11", name, url, "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "")
+    return Package("pybind11", name, url, "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
 
 
 def get_llvm_package_info():
@@ -57,7 +58,7 @@ def get_llvm_package_info():
     else:
         name = 'clang+llvm-14.0.0-x86_64-{}'.format(system_suffix)
         url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{}.tar.xz".format(name)
-    return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
+    return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
 
 
 def get_thirdparty_packages(triton_cache_path):
@@ -67,6 +68,8 @@ def get_thirdparty_packages(triton_cache_path):
         package_root_dir = os.path.join(triton_cache_path, p.package)
         package_dir = os.path.join(package_root_dir, p.name)
         test_file_path = os.path.join(package_dir, p.test_file)
+        if p.syspath_var_name in os.environ:
+            package_dir = os.environ[p.syspath_var_name]
         if not os.path.exists(test_file_path):
             try:
                 shutil.rmtree(package_root_dir)
@@ -422,7 +422,12 @@ void init_triton_ir(py::module &&m) {
       .def("get_int32_attr", &mlir::OpBuilder::getI32IntegerAttr)
       // Use arith.ConstantOp to create constants
      // // Constants
-      // .def("get_int1", &ir::builder::get_int1, ret::reference)
+      .def("get_int1",
+           [](mlir::OpBuilder &self, bool v) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
+                 loc, v, self.getI1Type()));
+           })
       .def("get_int32",
            [](mlir::OpBuilder &self, int64_t v) -> mlir::Value {
              auto loc = self.getUnknownLoc();
@@ -1177,20 +1177,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # # ---------------
 
 
-# @pytest.mark.parametrize("start", [0, 1, 7, 16])
-# def test_arange(start, device='cuda'):
-# BLOCK = 128
-# z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
+@pytest.mark.parametrize("start", [0, 1, 7, 16])
+def test_arange(start, device='cuda'):
+    BLOCK = 128
+    z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
 
-# @triton.jit
-# def _kernel(z, BLOCK: tl.constexpr,
-# START: tl.constexpr, END: tl.constexpr):
-# off = tl.arange(0, BLOCK)
-# val = tl.arange(START, END)
-# tl.store(z + off, val)
-# _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
-# z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
-# triton.testing.assert_almost_equal(z_tri, z_ref)
+    @triton.jit
+    def _kernel(z, BLOCK: tl.constexpr,
+                START: tl.constexpr, END: tl.constexpr):
+        off = tl.arange(0, BLOCK)
+        val = tl.arange(START, END)
+        tl.store(z + off, val)
+    _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
+    z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
+    triton.testing.assert_almost_equal(z_tri, z_ref)
 
 # # ---------------
 # # test load
@@ -1248,47 +1248,47 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # triton.testing.allclose(out, reference_out)
 
 
-# @pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
-# def test_load_cache_modifier(cache):
-# src = torch.empty(128, device='cuda')
-# dst = torch.empty(128, device='cuda')
+@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
+def test_load_cache_modifier(cache):
+    src = torch.empty(128, device='cuda')
+    dst = torch.empty(128, device='cuda')
 
-# @triton.jit
-# def _kernel(dst, src, CACHE: tl.constexpr):
-# offsets = tl.arange(0, 128)
-# x = tl.load(src + offsets, cache_modifier=CACHE)
-# tl.store(dst + offsets, x)
+    @triton.jit
+    def _kernel(dst, src, CACHE: tl.constexpr):
+        offsets = tl.arange(0, 128)
+        x = tl.load(src + offsets, cache_modifier=CACHE)
+        tl.store(dst + offsets, x)
 
-# pgm = _kernel[(1,)](dst, src, CACHE=cache)
-# ptx = pgm.asm['ptx']
-# if cache == '':
-# assert 'ld.global.ca' not in ptx
-# assert 'ld.global.cg' not in ptx
-# if cache == '.cg':
-# assert 'ld.global.cg' in ptx
-# assert 'ld.global.ca' not in ptx
-# if cache == '.ca':
-# assert 'ld.global.ca' in ptx
-# assert 'ld.global.cg' not in ptx
+    pgm = _kernel[(1,)](dst, src, CACHE=cache)
+    ptx = pgm.asm['ptx']
+    if cache == '':
+        assert 'ld.global.ca' not in ptx
+        assert 'ld.global.cg' not in ptx
+    if cache == '.cg':
+        assert 'ld.global.cg' in ptx
+        assert 'ld.global.ca' not in ptx
+    if cache == '.ca':
+        assert 'ld.global.ca' in ptx
+        assert 'ld.global.cg' not in ptx
 
 
-# @pytest.mark.parametrize("N", [16, 10, 11, 1024])
-# def test_vectorization(N):
-# src = torch.empty(1024, device='cuda')
-# dst = torch.empty(1024, device='cuda')
+@pytest.mark.parametrize("N", [16, 10, 11, 1024])
+def test_vectorization(N):
+    src = torch.empty(1024, device='cuda')
+    dst = torch.empty(1024, device='cuda')
 
-# @triton.jit
-# def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
-# offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-# x = tl.load(src + offsets, mask=offsets < N)
-# tl.store(dst + offsets, x, mask=offsets < N)
-# pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
-# ptx = pgm.asm["ptx"]
-# if N % 16 == 0:
-# assert "ld.global.v4.b32" in ptx
-# else:
-# assert "ld.global.b32" in ptx
-# # triton.testing.assert_almost_equal(dst, src[:N])
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N)
+        tl.store(dst + offsets, x, mask=offsets < N)
+    pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
+    ptx = pgm.asm["ptx"]
+    if N % 16 == 0:
+        assert "ld.global.v4.b32" in ptx
+    else:
+        assert "ld.global.b32" in ptx
+    # triton.testing.assert_almost_equal(dst, src[:N])
 # # ---------------
 # # test store
 # # ---------------
@@ -1335,145 +1335,149 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # # ----------------
 
 
-# def test_noop(device='cuda'):
-# @triton.jit
-# def kernel(x):
-# pass
-# x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
-# kernel[(1, )](x)
+def test_noop(device='cuda'):
+    @triton.jit
+    def kernel(x):
+        pass
+    x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
+    kernel[(1, )](x)
 
 
-# @pytest.mark.parametrize("value, value_type", [
-# (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
-# (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
-# (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
-# ])
-# def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
-# spec_type = None
+@pytest.mark.parametrize("value, value_type", [
+    (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
+    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
+    (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
+])
+def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
+    spec_type = None
 
-# def cache_hook(*args, **kwargs):
-# nonlocal spec_type
-# spec_type = kwargs["compile"]["signature"][0]
-# JITFunction.cache_hook = cache_hook
+    def cache_hook(*args, **kwargs):
+        nonlocal spec_type
+        spec_type = kwargs["compile"]["signature"][0]
+    JITFunction.cache_hook = cache_hook
 
-# @triton.jit
-# def kernel(VALUE, X):
-# pass
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
 
-# x = torch.tensor([3.14159], device='cuda')
-# pgm = kernel[(1, )](value, x)
+    x = torch.tensor([3.14159], device='cuda')
+    pgm = kernel[(1, )](value, x)
 
-# JITFunction.cache_hook = None
-# assert spec_type == value_type
+    JITFunction.cache_hook = None
+    assert spec_type == value_type
 
+# # --------------------
+# # value specialization
+# # --------------------
 
 
-# @pytest.mark.parametrize(
-# "value, overflow",
-# [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
-# )
-# def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
+@pytest.mark.parametrize(
+    "value, overflow",
+    [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
+)
+def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
 
-# @triton.jit
-# def kernel(VALUE, X):
-# pass
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
 
-# x = torch.tensor([3.14159], device='cuda')
+    x = torch.tensor([3.14159], device='cuda')
 
-# if overflow:
-# with pytest.raises(OverflowError):
-# kernel[(1, )](value, x)
-# else:
-# kernel[(1, )](value, x)
+    if overflow:
+        with pytest.raises(OverflowError):
+            kernel[(1, )](value, x)
+    else:
+        kernel[(1, )](value, x)
 
 
 # # ----------------
 # # test constexpr
 # # ----------------
 
-# @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
-# @pytest.mark.parametrize("is_lhs_constexpr", [False, True])
-# @pytest.mark.parametrize("is_rhs_constexpr", [True, False])
-# def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
+@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
+@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
+@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
+def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
 
-# @triton.jit
-# def kernel(Z, X, Y):
-# x = tl.load(X)
-# y = tl.load(Y)
-# z = GENERATE_TEST_HERE
-# tl.store(Z, z)
+    @triton.jit
+    def kernel(Z, X, Y):
+        x = tl.load(X)
+        y = tl.load(Y)
+        z = GENERATE_TEST_HERE
+        tl.store(Z, z)
 
-# x_str = "3.14" if is_lhs_constexpr else "x"
-# y_str = "4.13" if is_rhs_constexpr else "y"
-# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
-# x = numpy_random((1,), dtype_str="float32")
-# y = numpy_random((1,), dtype_str="float32")
-# z = np.array(eval(f"{x_str} {op} {y_str}"))
-# x_tri = to_triton(x)
-# y_tri = to_triton(y)
-# z_tri = to_triton(np.empty((1,), dtype=z.dtype))
-# kernel[(1,)](z_tri, x_tri, y_tri)
-# np.testing.assert_allclose(z, to_numpy(z_tri))
+    x_str = "3.14" if is_lhs_constexpr else "x"
+    y_str = "4.13" if is_rhs_constexpr else "y"
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
+    x = numpy_random((1,), dtype_str="float32")
+    y = numpy_random((1,), dtype_str="float32")
+    z = np.array(eval(f"{x_str} {op} {y_str}"))
+    x_tri = to_triton(x)
+    y_tri = to_triton(y)
+    z_tri = to_triton(np.empty((1,), dtype=z.dtype))
+    kernel[(1,)](z_tri, x_tri, y_tri)
+    np.testing.assert_allclose(z, to_numpy(z_tri))
 
 
-# def test_constexpr_shape():
+def test_constexpr_shape():
 
-# @triton.jit
-# def kernel(X):
-# off = tl.arange(0, 128 + 128)
-# tl.store(X + off, off)
+    @triton.jit
+    def kernel(X):
+        off = tl.arange(0, 128 + 128)
+        tl.store(X + off, off)
 
-# x_tri = to_triton(np.empty((256, ), dtype=np.int32))
-# kernel[(1,)](x_tri)
-# np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
+    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
+    kernel[(1,)](x_tri)
+    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
 
 
-# def test_constexpr_scalar_shape():
+def test_constexpr_scalar_shape():
 
-# @triton.jit
-# def kernel(X, s):
-# off = tl.arange(0, 256)
-# val = off % (256 // s)
-# tl.store(X + off, val)
+    @triton.jit
+    def kernel(X, s):
+        off = tl.arange(0, 256)
+        val = off % (256 // s)
+        tl.store(X + off, val)
 
-# x_tri = to_triton(np.empty((256, ), dtype=np.int32))
-# kernel[(1,)](x_tri, 32)
-# np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
+    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
+    kernel[(1,)](x_tri, 32)
+    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
 
 # # -------------
 # # test call
 # # -------------
 
 
-# @triton.jit
-# def val_multiplier(val, i):
-# return val * i
+@triton.jit
+def val_multiplier(val, i):
+    return val * i
 
 
-# @triton.jit
-# def vecmul_kernel(ptr, n_elements, rep):
-# pid = tl.program_id(axis=0)
-# offsets = pid * 128 + tl.arange(0, 128)
-# mask = offsets < n_elements
-# vec = tl.load(ptr + offsets, mask=mask)
-# for i in range(1, rep):
-# vec = val_multiplier(vec, i)
-# tl.store(ptr + offsets, vec, mask=mask)
+@triton.jit
+def vecmul_kernel(ptr, n_elements, rep):
+    pid = tl.program_id(axis=0)
+    offsets = pid * 128 + tl.arange(0, 128)
+    mask = offsets < n_elements
+    vec = tl.load(ptr + offsets, mask=mask)
+    for i in range(1, rep):
+        vec = val_multiplier(vec, i)
+    tl.store(ptr + offsets, vec, mask=mask)
 
 
-# def test_call():
+def test_call():
 
-# @triton.jit
-# def kernel(ptr, n_elements, num1, num2):
-# vecmul_kernel(ptr, n_elements, num1)
-# vecmul_kernel(ptr, n_elements, num2)
+    @triton.jit
+    def kernel(ptr, n_elements, num1, num2):
+        vecmul_kernel(ptr, n_elements, num1)
+        vecmul_kernel(ptr, n_elements, num2)
 
-# size = 1024
-# rand_val = numpy_random((size,), dtype_str="float32")
-# rand_val_tri = to_triton(rand_val, device='cuda')
-# kernel[(size // 128,)](rand_val_tri, size, 3, 5)
+    size = 1024
+    rand_val = numpy_random((size,), dtype_str="float32")
+    rand_val_tri = to_triton(rand_val, device='cuda')
+    kernel[(size // 128,)](rand_val_tri, size, 3, 5)
 
-# ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
-# np.testing.assert_equal(to_numpy(rand_val_tri), ans)
+    ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
+    np.testing.assert_equal(to_numpy(rand_val_tri), ans)
 
 # # -------------
 # # test if
@@ -685,8 +685,7 @@ class CodeGenerator(ast.NodeVisitor):
         fn_name = mangle_fn(fn.__name__, arg_types, constants)
         # generate function def if necessary
         if not self.module.has_function(fn_name):
-            ret_type = triton.language.void
-            prototype = triton.language.function_type([ret_type], arg_types)
+            prototype = triton.language.function_type([], arg_types)
             gscope = sys.modules[fn.fn.__module__].__dict__
             generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types)
             generator.visit(fn.parse())
@@ -696,7 +695,7 @@ class CodeGenerator(ast.NodeVisitor):
         callee_ret_type = self.function_ret_types[fn_name]
         symbol = self.module.get_function(fn_name)
         call_op = self.builder.call(symbol, arg_vals)
-        if call_op.get_num_results() == 0:
+        if call_op.get_num_results() == 0 or callee_ret_type is None:
            return None
        elif call_op.get_num_results() == 1:
            return triton.language.tensor(call_op.get_result(0), callee_ret_type)
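The two hunks above change how calls to other JIT functions are generated: a callee that returns nothing gets an empty result list in its prototype, and a call whose callee produces no results, or whose recorded return type is None, now evaluates to None in the caller. A small illustration of the resulting user-level semantics, mirroring the test_call pattern above (illustration only, not part of the diff; the kernel names are made up):

    import triton
    import triton.language as tl


    @triton.jit
    def scale(ptr, factor):
        offsets = tl.arange(0, 128)
        vec = tl.load(ptr + offsets)
        tl.store(ptr + offsets, vec * factor)  # no return value


    @triton.jit
    def kernel(ptr, a, b):
        # Each call expression evaluates to None; with the new TritonInlinerInterface
        # the callee bodies are also legal to inline into this kernel.
        scale(ptr, a)
        scale(ptr, b)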
@@ -473,6 +473,11 @@ class tensor:
         other = _to_tensor(other, _builder)
         return semantic.mod(self, other, _builder)
 
+    @builtin
+    def __rmod__(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.mod(other, self, _builder)
+
     # unary operators
     @builtin
     def __neg__(self, _builder=None):
@@ -541,6 +546,7 @@ class tensor:
 
     @builtin
     def __rlt__(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
         return semantic.less_than(other, self, _builder)
 
     # <=
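The final hunks add a reflected modulo operator and make `__rlt__` coerce its operand, so expressions with a plain Python constant on the left-hand side of `%` or `<` resolve inside kernels. A hedged illustration (not part of the diff; the kernel and names are made up):

    import triton
    import triton.language as tl


    @triton.jit
    def kernel(X, Z):
        offsets = tl.arange(0, 128)
        x = tl.load(X + offsets)
        r = 7.0 % x  # dispatches to tensor.__rmod__ since the constant is on the left
        m = 0.0 < x  # tensor.__rlt__ now coerces the Python float before comparing
        tl.store(Z + offsets, r, mask=m)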