diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 48e1a7818..a8960cf95 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -248,6 +248,13 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
       loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
 }
 
+static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
+                                       LLVMTypeConverter *converter, Type ty,
+                                       int64_t value) {
+  return builder.create<LLVM::ConstantOp>(loc, converter->convertType(ty),
+                                          builder.getIntegerAttr(ty, value));
+}
+
 Value getStructFromElements(Location loc, ValueRange resultVals,
                             ConversionPatternRewriter &rewriter,
                             Type structType) {
@@ -601,7 +608,7 @@ struct StoreOpConversion
 
     auto getLLVMElems =
         [&](Value value, Value llValue,
-            const BlockedEncodingAttr &layout) -> SmallVector {
+            const BlockedEncodingAttr &layout) -> SmallVector {
       auto ty = value.getType().cast();
       auto shape = ty.getShape();
       // Here, we assume that all inputs should have a blockedLayout
@@ -630,13 +637,16 @@ struct StoreOpConversion
     };
 
     auto [ptrLayout, ptrNumElems] = getLayout(ptr);
-    auto [maskLayout, maskNumElems] = getLayout(mask);
     auto [valueLayout, valueNumElems] = getLayout(value);
 
-    auto ptrElems = getLLVMElems(mask, llPtr, maskLayout);
+    auto ptrElems = getLLVMElems(ptr, llPtr, ptrLayout);
     auto valueElems = getLLVMElems(value, llValue, valueLayout);
-    auto maskElems = getLLVMElems(mask, llMask, maskLayout);
-    assert(valueElems.size() == maskElems.size());
+    SmallVector<Value> maskElems;
+    if (llMask) {
+      auto [maskLayout, maskNumElems] = getLayout(mask);
+      maskElems = getLLVMElems(mask, llMask, maskLayout);
+      assert(valueElems.size() == maskElems.size());
+    }
 
     auto getAlign = [this](Value val,
                            const BlockedEncodingAttr &layout) -> unsigned {
@@ -710,10 +720,12 @@ struct StoreOpConversion
 
       PTXBuilder ptxBuilder;
       auto &ptxStoreInstr = *ptxBuilder.create("st");
-      ptxStoreInstr.predicate(maskElems[vecIdx], "b")
-          .global()
-          .b(width)
-          .v(nWords);
+
+      Value maskVal =
+          llMask ? maskElems[vecIdx]
+                 : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
+                                             rewriter.getIntegerType(1), 1);
+      ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
 
       llvm::SmallVector asmArgs;
 
@@ -746,8 +758,8 @@ struct StoreOpConversion
      }
 
       ptxStoreInstr(asmAddr, asmArgList);
-
-      llvm::SmallVector argTys({mask.getType(), ptr.getType()});
+      Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
+      llvm::SmallVector argTys({boolTy, ptr.getType()});
       for (int i = 0; i < nWords; i++)
         argTys.push_back(valArgTy);
 
@@ -970,9 +982,12 @@ struct LoadOpConversion
     auto elemTy = resultTy.getElementType();
     unsigned numElems = getElemsPerThread(blockedLayout, shape);
     auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter);
-    auto maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
+    SmallVector<Value> maskVals;
+    if (mask) {
+      maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
+    }
     SmallVector<Value> otherVals;
-    if (other != nullptr) {
+    if (other) {
       otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
     }
     unsigned nbits = elemTy.isa() ?
@@ -1004,7 +1019,10 @@
       // TODO: Handle the optimization if ptr is from GEP and the idx is
       // constant. This should be a canonicalization pattern in LLVM Dialect
       unsigned in_off = 0;
-      Value pred = maskVals[i];
+      Value pred =
+          mask ? maskVals[i]
+               : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
+                                           rewriter.getIntegerType(1), 1);
 
       // ---
       // create inline asm string
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 7a4e3add6..e000d2604 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -227,8 +227,9 @@ struct TritonLoadPattern : public OpConversionPattern {
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp(
-        op, adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(),
-        adaptor.evict(), adaptor.isVolatile());
+        op, typeConverter->convertType(op.getType()), adaptor.ptr(),
+        adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
+        adaptor.isVolatile());
     return success();
   }
 };
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 549498175..8366b90b9 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -26,6 +26,7 @@
 #include "llvm/Support/raw_ostream.h"
 
 #include
+#include
 #include
 #include
 #include
@@ -99,6 +100,14 @@ long pow2_divisor(long N) {
   return 1;
 }
 
+bool getBoolEnv(const std::string &env) {
+  const char *s = std::getenv(env.c_str());
+  std::string str(s ? s : "");
+  std::transform(str.begin(), str.end(), str.begin(),
+                 [](unsigned char c) { return std::tolower(c); });
+  return (str == "on" || str == "true" || str == "1");
+}
+
 // Returns something like "int16", whether dtype is a torch.dtype or
 // triton.language.dtype.
 std::string dtype_cache_key_part(const py::object &dtype) {
@@ -1589,6 +1598,21 @@ void init_triton_ir(py::module &&m) {
 
   py::class_(m, "pass_manager")
       .def(py::init())
+      .def("enable_debug",
+           [](mlir::PassManager &self) {
+             auto printingFlags = mlir::OpPrintingFlags();
+             printingFlags.elideLargeElementsAttrs(16);
+             self.enableIRPrinting(
+                 /*shouldPrintBeforePass=*/nullptr,
+                 /*shouldPrintAfterPass=*/
+                 [](mlir::Pass *pass, mlir::Operation *) {
+                   return getBoolEnv("MLIR_ENABLE_DUMP");
+                 },
+                 /*printModuleScope=*/false,
+                 /*printAfterOnlyOnChange=*/true,
+                 /*printAfterOnlyOnFailure*/ false, llvm::dbgs(),
+                 printingFlags);
+           })
       .def("run",
            [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
              return mlir::succeeded(self.run(mod.getOperation()));
diff --git a/python/tests/test_vecadd_no_scf.py b/python/tests/test_vecadd_no_scf.py
index d1604d8f1..ee4994d2b 100644
--- a/python/tests/test_vecadd_no_scf.py
+++ b/python/tests/test_vecadd_no_scf.py
@@ -27,7 +27,13 @@ def test_vecadd_no_scf():
         z_ptrs = z_ptr + offset
         tl.store(z_ptrs, z)
 
-    ptx, shem_size, kernel_name = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
+    # TODO: add this to CI, to make sure the compilation flow is at least OK
+    # before we have GPU machines for CI.
+    # ptx, shem_size, kernel_name = triton.compile(kernel,
+    #                                              "*fp32,i32,*fp32,i32,*fp32,i32",
+    #                                              constants={"BLOCK_SIZE_N": 256},
+    #                                              num_warps=NUM_WARPS,
+    #                                              device=0, output="ptx")
 
     torch.zeros([10], device=torch.device('cuda'))
     device = torch.cuda.current_device()
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 45213d84a..821889dd6 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -772,6 +772,7 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
 
 def optimize_triton_ir(mod):
     pm = _triton.ir.pass_manager(mod.context)
+    pm.enable_debug()
     pm.add_inliner_pass()
     pm.add_canonicalizer_pass()
     pm.run(mod)
@@ -780,6 +781,7 @@ def optimize_triton_ir(mod):
 
 def make_tritongpu_ir(mod, num_warps):
     pm = _triton.ir.pass_manager(mod.context)
+    pm.enable_debug()
     pm.add_inliner_pass()
     pm.add_triton_combine_pass()
     pm.add_canonicalizer_pass()
@@ -791,6 +793,7 @@ def make_tritongpu_ir(mod, num_warps):
 
 def optimize_tritongpu_ir(mod, num_stages):
     pm = _triton.ir.pass_manager(mod.context)
+    pm.enable_debug()
     pm.add_tritongpu_pipeline_pass(num_stages)
     pm.add_canonicalizer_pass()
     pm.add_cse_pass()
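
Usage note, not part of the patch: with the pm.enable_debug() calls added in compiler.py, IR dumping is gated on the MLIR_ENABLE_DUMP environment variable checked by getBoolEnv() above. A minimal sketch of how this is exercised from Python, assuming a build that includes this change:

    import os

    # getBoolEnv() accepts "1", "true", or "on" (case-insensitive). The check
    # runs inside the shouldPrintAfterPass callback, so the variable only needs
    # to be set before the pass manager runs, i.e. before triton.compile().
    os.environ["MLIR_ENABLE_DUMP"] = "1"

    import triton
    # Any subsequent triton.compile(...) call now prints the module after each
    # pass, with element attributes larger than 16 values elided.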