[FRONTEND] Made more tests pass (#805)
This commit is contained in:
@@ -56,6 +56,22 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
|
|||||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// arith.bitcast doesn't support pointers
|
||||||
|
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
|
||||||
|
SameOperandsAndResultEncoding,
|
||||||
|
NoSideEffect,
|
||||||
|
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||||
|
let summary = "Cast between types of the same bitwidth";
|
||||||
|
|
||||||
|
let arguments = (ins TT_Type:$from);
|
||||||
|
|
||||||
|
let results = (outs TT_Type:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||||
|
|
||||||
|
// TODO: Add verifier
|
||||||
|
}
|
||||||
|
|
||||||
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
|
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
|
||||||
SameOperandsAndResultEncoding,
|
SameOperandsAndResultEncoding,
|
||||||
NoSideEffect,
|
NoSideEffect,
|
||||||
|
@@ -1400,6 +1400,8 @@ struct ExtractSliceOpConversion
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO: rewrite Ternary/Binary/Unary as Elementwise
|
||||||
|
|
||||||
// A CRTP style of base class.
|
// A CRTP style of base class.
|
||||||
template <typename SourceOp, typename DestOp, typename ConcreteT>
|
template <typename SourceOp, typename DestOp, typename ConcreteT>
|
||||||
class BinaryOpConversionBase
|
class BinaryOpConversionBase
|
||||||
@@ -1470,6 +1472,77 @@ struct BinaryOpConversion
|
|||||||
Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
|
Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// Ternary
|
||||||
|
//
|
||||||
|
|
||||||
|
template <typename SourceOp, typename DestOp, typename ConcreteT>
|
||||||
|
class TernaryOpConversionBase
|
||||||
|
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||||
|
public:
|
||||||
|
using OpAdaptor = typename SourceOp::Adaptor;
|
||||||
|
|
||||||
|
explicit TernaryOpConversionBase(LLVMTypeConverter &typeConverter,
|
||||||
|
PatternBenefit benefit = 1)
|
||||||
|
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
|
||||||
|
// ArithmeticToLLVM will handle the lowering of scalar ArithOps
|
||||||
|
if (!resultTy)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
auto resultLayout =
|
||||||
|
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
|
||||||
|
auto resultShape = resultTy.getShape();
|
||||||
|
assert(resultLayout && "Unexpected resultLayout in TernaryOpConversion");
|
||||||
|
unsigned elems = resultLayout.getElemsPerThread(resultShape);
|
||||||
|
Type elemTy =
|
||||||
|
this->getTypeConverter()->convertType(resultTy.getElementType());
|
||||||
|
SmallVector<Type> types(elems, elemTy);
|
||||||
|
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
|
||||||
|
|
||||||
|
auto *concreteThis = static_cast<const ConcreteT *>(this);
|
||||||
|
auto lhss =
|
||||||
|
this->getElementsFromStruct(loc, adaptor.getOperands()[0], rewriter);
|
||||||
|
auto rhss =
|
||||||
|
this->getElementsFromStruct(loc, adaptor.getOperands()[1], rewriter);
|
||||||
|
auto thss =
|
||||||
|
this->getElementsFromStruct(loc, adaptor.getOperands()[2], rewriter);
|
||||||
|
SmallVector<Value> resultVals(elems);
|
||||||
|
for (unsigned i = 0; i < elems; ++i) {
|
||||||
|
resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
|
||||||
|
rhss[i], thss[i], loc);
|
||||||
|
}
|
||||||
|
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
|
||||||
|
rewriter.replaceOp(op, view);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SourceOp, typename DestOp>
|
||||||
|
struct TernaryOpConversion
|
||||||
|
: public TernaryOpConversionBase<SourceOp, DestOp,
|
||||||
|
TernaryOpConversion<SourceOp, DestOp>> {
|
||||||
|
|
||||||
|
explicit TernaryOpConversion(LLVMTypeConverter &typeConverter,
|
||||||
|
PatternBenefit benefit = 1)
|
||||||
|
: TernaryOpConversionBase<SourceOp, DestOp,
|
||||||
|
TernaryOpConversion<SourceOp, DestOp>>(
|
||||||
|
typeConverter, benefit) {}
|
||||||
|
|
||||||
|
using OpAdaptor = typename SourceOp::Adaptor;
|
||||||
|
// An interface to support variant DestOp builder.
|
||||||
|
DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
|
||||||
|
Type elemTy, Value lhs, Value rhs, Value th,
|
||||||
|
Location loc) const {
|
||||||
|
return rewriter.create<DestOp>(loc, elemTy, lhs, rhs, th);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
// Unary
|
// Unary
|
||||||
//
|
//
|
||||||
@@ -3590,9 +3663,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
|||||||
benefit);
|
benefit);
|
||||||
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
|
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
|
||||||
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
|
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
|
||||||
|
|
||||||
|
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
|
||||||
|
patterns.add<TernaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||||
|
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp);
|
||||||
|
#undef POPULATE_TERNARY_OP
|
||||||
|
|
||||||
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
|
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
|
||||||
patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||||
|
|
||||||
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
|
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
|
||||||
POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp)
|
POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp)
|
||||||
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
|
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
|
||||||
@@ -3605,24 +3683,31 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
|||||||
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
|
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
|
||||||
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
|
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
|
||||||
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
|
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
|
||||||
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
|
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
|
||||||
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
|
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
|
||||||
|
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
|
||||||
|
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
|
||||||
|
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
|
||||||
|
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
|
||||||
#undef POPULATE_BINARY_OP
|
#undef POPULATE_BINARY_OP
|
||||||
|
|
||||||
patterns.add<CmpIOpConversion>(typeConverter, benefit);
|
patterns.add<CmpIOpConversion>(typeConverter, benefit);
|
||||||
patterns.add<CmpFOpConversion>(typeConverter, benefit);
|
patterns.add<CmpFOpConversion>(typeConverter, benefit);
|
||||||
#define POPULATE_CAST_OP(SRC_OP, DST_OP) \
|
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
|
||||||
patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||||
POPULATE_CAST_OP(arith::TruncIOp, LLVM::TruncOp)
|
POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
|
||||||
POPULATE_CAST_OP(arith::TruncFOp, LLVM::FPTruncOp)
|
POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp)
|
||||||
POPULATE_CAST_OP(arith::ExtSIOp, LLVM::SExtOp)
|
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
|
||||||
POPULATE_CAST_OP(arith::ExtUIOp, LLVM::ZExtOp)
|
POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
|
||||||
POPULATE_CAST_OP(arith::FPToUIOp, LLVM::FPToUIOp)
|
POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp)
|
||||||
POPULATE_CAST_OP(arith::FPToSIOp, LLVM::FPToSIOp)
|
POPULATE_UNARY_OP(arith::FPToSIOp, LLVM::FPToSIOp)
|
||||||
POPULATE_CAST_OP(arith::UIToFPOp, LLVM::UIToFPOp)
|
POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
|
||||||
POPULATE_CAST_OP(arith::SIToFPOp, LLVM::SIToFPOp)
|
POPULATE_UNARY_OP(arith::SIToFPOp, LLVM::SIToFPOp)
|
||||||
POPULATE_CAST_OP(arith::ExtFOp, LLVM::FPExtOp)
|
POPULATE_UNARY_OP(arith::ExtFOp, LLVM::FPExtOp)
|
||||||
#undef POPULATE_CAST_OP
|
POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
|
||||||
|
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
|
||||||
|
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
|
||||||
|
#undef POPULATE_UNARY_OP
|
||||||
|
|
||||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||||
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
|
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
|
||||||
|
@@ -351,6 +351,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
|||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
patterns.add< // TODO: view should have custom pattern that views the layout
|
patterns.add< // TODO: view should have custom pattern that views the layout
|
||||||
TritonGenericPattern<triton::ViewOp>,
|
TritonGenericPattern<triton::ViewOp>,
|
||||||
|
TritonGenericPattern<triton::BitcastOp>,
|
||||||
|
TritonGenericPattern<triton::IntToPtrOp>,
|
||||||
|
TritonGenericPattern<triton::PtrToIntOp>,
|
||||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||||
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
|
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
|
||||||
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
|
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
|
||||||
|
@@ -150,9 +150,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
|||||||
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
|
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
|
||||||
::mlir::triton::CacheModifier cache,
|
::mlir::triton::CacheModifier cache,
|
||||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
TensorType ptrType = ptr.getType().cast<TensorType>();
|
||||||
Type elementType =
|
Type elementType =
|
||||||
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
ptrType.getElementType().cast<PointerType>().getPointeeType();
|
||||||
auto shape = ptrType.getShape();
|
auto shape = ptrType.getShape();
|
||||||
Type resultType = RankedTensorType::get(shape, elementType);
|
Type resultType = RankedTensorType::get(shape, elementType);
|
||||||
state.addOperands(ptr);
|
state.addOperands(ptr);
|
||||||
|
@@ -441,11 +441,22 @@ void init_triton_ir(py::module &&m) {
|
|||||||
loc, self.getF32FloatAttr(v));
|
loc, self.getF32FloatAttr(v));
|
||||||
})
|
})
|
||||||
.def("get_null_value",
|
.def("get_null_value",
|
||||||
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
|
[](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
if (type.isa<mlir::FloatType>())
|
if (auto floatTy = type.dyn_cast<mlir::FloatType>())
|
||||||
return self.create<mlir::arith::ConstantOp>(
|
return self.create<mlir::arith::ConstantFloatOp>(
|
||||||
loc, self.getF32FloatAttr(0.0));
|
loc, mlir::APFloat(floatTy.getFloatSemantics(), 0), floatTy);
|
||||||
|
else if (auto intTy = type.dyn_cast<mlir::IntegerType>())
|
||||||
|
return self.create<mlir::arith::ConstantIntOp>(loc, 0, intTy);
|
||||||
|
else
|
||||||
|
throw std::runtime_error("Not implemented");
|
||||||
|
})
|
||||||
|
.def("get_all_ones_value",
|
||||||
|
[](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
|
||||||
|
auto loc = self.getUnknownLoc();
|
||||||
|
uint64_t val = 0xFFFFFFFFFFFFFFFF;
|
||||||
|
if (auto intTy = type.dyn_cast<mlir::IntegerType>())
|
||||||
|
return self.create<mlir::arith::ConstantIntOp>(loc, val, intTy);
|
||||||
else
|
else
|
||||||
throw std::runtime_error("Not implemented");
|
throw std::runtime_error("Not implemented");
|
||||||
})
|
})
|
||||||
@@ -602,7 +613,7 @@ void init_triton_ir(py::module &&m) {
|
|||||||
[](mlir::OpBuilder &self, mlir::Value &src,
|
[](mlir::OpBuilder &self, mlir::Value &src,
|
||||||
mlir::Type &dstType) -> mlir::Value {
|
mlir::Type &dstType) -> mlir::Value {
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
return self.create<mlir::arith::BitcastOp>(loc, dstType, src);
|
return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
|
||||||
})
|
})
|
||||||
// .def("create_cast", &ir::builder::create_cast)
|
// .def("create_cast", &ir::builder::create_cast)
|
||||||
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
|
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
|
||||||
@@ -1143,6 +1154,18 @@ void init_triton_ir(py::module &&m) {
|
|||||||
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
|
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
|
||||||
operand, axis);
|
operand, axis);
|
||||||
})
|
})
|
||||||
|
.def("create_ptr_to_int",
|
||||||
|
[](mlir::OpBuilder &self, mlir::Value &val,
|
||||||
|
mlir::Type &type) -> mlir::Value {
|
||||||
|
auto loc = self.getUnknownLoc();
|
||||||
|
return self.create<mlir::triton::PtrToIntOp>(loc, type, val);
|
||||||
|
})
|
||||||
|
.def("create_int_to_ptr",
|
||||||
|
[](mlir::OpBuilder &self, mlir::Value &val,
|
||||||
|
mlir::Type &type) -> mlir::Value {
|
||||||
|
auto loc = self.getUnknownLoc();
|
||||||
|
return self.create<mlir::triton::IntToPtrOp>(loc, type, val);
|
||||||
|
})
|
||||||
.def("create_select",
|
.def("create_select",
|
||||||
[](mlir::OpBuilder &self, mlir::Value &condition,
|
[](mlir::OpBuilder &self, mlir::Value &condition,
|
||||||
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
|
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
|
||||||
@@ -1231,7 +1254,6 @@ void init_triton_ir(py::module &&m) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void init_triton_translation(py::module &m) {
|
void init_triton_translation(py::module &m) {
|
||||||
|
|
||||||
using ret = py::return_value_policy;
|
using ret = py::return_value_policy;
|
||||||
|
|
||||||
m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
|
m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
|
||||||
|
@@ -281,141 +281,142 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
|
|||||||
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
|
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize("dtype_x, dtype_y",
|
@pytest.mark.parametrize("dtype_x, dtype_y",
|
||||||
# [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
|
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
|
||||||
# [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
|
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
|
||||||
# )
|
)
|
||||||
# def test_floordiv(dtype_x, dtype_y, device='cuda'):
|
def test_floordiv(dtype_x, dtype_y, device='cuda'):
|
||||||
# # Triton has IEEE, not numpy/torch, semantics for %, and those carry
|
# Triton has IEEE, not numpy/torch, semantics for %, and those carry
|
||||||
# # through to //, so we have to use a nonstandard expression to get a
|
# through to //, so we have to use a nonstandard expression to get a
|
||||||
# # reference result for //.
|
# reference result for //.
|
||||||
# expr = 'x // y'
|
expr = 'x // y'
|
||||||
# numpy_expr = '((x - np.fmod(x, y)) / y)'
|
numpy_expr = '((x - np.fmod(x, y)) / y)'
|
||||||
# _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
|
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
|
||||||
|
|
||||||
|
|
||||||
# # ---------------
|
# ---------------
|
||||||
# # test bitwise ops
|
# test bitwise ops
|
||||||
# # ---------------
|
# ---------------
|
||||||
# @pytest.mark.parametrize("dtype_x, dtype_y, op", [
|
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
|
||||||
# (dtype_x, dtype_y, op)
|
(dtype_x, dtype_y, op)
|
||||||
# for op in ['&', '|', '^']
|
for op in ['&', '|', '^']
|
||||||
# for dtype_x in dtypes + dtypes_with_bfloat16
|
for dtype_x in dtypes + dtypes_with_bfloat16
|
||||||
# for dtype_y in dtypes + dtypes_with_bfloat16
|
for dtype_y in dtypes + dtypes_with_bfloat16
|
||||||
# ])
|
])
|
||||||
# def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
|
def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
|
||||||
# expr = f'x {op} y'
|
expr = f'x {op} y'
|
||||||
# if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
|
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
|
||||||
# numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
|
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
|
||||||
# elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
|
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
|
||||||
# numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
|
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
|
||||||
# else:
|
else:
|
||||||
# numpy_expr = None
|
numpy_expr = None
|
||||||
# if 'float' in dtype_x + dtype_y:
|
if 'float' in dtype_x + dtype_y:
|
||||||
# with pytest.raises(triton.CompilationError) as exc_info:
|
with pytest.raises(triton.CompilationError) as exc_info:
|
||||||
# _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
|
_test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
|
||||||
# # The CompilationError must have been caused by a C++ exception with this text.
|
# The CompilationError must have been caused by a C++ exception with this text.
|
||||||
# assert re.match('invalid operands of type', str(exc_info.value.__cause__))
|
assert re.match('invalid operands of type', str(exc_info.value.__cause__))
|
||||||
# else:
|
else:
|
||||||
# _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
|
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize("dtype_x, dtype_y, op", [
|
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
|
||||||
# (dtype_x, dtype_y, op)
|
(dtype_x, dtype_y, op)
|
||||||
# for op in ['<<', '>>']
|
for op in ['<<', '>>']
|
||||||
# for dtype_x in int_dtypes + uint_dtypes
|
for dtype_x in int_dtypes + uint_dtypes
|
||||||
# for dtype_y in int_dtypes + uint_dtypes
|
for dtype_y in int_dtypes + uint_dtypes
|
||||||
# ])
|
])
|
||||||
# def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
|
def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
|
||||||
# expr = f'x {op} y'
|
expr = f'x {op} y'
|
||||||
# bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
|
bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
|
||||||
# dtype_z = f'uint{bw}'
|
dtype_z = f'uint{bw}'
|
||||||
# numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
|
numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
|
||||||
# _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)
|
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)
|
||||||
|
|
||||||
|
|
||||||
# # ---------------
|
# ---------------
|
||||||
# # test compare ops
|
# test compare ops
|
||||||
# # ---------------
|
# ---------------
|
||||||
# ops = ['==', '!=', '>', '<', '>=', '<=']
|
ops = ['==', '!=', '>', '<', '>=', '<=']
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
|
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
|
||||||
# # real
|
# real
|
||||||
# [
|
[
|
||||||
# (dtype_x, dtype_y, op, 'real', 'real')
|
(dtype_x, dtype_y, op, 'real', 'real')
|
||||||
# for op in ops
|
for op in ops
|
||||||
# for dtype_x in dtypes
|
for dtype_x in dtypes
|
||||||
# for dtype_y in dtypes
|
for dtype_y in dtypes
|
||||||
# ] +
|
] +
|
||||||
# # NaNs
|
# NaNs
|
||||||
# [('float32', 'float32', op, mode_x, mode_y)
|
[('float32', 'float32', op, mode_x, mode_y)
|
||||||
# for op in ops
|
for op in ops
|
||||||
# for mode_x, mode_y in [('nan', 'real'),
|
for mode_x, mode_y in [('nan', 'real'),
|
||||||
# ('real', 'nan'),
|
('real', 'nan'),
|
||||||
# ('nan', 'nan')]
|
('nan', 'nan')]
|
||||||
|
|
||||||
# ])
|
])
|
||||||
# def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
|
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
|
||||||
# expr = f'x {op} y'
|
expr = f'x {op} y'
|
||||||
# if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
|
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
|
||||||
# numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
|
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
|
||||||
# elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
|
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
|
||||||
# numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
|
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
|
||||||
# else:
|
else:
|
||||||
# numpy_expr = None
|
numpy_expr = None
|
||||||
# _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
|
_test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
|
||||||
|
|
||||||
|
|
||||||
# # ---------------
|
# ---------------
|
||||||
# # test where
|
# test where
|
||||||
# # ---------------
|
# ---------------
|
||||||
# @pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"])
|
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"])
|
||||||
# def test_where(dtype):
|
def test_where(dtype):
|
||||||
# select_ptrs = False
|
select_ptrs = False
|
||||||
# if dtype == "*int32":
|
if dtype == "*int32":
|
||||||
# dtype = "int64"
|
dtype = "int64"
|
||||||
# select_ptrs = True
|
select_ptrs = True
|
||||||
# check_type_supported(dtype)
|
check_type_supported(dtype)
|
||||||
|
|
||||||
# @triton.jit
|
@triton.jit
|
||||||
# def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
|
def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
|
||||||
# BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
# TEST_POINTERS: tl.constexpr):
|
TEST_POINTERS: tl.constexpr):
|
||||||
# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
# mask = offsets < n_elements
|
mask = offsets < n_elements
|
||||||
# decide = tl.load(cond_ptr + offsets, mask=mask)
|
decide = tl.load(cond_ptr + offsets, mask=mask)
|
||||||
# if TEST_POINTERS:
|
if TEST_POINTERS:
|
||||||
# a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
|
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||||
# b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
|
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||||
# else:
|
else:
|
||||||
# a = tl.load(a_ptr + offsets, mask=mask)
|
a = tl.load(a_ptr + offsets, mask=mask)
|
||||||
# b = tl.load(b_ptr + offsets, mask=mask)
|
b = tl.load(b_ptr + offsets, mask=mask)
|
||||||
# output = tl.where(decide, a, b)
|
output = tl.where(decide, a, b)
|
||||||
# tl.store(output_ptr + offsets, output, mask=mask)
|
tl.store(output_ptr + offsets, output, mask=mask)
|
||||||
|
|
||||||
# SIZE = 1_000
|
SIZE = 1_000
|
||||||
# rs = RandomState(17)
|
rs = RandomState(17)
|
||||||
# cond = numpy_random(SIZE, 'bool', rs)
|
cond = numpy_random(SIZE, 'bool', rs)
|
||||||
# x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
|
x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
|
||||||
# y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
|
y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
|
||||||
# z = np.where(cond, x, y)
|
z = np.where(cond, x, y)
|
||||||
|
|
||||||
# cond_tri = to_triton(cond, device='cuda')
|
cond_tri = to_triton(cond, device='cuda')
|
||||||
# x_tri = to_triton(x, device='cuda', dst_type=dtype)
|
x_tri = to_triton(x, device='cuda', dst_type=dtype)
|
||||||
# y_tri = to_triton(y, device='cuda', dst_type=dtype)
|
y_tri = to_triton(y, device='cuda', dst_type=dtype)
|
||||||
# z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
|
z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
|
||||||
|
|
||||||
# grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
|
grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
|
||||||
# where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
|
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
|
||||||
# assert (z == to_numpy(z_tri)).all()
|
assert (z == to_numpy(z_tri)).all()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: wrong result
|
||||||
# def test_where_broadcast():
|
# def test_where_broadcast():
|
||||||
# @triton.jit
|
# @triton.jit
|
||||||
# def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
|
# def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
|
||||||
# xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
|
# xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
|
||||||
# yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
|
# yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
|
||||||
|
|
||||||
# mask = tl.load(cond_ptr + yoffsets)
|
# mask = tl.load(cond_ptr + yoffsets)
|
||||||
# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
|
# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
|
||||||
@@ -424,8 +425,8 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
|
|||||||
|
|
||||||
# @triton.jit
|
# @triton.jit
|
||||||
# def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
|
# def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
|
||||||
# xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
|
# xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
|
||||||
# yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
|
# yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
|
||||||
# mask = 0
|
# mask = 0
|
||||||
# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
|
# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
|
||||||
# res = tl.where(mask, vals, 0.)
|
# res = tl.where(mask, vals, 0.)
|
||||||
@@ -451,17 +452,19 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
|
|||||||
# # ---------------
|
# # ---------------
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize("dtype_x, expr", [
|
@pytest.mark.parametrize("dtype_x, expr", [
|
||||||
# (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
|
(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
|
||||||
# ] + [
|
] + [
|
||||||
# (dtype_x, ' ~x') for dtype_x in int_dtypes
|
(dtype_x, ' ~x') for dtype_x in int_dtypes
|
||||||
# ])
|
])
|
||||||
# def test_unary_op(dtype_x, expr, device='cuda'):
|
def test_unary_op(dtype_x, expr, device='cuda'):
|
||||||
# _test_unary(dtype_x, expr, device=device)
|
_test_unary(dtype_x, expr, device=device)
|
||||||
|
|
||||||
# # ----------------
|
# # ----------------
|
||||||
# # test math ops
|
# # test math ops
|
||||||
# # ----------------
|
# # ----------------
|
||||||
|
|
||||||
|
# TODO: Math module
|
||||||
# # @pytest.mark.parametrize("expr", [
|
# # @pytest.mark.parametrize("expr", [
|
||||||
# # 'exp', 'log', 'cos', 'sin'
|
# # 'exp', 'log', 'cos', 'sin'
|
||||||
# # ])
|
# # ])
|
||||||
@@ -479,17 +482,18 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
|
|||||||
# # ----------------
|
# # ----------------
|
||||||
|
|
||||||
|
|
||||||
# def make_ptr_str(name, shape):
|
def make_ptr_str(name, shape):
|
||||||
# rank = len(shape)
|
rank = len(shape)
|
||||||
# offsets = []
|
offsets = []
|
||||||
# stride = 1
|
stride = 1
|
||||||
# for i in reversed(range(rank)):
|
for i in reversed(range(rank)):
|
||||||
# idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
|
idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
|
||||||
# offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
|
offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
|
||||||
# stride *= shape[i]
|
stride *= shape[i]
|
||||||
# return f"{name} + {' + '.join(offsets)}"
|
return f"{name} + {' + '.join(offsets)}"
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
|
||||||
# @pytest.mark.parametrize("expr, dtype_str", [
|
# @pytest.mark.parametrize("expr, dtype_str", [
|
||||||
# (f'x[{s}]', d)
|
# (f'x[{s}]', d)
|
||||||
# for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
|
# for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
|
||||||
|
@@ -45,6 +45,7 @@ def str_to_ty(name):
|
|||||||
"u32": triton.language.uint32,
|
"u32": triton.language.uint32,
|
||||||
"u64": triton.language.uint64,
|
"u64": triton.language.uint64,
|
||||||
"B": triton.language.int1,
|
"B": triton.language.int1,
|
||||||
|
"i1": triton.language.int1,
|
||||||
}
|
}
|
||||||
return tys[name]
|
return tys[name]
|
||||||
|
|
||||||
|
@@ -729,9 +729,10 @@ def cat(input, other, _builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def reshape(input, shape, _builder=None):
|
def view(input, shape, _builder=None):
|
||||||
"""
|
"""
|
||||||
Tries to reshape the given tensor to a new shape.
|
Returns a tensor with the same elements as `input` but a different shape.
|
||||||
|
The order of the elements may not be preserved.
|
||||||
|
|
||||||
:param input: The input tensor.
|
:param input: The input tensor.
|
||||||
:type input:
|
:type input:
|
||||||
@@ -740,7 +741,7 @@ def reshape(input, shape, _builder=None):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
shape = [x.value for x in shape]
|
shape = [x.value for x in shape]
|
||||||
return semantic.reshape(input, shape, _builder)
|
return semantic.view(input, shape, _builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
@@ -1151,7 +1152,7 @@ def ravel(x):
|
|||||||
:param x: the input tensor
|
:param x: the input tensor
|
||||||
:type x: Block
|
:type x: Block
|
||||||
"""
|
"""
|
||||||
return triton.language.reshape(x, [x.numel])
|
return triton.language.view(x, [x.numel])
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
@@ -345,7 +345,7 @@ def invert(input: tl.tensor,
|
|||||||
input_sca_ty = input.type.scalar
|
input_sca_ty = input.type.scalar
|
||||||
if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
|
if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
|
||||||
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
|
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
|
||||||
_1 = tl.tensor(ir.constant.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
|
_1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
|
||||||
return xor_(input, _1, builder)
|
return xor_(input, _1, builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -481,11 +481,13 @@ def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
|
|||||||
def view(input: tl.tensor,
|
def view(input: tl.tensor,
|
||||||
dst_shape: List[int],
|
dst_shape: List[int],
|
||||||
builder: ir.builder) -> tl.tensor:
|
builder: ir.builder) -> tl.tensor:
|
||||||
|
# TODO: disable when TritonToTritonGPU handles views properly
|
||||||
|
assert len(input.shape) == len(dst_shape)
|
||||||
numel = 1
|
numel = 1
|
||||||
for s in dst_shape:
|
for s in dst_shape:
|
||||||
numel *= s
|
numel *= s
|
||||||
if input.type.numel != numel:
|
if input.type.numel != numel:
|
||||||
raise ValueError("cannot reshape block of different shape")
|
raise ValueError("cannot view block of different shape")
|
||||||
ret_ty = tl.block_type(input.type.scalar, dst_shape)
|
ret_ty = tl.block_type(input.type.scalar, dst_shape)
|
||||||
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
|
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
|
||||||
|
|
||||||
@@ -516,7 +518,7 @@ def broadcast_impl_shape(input: tl.tensor,
|
|||||||
for i in range(len(src_shape)):
|
for i in range(len(src_shape)):
|
||||||
if shape[i] != src_shape[i] and src_shape[i] != 1:
|
if shape[i] != src_shape[i] and src_shape[i] != 1:
|
||||||
raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
|
raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
|
||||||
f" must match the existing size ({src_shape[1]}) at non-singleton dimension"
|
f" must match the existing size ({src_shape[i]}) at non-singleton dimension"
|
||||||
f" {i}: {src_shape}, {shape}")
|
f" {i}: {src_shape}, {shape}")
|
||||||
ret_ty = tl.block_type(input.type.scalar, shape)
|
ret_ty = tl.block_type(input.type.scalar, shape)
|
||||||
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
|
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
|
||||||
@@ -679,7 +681,7 @@ def cast(input: tl.tensor,
|
|||||||
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
|
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
|
||||||
bitwidth = dst_sca_ty.int_bitwidth
|
bitwidth = dst_sca_ty.int_bitwidth
|
||||||
if bitwidth == 64:
|
if bitwidth == 64:
|
||||||
return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)),
|
return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
|
||||||
dst_ty)
|
dst_ty)
|
||||||
if bitwidth == 1:
|
if bitwidth == 1:
|
||||||
return not_equal(cast(input, tl.int64, builder),
|
return not_equal(cast(input, tl.int64, builder),
|
||||||
@@ -989,18 +991,21 @@ def where(condition: tl.tensor,
|
|||||||
builder: ir.builder) -> tl.tensor:
|
builder: ir.builder) -> tl.tensor:
|
||||||
condition = cast(condition, tl.int1, builder)
|
condition = cast(condition, tl.int1, builder)
|
||||||
if condition.type.is_block():
|
if condition.type.is_block():
|
||||||
x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder)
|
condition, x = broadcast_impl_value(condition, x, builder)
|
||||||
y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder)
|
x, y = broadcast_impl_value(x, y, builder)
|
||||||
|
condition, x = broadcast_impl_value(condition, x, builder)
|
||||||
|
|
||||||
x, y = binary_op_type_checking_impl(x, y, builder, True, True)
|
x, y = binary_op_type_checking_impl(x, y, builder, True, True)
|
||||||
|
if not condition.type.is_block():
|
||||||
|
condition, _ = broadcast_impl_value(condition, x, builder)
|
||||||
ret_ty = x.type
|
ret_ty = x.type
|
||||||
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
|
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
|
||||||
|
|
||||||
|
|
||||||
# ===----------------------------------------------------------------------===//
|
# ===----------------------------------------------------------------------===//
|
||||||
# Reductions
|
# Reductions
|
||||||
# ===----------------------------------------------------------------------===
|
# ===----------------------------------------------------------------------===
|
||||||
|
|
||||||
|
|
||||||
def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
|
def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
|
||||||
FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
|
FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
|
||||||
scalar_ty = input.type.scalar
|
scalar_ty = input.type.scalar
|
||||||
|
Reference in New Issue
Block a user