From 5123db0b7d71f53c0245b94887eaefb921a14bca Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Mon, 4 Oct 2021 18:39:40 -0700
Subject: [PATCH] [LANG] Various (relatively minor) improvements (#320)

---
 .github/workflows/integration-tests.yml  |  6 ++++-
 lib/codegen/analysis/layout.cc           |  1 +
 lib/codegen/selection/generator.cc       |  5 ++--
 lib/ir/instructions.cc                   |  2 +-
 python/src/triton.cc                     |  1 +
 python/test/unit/language/test_core.py   | 16 +++++++++++++
 python/test/unit/language/test_random.py |  2 +-
 python/triton/code_gen.py                | 29 ++++++++++++++++--------
 python/triton/language/core.py           |  8 ++++++-
 python/triton/language/random.py         |  5 +++-
 10 files changed, 59 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index fb69bae2f..5fdf4e140 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -18,12 +18,16 @@ jobs:
     - name: Checkout
       uses: actions/checkout@v2
 
+    - name: Clear cache
+      run: |
+        rm -r /tmp/triton/
+      continue-on-error: true
+
     - name: Install Triton
       run: |
         alias python='python3'
         cd python
         pip3 install -e .
-        rm -r /tmp/triton/
 
     - name: Unit tests
       run: |
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc
index 1693eff42..5ad6eb304 100644
--- a/lib/codegen/analysis/layout.cc
+++ b/lib/codegen/analysis/layout.cc
@@ -537,6 +537,7 @@ void layouts::run(ir::module &mod) {
       tmp_[atom] = id;
     }
   });
+
 }
 
 }
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index f0068be11..20692036f 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -162,7 +162,7 @@ Type *generator::cvt(ir::type *ty) {
     case ir::type::VoidTyID:      return Type::getVoidTy(*ctx_);
     case ir::type::FP8TyID:       return Type::getInt8Ty(*ctx_);
     case ir::type::FP16TyID:      return Type::getHalfTy(*ctx_);
-    case ir::type::BF16TyID:     return Type::getInt16Ty(*ctx_);
+    case ir::type::BF16TyID:      return Type::getInt16Ty(*ctx_);
     case ir::type::FP32TyID:      return Type::getFloatTy(*ctx_);
     case ir::type::FP64TyID:      return Type::getDoubleTy(*ctx_);
     case ir::type::LabelTyID:     return Type::getLabelTy(*ctx_);
@@ -2197,7 +2197,8 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
 
 void generator::visit_make_range(ir::make_range* x) {
   for(indices_t idx: idxs_.at(x)){
-    vals_[x][idx] = idx[0];
+    Value* start = ConstantInt::get(idx[0]->getType(), x->get_first()->get_value());
+    vals_[x][idx] = add(start, idx[0]);
   }
 }
diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc
index e50fd790f..298492a30 100644
--- a/lib/ir/instructions.cc
+++ b/lib/ir/instructions.cc
@@ -875,7 +875,7 @@ make_range *make_range::create(constant_int *first, constant_int *last) {
   assert(first->get_type()->is_integer_ty());
   assert(first->get_type() == last->get_type());
   assert(((constant_int*)first)->get_value() == 0);
-  type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value()});
+  type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()});
   return new make_range(ty, first, last);
 }
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 0140a1362..1378710e1 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -476,6 +476,7 @@ void init_triton_ir(py::module &&m) {
       // constants
       .def("get_int1", &ir::builder::get_int1, ret::reference)
       .def("get_int32", &ir::builder::get_int32, ret::reference)
+      .def("get_int64", &ir::builder::get_int64, ret::reference)
       .def("get_float16", &ir::builder::get_float16, ret::reference)
       .def("get_float32", &ir::builder::get_float32, ret::reference)
       .def("get_range", &ir::builder::get_range, ret::reference);
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 07b0df31a..3744e4d03 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -515,6 +515,22 @@ def test_dot(epilogue, device='cuda'):
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx
 
+# ---------------
+# test arange
+# ---------------
+
+@pytest.mark.parametrize("start", [0, 1, 7, 16])
+def test_arange(start, device='cuda'):
+    BLOCK = 128
+    z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
+    @triton.jit
+    def _kernel(z, **meta):
+        off = tl.arange(0, meta['BLOCK'])
+        val = tl.arange(meta['START'], meta['END'])
+        tl.store(z + off, val)
+    _kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
+    z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
+    triton.testing.assert_almost_equal(z_tri, z_ref)
 
 # ---------------
 # test load
diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py
index 6c15a7588..a7f178f02 100644
--- a/python/test/unit/language/test_random.py
+++ b/python/test/unit/language/test_random.py
@@ -112,7 +112,7 @@ BLOCK = 1024
 # test generation of random uint32
 @pytest.mark.parametrize('size, seed',
                          [(size, seed) for size in ['10', '4,53', '10000']\
-                          for seed in [0, 42, 124, 54]]
+                          for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
                          )
 def test_randint(size, seed, device='cuda'):
     size = list(map(int, size.split(',')))
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index cd919cc6c..7f6982329 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -103,7 +103,8 @@ class CodeGenerator(ast.NodeVisitor):
         arg_values = []
         for i, arg_name in enumerate(arg_names):
             if i in self.constants:
-                arg_values.append(self.constants[i])
+                cst = triton.language.core._to_ir(self.constants[i], self.builder)
+                arg_values.append(cst)
             else:
                 if i in self.attributes:
                     is_ptr = fn.args[i].type.is_ptr()
@@ -463,9 +464,6 @@ class Kernel:
     @staticmethod
     def _type_name(obj):
         type_names = {
-            int: 'I',
-            float: 'f',
-            bool: 'B',
             triton.language.float8: 'f8',
             torch.bfloat16: 'bf16',
             torch.float16: 'f16',
@@ -477,12 +475,25 @@ class Kernel:
             torch.int32: 'i32',
             torch.int64: 'i64',
         }
-        return type_names[obj]
+        if hasattr(obj, 'data_ptr'):
+            return type_names[obj.dtype]
+        if isinstance(obj, int):
+            if abs(obj) <= 0xffffffff:
+                return 'I'
+            return 'L'
+        if isinstance(obj, float):
+            return 'f'
+        if isinstance(obj, bool):
+            return 'B'
+        assert False
+
+
 
     @staticmethod
     def _to_triton_ir(context, obj):
         type_map = {
             'I': _triton.ir.type.get_int32,
+            'L': _triton.ir.type.get_int64,
             'f': _triton.ir.type.get_fp32,
             'B': _triton.ir.type.get_int1,
             'f8': _triton.ir.type.get_fp8,
@@ -498,11 +509,11 @@ class Kernel:
         }
         # convert torch.Tensor to Triton IR pointers
         if hasattr(obj, 'data_ptr'):
-            name = Kernel._type_name(obj.dtype)
+            name = Kernel._type_name(obj)
             elt_ty = type_map[name](context)
             return _triton.ir.type.make_ptr(elt_ty, 1)
         # default path returns triton.ir.type directly
-        name = Kernel._type_name(obj.__class__)
+        name = Kernel._type_name(obj)
         return type_map[name](context)
 
     @staticmethod
@@ -511,7 +522,7 @@ class Kernel:
         types_key = [None] * len(wargs)
         for i, arg in enumerate(wargs):
             prefix = 'P' if i in tensor_idxs else ''
-            suffix = Kernel._type_name(arg.dtype) if i in tensor_idxs else Kernel._type_name(arg.__class__)
+            suffix = Kernel._type_name(arg)
             types_key[i] = prefix + suffix
         return tuple(types_key)
 
@@ -646,7 +657,7 @@ class Kernel:
             drv_cache[key] = LoadedBinary(device_idx, binary)
 
         # pack arguments
-        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
+        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs)])
         params = struct.pack(fmt, *args)
         # enqueue cached function into stream
         callable = drv_cache[key]
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 29e867502..f492b33d9 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -9,7 +9,9 @@ def _to_ir(x, builder):
     if isinstance(x, bool):
         return builder.get_int1(x)
     elif isinstance(x, int):
-        return builder.get_int32(x)
+        if -2**31 <= x < 2**31:
+            return builder.get_int32(x)
+        return builder.get_int64(x)
     elif isinstance(x, float):
         return builder.get_float32(x)
     if isinstance(x, block):
@@ -636,6 +638,10 @@ def max_contiguous(input, value, _builder=None):
 # Standard library
 # -----------------------
 
+@triton.jit
+def abs(x):
+    return where(x >= 0, x, -x)
+
 @triton.jit
 def cdiv(x, div):
     """
diff --git a/python/triton/language/random.py b/python/triton/language/random.py
index de9c1b3b2..414f61cc0 100644
--- a/python/triton/language/random.py
+++ b/python/triton/language/random.py
@@ -128,7 +128,10 @@ def randint4x(seed, offset):
     :param offsets: The offsets to generate random numbers for.
    """
     z = 0
-    return philox_f(offset, z, z, z, seed, z)
+    seed = hacky_to_uint64(seed)  # uint will solve this
+    seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
+    seed_lo = (seed & 0xffffffff).to(tl.int32)
+    return philox_f(offset, z, z, z, seed_lo, seed_hi)
 
 
 @triton.jit
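
Note: a minimal usage sketch of the main user-visible change above (tl.arange with
a non-zero start, which now lowers to `start + arange(0, end - start)`), mirroring
the new test_arange. The kernel name `_shift` and the launch constants are
illustrative assumptions, not code from the patch; only the **meta-style kernel
API used elsewhere in this patch is assumed:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _shift(z, **meta):
        # offsets into the output tensor
        off = tl.arange(0, meta['BLOCK'])
        # values in [START, END): non-zero starts work after this patch
        val = tl.arange(meta['START'], meta['END'])
        tl.store(z + off, val)

    z = torch.empty(128, dtype=torch.int32, device='cuda')
    _shift[(1,)](z, START=16, END=16 + 128, BLOCK=128)
    assert torch.equal(z.cpu(), torch.arange(16, 144, dtype=torch.int32))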