From d01353de077bcb1d4fc80d525e114c5cf9c4c431 Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Wed, 31 Aug 2022 18:55:32 -0700 Subject: [PATCH] [CI] add assert-enabled MLIR option (#78) This deprecates the use of release-build LLVM hosted by the LLVM project, which makes debugging harder for developers. This PR implements the following solution: 1. Create LLVM release tarballs with assert enabled on our own (using Docker) 2. Host them in our own GitHub repositories 3. Use our LLVM for CI and/or development if `TRITON_USE_ASSERT_ENABLED_LLVM=1` is set. --- .github/workflows/integration-tests.yml | 2 +- CMakeLists.txt | 2 +- .../triton/Dialect/Triton/IR/TritonDialect.td | 3 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 37 ++++++++++--------- lib/Target/LLVMIR/LLVMIRTranslation.cpp | 2 +- python/setup.py | 16 ++++++-- test/Conversion/triton_to_llvm.mlir | 4 +- test/Conversion/tritongpu_to_llvm.mlir | 4 +- 8 files changed, 41 insertions(+), 29 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index b9605d427..70b0f5d7c 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -47,7 +47,7 @@ jobs: run: | alias python='python3' cd python - pip3 install -e '.[tests]' + TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]' - name: Run lit tests run: | diff --git a/CMakeLists.txt b/CMakeLists.txt index 6d217b903..0ae7c245e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -192,12 +192,12 @@ target_link_libraries(triton # optimizations MLIRPass MLIRTransforms - MLIRIR MLIRLLVMIR MLIRSupport MLIRTargetLLVMIRExport MLIRExecutionEngine MLIRNVVMToLLVMIRTranslation + MLIRIR ) target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td index 3c5611448..9262dc098 100644 --- a/include/triton/Dialect/Triton/IR/TritonDialect.td +++ b/include/triton/Dialect/Triton/IR/TritonDialect.td @@ -23,7 +23,8 @@ def Triton_Dialect : Dialect { let dependentDialects = [ "arith::ArithmeticDialect", "StandardOpsDialect", - "scf::SCFDialect" + "scf::SCFDialect", + "gpu::GPUDialect", // Since LLVM 15 // "cf::ControlFlowDialect", diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index a8960cf95..3964de000 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -721,18 +721,11 @@ struct StoreOpConversion PTXBuilder ptxBuilder; auto &ptxStoreInstr = *ptxBuilder.create("st"); - Value maskVal = - llMask ? maskElems[vecIdx] - : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(), - rewriter.getIntegerType(1), 1); - ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords); - llvm::SmallVector asmArgs; Type valArgTy = IntegerType::get(ctx, width); auto wordTy = VectorType::get(wordNElems, valueElemTy); - auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off); auto *asmArgList = ptxBuilder.newListOperand(); for (int wordIdx = 0; wordIdx < nWords; wordIdx++) { // llWord is a width-len composition @@ -757,13 +750,21 @@ struct StoreOpConversion asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint)); } + Value maskVal = + llMask ? maskElems[vecIdx] + : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(), + rewriter.getIntegerType(1), 1); + ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords); + + auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off); + ptxStoreInstr(asmAddr, asmArgList); Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1)); llvm::SmallVector argTys({boolTy, ptr.getType()}); for (int i = 0; i < nWords; i++) argTys.push_back(valArgTy); - auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {}); + auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx); auto inlineAsm = rewriter.create( loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands @@ -1028,14 +1029,22 @@ struct LoadOpConversion // create inline asm string // --- - const std::string writeConstrait = - (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); const std::string readConstrait = (width == 64) ? "l" : ((width == 32) ? "r" : "c"); + const std::string writeConstrait = + (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); PTXBuilder ptxBuilder; PtxIOInstr &ld = *ptxBuilder.create("ld"); + // prepare asm operands + auto *dstsOpr = ptxBuilder.newListOperand(); + for (int i = 0; i < n_words; i++) { + auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations + dstsOpr->listAppend(opr); + } + auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off); + // Define the instruction opcode ld.predicate(pred, "b") .o("violatile", op.isVolatile()) @@ -1049,14 +1058,6 @@ struct LoadOpConversion .v(n_words) .b(width); - // prepare asm operands - auto *dstsOpr = ptxBuilder.newListOperand(); - for (int i = 0; i < n_words; i++) { - auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations - dstsOpr->listAppend(opr); - } - auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off); - PTXBuilder::Operand *evictOpr{}; // Here lack a mlir::Value to bind to this operation, so disabled. // if (has_l2_evict_policy) diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 3f35e7c23..03ebb1578 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -65,7 +65,7 @@ void extractNVVMMetadata(mlir::ModuleOp module, // maxntid if (op->hasAttr(NVVMMetadataField::MaxNTid)) { auto attr = op->getAttr(NVVMMetadataField::MaxNTid); - meta.maxntidx = attr.dyn_cast().getInt(); + meta.maxntidx = attr.dyn_cast().getSInt(); hasMetadata = true; } diff --git a/python/setup.py b/python/setup.py index 5c994f2d5..305de5520 100644 --- a/python/setup.py +++ b/python/setup.py @@ -15,11 +15,22 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext +# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py +def check_env_flag(name: str, default: str = "") -> bool: + return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] + + def get_llvm(): # download if nothing is installed system = platform.system() - suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system] - name = f'clang+llvm-14.0.0-x86_64-{suffix}' + system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system] + use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False") + if use_assert_enabled_llvm: + name = 'llvm+mlir-14.0.0-x86_64-{}-assert'.format(system_suffix) + url = "https://github.com/shintaro-iwasaki/llvm-releases/releases/download/llvm-14.0.0-329fda39c507/{}.tar.xz".format(name) + else: + name = 'clang+llvm-14.0.0-x86_64-{}'.format(system_suffix) + url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{}.tar.xz".format(name) dir = '/tmp' llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name) llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name) @@ -28,7 +39,6 @@ def get_llvm(): shutil.rmtree(os.path.join(dir, name)) except Exception: pass - url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{name}.tar.xz".format(name=name) print('downloading and extracting ' + url + '...') ftpstream = urllib.request.urlopen(url) file = tarfile.open(fileobj=ftpstream, mode="r|xz") diff --git a/test/Conversion/triton_to_llvm.mlir b/test/Conversion/triton_to_llvm.mlir index 1637d11af..218d58488 100644 --- a/test/Conversion/triton_to_llvm.mlir +++ b/test/Conversion/triton_to_llvm.mlir @@ -29,8 +29,8 @@ func @test_store_splat(%ptr: !tt.ptr) { %vs = tt.splat %a : (f32) -> tensor<128xf32> %mask = tt.splat %true : (i1) -> tensor<128xi1> - // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };", - // CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr, i32) -> !llvm.struct<()> + // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$1 st.global.b32 [ $2 + 0 ], { $0 };", + // CHECK-SAME: "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr) -> !llvm.void tt.store %ptrs, %vs, %mask : tensor<128xf32> return diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index cde1f85cd..b164db9b0 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -183,9 +183,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_store func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr, i32) -> !llvm.struct<()> + // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr) -> !llvm.void // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr, i32) -> !llvm.struct<()> + // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr) -> !llvm.void tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0> return }