From 59b0ac672a5df3a18646ced6ba3993b750bf0039 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Fri, 21 May 2021 02:47:53 -0400
Subject: [PATCH] [LANGUAGE] Added support for bitcast (#119)

Add ir::dispatch::bitcast and expose it to the Python frontend as
`bitcast`.

dispatch::bitcast(input, type, builder):
  * when the input has a block type, the destination type is first
    wrapped into a block type carrying the input's block shapes;
  * returns the input unchanged when source and destination types are
    identical;
  * throws std::runtime_error when the scalar source and destination
    types do not have the same primitive size in bits;
  * otherwise emits a cast instruction with the ir::BitCast opcode.

Python side:
  * block.to() gains a `bitcast=False` keyword; when it is true the call
    is routed to frontend.bitcast instead of frontend.cast.
  * pointer_dtype.handle() now passes the element type's handle (rather
    than the element dtype object itself) to ir.type.make_ptr.
  * test_cast covers regular casts for every (dtype_x, dtype_z) pair in
    `dtypes`, plus one float32 -> int32 bitcast checked against numpy's
    view().
---
 include/triton/ir/dispatch.h |  1 +
 lib/ir/dispatch.cc           | 17 +++++++++++++++++
 python/src/triton.cc         |  1 +
 python/test/test_language.py | 32 ++++++++++++++++++++++++++++++++
 python/triton/language.py    |  9 ++++++---
 5 files changed, 57 insertions(+), 3 deletions(-)

diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h
index 6d06baff1..f9ada21ef 100644
--- a/include/triton/ir/dispatch.h
+++ b/include/triton/ir/dispatch.h
@@ -62,6 +62,7 @@ struct dispatch{
   static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
   static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
   static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
+  static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
   static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
 
   // memory operators
diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc
index 5f6d66916..c3cd71c5d 100644
--- a/lib/ir/dispatch.cc
+++ b/lib/ir/dispatch.cc
@@ -414,6 +414,23 @@ std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value
   return std::make_tuple(lhs, rhs);
 }
 
+ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){
+  ir::type *src_ty = input->get_type();
+  if (src_ty->is_block_ty())
+    dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
+  if(src_ty == dst_ty)
+    return input;
+  ir::type *src_sca_ty = src_ty->get_scalar_ty();
+  ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
+  // Bitcast
+  int src_bits = src_sca_ty->get_primitive_size_in_bits();
+  int dst_bits = dst_sca_ty->get_primitive_size_in_bits();
+  if( src_bits!= dst_bits)
+    throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) +
+                             " to data-type of size " + std::to_string(dst_bits));
+  return builder->create_cast(ir::BitCast, input, dst_ty);
+}
+
 ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) {
   ir::type *src_ty = input->get_type();
   if (src_ty->is_block_ty())
diff --git a/python/src/triton.cc b/python/src/triton.cc
index c07fade36..43a8e5642 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -129,6 +129,7 @@ void init_triton_frontend(py::module &&m) {
   typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
   m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
   m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
+  m.def("bitcast", &ir::dispatch::bitcast, ret::reference);
   m.def("cast", &ir::dispatch::cast, ret::reference);
   // memory
   m.def("load", &ir::dispatch::load, ret::reference);
diff --git a/python/test/test_language.py b/python/test/test_language.py
index 4ef7302b4..35cbe7b0e 100644
--- a/python/test/test_language.py
+++ b/python/test/test_language.py
@@ -263,6 +263,38 @@ def test_atomic_add(dtype_x, device='cuda'):
     triton.testing.assert_allclose(x_ref, last_sum)
 
 
+# ---------------
+# test cast
+# ---------------
+@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
+    (dtype_x, dtype_z, False) for dtype_x in dtypes \
+                              for dtype_z in dtypes
+] + [
+    ('float32', 'int32', True)
+])
+def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
+    x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
+
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z, **meta):
+        x = tl.load(X)
+        z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
+        tl.store(Z, z)
+
+    # triton result
+    z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
+    kernel[(1, )](x, z_tri, BITCAST=bitcast)
+    # torch result
+    if bitcast:
+        import numpy as np
+        z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z))
+        z_ref = torch.from_numpy(z_ref).to(device)
+    else:
+        z_ref = x.to(z_tri.dtype)
+    assert z_tri == z_ref
+
+
 # ---------------
 # test load
 # ---------------
diff --git a/python/triton/language.py b/python/triton/language.py
index a3ab0df54..e5791cebf 100644
--- a/python/triton/language.py
+++ b/python/triton/language.py
@@ -76,7 +76,7 @@ class pointer_dtype:
         self.element_ty = element_ty
 
     def handle(self, builder):
-        return ir.type.make_ptr(self.element_ty, 1)
+        return ir.type.make_ptr(self.element_ty.handle(builder), 1)
 
 
 int1 = dtype(ir.type.get_int1)
@@ -249,8 +249,11 @@ class block:
         return ret
 
     @builtin
-    def to(self, dtype, builder=None):
-        return frontend.cast(self, dtype.handle(builder), builder)
+    def to(self, dtype, bitcast=False, builder=None):
+        dtype = dtype.handle(builder)
+        if bitcast:
+            return frontend.bitcast(self, dtype, builder)
+        return frontend.cast(self, dtype, builder)
 
 
 # -----------------------