[BACKEND] Support optional mask in TritonGPUToLLVM (#80)

Co-authored-by: gzhu <gzhu@nvidia.com>
This commit is contained in:
goostavz
2022-08-25 08:51:37 +08:00
committed by GitHub
parent 84aa7d025a
commit bedbf221c0
5 changed files with 69 additions and 17 deletions

View File

@@ -248,6 +248,13 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
} }
// Emits an LLVM::ConstantOp at `loc` holding the integer `value`.
//
// The attribute is built against the source-level type `ty`, while the
// result type of the op is whatever `converter` maps `ty` to.
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
                                       LLVMTypeConverter *converter, Type ty,
                                       int64_t value) {
  Type convertedTy = converter->convertType(ty);
  auto valueAttr = builder.getIntegerAttr(ty, value);
  return builder.create<LLVM::ConstantOp>(loc, convertedTy, valueAttr);
}
Value getStructFromElements(Location loc, ValueRange resultVals, Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
Type structType) { Type structType) {
@@ -601,7 +608,7 @@ struct StoreOpConversion
auto getLLVMElems = auto getLLVMElems =
[&](Value value, Value llValue, [&](Value value, Value llValue,
const BlockedEncodingAttr &layout) -> SmallVector<Value, 4> { const BlockedEncodingAttr &layout) -> SmallVector<Value> {
auto ty = value.getType().cast<RankedTensorType>(); auto ty = value.getType().cast<RankedTensorType>();
auto shape = ty.getShape(); auto shape = ty.getShape();
// Here, we assume that all inputs should have a blockedLayout // Here, we assume that all inputs should have a blockedLayout
@@ -630,13 +637,16 @@ struct StoreOpConversion
}; };
auto [ptrLayout, ptrNumElems] = getLayout(ptr); auto [ptrLayout, ptrNumElems] = getLayout(ptr);
auto [maskLayout, maskNumElems] = getLayout(mask);
auto [valueLayout, valueNumElems] = getLayout(value); auto [valueLayout, valueNumElems] = getLayout(value);
auto ptrElems = getLLVMElems(mask, llPtr, maskLayout); auto ptrElems = getLLVMElems(ptr, llPtr, ptrLayout);
auto valueElems = getLLVMElems(value, llValue, valueLayout); auto valueElems = getLLVMElems(value, llValue, valueLayout);
auto maskElems = getLLVMElems(mask, llMask, maskLayout); SmallVector<Value> maskElems;
if (llMask) {
auto [maskLayout, maskNumElems] = getLayout(mask);
maskElems = getLLVMElems(mask, llMask, maskLayout);
assert(valueElems.size() == maskElems.size()); assert(valueElems.size() == maskElems.size());
}
auto getAlign = [this](Value val, auto getAlign = [this](Value val,
const BlockedEncodingAttr &layout) -> unsigned { const BlockedEncodingAttr &layout) -> unsigned {
@@ -710,10 +720,12 @@ struct StoreOpConversion
PTXBuilder ptxBuilder; PTXBuilder ptxBuilder;
auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st"); auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
ptxStoreInstr.predicate(maskElems[vecIdx], "b")
.global() Value maskVal =
.b(width) llMask ? maskElems[vecIdx]
.v(nWords); : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
rewriter.getIntegerType(1), 1);
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
llvm::SmallVector<std::string> asmArgs; llvm::SmallVector<std::string> asmArgs;
@@ -746,8 +758,8 @@ struct StoreOpConversion
} }
ptxStoreInstr(asmAddr, asmArgList); ptxStoreInstr(asmAddr, asmArgList);
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
llvm::SmallVector<Type, 4> argTys({mask.getType(), ptr.getType()}); llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
for (int i = 0; i < nWords; i++) for (int i = 0; i < nWords; i++)
argTys.push_back(valArgTy); argTys.push_back(valArgTy);
@@ -970,9 +982,12 @@ struct LoadOpConversion
auto elemTy = resultTy.getElementType(); auto elemTy = resultTy.getElementType();
unsigned numElems = getElemsPerThread(blockedLayout, shape); unsigned numElems = getElemsPerThread(blockedLayout, shape);
auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter); auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter);
auto maskVals = getElementsFromStruct(loc, mask, numElems, rewriter); SmallVector<Value> maskVals;
if (mask) {
maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
}
SmallVector<Value> otherVals; SmallVector<Value> otherVals;
if (other != nullptr) { if (other) {
otherVals = getElementsFromStruct(loc, other, numElems, rewriter); otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
} }
unsigned nbits = elemTy.isa<FloatType>() unsigned nbits = elemTy.isa<FloatType>()
@@ -1004,7 +1019,10 @@ struct LoadOpConversion
// TODO: Handle the optimization if ptr is from GEP and the idx is // TODO: Handle the optimization if ptr is from GEP and the idx is
// constant. This should be a canonicalization pattern in LLVM Dialect // constant. This should be a canonicalization pattern in LLVM Dialect
unsigned in_off = 0; unsigned in_off = 0;
Value pred = maskVals[i]; Value pred =
mask ? maskVals[i]
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
rewriter.getIntegerType(1), 1);
// --- // ---
// create inline asm string // create inline asm string

