From 3e6cc6d66c2ca92f21fe29529de63a31e59db426 Mon Sep 17 00:00:00 2001
From: Philippe Tillet <phil@openai.com>
Date: Wed, 26 Oct 2022 17:47:33 -0700
Subject: [PATCH] [FRONTEND] Made more tests pass (#805)

---
 include/triton/Dialect/Triton/IR/TritonOps.td |  16 ++
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 113 +++++++-
 .../TritonToTritonGPU/TritonToTritonGPU.cpp   |   3 +
 lib/Dialect/Triton/IR/Ops.cpp                 |   4 +-
 python/src/triton.cc                          |  34 ++-
 python/tests/test_core.py                     | 270 +++++++++---------
 python/triton/compiler.py                     |   1 +
 python/triton/language/core.py                |   9 +-
 python/triton/language/semantic.py            |  19 +-
 9 files changed, 303 insertions(+), 166 deletions(-)

diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 0d1675b39..325ea3500 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -56,6 +56,22 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
   let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
 }
 
+// arith.bitcast doesn't support pointers
+def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
+                                     SameOperandsAndResultEncoding,
+                                     NoSideEffect,
+                                     /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
+  let summary = "Cast between types of the same bitwidth";
+
+  let arguments = (ins TT_Type:$from);
+
+  let results = (outs TT_Type:$result);
+
+  let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
+
+  // TODO: Add verifier
+}
+
 def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
                                    SameOperandsAndResultEncoding,
                                    NoSideEffect,
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 2fc0505c9..ba98f0bba 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -1400,6 +1400,8 @@ struct ExtractSliceOpConversion
   }
 };
 
+// TODO: rewrite Ternary/Binary/Unary as Elementwise
+
 // A CRTP style of base class.
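 // e.g., a converter for a hypothetical FooOp passes itself as `ConcreteT`,
 // so the base class reaches the derived `createDestOp` hook through a
 // static_cast rather than a virtual call:
 //   struct FooOpConversion
 //       : public BinaryOpConversionBase<FooOp, FooOpConversion> {
 //     Value createDestOp(FooOp op, ...) const; // statically dispatched
 //   };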
 template <typename SourceOp, typename ConcreteT>
 class BinaryOpConversionBase
@@ -1470,6 +1472,77 @@ struct BinaryOpConversion
   Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
 };
 
+//
+// Ternary
+//
+
+template <typename SourceOp, typename ConcreteT>
+class TernaryOpConversionBase
+    : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
+public:
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  explicit TernaryOpConversionBase(LLVMTypeConverter &typeConverter,
+                                   PatternBenefit benefit = 1)
+      : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
+    // ArithmeticToLLVM will handle the lowering of scalar ArithOps
+    if (!resultTy)
+      return failure();
+
+    Location loc = op->getLoc();
+    auto resultLayout =
+        resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
+    auto resultShape = resultTy.getShape();
+    assert(resultLayout && "Unexpected resultLayout in TernaryOpConversion");
+    unsigned elems = resultLayout.getElemsPerThread(resultShape);
+    Type elemTy =
+        this->getTypeConverter()->convertType(resultTy.getElementType());
+    SmallVector<Type> types(elems, elemTy);
+    Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
+
+    auto *concreteThis = static_cast<const ConcreteT *>(this);
+    auto lhss =
+        this->getElementsFromStruct(loc, adaptor.getOperands()[0], rewriter);
+    auto rhss =
+        this->getElementsFromStruct(loc, adaptor.getOperands()[1], rewriter);
+    auto thss =
+        this->getElementsFromStruct(loc, adaptor.getOperands()[2], rewriter);
+    SmallVector<Value> resultVals(elems);
+    for (unsigned i = 0; i < elems; ++i) {
+      resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
+                                                 rhss[i], thss[i], loc);
+    }
+    Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
+    rewriter.replaceOp(op, view);
+    return success();
+  }
+};
+
+template <typename SourceOp, typename DestOp>
+struct TernaryOpConversion
+    : public TernaryOpConversionBase<SourceOp,
+                                     TernaryOpConversion<SourceOp, DestOp>> {
+
+  explicit TernaryOpConversion(LLVMTypeConverter &typeConverter,
+                               PatternBenefit benefit = 1)
+      : TernaryOpConversionBase<SourceOp,
+                                TernaryOpConversion<SourceOp, DestOp>>(
+            typeConverter, benefit) {}
+
+  using OpAdaptor = typename SourceOp::Adaptor;
+  // An interface to support variant DestOp builder.
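+  // e.g., TernaryOpConversion<triton::gpu::SelectOp, LLVM::SelectOp> (the
+  // only instantiation registered below) emits one `llvm.select` per scalar
+  // element through this hook.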
+  DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
+                      Type elemTy, Value lhs, Value rhs, Value th,
+                      Location loc) const {
+    return rewriter.create<DestOp>(loc, elemTy, lhs, rhs, th);
+  }
+};
+
 //
 // Unary
 //
@@ -3590,9 +3663,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                         benefit);
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
+
+#define POPULATE_TERNARY_OP(SRC_OP, DST_OP)                                    \
+  patterns.add<TernaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
+  POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp);
+#undef POPULATE_TERNARY_OP
+
 #define POPULATE_BINARY_OP(SRC_OP, DST_OP)                                     \
   patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
   POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
   POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp)
   POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
@@ -3605,24 +3683,31 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
   POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
   POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
-  POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
-  POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp)   // |
+  POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp)    // &
+  POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp)      // |
+  POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp)    // ^
+  POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp)    // <<
+  POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp)  // >>
+  POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp)  // >>
 #undef POPULATE_BINARY_OP
 
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
-#define POPULATE_CAST_OP(SRC_OP, DST_OP)                                       \
+#define POPULATE_UNARY_OP(SRC_OP, DST_OP)                                      \
   patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
-  POPULATE_CAST_OP(arith::TruncIOp, LLVM::TruncOp)
-  POPULATE_CAST_OP(arith::TruncFOp, LLVM::FPTruncOp)
-  POPULATE_CAST_OP(arith::ExtSIOp, LLVM::SExtOp)
-  POPULATE_CAST_OP(arith::ExtUIOp, LLVM::ZExtOp)
-  POPULATE_CAST_OP(arith::FPToUIOp, LLVM::FPToUIOp)
-  POPULATE_CAST_OP(arith::FPToSIOp, LLVM::FPToSIOp)
-  POPULATE_CAST_OP(arith::UIToFPOp, LLVM::UIToFPOp)
-  POPULATE_CAST_OP(arith::SIToFPOp, LLVM::SIToFPOp)
-  POPULATE_CAST_OP(arith::ExtFOp, LLVM::FPExtOp)
-#undef POPULATE_CAST_OP
+  POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
+  POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp)
+  POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
+  POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
+  POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp)
+  POPULATE_UNARY_OP(arith::FPToSIOp, LLVM::FPToSIOp)
+  POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
+  POPULATE_UNARY_OP(arith::SIToFPOp, LLVM::SIToFPOp)
+  POPULATE_UNARY_OP(arith::ExtFOp, LLVM::FPExtOp)
+  POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
+  POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
+  POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
+#undef POPULATE_UNARY_OP
 
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, allocation, smem,
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 34b7e07f6..576ef735a 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -351,6 +351,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
   MLIRContext *context = patterns.getContext();
   patterns.add< // TODO: view should have custom pattern that views the layout
       TritonGenericPattern<triton::ViewOp>,
+      TritonGenericPattern<triton::BitcastOp>,
+      TritonGenericPattern<triton::IntToPtrOp>,
+      TritonGenericPattern<triton::PtrToIntOp>,
       TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
       TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
       TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index 32b56fb0d..d8db733f4 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -150,9 +150,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
                    ::mlir::triton::CacheModifier cache,
                    ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
-  TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
+  TensorType ptrType = ptr.getType().cast<TensorType>();
   Type elementType =
-      ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
+      ptrType.getElementType().cast<PointerType>().getPointeeType();
   auto shape = ptrType.getShape();
   Type resultType = RankedTensorType::get(shape, elementType);
   state.addOperands(ptr);
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 33579b09e..fb893219c 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -441,11 +441,22 @@ void init_triton_ir(py::module &&m) {
                loc, self.getF32FloatAttr(v));
          })
     .def("get_null_value",
-         [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
+         [](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
            auto loc = self.getUnknownLoc();
-           if (type.isa<mlir::FloatType>())
-             return self.create<mlir::arith::ConstantOp>(
-                 loc, self.getF32FloatAttr(0.0));
+           if (auto floatTy = type.dyn_cast<mlir::FloatType>())
+             return self.create<mlir::arith::ConstantFloatOp>(
+                 loc, mlir::APFloat(floatTy.getFloatSemantics(), 0), floatTy);
+           else if (auto intTy = type.dyn_cast<mlir::IntegerType>())
+             return self.create<mlir::arith::ConstantIntOp>(loc, 0, intTy);
+           else
+             throw std::runtime_error("Not implemented");
+         })
+    .def("get_all_ones_value",
+         [](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
+           auto loc = self.getUnknownLoc();
+           uint64_t val = 0xFFFFFFFFFFFFFFFF;
+           if (auto intTy = type.dyn_cast<mlir::IntegerType>())
+             return self.create<mlir::arith::ConstantIntOp>(loc, val, intTy);
            else
              throw std::runtime_error("Not implemented");
          })
@@ -602,7 +613,7 @@ void init_triton_ir(py::module &&m) {
        [](mlir::OpBuilder &self, mlir::Value &src,
           mlir::Type &dstType) -> mlir::Value {
          auto loc = self.getUnknownLoc();
-         return self.create<mlir::arith::BitcastOp>(loc, dstType, src);
+         return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
        })
     // .def("create_cast", &ir::builder::create_cast)
     // .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
@@ -1143,6 +1154,18 @@ void init_triton_ir(py::module &&m) {
           return self.create<mlir::triton::ReduceOp>(loc, resType, redOp, operand, axis);
         })
+    .def("create_ptr_to_int",
+         [](mlir::OpBuilder &self, mlir::Value &val,
+            mlir::Type &type) -> mlir::Value {
+           auto loc = self.getUnknownLoc();
+           return self.create<mlir::triton::PtrToIntOp>(loc, type, val);
+         })
+    .def("create_int_to_ptr",
+         [](mlir::OpBuilder &self, mlir::Value &val,
+            mlir::Type &type) -> mlir::Value {
+           auto loc = self.getUnknownLoc();
+           return self.create<mlir::triton::IntToPtrOp>(loc, type, val);
+         })
     .def("create_select",
         [](mlir::OpBuilder &self, mlir::Value &condition,
           mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
@@ -1231,7 +1254,6 @@ void init_triton_ir(py::module &&m) {
 }
 
 void init_triton_translation(py::module &m) {
-  using ret = py::return_value_policy;
 
   m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index 74692edca..e36f436bd 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -281,141 +281,142 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
 
 
-# @pytest.mark.parametrize("dtype_x, dtype_y",
dtype_y", -# [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + -# [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] -# ) -# def test_floordiv(dtype_x, dtype_y, device='cuda'): -# # Triton has IEEE, not numpy/torch, semantics for %, and those carry -# # through to //, so we have to use a nonstandard expression to get a -# # reference result for //. -# expr = 'x // y' -# numpy_expr = '((x - np.fmod(x, y)) / y)' -# _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) +@pytest.mark.parametrize("dtype_x, dtype_y", + [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] + ) +def test_floordiv(dtype_x, dtype_y, device='cuda'): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + expr = 'x // y' + numpy_expr = '((x - np.fmod(x, y)) / y)' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) -# # --------------- -# # test bitwise ops -# # --------------- -# @pytest.mark.parametrize("dtype_x, dtype_y, op", [ -# (dtype_x, dtype_y, op) -# for op in ['&', '|', '^'] -# for dtype_x in dtypes + dtypes_with_bfloat16 -# for dtype_y in dtypes + dtypes_with_bfloat16 -# ]) -# def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): -# expr = f'x {op} y' -# if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): -# numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' -# elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): -# numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' -# else: -# numpy_expr = None -# if 'float' in dtype_x + dtype_y: -# with pytest.raises(triton.CompilationError) as exc_info: -# _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device) -# # The CompilationError must have been caused by a C++ exception with this text. -# assert re.match('invalid operands of type', str(exc_info.value.__cause__)) -# else: -# _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) +# --------------- +# test bitwise ops +# --------------- +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ + (dtype_x, dtype_y, op) + for op in ['&', '|', '^'] + for dtype_x in dtypes + dtypes_with_bfloat16 + for dtype_y in dtypes + dtypes_with_bfloat16 +]) +def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if 'float' in dtype_x + dtype_y: + with pytest.raises(triton.CompilationError) as exc_info: + _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device) + # The CompilationError must have been caused by a C++ exception with this text. 
+        assert re.match('invalid operands of type', str(exc_info.value.__cause__))
+    else:
+        _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
 
 
-# @pytest.mark.parametrize("dtype_x, dtype_y, op", [
-#     (dtype_x, dtype_y, op)
-#     for op in ['<<', '>>']
-#     for dtype_x in int_dtypes + uint_dtypes
-#     for dtype_y in int_dtypes + uint_dtypes
-# ])
-# def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
-#     expr = f'x {op} y'
-#     bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
-#     dtype_z = f'uint{bw}'
-#     numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
-#     _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)
+@pytest.mark.parametrize("dtype_x, dtype_y, op", [
+    (dtype_x, dtype_y, op)
+    for op in ['<<', '>>']
+    for dtype_x in int_dtypes + uint_dtypes
+    for dtype_y in int_dtypes + uint_dtypes
+])
+def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
+    expr = f'x {op} y'
+    bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
+    dtype_z = f'uint{bw}'
+    numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
+    _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)
 
 
-# # ---------------
-# # test compare ops
-# # ---------------
-# ops = ['==', '!=', '>', '<', '>=', '<=']
+# ---------------
+# test compare ops
+# ---------------
+ops = ['==', '!=', '>', '<', '>=', '<=']
 
 
-# @pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
-#                          # real
-#                          [
-#                              (dtype_x, dtype_y, op, 'real', 'real')
-#                              for op in ops
-#                              for dtype_x in dtypes
-#                              for dtype_y in dtypes
-#                          ] +
-#                          # NaNs
-#                          [('float32', 'float32', op, mode_x, mode_y)
-#                              for op in ops
-#                              for mode_x, mode_y in [('nan', 'real'),
-#                                                     ('real', 'nan'),
-#                                                     ('nan', 'nan')]
+@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
+                         # real
+                         [
+                             (dtype_x, dtype_y, op, 'real', 'real')
+                             for op in ops
+                             for dtype_x in dtypes
+                             for dtype_y in dtypes
+                         ] +
+                         # NaNs
+                         [('float32', 'float32', op, mode_x, mode_y)
+                             for op in ops
+                             for mode_x, mode_y in [('nan', 'real'),
+                                                    ('real', 'nan'),
+                                                    ('nan', 'nan')]
 
-#                          ])
-# def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
-#     expr = f'x {op} y'
-#     if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
-#         numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
-#     elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
-#         numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
-#     else:
-#         numpy_expr = None
-#     _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
+                         ])
+def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
+    expr = f'x {op} y'
+    if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
+        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
+    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
+        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
+    else:
+        numpy_expr = None
+    _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
 
 
-# # ---------------
-# # test where
-# # ---------------
+# ---------------
+# test where
+# ---------------
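+# tl.where follows np.where semantics: `cond`, `x` and `y` are broadcast to a
+# common shape, then elements are picked from `x` where `cond` is true
+# (e.g. np.where([True, False], [1, 2], [3, 4]) gives [1, 4]).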
+@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) +def test_where(dtype): + select_ptrs = False + if dtype == "*int32": + dtype = "int64" + select_ptrs = True + check_type_supported(dtype) -# @triton.jit -# def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, -# BLOCK_SIZE: tl.constexpr, -# TEST_POINTERS: tl.constexpr): -# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# mask = offsets < n_elements -# decide = tl.load(cond_ptr + offsets, mask=mask) -# if TEST_POINTERS: -# a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t) -# b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t) -# else: -# a = tl.load(a_ptr + offsets, mask=mask) -# b = tl.load(b_ptr + offsets, mask=mask) -# output = tl.where(decide, a, b) -# tl.store(output_ptr + offsets, output, mask=mask) + @triton.jit + def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, + BLOCK_SIZE: tl.constexpr, + TEST_POINTERS: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + decide = tl.load(cond_ptr + offsets, mask=mask) + if TEST_POINTERS: + a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t) + b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t) + else: + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + output = tl.where(decide, a, b) + tl.store(output_ptr + offsets, output, mask=mask) -# SIZE = 1_000 -# rs = RandomState(17) -# cond = numpy_random(SIZE, 'bool', rs) -# x = numpy_random(SIZE, dtype_str=dtype, rs=rs) -# y = numpy_random(SIZE, dtype_str=dtype, rs=rs) -# z = np.where(cond, x, y) + SIZE = 1_000 + rs = RandomState(17) + cond = numpy_random(SIZE, 'bool', rs) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + z = np.where(cond, x, y) -# cond_tri = to_triton(cond, device='cuda') -# x_tri = to_triton(x, device='cuda', dst_type=dtype) -# y_tri = to_triton(y, device='cuda', dst_type=dtype) -# z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype) + cond_tri = to_triton(cond, device='cuda') + x_tri = to_triton(x, device='cuda', dst_type=dtype) + y_tri = to_triton(y, device='cuda', dst_type=dtype) + z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype) -# grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),) -# where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs) -# assert (z == to_numpy(z_tri)).all() + grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs) + assert (z == to_numpy(z_tri)).all() +# TODO: wrong result # def test_where_broadcast(): # @triton.jit # def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): -# xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1]) -# yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE]) +# xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] +# yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] # mask = tl.load(cond_ptr + yoffsets) # vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) @@ -424,8 +425,8 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'): # @triton.jit # def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): -# xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1]) -# yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE]) +# xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] +# yoffsets = 
#         mask = 0
#         vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
#         res = tl.where(mask, vals, 0.)
@@ -451,17 +452,19 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
 # # ---------------
 
-# @pytest.mark.parametrize("dtype_x, expr", [
-#     (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
-# ] + [
-#     (dtype_x, ' ~x') for dtype_x in int_dtypes
-# ])
-# def test_unary_op(dtype_x, expr, device='cuda'):
-#     _test_unary(dtype_x, expr, device=device)
+@pytest.mark.parametrize("dtype_x, expr", [
+    (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
+] + [
+    (dtype_x, ' ~x') for dtype_x in int_dtypes
+])
+def test_unary_op(dtype_x, expr, device='cuda'):
+    _test_unary(dtype_x, expr, device=device)
 
 # # ----------------
 # # test math ops
 # # ----------------
+
+# TODO: Math module
 # # @pytest.mark.parametrize("expr", [
 # #     'exp', 'log', 'cos', 'sin'
 # # ])
@@ -479,17 +482,18 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
 # # ----------------
 
-# def make_ptr_str(name, shape):
-#     rank = len(shape)
-#     offsets = []
-#     stride = 1
-#     for i in reversed(range(rank)):
-#         idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
-#         offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
-#         stride *= shape[i]
-#     return f"{name} + {' + '.join(offsets)}"
+def make_ptr_str(name, shape):
+    rank = len(shape)
+    offsets = []
+    stride = 1
+    for i in reversed(range(rank)):
+        idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
+        offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
+        stride *= shape[i]
+    return f"{name} + {' + '.join(offsets)}"
 
+# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`
 # @pytest.mark.parametrize("expr, dtype_str", [
 #     (f'x[{s}]', d)
 #     for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index a48d1ec25..2ef5293e4 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -45,6 +45,7 @@ def str_to_ty(name):
         "u32": triton.language.uint32,
         "u64": triton.language.uint64,
         "B": triton.language.int1,
+        "i1": triton.language.int1,
     }
     return tys[name]
 
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index c958fa7f1..e7f6a744d 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -729,9 +729,10 @@ def cat(input, other, _builder=None):
 
 
 @builtin
-def reshape(input, shape, _builder=None):
+def view(input, shape, _builder=None):
     """
-    Tries to reshape the given tensor to a new shape.
+    Returns a tensor with the same elements as `input` but a different shape.
+    The order of the elements may not be preserved.
 
     :param input: The input tensor.
     :type input:
@@ -740,7 +741,7 @@
 
     """
     shape = [x.value for x in shape]
-    return semantic.reshape(input, shape, _builder)
+    return semantic.view(input, shape, _builder)
 
 
 # -----------------------
@@ -1151,7 +1152,7 @@ def ravel(x):
     :param x: the input tensor
     :type x: Block
     """
-    return triton.language.reshape(x, [x.numel])
+    return triton.language.view(x, [x.numel])
 
 
 @triton.jit
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 619fad4ee..b7fda1736 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -345,7 +345,7 @@ def invert(input: tl.tensor,
     input_sca_ty = input.type.scalar
     if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
         raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
-    _1 = tl.tensor(ir.constant.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
+    _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
     return xor_(input, _1, builder)
 
 
@@ -481,11 +481,13 @@ def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
 
 def view(input: tl.tensor,
          dst_shape: List[int],
         builder: ir.builder) -> tl.tensor:
+    # TODO: disable when TritonToTritonGPU handles views properly
+    assert len(input.shape) == len(dst_shape)
     numel = 1
     for s in dst_shape:
         numel *= s
     if input.type.numel != numel:
-        raise ValueError("cannot reshape block of different shape")
+        raise ValueError("cannot view block of different shape")
     ret_ty = tl.block_type(input.type.scalar, dst_shape)
     return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
 
@@ -516,7 +518,7 @@ def broadcast_impl_shape(input: tl.tensor,
     for i in range(len(src_shape)):
         if shape[i] != src_shape[i] and src_shape[i] != 1:
             raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
-                             f" must match the existing size ({src_shape[1]}) at non-singleton dimension"
+                             f" must match the existing size ({src_shape[i]}) at non-singleton dimension"
                              f" {i}: {src_shape}, {shape}")
     ret_ty = tl.block_type(input.type.scalar, shape)
     return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
@@ -679,7 +681,7 @@ def cast(input: tl.tensor,
     if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
         bitwidth = dst_sca_ty.int_bitwidth
         if bitwidth == 64:
-            return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)),
+            return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
                              dst_ty)
         if bitwidth == 1:
             return not_equal(cast(input, tl.int64, builder),
@@ -989,18 +991,21 @@ def where(condition: tl.tensor,
           builder: ir.builder) -> tl.tensor:
     condition = cast(condition, tl.int1, builder)
     if condition.type.is_block():
-        x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder)
-        y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder)
+        condition, x = broadcast_impl_value(condition, x, builder)
+        x, y = broadcast_impl_value(x, y, builder)
+        condition, x = broadcast_impl_value(condition, x, builder)
     x, y = binary_op_type_checking_impl(x, y, builder, True, True)
+    if not condition.type.is_block():
+        condition, _ = broadcast_impl_value(condition, x, builder)
     ret_ty = x.type
     return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
 
-
 # ===----------------------------------------------------------------------===//
 #                               Reductions
 # ===----------------------------------------------------------------------===
 
+
 def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
                 FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
     scalar_ty = input.type.scalar
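
As a usage sketch of what the frontend changes above enable (a hypothetical kernel, not part of the patch; assumes a CUDA device):

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def demo_kernel(x_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
        offsets = tl.arange(0, BLOCK_SIZE)
        x = tl.load(x_ptr + offsets)
        # << and >> now lower through POPULATE_BINARY_OP (LLVM shl/ashr/lshr)
        y = (x << 2) >> 1
        # tl.where now broadcasts its operands like np.where
        z = tl.where(offsets % 2 == 0, x, y)
        tl.store(out_ptr + offsets, z)


    x = torch.arange(16, dtype=torch.int32, device='cuda')
    out = torch.empty_like(x)
    demo_kernel[(1,)](x, out, BLOCK_SIZE=16)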