[FRONTEND] Made more tests pass (#805)

This commit is contained in:
Philippe Tillet
2022-10-26 17:47:33 -07:00
committed by GitHub
parent bb7008651a
commit 3e6cc6d66c
9 changed files with 303 additions and 166 deletions

View File

@@ -56,6 +56,22 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)"; let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
} }
// arith.bitcast doesn't support pointers
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast between types of the same bitwidth";
let arguments = (ins TT_Type:$from);
let results = (outs TT_Type:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
// TODO: Add verifier
}
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding, SameOperandsAndResultEncoding,
NoSideEffect, NoSideEffect,

View File

@@ -1400,6 +1400,8 @@ struct ExtractSliceOpConversion
} }
}; };
// TODO: rewrite Ternary/Binary/Unary as Elementwise
// A CRTP style of base class. // A CRTP style of base class.
template <typename SourceOp, typename DestOp, typename ConcreteT> template <typename SourceOp, typename DestOp, typename ConcreteT>
class BinaryOpConversionBase class BinaryOpConversionBase
@@ -1470,6 +1472,77 @@ struct BinaryOpConversion
Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); } Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
}; };
//
// Ternary
//
template <typename SourceOp, typename DestOp, typename ConcreteT>
class TernaryOpConversionBase
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit TernaryOpConversionBase(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
// ArithmeticToLLVM will handle the lowering of scalar ArithOps
if (!resultTy)
return failure();
Location loc = op->getLoc();
auto resultLayout =
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
assert(resultLayout && "Unexpected resultLayout in TernaryOpConversion");
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto *concreteThis = static_cast<const ConcreteT *>(this);
auto lhss =
this->getElementsFromStruct(loc, adaptor.getOperands()[0], rewriter);
auto rhss =
this->getElementsFromStruct(loc, adaptor.getOperands()[1], rewriter);
auto thss =
this->getElementsFromStruct(loc, adaptor.getOperands()[2], rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
rhss[i], thss[i], loc);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
template <typename SourceOp, typename DestOp>
struct TernaryOpConversion
: public TernaryOpConversionBase<SourceOp, DestOp,
TernaryOpConversion<SourceOp, DestOp>> {
explicit TernaryOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: TernaryOpConversionBase<SourceOp, DestOp,
TernaryOpConversion<SourceOp, DestOp>>(
typeConverter, benefit) {}
using OpAdaptor = typename SourceOp::Adaptor;
// An interface to support variant DestOp builder.
DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
Type elemTy, Value lhs, Value rhs, Value th,
Location loc) const {
return rewriter.create<DestOp>(loc, elemTy, lhs, rhs, th);
}
};
// //
// Unary // Unary
// //
@@ -3590,9 +3663,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
benefit); benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit); patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit); patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<TernaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp);
#undef POPULATE_TERNARY_OP
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ #define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit); patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // - POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp) POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp)
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // + POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
@@ -3605,24 +3683,31 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // % POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp) POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp) POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
#undef POPULATE_BINARY_OP #undef POPULATE_BINARY_OP
patterns.add<CmpIOpConversion>(typeConverter, benefit); patterns.add<CmpIOpConversion>(typeConverter, benefit);
patterns.add<CmpFOpConversion>(typeConverter, benefit); patterns.add<CmpFOpConversion>(typeConverter, benefit);
#define POPULATE_CAST_OP(SRC_OP, DST_OP) \ #define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit); patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_CAST_OP(arith::TruncIOp, LLVM::TruncOp) POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_CAST_OP(arith::TruncFOp, LLVM::FPTruncOp) POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp)
POPULATE_CAST_OP(arith::ExtSIOp, LLVM::SExtOp) POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
POPULATE_CAST_OP(arith::ExtUIOp, LLVM::ZExtOp) POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
POPULATE_CAST_OP(arith::FPToUIOp, LLVM::FPToUIOp) POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp)
POPULATE_CAST_OP(arith::FPToSIOp, LLVM::FPToSIOp) POPULATE_UNARY_OP(arith::FPToSIOp, LLVM::FPToSIOp)
POPULATE_CAST_OP(arith::UIToFPOp, LLVM::UIToFPOp) POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
POPULATE_CAST_OP(arith::SIToFPOp, LLVM::SIToFPOp) POPULATE_UNARY_OP(arith::SIToFPOp, LLVM::SIToFPOp)
POPULATE_CAST_OP(arith::ExtFOp, LLVM::FPExtOp) POPULATE_UNARY_OP(arith::ExtFOp, LLVM::FPExtOp)
#undef POPULATE_CAST_OP POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP
patterns.add<BroadcastOpConversion>(typeConverter, benefit); patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem, patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,