View File

@@ -227,8 +227,9 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::LoadOp>( rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(), op, typeConverter->convertType(op.getType()), adaptor.ptr(),
adaptor.evict(), adaptor.isVolatile()); adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
adaptor.isVolatile());
return success(); return success();
} }
}; };

View File

@@ -26,6 +26,7 @@
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <Python.h> #include <Python.h>
#include <cctype>
#include <optional> #include <optional>
#include <pybind11/buffer_info.h> #include <pybind11/buffer_info.h>
#include <pybind11/functional.h> #include <pybind11/functional.h>
@@ -99,6 +100,14 @@ long pow2_divisor(long N) {
return 1; return 1;
} }
// Interprets the environment variable named `env` as a boolean switch.
//
// Returns true only when the variable is set and its value, compared
// case-insensitively, is exactly "on", "true" or "1". An unset or empty
// variable, or any other value, yields false.
bool getBoolEnv(const std::string &env) {
  const char *raw = std::getenv(env.c_str());
  if (raw == nullptr)
    return false;

  // Lower-case a private copy so the comparison ignores letter case.
  std::string lowered;
  for (const char *p = raw; *p != '\0'; ++p) {
    const unsigned char ch = static_cast<unsigned char>(*p);
    lowered.push_back(static_cast<char>(std::tolower(ch)));
  }

  if (lowered == "1")
    return true;
  if (lowered == "on")
    return true;
  return lowered == "true";
}
// Returns something like "int16", whether dtype is a torch.dtype or // Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype. // triton.language.dtype.
std::string dtype_cache_key_part(const py::object &dtype) { std::string dtype_cache_key_part(const py::object &dtype) {
@@ -1589,6 +1598,21 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::PassManager>(m, "pass_manager") py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>()) .def(py::init<mlir::MLIRContext *>())
.def("enable_debug",
[](mlir::PassManager &self) {
auto printingFlags = mlir::OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
self.enableIRPrinting(
/*shouldPrintBeforePass=*/nullptr,
/*shouldPrintAfterPass=*/
[](mlir::Pass *pass, mlir::Operation *) {
return getBoolEnv("MLIR_ENABLE_DUMP");
},
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(),
printingFlags);
})
.def("run", .def("run",
[](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool { [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
return mlir::succeeded(self.run(mod.getOperation())); return mlir::succeeded(self.run(mod.getOperation()));

View File

@@ -27,7 +27,13 @@ def test_vecadd_no_scf():
z_ptrs = z_ptr + offset z_ptrs = z_ptr + offset
tl.store(z_ptrs, z) tl.store(z_ptrs, z)
ptx, shem_size, kernel_name = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx") # TODO: add this to CI, to make sure the compilation flow is at least OK
# before we have GPU machines for CI.
# ptx, shem_size, kernel_name = triton.compile(kernel,
# "*fp32,i32,*fp32,i32,*fp32,i32",
# constants={"BLOCK_SIZE_N": 256},
# num_warps=NUM_WARPS,
# device=0, output="ptx")
torch.zeros([10], device=torch.device('cuda')) torch.zeros([10], device=torch.device('cuda'))
device = torch.cuda.current_device() device = torch.cuda.current_device()

View File

@@ -772,6 +772,7 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
def optimize_triton_ir(mod): def optimize_triton_ir(mod):
pm = _triton.ir.pass_manager(mod.context) pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass() pm.add_inliner_pass()
pm.add_canonicalizer_pass() pm.add_canonicalizer_pass()
pm.run(mod) pm.run(mod)
@@ -780,6 +781,7 @@ def optimize_triton_ir(mod):
def make_tritongpu_ir(mod, num_warps): def make_tritongpu_ir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context) pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass() pm.add_inliner_pass()
pm.add_triton_combine_pass() pm.add_triton_combine_pass()
pm.add_canonicalizer_pass() pm.add_canonicalizer_pass()
@@ -791,6 +793,7 @@ def make_tritongpu_ir(mod, num_warps):
def optimize_tritongpu_ir(mod, num_stages): def optimize_tritongpu_ir(mod, num_stages):
pm = _triton.ir.pass_manager(mod.context) pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_tritongpu_pipeline_pass(num_stages) pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_canonicalizer_pass() pm.add_canonicalizer_pass()
pm.add_cse_pass() pm.add_cse_pass()