[BACKEND] Support optional mask in TritonGPUToLLVM (#80)
Co-authored-by: gzhu <gzhu@nvidia.com>
@@ -248,6 +248,13 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
       loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
 }
 
+static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
+                                       LLVMTypeConverter *converter, Type ty,
+                                       int64_t value) {
+  return builder.create<LLVM::ConstantOp>(loc, converter->convertType(ty),
+                                          builder.getIntegerAttr(ty, value));
+}
+
 Value getStructFromElements(Location loc, ValueRange resultVals,
                             ConversionPatternRewriter &rewriter,
                             Type structType) {
@@ -601,7 +608,7 @@ struct StoreOpConversion
 
     auto getLLVMElems =
         [&](Value value, Value llValue,
-            const BlockedEncodingAttr &layout) -> SmallVector<Value, 4> {
+            const BlockedEncodingAttr &layout) -> SmallVector<Value> {
       auto ty = value.getType().cast<RankedTensorType>();
       auto shape = ty.getShape();
       // Here, we assume that all inputs should have a blockedLayout
@@ -630,13 +637,16 @@ struct StoreOpConversion
     };
 
     auto [ptrLayout, ptrNumElems] = getLayout(ptr);
-    auto [maskLayout, maskNumElems] = getLayout(mask);
    auto [valueLayout, valueNumElems] = getLayout(value);
 
-    auto ptrElems = getLLVMElems(mask, llPtr, maskLayout);
+    auto ptrElems = getLLVMElems(ptr, llPtr, ptrLayout);
     auto valueElems = getLLVMElems(value, llValue, valueLayout);
-    auto maskElems = getLLVMElems(mask, llMask, maskLayout);
-    assert(valueElems.size() == maskElems.size());
+    SmallVector<Value> maskElems;
+    if (llMask) {
+      auto [maskLayout, maskNumElems] = getLayout(mask);
+      maskElems = getLLVMElems(mask, llMask, maskLayout);
+      assert(valueElems.size() == maskElems.size());
+    }
 
     auto getAlign = [this](Value val,
                            const BlockedEncodingAttr &layout) -> unsigned {
@@ -710,10 +720,12 @@ struct StoreOpConversion
 
       PTXBuilder ptxBuilder;
       auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
-      ptxStoreInstr.predicate(maskElems[vecIdx], "b")
-          .global()
-          .b(width)
-          .v(nWords);
+      Value maskVal =
+          llMask ? maskElems[vecIdx]
+                 : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
+                                             rewriter.getIntegerType(1), 1);
+
+      ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
 
       llvm::SmallVector<std::string> asmArgs;
 
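The net effect of this hunk: every PTX st is now emitted in predicated form, and when the Triton-level store has no mask the predicate is simply the i1 constant 1 built by the new createLLVMIntegerConstant helper. At the language level, both of these kernels now lower through the same path (a minimal sketch; kernel names and shapes are ours, not from this commit):

    import triton
    import triton.language as tl

    @triton.jit
    def store_masked(dst_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        # Predicate comes from the per-element mask values.
        tl.store(dst_ptr + offs, offs.to(tl.float32), mask=offs < n)

    @triton.jit
    def store_unmasked(dst_ptr, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        # No mask: after this commit the predicate defaults to constant true.
        tl.store(dst_ptr + offs, offs.to(tl.float32))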
@@ -746,8 +758,8 @@ struct StoreOpConversion
       }
 
       ptxStoreInstr(asmAddr, asmArgList);
-      llvm::SmallVector<Type, 4> argTys({mask.getType(), ptr.getType()});
+      Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
+      llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
       for (int i = 0; i < nWords; i++)
         argTys.push_back(valArgTy);
 
@@ -970,9 +982,12 @@ struct LoadOpConversion
     auto elemTy = resultTy.getElementType();
     unsigned numElems = getElemsPerThread(blockedLayout, shape);
     auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter);
-    auto maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
+    SmallVector<Value> maskVals;
+    if (mask) {
+      maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
+    }
     SmallVector<Value> otherVals;
-    if (other != nullptr) {
+    if (other) {
       otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
     }
     unsigned nbits = elemTy.isa<FloatType>()
@@ -1004,7 +1019,10 @@ struct LoadOpConversion
       // TODO: Handle the optimization if ptr is from GEP and the idx is
       // constant. This should be a canonicalization pattern in LLVM Dialect
       unsigned in_off = 0;
-      Value pred = maskVals[i];
+      Value pred =
+          mask ? maskVals[i]
+               : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
+                                           rewriter.getIntegerType(1), 1);
 
       // ---
       // create inline asm string
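LoadOpConversion mirrors the store path: pred is the mask element when a mask is present and a constant true otherwise, while `other` (gathered into otherVals above) supplies values for masked-off lanes. A hedged Triton-level illustration (our names, not the commit's):

    import triton
    import triton.language as tl

    @triton.jit
    def masked_copy(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        # mask selects the live lanes; `other` fills the rest with 0.0,
        # which is what the otherVals handling above implements.
        x = tl.load(x_ptr + offs, mask=offs < n, other=0.0)
        tl.store(y_ptr + offs, x, mask=offs < n)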
@@ -227,8 +227,9 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<triton::LoadOp>(
-        op, adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(),
-        adaptor.evict(), adaptor.isVolatile());
+        op, typeConverter->convertType(op.getType()), adaptor.ptr(),
+        adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
+        adaptor.isVolatile());
     return success();
   }
 };
@@ -26,6 +26,7 @@
 #include "llvm/Support/raw_ostream.h"
 
 #include <Python.h>
+#include <cctype>
 #include <optional>
 #include <pybind11/buffer_info.h>
 #include <pybind11/functional.h>
@@ -99,6 +100,14 @@ long pow2_divisor(long N) {
   return 1;
 }
 
+bool getBoolEnv(const std::string &env) {
+  const char *s = std::getenv(env.c_str());
+  std::string str(s ? s : "");
+  std::transform(str.begin(), str.end(), str.begin(),
+                 [](unsigned char c) { return std::tolower(c); });
+  return (str == "on" || str == "true" || str == "1");
+}
+
 // Returns something like "int16", whether dtype is a torch.dtype or
 // triton.language.dtype.
 std::string dtype_cache_key_part(const py::object &dtype) {
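getBoolEnv reads an environment variable and treats "on", "true", or "1" (case-insensitively) as true; an unset variable reads as false. A Python rendering of the same semantics, for reference only (not part of the commit):

    import os

    def get_bool_env(name: str) -> bool:
        # Missing variables read as "" and therefore as False,
        # matching the C++ helper above.
        return os.environ.get(name, "").lower() in ("on", "true", "1")

    assert get_bool_env("SOME_UNSET_VARIABLE") is False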
@@ -1589,6 +1598,21 @@ void init_triton_ir(py::module &&m) {
 
   py::class_<mlir::PassManager>(m, "pass_manager")
       .def(py::init<mlir::MLIRContext *>())
+      .def("enable_debug",
+           [](mlir::PassManager &self) {
+             auto printingFlags = mlir::OpPrintingFlags();
+             printingFlags.elideLargeElementsAttrs(16);
+             self.enableIRPrinting(
+                 /*shouldPrintBeforePass=*/nullptr,
+                 /*shouldPrintAfterPass=*/
+                 [](mlir::Pass *pass, mlir::Operation *) {
+                   return getBoolEnv("MLIR_ENABLE_DUMP");
+                 },
+                 /*printModuleScope=*/false,
+                 /*printAfterOnlyOnChange=*/true,
+                 /*printAfterOnlyOnFailure*/ false, llvm::dbgs(),
+                 printingFlags);
+           })
       .def("run",
            [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
              return mlir::succeeded(self.run(mod.getOperation()));
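The new enable_debug binding installs an IR-printing hook whose after-pass callback consults MLIR_ENABLE_DUMP on every pass, so dumping can be toggled without recompiling. Roughly how it is driven from Python (a sketch assuming the bindings above; `mod` is a module built elsewhere):

    import os

    os.environ["MLIR_ENABLE_DUMP"] = "1"  # any of "on"/"true"/"1", case-insensitive

    # Sketch: with the variable set, every pass added below prints the module
    # (large element attrs elided to 16) to stderr after it runs.
    # pm = _triton.ir.pass_manager(mod.context)
    # pm.enable_debug()
    # pm.add_canonicalizer_pass()
    # pm.run(mod)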
@@ -27,7 +27,13 @@ def test_vecadd_no_scf():
         z_ptrs = z_ptr + offset
         tl.store(z_ptrs, z)
 
-    ptx, shem_size, kernel_name = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
+    # TODO: add this to CI, to make sure the compilation flow is at least OK
+    # before we have GPU machines for CI.
+    # ptx, shem_size, kernel_name = triton.compile(kernel,
+    #                                              "*fp32,i32,*fp32,i32,*fp32,i32",
+    #                                              constants={"BLOCK_SIZE_N": 256},
+    #                                              num_warps=NUM_WARPS,
+    #                                              device=0, output="ptx")
 
     torch.zeros([10], device=torch.device('cuda'))
     device = torch.cuda.current_device()
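Until GPU machines are available in CI, the compile call stays commented out. Run by hand it is the same call, optionally guarded so the test file still imports on CPU-only machines (a sketch restating the commented code above; shem_size is spelled as in the source):

    import torch
    import triton

    # Manual smoke test of the compilation flow, pending GPU CI.
    if torch.cuda.is_available():
        ptx, shem_size, kernel_name = triton.compile(
            kernel, "*fp32,i32,*fp32,i32,*fp32,i32",
            constants={"BLOCK_SIZE_N": 256},
            num_warps=NUM_WARPS, device=0, output="ptx")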
@@ -772,6 +772,7 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
 
 def optimize_triton_ir(mod):
     pm = _triton.ir.pass_manager(mod.context)
+    pm.enable_debug()
     pm.add_inliner_pass()
     pm.add_canonicalizer_pass()
     pm.run(mod)
@@ -780,6 +781,7 @@ def optimize_triton_ir(mod):
 
 def make_tritongpu_ir(mod, num_warps):
     pm = _triton.ir.pass_manager(mod.context)
+    pm.enable_debug()
     pm.add_inliner_pass()
     pm.add_triton_combine_pass()
     pm.add_canonicalizer_pass()
@@ -791,6 +793,7 @@ def make_tritongpu_ir(mod, num_warps):
 
 def optimize_tritongpu_ir(mod, num_stages):
     pm = _triton.ir.pass_manager(mod.context)
+    pm.enable_debug()
     pm.add_tritongpu_pipeline_pass(num_stages)
     pm.add_canonicalizer_pass()
     pm.add_cse_pass()
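With pm.enable_debug() added to all three stages, a single environment variable exposes the IR between every pass of the whole flow. Condensed, the pipeline these functions implement looks roughly like this (a sketch of compiler.py's flow, not verbatim):

    def compile_ir_sketch(fn, signature, num_warps=4, num_stages=3):
        # Each stage constructs a fresh pass_manager, calls enable_debug(),
        # queues its passes, and runs them on the module in place.
        mod = make_triton_ir(fn, signature)
        optimize_triton_ir(mod)
        make_tritongpu_ir(mod, num_warps)
        optimize_tritongpu_ir(mod, num_stages)
        return mod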