[CI] add assert-enabled MLIR option (#78)

This deprecates the use of release-build LLVM hosted by the LLVM project, which makes debugging harder for developers.

This PR implements the following solution:
1. Create LLVM release tarballs with assert enabled on our own (using Docker)
2. Host them in our own GitHub repositories
3. Use our LLVM for CI and/or development if `TRITON_USE_ASSERT_ENABLED_LLVM=1` is set.
This commit is contained in:
Shintaro Iwasaki
2022-08-31 18:55:32 -07:00
committed by GitHub
parent 02ebf24d35
commit d01353de07
8 changed files with 41 additions and 29 deletions

View File

@@ -47,7 +47,7 @@ jobs:
run: |
alias python='python3'
cd python
pip3 install -e '.[tests]'
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
- name: Run lit tests
run: |

View File

@@ -192,12 +192,12 @@ target_link_libraries(triton
# optimizations
MLIRPass
MLIRTransforms
MLIRIR
MLIRLLVMIR
MLIRSupport
MLIRTargetLLVMIRExport
MLIRExecutionEngine
MLIRNVVMToLLVMIRTranslation
MLIRIR
)
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})

View File

@@ -23,7 +23,8 @@ def Triton_Dialect : Dialect {
let dependentDialects = [
"arith::ArithmeticDialect",
"StandardOpsDialect",
"scf::SCFDialect"
"scf::SCFDialect",
"gpu::GPUDialect",
// Since LLVM 15
// "cf::ControlFlowDialect",

View File

@@ -721,18 +721,11 @@ struct StoreOpConversion
PTXBuilder ptxBuilder;
auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
Value maskVal =
llMask ? maskElems[vecIdx]
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
rewriter.getIntegerType(1), 1);
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
llvm::SmallVector<std::string> asmArgs;
Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = VectorType::get(wordNElems, valueElemTy);
auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off);
auto *asmArgList = ptxBuilder.newListOperand();
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
// llWord is a width-len composition
@@ -757,13 +750,21 @@ struct StoreOpConversion
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
}
Value maskVal =
llMask ? maskElems[vecIdx]
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
rewriter.getIntegerType(1), 1);
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off);
ptxStoreInstr(asmAddr, asmArgList);
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
for (int i = 0; i < nWords; i++)
argTys.push_back(valArgTy);
auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {});
auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx);
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
@@ -1028,14 +1029,22 @@ struct LoadOpConversion
// create inline asm string
// ---
const std::string writeConstrait =
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
const std::string readConstrait =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
const std::string writeConstrait =
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
PTXBuilder ptxBuilder;
PtxIOInstr &ld = *ptxBuilder.create<PtxIOInstr>("ld");
// prepare asm operands
auto *dstsOpr = ptxBuilder.newListOperand();
for (int i = 0; i < n_words; i++) {
auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations
dstsOpr->listAppend(opr);
}
auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off);
// Define the instruction opcode
ld.predicate(pred, "b")
.o("violatile", op.isVolatile())
@@ -1049,14 +1058,6 @@ struct LoadOpConversion
.v(n_words)
.b(width);
// prepare asm operands
auto *dstsOpr = ptxBuilder.newListOperand();
for (int i = 0; i < n_words; i++) {
auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations
dstsOpr->listAppend(opr);
}
auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off);
PTXBuilder::Operand *evictOpr{};
// Here lack a mlir::Value to bind to this operation, so disabled.
// if (has_l2_evict_policy)

View File

@@ -65,7 +65,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
// maxntid
if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getSInt();
hasMetadata = true;
}

View File

@@ -15,11 +15,22 @@ from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
def check_env_flag(name: str, default: str = "") -> bool:
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
def get_llvm():
# download if nothing is installed
system = platform.system()
suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
name = f'clang+llvm-14.0.0-x86_64-{suffix}'
system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
if use_assert_enabled_llvm:
name = 'llvm+mlir-14.0.0-x86_64-{}-assert'.format(system_suffix)
url = "https://github.com/shintaro-iwasaki/llvm-releases/releases/download/llvm-14.0.0-329fda39c507/{}.tar.xz".format(name)
else:
name = 'clang+llvm-14.0.0-x86_64-{}'.format(system_suffix)
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{}.tar.xz".format(name)
dir = '/tmp'
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
@@ -28,7 +39,6 @@ def get_llvm():
shutil.rmtree(os.path.join(dir, name))
except Exception:
pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...')
ftpstream = urllib.request.urlopen(url)
file = tarfile.open(fileobj=ftpstream, mode="r|xz")

View File

@@ -29,8 +29,8 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
%vs = tt.splat %a : (f32) -> tensor<128xf32>
%mask = tt.splat %true : (i1) -> tensor<128xi1>
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };",
// CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$1 st.global.b32 [ $2 + 0 ], { $0 };",
// CHECK-SAME: "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
tt.store %ptrs, %vs, %mask : tensor<128xf32>
return

View File

@@ -183,9 +183,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_store
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
return
}