View File

@@ -351,6 +351,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
patterns.add< // TODO: view should have custom pattern that views the layout patterns.add< // TODO: view should have custom pattern that views the layout
TritonGenericPattern<triton::ViewOp>, TritonGenericPattern<triton::ViewOp>,
TritonGenericPattern<triton::BitcastOp>,
TritonGenericPattern<triton::IntToPtrOp>,
TritonGenericPattern<triton::PtrToIntOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern, TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern, TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,

View File

@@ -150,9 +150,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other, ::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
::mlir::triton::CacheModifier cache, ::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) { ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>(); TensorType ptrType = ptr.getType().cast<TensorType>();
Type elementType = Type elementType =
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType(); ptrType.getElementType().cast<PointerType>().getPointeeType();
auto shape = ptrType.getShape(); auto shape = ptrType.getShape();
Type resultType = RankedTensorType::get(shape, elementType); Type resultType = RankedTensorType::get(shape, elementType);
state.addOperands(ptr); state.addOperands(ptr);

View File

@@ -441,11 +441,22 @@ void init_triton_ir(py::module &&m) {
loc, self.getF32FloatAttr(v)); loc, self.getF32FloatAttr(v));
}) })
.def("get_null_value", .def("get_null_value",
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value { [](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
if (type.isa<mlir::FloatType>()) if (auto floatTy = type.dyn_cast<mlir::FloatType>())
return self.create<mlir::arith::ConstantOp>( return self.create<mlir::arith::ConstantFloatOp>(
loc, self.getF32FloatAttr(0.0)); loc, mlir::APFloat(floatTy.getFloatSemantics(), 0), floatTy);
else if (auto intTy = type.dyn_cast<mlir::IntegerType>())
return self.create<mlir::arith::ConstantIntOp>(loc, 0, intTy);
else
throw std::runtime_error("Not implemented");
})
.def("get_all_ones_value",
[](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
auto loc = self.getUnknownLoc();
uint64_t val = 0xFFFFFFFFFFFFFFFF;
if (auto intTy = type.dyn_cast<mlir::IntegerType>())
return self.create<mlir::arith::ConstantIntOp>(loc, val, intTy);
else else
throw std::runtime_error("Not implemented"); throw std::runtime_error("Not implemented");
}) })
@@ -602,7 +613,7 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self, mlir::Value &src, [](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value { mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return self.create<mlir::arith::BitcastOp>(loc, dstType, src); return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
}) })
// .def("create_cast", &ir::builder::create_cast) // .def("create_cast", &ir::builder::create_cast)
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int) // .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
@@ -1143,6 +1154,18 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp, return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
operand, axis); operand, axis);
}) })
.def("create_ptr_to_int",
[](mlir::OpBuilder &self, mlir::Value &val,
mlir::Type &type) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::PtrToIntOp>(loc, type, val);
})
.def("create_int_to_ptr",
[](mlir::OpBuilder &self, mlir::Value &val,
mlir::Type &type) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::IntToPtrOp>(loc, type, val);
})
.def("create_select", .def("create_select",
[](mlir::OpBuilder &self, mlir::Value &condition, [](mlir::OpBuilder &self, mlir::Value &condition,
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value { mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
@@ -1231,7 +1254,6 @@ void init_triton_ir(py::module &&m) {
} }
void init_triton_translation(py::module &m) { void init_triton_translation(py::module &m) {
using ret = py::return_value_policy; using ret = py::return_value_policy;
m.def("get_shared_memory_size", [](mlir::ModuleOp module) { m.def("get_shared_memory_size", [](mlir::ModuleOp module) {

View File

@@ -281,141 +281,142 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
# @pytest.mark.parametrize("dtype_x, dtype_y", @pytest.mark.parametrize("dtype_x, dtype_y",
# [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
# [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
# ) )
# def test_floordiv(dtype_x, dtype_y, device='cuda'): def test_floordiv(dtype_x, dtype_y, device='cuda'):
# # Triton has IEEE, not numpy/torch, semantics for %, and those carry # Triton has IEEE, not numpy/torch, semantics for %, and those carry
# # through to //, so we have to use a nonstandard expression to get a # through to //, so we have to use a nonstandard expression to get a
# # reference result for //. # reference result for //.
# expr = 'x // y' expr = 'x // y'
# numpy_expr = '((x - np.fmod(x, y)) / y)' numpy_expr = '((x - np.fmod(x, y)) / y)'
# _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
# # --------------- # ---------------
# # test bitwise ops # test bitwise ops
# # --------------- # ---------------
# @pytest.mark.parametrize("dtype_x, dtype_y, op", [ @pytest.mark.parametrize("dtype_x, dtype_y, op", [
# (dtype_x, dtype_y, op) (dtype_x, dtype_y, op)
# for op in ['&', '|', '^'] for op in ['&', '|', '^']
# for dtype_x in dtypes + dtypes_with_bfloat16 for dtype_x in dtypes + dtypes_with_bfloat16
# for dtype_y in dtypes + dtypes_with_bfloat16 for dtype_y in dtypes + dtypes_with_bfloat16
# ]) ])
# def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
# expr = f'x {op} y' expr = f'x {op} y'
# if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
# numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
# elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
# numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
# else: else:
# numpy_expr = None numpy_expr = None
# if 'float' in dtype_x + dtype_y: if 'float' in dtype_x + dtype_y:
# with pytest.raises(triton.CompilationError) as exc_info: with pytest.raises(triton.CompilationError) as exc_info:
# _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device) _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
# # The CompilationError must have been caused by a C++ exception with this text. # The CompilationError must have been caused by a C++ exception with this text.
# assert re.match('invalid operands of type', str(exc_info.value.__cause__)) assert re.match('invalid operands of type', str(exc_info.value.__cause__))
# else: else:
# _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
# @pytest.mark.parametrize("dtype_x, dtype_y, op", [ @pytest.mark.parametrize("dtype_x, dtype_y, op", [
# (dtype_x, dtype_y, op) (dtype_x, dtype_y, op)
# for op in ['<<', '>>'] for op in ['<<', '>>']
# for dtype_x in int_dtypes + uint_dtypes for dtype_x in int_dtypes + uint_dtypes
# for dtype_y in int_dtypes + uint_dtypes for dtype_y in int_dtypes + uint_dtypes
# ]) ])
# def test_shift_op(dtype_x, dtype_y, op, device='cuda'): def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
# expr = f'x {op} y' expr = f'x {op} y'
# bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
# dtype_z = f'uint{bw}' dtype_z = f'uint{bw}'
# numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
# _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65) _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)
# # --------------- # ---------------
# # test compare ops # test compare ops
# # --------------- # ---------------
# ops = ['==', '!=', '>', '<', '>=', '<='] ops = ['==', '!=', '>', '<', '>=', '<=']
# @pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", @pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
# # real # real
# [ [
# (dtype_x, dtype_y, op, 'real', 'real') (dtype_x, dtype_y, op, 'real', 'real')
# for op in ops for op in ops
# for dtype_x in dtypes for dtype_x in dtypes
# for dtype_y in dtypes for dtype_y in dtypes
# ] + ] +
# # NaNs # NaNs
# [('float32', 'float32', op, mode_x, mode_y) [('float32', 'float32', op, mode_x, mode_y)
# for op in ops for op in ops
# for mode_x, mode_y in [('nan', 'real'), for mode_x, mode_y in [('nan', 'real'),
# ('real', 'nan'), ('real', 'nan'),
# ('nan', 'nan')] ('nan', 'nan')]
# ]) ])
# def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
# expr = f'x {op} y' expr = f'x {op} y'
# if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
# numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
# elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
# numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
# else: else:
# numpy_expr = None numpy_expr = None
# _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device) _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
# # --------------- # ---------------
# # test where # test where
# # --------------- # ---------------
# @pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) @pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"])
# def test_where(dtype): def test_where(dtype):
# select_ptrs = False select_ptrs = False
# if dtype == "*int32": if dtype == "*int32":
# dtype = "int64" dtype = "int64"
# select_ptrs = True select_ptrs = True
# check_type_supported(dtype) check_type_supported(dtype)
# @triton.jit @triton.jit
# def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
# BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
# TEST_POINTERS: tl.constexpr): TEST_POINTERS: tl.constexpr):
# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# mask = offsets < n_elements mask = offsets < n_elements
# decide = tl.load(cond_ptr + offsets, mask=mask) decide = tl.load(cond_ptr + offsets, mask=mask)
# if TEST_POINTERS: if TEST_POINTERS:
# a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t) a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
# b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t) b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
# else: else:
# a = tl.load(a_ptr + offsets, mask=mask) a = tl.load(a_ptr + offsets, mask=mask)
# b = tl.load(b_ptr + offsets, mask=mask) b = tl.load(b_ptr + offsets, mask=mask)
# output = tl.where(decide, a, b) output = tl.where(decide, a, b)
# tl.store(output_ptr + offsets, output, mask=mask) tl.store(output_ptr + offsets, output, mask=mask)
# SIZE = 1_000 SIZE = 1_000
# rs = RandomState(17) rs = RandomState(17)
# cond = numpy_random(SIZE, 'bool', rs) cond = numpy_random(SIZE, 'bool', rs)
# x = numpy_random(SIZE, dtype_str=dtype, rs=rs) x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
# y = numpy_random(SIZE, dtype_str=dtype, rs=rs) y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
# z = np.where(cond, x, y) z = np.where(cond, x, y)
# cond_tri = to_triton(cond, device='cuda') cond_tri = to_triton(cond, device='cuda')
# x_tri = to_triton(x, device='cuda', dst_type=dtype) x_tri = to_triton(x, device='cuda', dst_type=dtype)
# y_tri = to_triton(y, device='cuda', dst_type=dtype) y_tri = to_triton(y, device='cuda', dst_type=dtype)
# z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype) z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
# grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),) grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
# where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs) where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
# assert (z == to_numpy(z_tri)).all() assert (z == to_numpy(z_tri)).all()
# TODO: wrong result
# def test_where_broadcast(): # def test_where_broadcast():
# @triton.jit # @triton.jit
# def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): # def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
# xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1]) # xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
# yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE]) # yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
# mask = tl.load(cond_ptr + yoffsets) # mask = tl.load(cond_ptr + yoffsets)
# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) # vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
@@ -424,8 +425,8 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
# @triton.jit # @triton.jit
# def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): # def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
# xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1]) # xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
# yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE]) # yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
# mask = 0 # mask = 0
# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) # vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
# res = tl.where(mask, vals, 0.) # res = tl.where(mask, vals, 0.)
@@ -451,17 +452,19 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
# # --------------- # # ---------------
# @pytest.mark.parametrize("dtype_x, expr", [ @pytest.mark.parametrize("dtype_x, expr", [
# (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16 (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
# ] + [ ] + [
# (dtype_x, ' ~x') for dtype_x in int_dtypes (dtype_x, ' ~x') for dtype_x in int_dtypes
# ]) ])
# def test_unary_op(dtype_x, expr, device='cuda'): def test_unary_op(dtype_x, expr, device='cuda'):
# _test_unary(dtype_x, expr, device=device) _test_unary(dtype_x, expr, device=device)
# # ---------------- # # ----------------
# # test math ops # # test math ops
# # ---------------- # # ----------------
# TODO: Math module
# # @pytest.mark.parametrize("expr", [ # # @pytest.mark.parametrize("expr", [
# # 'exp', 'log', 'cos', 'sin' # # 'exp', 'log', 'cos', 'sin'
# # ]) # # ])
@@ -479,17 +482,18 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
# # ---------------- # # ----------------
# def make_ptr_str(name, shape): def make_ptr_str(name, shape):
# rank = len(shape) rank = len(shape)
# offsets = [] offsets = []
# stride = 1 stride = 1
# for i in reversed(range(rank)): for i in reversed(range(rank)):
# idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
# offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}'] offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
# stride *= shape[i] stride *= shape[i]
# return f"{name} + {' + '.join(offsets)}" return f"{name} + {' + '.join(offsets)}"
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
# @pytest.mark.parametrize("expr, dtype_str", [ # @pytest.mark.parametrize("expr, dtype_str", [
# (f'x[{s}]', d) # (f'x[{s}]', d)
# for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] # for s in ['None, :', ':, None', 'None, :, :', ':, :, None']

View File

@@ -45,6 +45,7 @@ def str_to_ty(name):
"u32": triton.language.uint32, "u32": triton.language.uint32,
"u64": triton.language.uint64, "u64": triton.language.uint64,
"B": triton.language.int1, "B": triton.language.int1,
"i1": triton.language.int1,
} }
return tys[name] return tys[name]

View File

@@ -729,9 +729,10 @@ def cat(input, other, _builder=None):
@builtin @builtin
def reshape(input, shape, _builder=None): def view(input, shape, _builder=None):
""" """
Tries to reshape the given tensor to a new shape. Returns a tensor with the same elements as `input` but a different shape.
The order of the elements may not be preserved.
:param input: The input tensor. :param input: The input tensor.
:type input: :type input:
@@ -740,7 +741,7 @@ def reshape(input, shape, _builder=None):
""" """
shape = [x.value for x in shape] shape = [x.value for x in shape]
return semantic.reshape(input, shape, _builder) return semantic.view(input, shape, _builder)
# ----------------------- # -----------------------
@@ -1151,7 +1152,7 @@ def ravel(x):
:param x: the input tensor :param x: the input tensor
:type x: Block :type x: Block
""" """
return triton.language.reshape(x, [x.numel]) return triton.language.view(x, [x.numel])
@triton.jit @triton.jit

View File

@@ -345,7 +345,7 @@ def invert(input: tl.tensor,
input_sca_ty = input.type.scalar input_sca_ty = input.type.scalar
if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
_1 = tl.tensor(ir.constant.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
return xor_(input, _1, builder) return xor_(input, _1, builder)
@@ -481,11 +481,13 @@ def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
def view(input: tl.tensor, def view(input: tl.tensor,
dst_shape: List[int], dst_shape: List[int],
builder: ir.builder) -> tl.tensor: builder: ir.builder) -> tl.tensor:
# TODO: disable when TritonToTritonGPU handles views properly
assert len(input.shape) == len(dst_shape)
numel = 1 numel = 1
for s in dst_shape: for s in dst_shape:
numel *= s numel *= s
if input.type.numel != numel: if input.type.numel != numel:
raise ValueError("cannot reshape block of different shape") raise ValueError("cannot view block of different shape")
ret_ty = tl.block_type(input.type.scalar, dst_shape) ret_ty = tl.block_type(input.type.scalar, dst_shape)
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty) return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
@@ -516,7 +518,7 @@ def broadcast_impl_shape(input: tl.tensor,
for i in range(len(src_shape)): for i in range(len(src_shape)):
if shape[i] != src_shape[i] and src_shape[i] != 1: if shape[i] != src_shape[i] and src_shape[i] != 1:
raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
f" must match the existing size ({src_shape[1]}) at non-singleton dimension" f" must match the existing size ({src_shape[i]}) at non-singleton dimension"
f" {i}: {src_shape}, {shape}") f" {i}: {src_shape}, {shape}")
ret_ty = tl.block_type(input.type.scalar, shape) ret_ty = tl.block_type(input.type.scalar, shape)
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
@@ -679,7 +681,7 @@ def cast(input: tl.tensor,
if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
bitwidth = dst_sca_ty.int_bitwidth bitwidth = dst_sca_ty.int_bitwidth
if bitwidth == 64: if bitwidth == 64:
return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)), return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
dst_ty) dst_ty)
if bitwidth == 1: if bitwidth == 1:
return not_equal(cast(input, tl.int64, builder), return not_equal(cast(input, tl.int64, builder),
@@ -989,18 +991,21 @@ def where(condition: tl.tensor,
builder: ir.builder) -> tl.tensor: builder: ir.builder) -> tl.tensor:
condition = cast(condition, tl.int1, builder) condition = cast(condition, tl.int1, builder)
if condition.type.is_block(): if condition.type.is_block():
x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder) condition, x = broadcast_impl_value(condition, x, builder)
y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder) x, y = broadcast_impl_value(x, y, builder)
condition, x = broadcast_impl_value(condition, x, builder)
x, y = binary_op_type_checking_impl(x, y, builder, True, True) x, y = binary_op_type_checking_impl(x, y, builder, True, True)
if not condition.type.is_block():
condition, _ = broadcast_impl_value(condition, x, builder)
ret_ty = x.type ret_ty = x.type
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
# ===----------------------------------------------------------------------===// # ===----------------------------------------------------------------------===//
# Reductions # Reductions
# ===----------------------------------------------------------------------=== # ===----------------------------------------------------------------------===
def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor: FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
scalar_ty = input.type.scalar scalar_ty = input.type.scalar