[FRONTEND] Fixed inliner and got more tests to pass (#822)

This adds a `DialectInlinerInterface` to the Triton dialect. This, along
with a few other minor semantic changes, fixes our tests on call
instructions. Also added the option to use an "LLVM_SYSPATH" environment
variable to link against a local build of LLVM; this was useful for
debugging this issue.
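For example, a local LLVM tree can be picked up with something like `LLVM_SYSPATH=/path/to/llvm-14 pip install -e python` (path illustrative). A minimal sketch of the lookup that setup.py now performs, with a hypothetical default cache path:

```python
import os

# Sketch only: mirrors the setup.py change in this commit. If the *_SYSPATH
# variable is set, that directory is used instead of the downloaded tarball.
default_dir = os.path.expanduser("~/.triton/llvm/clang+llvm-14.0.0")  # hypothetical default
package_dir = os.environ.get("LLVM_SYSPATH", default_dir)
print(package_dir)
```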
Philippe Tillet authored on 2022-10-30 14:10:02 -07:00, committed by GitHub
parent 71428194a1
commit e61dc75942
7 changed files with 192 additions and 155 deletions

View File

@@ -8,7 +8,7 @@ namespace triton {
 std::unique_ptr<Pass> createCombineOpsPass();
-}
+} // namespace triton
 #define GEN_PASS_REGISTRATION
 #include "triton/Dialect/Triton/Transforms/Passes.h.inc"

View File

@@ -8,11 +8,30 @@
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/Transforms/InliningUtils.h"
 #include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
 using namespace mlir;
 using namespace mlir::triton;
+//===----------------------------------------------------------------------===//
+// TritonDialect Dialect Interfaces
+//===----------------------------------------------------------------------===//
+namespace {
+struct TritonInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+                       BlockAndValueMapping &valueMapping) const final {
+    return true;
+  }
+  bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+};
+} // namespace
 void TritonDialect::initialize() {
   registerTypes();
@@ -22,6 +41,7 @@ void TritonDialect::initialize() {
       >();
   // We can also add interface here.
+  addInterfaces<TritonInlinerInterface>();
 }
 Operation *TritonDialect::materializeConstant(OpBuilder &builder,

View File

@@ -38,12 +38,13 @@ class Package(NamedTuple):
     test_file: str
     include_flag: str
     lib_flag: str
+    syspath_var_name: str
 def get_pybind11_package_info():
     name = "pybind11-2.10.0"
     url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz"
-    return Package("pybind11", name, url, "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "")
+    return Package("pybind11", name, url, "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
 def get_llvm_package_info():
@@ -57,7 +58,7 @@ def get_llvm_package_info():
     else:
         name = 'clang+llvm-14.0.0-x86_64-{}'.format(system_suffix)
     url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{}.tar.xz".format(name)
-    return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
+    return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
 def get_thirdparty_packages(triton_cache_path):
@@ -67,6 +68,8 @@ def get_thirdparty_packages(triton_cache_path):
         package_root_dir = os.path.join(triton_cache_path, p.package)
         package_dir = os.path.join(package_root_dir, p.name)
         test_file_path = os.path.join(package_dir, p.test_file)
+        if p.syspath_var_name in os.environ:
+            package_dir = os.environ[p.syspath_var_name]
         if not os.path.exists(test_file_path):
             try:
                 shutil.rmtree(package_root_dir)

View File

@@ -422,7 +422,12 @@ void init_triton_ir(py::module &&m) {
       .def("get_int32_attr", &mlir::OpBuilder::getI32IntegerAttr)
       // Use arith.ConstantOp to create constants
       // // Constants
-      // .def("get_int1", &ir::builder::get_int1, ret::reference)
+      .def("get_int1",
+           [](mlir::OpBuilder &self, bool v) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
+                 loc, v, self.getI1Type()));
+           })
       .def("get_int32",
            [](mlir::OpBuilder &self, int64_t v) -> mlir::Value {
              auto loc = self.getUnknownLoc();

View File

@@ -1177,20 +1177,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # # ---------------
-# @pytest.mark.parametrize("start", [0, 1, 7, 16])
-# def test_arange(start, device='cuda'):
-#     BLOCK = 128
-#     z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
-#     @triton.jit
-#     def _kernel(z, BLOCK: tl.constexpr,
-#                 START: tl.constexpr, END: tl.constexpr):
-#         off = tl.arange(0, BLOCK)
-#         val = tl.arange(START, END)
-#         tl.store(z + off, val)
-#     _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
-#     z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
-#     triton.testing.assert_almost_equal(z_tri, z_ref)
+@pytest.mark.parametrize("start", [0, 1, 7, 16])
+def test_arange(start, device='cuda'):
+    BLOCK = 128
+    z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
+    @triton.jit
+    def _kernel(z, BLOCK: tl.constexpr,
+                START: tl.constexpr, END: tl.constexpr):
+        off = tl.arange(0, BLOCK)
+        val = tl.arange(START, END)
+        tl.store(z + off, val)
+    _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
+    z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
+    triton.testing.assert_almost_equal(z_tri, z_ref)
 # # ---------------
 # # test load
@@ -1248,47 +1248,47 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 #     triton.testing.allclose(out, reference_out)
-# @pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
-# def test_load_cache_modifier(cache):
-#     src = torch.empty(128, device='cuda')
-#     dst = torch.empty(128, device='cuda')
-#     @triton.jit
-#     def _kernel(dst, src, CACHE: tl.constexpr):
-#         offsets = tl.arange(0, 128)
-#         x = tl.load(src + offsets, cache_modifier=CACHE)
-#         tl.store(dst + offsets, x)
-#     pgm = _kernel[(1,)](dst, src, CACHE=cache)
-#     ptx = pgm.asm['ptx']
-#     if cache == '':
-#         assert 'ld.global.ca' not in ptx
-#         assert 'ld.global.cg' not in ptx
-#     if cache == '.cg':
-#         assert 'ld.global.cg' in ptx
-#         assert 'ld.global.ca' not in ptx
-#     if cache == '.ca':
-#         assert 'ld.global.ca' in ptx
-#         assert 'ld.global.cg' not in ptx
-# @pytest.mark.parametrize("N", [16, 10, 11, 1024])
-# def test_vectorization(N):
-#     src = torch.empty(1024, device='cuda')
-#     dst = torch.empty(1024, device='cuda')
-#     @triton.jit
-#     def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
-#         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-#         x = tl.load(src + offsets, mask=offsets < N)
-#         tl.store(dst + offsets, x, mask=offsets < N)
-#     pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
-#     ptx = pgm.asm["ptx"]
-#     if N % 16 == 0:
-#         assert "ld.global.v4.b32" in ptx
-#     else:
-#         assert "ld.global.b32" in ptx
-#     # triton.testing.assert_almost_equal(dst, src[:N])
+@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
+def test_load_cache_modifier(cache):
+    src = torch.empty(128, device='cuda')
+    dst = torch.empty(128, device='cuda')
+    @triton.jit
+    def _kernel(dst, src, CACHE: tl.constexpr):
+        offsets = tl.arange(0, 128)
+        x = tl.load(src + offsets, cache_modifier=CACHE)
+        tl.store(dst + offsets, x)
+    pgm = _kernel[(1,)](dst, src, CACHE=cache)
+    ptx = pgm.asm['ptx']
+    if cache == '':
+        assert 'ld.global.ca' not in ptx
+        assert 'ld.global.cg' not in ptx
+    if cache == '.cg':
+        assert 'ld.global.cg' in ptx
+        assert 'ld.global.ca' not in ptx
+    if cache == '.ca':
+        assert 'ld.global.ca' in ptx
+        assert 'ld.global.cg' not in ptx
+@pytest.mark.parametrize("N", [16, 10, 11, 1024])
+def test_vectorization(N):
+    src = torch.empty(1024, device='cuda')
+    dst = torch.empty(1024, device='cuda')
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N)
+        tl.store(dst + offsets, x, mask=offsets < N)
+    pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
+    ptx = pgm.asm["ptx"]
+    if N % 16 == 0:
+        assert "ld.global.v4.b32" in ptx
+    else:
+        assert "ld.global.b32" in ptx
+    # triton.testing.assert_almost_equal(dst, src[:N])
 # # ---------------
 # # test store
 # # ---------------
@@ -1335,145 +1335,149 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # # ----------------
-# def test_noop(device='cuda'):
-#     @triton.jit
-#     def kernel(x):
-#         pass
-#     x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
-#     kernel[(1, )](x)
-# @pytest.mark.parametrize("value, value_type", [
-#     (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
-#     (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
-#     (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
-# ])
-# def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
-#     spec_type = None
-#     def cache_hook(*args, **kwargs):
-#         nonlocal spec_type
-#         spec_type = kwargs["compile"]["signature"][0]
-#     JITFunction.cache_hook = cache_hook
-#     @triton.jit
-#     def kernel(VALUE, X):
-#         pass
-#     x = torch.tensor([3.14159], device='cuda')
-#     pgm = kernel[(1, )](value, x)
-#     JITFunction.cache_hook = None
-#     assert spec_type == value_type
+def test_noop(device='cuda'):
+    @triton.jit
+    def kernel(x):
+        pass
+    x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
+    kernel[(1, )](x)
+@pytest.mark.parametrize("value, value_type", [
+    (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
+    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
+    (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
+])
+def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
+    spec_type = None
+    def cache_hook(*args, **kwargs):
+        nonlocal spec_type
+        spec_type = kwargs["compile"]["signature"][0]
+    JITFunction.cache_hook = cache_hook
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
+    x = torch.tensor([3.14159], device='cuda')
+    pgm = kernel[(1, )](value, x)
+    JITFunction.cache_hook = None
+    assert spec_type == value_type
 # # --------------------
 # # value specialization
 # # --------------------
-# @pytest.mark.parametrize(
-#     "value, overflow",
-#     [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
-# )
-# def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
-#     @triton.jit
-#     def kernel(VALUE, X):
-#         pass
-#     x = torch.tensor([3.14159], device='cuda')
-#     if overflow:
-#         with pytest.raises(OverflowError):
-#             kernel[(1, )](value, x)
-#     else:
-#         kernel[(1, )](value, x)
+@pytest.mark.parametrize(
+    "value, overflow",
+    [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
+)
+def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
+    x = torch.tensor([3.14159], device='cuda')
+    if overflow:
+        with pytest.raises(OverflowError):
+            kernel[(1, )](value, x)
+    else:
+        kernel[(1, )](value, x)
 # # ----------------
 # # test constexpr
 # # ----------------
-# @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
-# @pytest.mark.parametrize("is_lhs_constexpr", [False, True])
-# @pytest.mark.parametrize("is_rhs_constexpr", [True, False])
-# def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
-#     @triton.jit
-#     def kernel(Z, X, Y):
-#         x = tl.load(X)
-#         y = tl.load(Y)
-#         z = GENERATE_TEST_HERE
-#         tl.store(Z, z)
-#     x_str = "3.14" if is_lhs_constexpr else "x"
-#     y_str = "4.13" if is_rhs_constexpr else "y"
-#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
-#     x = numpy_random((1,), dtype_str="float32")
-#     y = numpy_random((1,), dtype_str="float32")
-#     z = np.array(eval(f"{x_str} {op} {y_str}"))
-#     x_tri = to_triton(x)
-#     y_tri = to_triton(y)
-#     z_tri = to_triton(np.empty((1,), dtype=z.dtype))
-#     kernel[(1,)](z_tri, x_tri, y_tri)
-#     np.testing.assert_allclose(z, to_numpy(z_tri))
-# def test_constexpr_shape():
-#     @triton.jit
-#     def kernel(X):
-#         off = tl.arange(0, 128 + 128)
-#         tl.store(X + off, off)
-#     x_tri = to_triton(np.empty((256, ), dtype=np.int32))
-#     kernel[(1,)](x_tri)
-#     np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
-# def test_constexpr_scalar_shape():
-#     @triton.jit
-#     def kernel(X, s):
-#         off = tl.arange(0, 256)
-#         val = off % (256 // s)
-#         tl.store(X + off, val)
-#     x_tri = to_triton(np.empty((256, ), dtype=np.int32))
-#     kernel[(1,)](x_tri, 32)
-#     np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
+@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
+@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
+@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
+def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
+    @triton.jit
+    def kernel(Z, X, Y):
+        x = tl.load(X)
+        y = tl.load(Y)
+        z = GENERATE_TEST_HERE
+        tl.store(Z, z)
+    x_str = "3.14" if is_lhs_constexpr else "x"
+    y_str = "4.13" if is_rhs_constexpr else "y"
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
+    x = numpy_random((1,), dtype_str="float32")
+    y = numpy_random((1,), dtype_str="float32")
+    z = np.array(eval(f"{x_str} {op} {y_str}"))
+    x_tri = to_triton(x)
+    y_tri = to_triton(y)
+    z_tri = to_triton(np.empty((1,), dtype=z.dtype))
+    kernel[(1,)](z_tri, x_tri, y_tri)
+    np.testing.assert_allclose(z, to_numpy(z_tri))
+def test_constexpr_shape():
+    @triton.jit
+    def kernel(X):
+        off = tl.arange(0, 128 + 128)
+        tl.store(X + off, off)
+    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
+    kernel[(1,)](x_tri)
+    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
+def test_constexpr_scalar_shape():
+    @triton.jit
+    def kernel(X, s):
+        off = tl.arange(0, 256)
+        val = off % (256 // s)
+        tl.store(X + off, val)
+    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
+    kernel[(1,)](x_tri, 32)
+    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
 # # -------------
 # # test call
 # # -------------
-# @triton.jit
-# def val_multiplier(val, i):
-#     return val * i
-# @triton.jit
-# def vecmul_kernel(ptr, n_elements, rep):
-#     pid = tl.program_id(axis=0)
-#     offsets = pid * 128 + tl.arange(0, 128)
-#     mask = offsets < n_elements
-#     vec = tl.load(ptr + offsets, mask=mask)
-#     for i in range(1, rep):
-#         vec = val_multiplier(vec, i)
-#     tl.store(ptr + offsets, vec, mask=mask)
-# def test_call():
-#     @triton.jit
-#     def kernel(ptr, n_elements, num1, num2):
-#         vecmul_kernel(ptr, n_elements, num1)
-#         vecmul_kernel(ptr, n_elements, num2)
-#     size = 1024
-#     rand_val = numpy_random((size,), dtype_str="float32")
-#     rand_val_tri = to_triton(rand_val, device='cuda')
-#     kernel[(size // 128,)](rand_val_tri, size, 3, 5)
-#     ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
-#     np.testing.assert_equal(to_numpy(rand_val_tri), ans)
+@triton.jit
+def val_multiplier(val, i):
+    return val * i
+@triton.jit
+def vecmul_kernel(ptr, n_elements, rep):
+    pid = tl.program_id(axis=0)
+    offsets = pid * 128 + tl.arange(0, 128)
+    mask = offsets < n_elements
+    vec = tl.load(ptr + offsets, mask=mask)
+    for i in range(1, rep):
+        vec = val_multiplier(vec, i)
+    tl.store(ptr + offsets, vec, mask=mask)
+def test_call():
+    @triton.jit
+    def kernel(ptr, n_elements, num1, num2):
+        vecmul_kernel(ptr, n_elements, num1)
+        vecmul_kernel(ptr, n_elements, num2)
+    size = 1024
+    rand_val = numpy_random((size,), dtype_str="float32")
+    rand_val_tri = to_triton(rand_val, device='cuda')
+    kernel[(size // 128,)](rand_val_tri, size, 3, 5)
+    ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
+    np.testing.assert_equal(to_numpy(rand_val_tri), ans)
 # # -------------
 # # test if

View File

@@ -685,8 +685,7 @@ class CodeGenerator(ast.NodeVisitor):
         fn_name = mangle_fn(fn.__name__, arg_types, constants)
         # generate function def if necessary
         if not self.module.has_function(fn_name):
-            ret_type = triton.language.void
-            prototype = triton.language.function_type([ret_type], arg_types)
+            prototype = triton.language.function_type([], arg_types)
             gscope = sys.modules[fn.fn.__module__].__dict__
             generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types)
             generator.visit(fn.parse())
@@ -696,7 +695,7 @@ class CodeGenerator(ast.NodeVisitor):
             callee_ret_type = self.function_ret_types[fn_name]
         symbol = self.module.get_function(fn_name)
         call_op = self.builder.call(symbol, arg_vals)
-        if call_op.get_num_results() == 0:
+        if call_op.get_num_results() == 0 or callee_ret_type is None:
            return None
        elif call_op.get_num_results() == 1:
            return triton.language.tensor(call_op.get_result(0), callee_ret_type)
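Context for the return-type change above: `@triton.jit` helpers that return nothing now propagate `None` to the caller instead of a void-typed tensor. A minimal sketch of the pattern this enables (kernel and names are illustrative, not taken from this commit):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_and_store(ptr, off, factor):
    # no return value: the generated call has zero results, so the caller
    # now simply gets None back
    tl.store(ptr + off, tl.load(ptr + off) * factor)


@triton.jit
def kernel(ptr, factor):
    off = tl.arange(0, 128)
    scale_and_store(ptr, off, factor)


x = torch.ones(128, device='cuda')
kernel[(1,)](x, 3)
```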

View File

@@ -473,6 +473,11 @@ class tensor:
         other = _to_tensor(other, _builder)
         return semantic.mod(self, other, _builder)
+    @builtin
+    def __rmod__(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.mod(other, self, _builder)
     # unary operators
     @builtin
     def __neg__(self, _builder=None):
@@ -541,6 +546,7 @@ class tensor:
     @builtin
     def __rlt__(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
         return semantic.less_than(other, self, _builder)
     # <=
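The `__rmod__` addition and the `__rlt__` fix above cover expressions where a Python scalar sits on the left of a tensor inside a `@triton.jit` function, which the kernel code generator resolves via these reflected methods. A small illustrative sketch (not taken from the test suite):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def kernel(X):
    off = tl.arange(0, 128)
    val = 256 % (off + 1)   # scalar % tensor -> tensor.__rmod__
    mask = 64 < off         # scalar < tensor -> tensor.__rlt__
    tl.store(X + off, val, mask=mask)


x = torch.zeros(128, dtype=torch.int32, device='cuda')
kernel[(1,)](x)
```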