From fc58250a0698168bd342d8c6a64e24598db0a6f0 Mon Sep 17 00:00:00 2001 From: goostavz <109190422+goostavz@users.noreply.github.com> Date: Thu, 18 Aug 2022 20:46:45 +0800 Subject: [PATCH] [BACKEND] Add backend support of arith::AddIOp, arith::AddFOp, GetProgramIdOp & GEPOp and bugfix for SplatOp, StoreOp, FuncOp (#60) Add backend support of arith::AddIOp, arith::AddFOp, GetProgramIdOp, GEPOp and bugfix for SplatOp, StoreOp, FuncOp Co-authored-by: gzhu --- CMakeLists.txt | 1 + bin/CMakeLists.txt | 1 + .../TritonGPUToLLVM/PtxAsmFormat.cpp | 4 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 243 +++++++++++------- lib/Target/LLVMIR/LLVMIRTranslation.cpp | 7 +- python/src/triton.cc | 2 +- python/test/vecadd_no_scf.py | 30 +++ test/Conversion/triton_to_llvm.mlir | 7 +- test/Conversion/tritongpu_to_llvm.mlir | 97 +++++-- 9 files changed, 270 insertions(+), 122 deletions(-) create mode 100644 python/test/vecadd_no_scf.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 1ffb44152..6d217b903 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -197,6 +197,7 @@ target_link_libraries(triton MLIRSupport MLIRTargetLLVMIRExport MLIRExecutionEngine + MLIRNVVMToLLVMIRTranslation ) target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 6d5673a02..462873630 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -54,5 +54,6 @@ target_link_libraries(triton-translate PRIVATE MLIRExecutionEngine MLIRTransformUtils MLIRLLVMToLLVMIRTranslation + MLIRNVVMToLLVMIRTranslation ) mlir_check_all_link_libraries(triton-translate) diff --git a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp index 579b91c39..ee126df18 100644 --- a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp +++ b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp @@ -61,7 +61,7 @@ std::string PtxInstr::Operand::dump() const { if (repr) return repr(idx); if (!isList()) - return llvm::formatv("%{0}", idx); + return 
llvm::formatv("${0}", idx); llvm::SmallVector oprs; for (auto *opr : list) oprs.push_back(opr->dump()); @@ -72,7 +72,7 @@ PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr, StringRef constraint, int off) { auto *opr = newOperand(addr, constraint); opr->repr = [off](int idx) -> std::string { - return llvm::formatv("[ %{0} + {1} ]", idx, off); + return llvm::formatv("[ ${0} + {1} ]", idx, off); }; return opr; diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 74dcf8220..3c8ce991e 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -46,20 +46,10 @@ template size_t product(llvm::ArrayRef arr) { return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{}); } -// The following code are borrowed from mlir project including the following -// functions or classes: -// - filterFuncAttributes -// - ConvertOpToLLVMPattern -// - FuncOpConversion -// -// The code are hidden in the CPP files in MLIR repo, and we can't call them -// directly. I found such code snippets are refactored and added to LLVMCommon -// in the latest MLIR code, but the v14.0.0 version currentlly used in Triton -// doesn't contain the code. +// FuncOpConversion/FuncOpConversionBase is borrowed from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276 +// since it is not exposed on header files in mlir v14 // TODO(Superjomn) Remove the code when mlir v15.0 is included. -// -// The original code: -// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L219 // All the rights are reserved by LLVM community. 
/// Only retain those attributes that are not constructed by @@ -79,6 +69,12 @@ static void filterFuncAttributes(ArrayRef attrs, } } +/// Helper function for wrapping all attributes into a single DictionaryAttr +static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) { + return DictionaryAttr::get( + b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs)); +} + struct FuncOpConversionBase : public ConvertOpToLLVMPattern { protected: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -90,25 +86,34 @@ protected: ConversionPatternRewriter &rewriter) const { // Convert the original function arguments. They are converted using the // LLVMTypeConverter provided to this legalization pattern. - auto varargsAttr = funcOp->getAttrOfType("std.varargs"); + auto varargsAttr = funcOp->getAttrOfType("func.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); auto llvmType = getTypeConverter()->convertFunctionSignature( funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); if (!llvmType) return nullptr; - // Propagate argument attributes to all converted arguments obtained after - // converting a given original argument. + // Propagate argument/result attributes to all converted arguments/result + // obtained after converting a given original argument/result. SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true, + filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true, attributes); + if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { + assert(!resAttrDicts.empty() && "expected array to be non-empty"); + auto newResAttrDicts = + (funcOp.getNumResults() == 1) + ? 
resAttrDicts + : rewriter.getArrayAttr( + {wrapAsStructAttrs(rewriter, resAttrDicts)}); + attributes.push_back(rewriter.getNamedAttr( + FunctionOpInterface::getResultDictAttrName(), newResAttrDicts)); + } if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { SmallVector newArgAttrs( llvmType.cast().getNumParams()); for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { auto mapping = result.getInputMapping(i); - assert(mapping.hasValue() && - "unexpected deletion of function argument"); + assert(mapping && "unexpected deletion of function argument"); for (size_t j = 0; j < mapping->size; ++j) newArgAttrs[mapping->inputNo + j] = argAttrDicts[i]; } @@ -136,37 +141,15 @@ protected: } linkage = attr.getLinkage(); } - - auto oldArgs = funcOp.getArguments(); auto newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), llvmType, linkage, /*dsoLocal*/ false, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter, &result))) return nullptr; - // Convert argument - llvm::DenseMap argMap; - for (int i = 0, n = funcOp.getNumArguments(); i < n; i++) { - Value oldArg = oldArgs[i]; - Value newArg = newFuncOp.getArgument(i); - argMap.try_emplace(oldArg, newArg); - } - - newFuncOp.getBody().walk([&](Operation *op) { - // Convert the function argument types, e.g, from !tt.ptr to - // ptr - for (int i = 0; i < op->getNumOperands(); i++) { - auto arg = op->getOperand(i); - auto it = argMap.find(arg); - if (it != argMap.end()) - op->setOperand(i, it->second); - } - }); - return newFuncOp; } }; @@ -245,8 +228,13 @@ static int64_t getLinearIndex(std::vector multidim_index, static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout, ArrayRef shape) { - return product(shape) / (product(layout.getThreadsPerWarp()) * - product(layout.getWarpsPerCTA())); + size_t rank = shape.size(); + SmallVector elemsPerThreadPerDim(rank); + for (size_t i = 
0; i < rank; ++i) { + unsigned t = layout.getThreadsPerWarp()[i] * layout.getWarpsPerCTA()[i]; + elemsPerThreadPerDim[i] = (shape[i] + t - 1) / t; + } + return product(elemsPerThreadPerDim); } static Value createIndexAttrConstant(OpBuilder &builder, Location loc, @@ -257,7 +245,7 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Value getStructFromElements(Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter, - Type structType, Type elemPtrPtrType) { + Type structType) { Value llvmStruct = rewriter.create(loc, structType); for (auto v : llvm::enumerate(resultVals)) { llvmStruct = rewriter.create( @@ -513,10 +501,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, auto structTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); - auto llElemPtrPtrTy = - LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(srcType)); - auto llStruct = - getStructFromElements(loc, elems, rewriter, structTy, llElemPtrPtrTy); + auto llStruct = getStructFromElements(loc, elems, rewriter, structTy); return llStruct; } @@ -529,29 +514,7 @@ struct SplatOpConversion matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto src = op->getOperand(0); - - LLVM::ConstantOp arithConstantOp; - if (src.getDefiningOp() && - (arithConstantOp = - llvm::dyn_cast(src.getDefiningOp()))) { - Value constant; - auto values = arithConstantOp.getValue().dyn_cast(); - - assert(values.size() == 1); - Attribute val; - if (type::isInt(src.getType())) { - val = values.getValues()[0]; - } else if (type::isFloat(src.getType())) { - val = values.getValues()[0]; - } else { - llvm::errs() << "Constant op type not supported"; - return failure(); - } - - src = rewriter.create(loc, val.getType(), val); - } - + auto src = adaptor.src(); auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src, getTypeConverter(), rewriter, loc); 
rewriter.replaceOp(op, {llStruct}); @@ -618,12 +581,15 @@ struct StoreOpConversion Value mask = op.mask(); Value value = op.value(); - Value llPtr = adaptor.ptr(); // should be LLVM ops + Value llPtr = adaptor.ptr(); Value llMask = adaptor.mask(); Value llValue = adaptor.value(); - Type valueElemTy = getTypeConverter()->convertType( - value.getType().cast().getElementType()); + auto valueTy = value.getType().dyn_cast(); + if (!valueTy) + return failure(); + Type valueElemTy = + getTypeConverter()->convertType(valueTy.getElementType()); MLIRContext *ctx = rewriter.getContext(); auto loc = op->getLoc(); @@ -662,6 +628,7 @@ struct StoreOpConversion auto [maskLayout, maskNumElems] = getLayout(mask); auto [valueLayout, valueNumElems] = getLayout(value); + auto ptrElems = getLLVMElems(mask, llPtr, maskLayout); auto valueElems = getLLVMElems(value, llValue, valueLayout); auto maskElems = getLLVMElems(mask, llMask, maskLayout); assert(valueElems.size() == maskElems.size()); @@ -718,17 +685,8 @@ struct StoreOpConversion const int numVecs = ptrNumElems / vec; for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) { - size_t in_off{}; - auto ptrProducer = llPtr.getDefiningOp(); - auto in_gep = llvm::dyn_cast(ptrProducer); - - if (in_gep) { - auto indices = in_gep.getIndices(); - auto cst = dyn_cast(indices.front().getDefiningOp()); - in_off = - cst ? cst.getValue().dyn_cast().getInt() * dtsize : 0; - ptr = cst ? 
in_gep.getBase() : in_gep; - } + // TODO: optimization when ptr is GEP with constant offset + size_t in_off = 0; // pack sub-words (< 32/64bits) into words // each load has width min(nbits*vec, 32/64) @@ -747,7 +705,7 @@ struct StoreOpConversion const bool hasL2EvictPolicy = false; PtxIOInstr asmStoreInstr("st"); - asmStoreInstr.predicate(llMask, "b"); + asmStoreInstr.predicate(maskElems[vecIdx], "b"); asmStoreInstr.global().v(width).b(nWords); llvm::SmallVector asmArgs; @@ -755,7 +713,8 @@ struct StoreOpConversion Type valArgTy = IntegerType::get(ctx, width); auto wordTy = VectorType::get(wordNElems, valueElemTy); - auto *asmAddr = asmStoreInstr.newAddrOperand(llPtr, "l", in_off); + auto *asmAddr = + asmStoreInstr.newAddrOperand(ptrElems[vecIdx], "l", in_off); auto *asmArgList = asmStoreInstr.newList(); for (int wordIdx = 0; wordIdx < nWords; wordIdx++) { // llWord is a width-len composition @@ -800,9 +759,8 @@ struct StoreOpConversion LLVM::AsmDialect::AD_ATT), // asm_dialect ArrayAttr::get(ctx, {}) // operand_attrs ); - - rewriter.replaceOp(op, inlineAsm.getRes()); } + rewriter.eraseOp(op); return success(); } @@ -1135,6 +1093,10 @@ struct LoadOpConversion // finally call inline ASM // --- SmallVector args = {pred, ptr}; + for (Value v : others) { + args.push_back(v); + } + // TODO: if (has_l2_evict_policy) auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), LLVM::AsmDialect::AD_ATT); auto inlineAsmOp = rewriter.create( @@ -1177,6 +1139,95 @@ struct LoadOpConversion } }; +struct GetProgramIdOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::GetProgramIdOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>( + loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x); + auto 
llvmIndexTy = getTypeConverter()->getIndexType(); + rewriter.replaceOpWithNewOp( + op, TypeRange{llvmIndexTy}, ValueRange{blockId}); + return success(); + } +}; + +struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::GEPOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = op.getType().dyn_cast(); + auto resultLayout = + resultTy.getEncoding().dyn_cast(); + auto resultShape = resultTy.getShape(); + unsigned elems = getElemsPerThread(resultLayout, resultShape); + Type elemTy = + this->getTypeConverter()->convertType(resultTy.getElementType()); + SmallVector types(elems, elemTy); + Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types); + auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter); + auto offsets = + getElementsFromStruct(loc, adaptor.offset(), elems, rewriter); + SmallVector resultVals(elems); + for (unsigned i = 0; i < elems; ++i) { + resultVals[i] = + rewriter.create(loc, elemTy, ptrs[i], offsets[i]); + } + Value view = getStructFromElements(loc, resultVals, rewriter, structTy); + rewriter.replaceOp(op, view); + return success(); + } +}; + +template +class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern { +public: + using OpAdaptor = typename SourceOp::Adaptor; + + explicit BinaryOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} + + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resultTy = op.getType().template dyn_cast(); + // ArithmeticToLLVM will handle the lowering of scalar ArithOps + if (!resultTy) + return failure(); + + Location loc = op->getLoc(); + auto resultLayout = resultTy.getEncoding() + .template 
dyn_cast(); + auto resultShape = resultTy.getShape(); + unsigned elems = getElemsPerThread(resultLayout, resultShape); + Type elemTy = + this->getTypeConverter()->convertType(resultTy.getElementType()); + SmallVector types(elems, elemTy); + Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types); + auto lhss = + this->getElementsFromStruct(loc, adaptor.getLhs(), elems, rewriter); + auto rhss = + this->getElementsFromStruct(loc, adaptor.getRhs(), elems, rewriter); + SmallVector resultVals(elems); + for (unsigned i = 0; i < elems; ++i) { + resultVals[i] = rewriter.create(loc, elemTy, lhss[i], rhss[i]); + } + Value view = getStructFromElements(loc, resultVals, rewriter, structTy); + rewriter.replaceOp(op, view); + return success(); + } +}; + class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { public: using TypeConverter::convertType; @@ -1221,14 +1272,20 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, AxisInfoAnalysis &analysis, PatternBenefit benefit = 1) { - patterns.add(typeConverter, numWarps, benefit); - patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - patterns.add(typeConverter, analysis, benefit); - patterns.add(typeConverter, benefit); + patterns.add>(typeConverter, + benefit); + patterns.add>(typeConverter, + benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, numWarps, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, analysis, benefit); patterns.add(typeConverter, benefit); } diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 99d4710ca..3f35e7c23 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ 
b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -8,8 +8,10 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "mlir/Transforms/Passes.h" #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" #include "triton/driver/llvm.h" #include "llvm/IR/Constants.h" @@ -82,7 +84,8 @@ std::unique_ptr translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) { auto context = module->getContext(); DialectRegistry registry; - registerLLVMDialectTranslation(registry); + mlir::registerLLVMDialectTranslation(registry); + mlir::registerNVVMDialectTranslation(registry); context->appendDialectRegistry(registry); llvm::DenseMap nvvmMetadata; @@ -123,6 +126,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, applyPassManagerCLOptions(pm); pm.addPass(createConvertTritonGPUToLLVMPass()); + // Canonicalize to eliminate the remaining UnrealizedConversionCastOp + pm.addPass(mlir::createCanonicalizerPass()); if (failed(pm.run(module))) { llvm::errs() << "Pass execution failed"; diff --git a/python/src/triton.cc b/python/src/triton.cc index 8b7d93e89..f4e038d8a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1504,7 +1504,7 @@ void init_triton_ir(py::module &&m) { [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUVerifier()); }) - .def("triton_gpu_to_llvm", [](mlir::PassManager &self) { + .def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) { self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass()); }); } diff --git a/python/test/vecadd_no_scf.py b/python/test/vecadd_no_scf.py new file mode 100644 index 000000000..573194a59 --- /dev/null +++ b/python/test/vecadd_no_scf.py @@ -0,0 +1,30 @@ +import triton +import triton.language as tl + +NUM_WARPS = 4
+ +# triton kernel + + +@triton.jit +def kernel(x_ptr, stride_xn, + y_ptr, stride_yn, + z_ptr, stride_zn, + BLOCK_SIZE_N: tl.constexpr): + pid = tl.program_id(axis=0) + offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + x_ptrs = x_ptr + offset + y_ptrs = y_ptr + offset + x = tl.load(x_ptrs) + y = tl.load(y_ptrs) + z = x + y + z_ptrs = z_ptr + offset + tl.store(z_ptrs, z) + + +ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx") + +print(ret) + +# TODO: base class for python end2end tests, +# runtime execution, correctness comparison etc. diff --git a/test/Conversion/triton_to_llvm.mlir b/test/Conversion/triton_to_llvm.mlir index 0d9aea81d..5bb411c35 100644 --- a/test/Conversion/triton_to_llvm.mlir +++ b/test/Conversion/triton_to_llvm.mlir @@ -19,6 +19,8 @@ func @test_splat(%ptr: !tt.ptr) { return } +// ----- + func @test_store_splat(%ptr: !tt.ptr) { %ptrs = tt.splat %ptr : (!tt.ptr) -> tensor<128x!tt.ptr> %a = arith.constant 1.0 : f32 @@ -27,9 +29,8 @@ func @test_store_splat(%ptr: !tt.ptr) { %vs = tt.splat %a : (f32) -> tensor<128xf32> %mask = tt.splat %true : (i1) -> tensor<128xi1> - // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@%0 st.global.v32.b1 [ %1 + 0 ], { %2 };", - // CHECK: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.struct<(i1, i1)>, !llvm.struct<(ptr, ptr)>, i32) -> !llvm.struct<()> - + // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.v32.b1 [ $1 + 0 ], { $2 };", + // CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr, i32) -> !llvm.struct<()> tt.store %ptrs, %vs, %mask, {} : tensor<128xf32> return diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 4628e326d..b714fbb44 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -112,28 +112,81 @@ module attributes 
{"triton_gpu.num-warps" = 4 : i32} { } } -// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -// #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -// #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -// module attributes {"triton_gpu.num-warps" = 4 : i32} { -// func @debut_kernel(%lb : index, %A : !tt.ptr, %B : !tt.ptr, %C : !tt.ptr) { -// %cst = arith.constant dense : tensor<256xi1, #blocked0> -// %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> -// %cst_1 = arith.constant dense : tensor<1024x256xi1, #blocked1> -// %cst_2 = arith.constant dense : tensor<256x2048xi1, #blocked2> -// %a_ptr_init = tt.splat %A : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> -// %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> -// %4 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<1x256xf32,#blocked1> -// %5 = tt.broadcast %4 : (tensor<1x256xf32,#blocked1>) -> tensor<1024x256xf32, #blocked1> -// %6 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2> -// %7 = tt.broadcast %6 : (tensor<256x1xf32,#blocked2>) -> tensor<256x2048xf32, #blocked2> -// %b_ptr_init = tt.splat %A : (!tt.ptr) -> tensor<1024x256x!tt.ptr, #blocked1> -// %c_ptr_init = tt.splat %A : (!tt.ptr) -> tensor<256x2048x!tt.ptr, #blocked2> -// tt.store %b_ptr_init, %5, %cst_1, : tensor<1024x256xf32, #blocked1> -// tt.store %c_ptr_init, %7, %cst_2, : tensor<256x2048xf32, #blocked2> -// return -// } -// } +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_addf + func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) { + // 
CHECK: llvm.fadd + // CHECK: llvm.fadd + %1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0> + return + } +} +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_addi + func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { + // CHECK: llvm.add + // CHECK: llvm.add + %1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0> + return + } +} + +// ----- + +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_program_id + func @basic_program_id() { + // CHECK: nvvm.read.ptx.sreg.ctaid.x : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_gep + func @basic_gep(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { + // CHECK: llvm.getelementptr + // CHECK: llvm.getelementptr + %0 = tt.getelementptr %arg0, %arg1 : tensor<256x!tt.ptr, #blocked0> + return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK: basic_splat + func @basic_splat(%ptr: !tt.ptr) { + // CHECK: llvm.mlir.undef + // CHECK: llvm.insertvalue + // CHECK: llvm.insertvalue + %0 = tt.splat %ptr : (!tt.ptr) -> tensor<256x!tt.ptr,#blocked0> + return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_store + func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { + // CHECK: 
llvm.inline_asm has_side_effects asm_dialect = att + // CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr, i32) -> !llvm.struct<()> + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att + // CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr, i32) -> !llvm.struct<()> + tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0> + return + } +} \ No newline at end of file