[CI] add assert-enabled MLIR option (#78)
This deprecates the use of release-build LLVM hosted by the LLVM project, which makes debugging harder for developers. This PR implements the following solution: 1. Create LLVM release tarballs with assert enabled on our own (using Docker) 2. Host them in our own GitHub repositories 3. Use our LLVM for CI and/or development if `TRITON_USE_ASSERT_ENABLED_LLVM=1` is set.
This commit is contained in:
2
.github/workflows/integration-tests.yml
vendored
2
.github/workflows/integration-tests.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
||||
run: |
|
||||
alias python='python3'
|
||||
cd python
|
||||
pip3 install -e '.[tests]'
|
||||
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
|
||||
|
||||
- name: Run lit tests
|
||||
run: |
|
||||
|
@@ -192,12 +192,12 @@ target_link_libraries(triton
|
||||
# optimizations
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRSupport
|
||||
MLIRTargetLLVMIRExport
|
||||
MLIRExecutionEngine
|
||||
MLIRNVVMToLLVMIRTranslation
|
||||
MLIRIR
|
||||
)
|
||||
|
||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||
|
@@ -23,7 +23,8 @@ def Triton_Dialect : Dialect {
|
||||
let dependentDialects = [
|
||||
"arith::ArithmeticDialect",
|
||||
"StandardOpsDialect",
|
||||
"scf::SCFDialect"
|
||||
"scf::SCFDialect",
|
||||
"gpu::GPUDialect",
|
||||
|
||||
// Since LLVM 15
|
||||
// "cf::ControlFlowDialect",
|
||||
|
@@ -721,18 +721,11 @@ struct StoreOpConversion
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
|
||||
|
||||
Value maskVal =
|
||||
llMask ? maskElems[vecIdx]
|
||||
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
|
||||
rewriter.getIntegerType(1), 1);
|
||||
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
|
||||
|
||||
llvm::SmallVector<std::string> asmArgs;
|
||||
|
||||
Type valArgTy = IntegerType::get(ctx, width);
|
||||
auto wordTy = VectorType::get(wordNElems, valueElemTy);
|
||||
|
||||
auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off);
|
||||
auto *asmArgList = ptxBuilder.newListOperand();
|
||||
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
|
||||
// llWord is a width-len composition
|
||||
@@ -757,13 +750,21 @@ struct StoreOpConversion
|
||||
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
|
||||
}
|
||||
|
||||
Value maskVal =
|
||||
llMask ? maskElems[vecIdx]
|
||||
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
|
||||
rewriter.getIntegerType(1), 1);
|
||||
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
|
||||
|
||||
auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off);
|
||||
|
||||
ptxStoreInstr(asmAddr, asmArgList);
|
||||
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
||||
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
|
||||
for (int i = 0; i < nWords; i++)
|
||||
argTys.push_back(valArgTy);
|
||||
|
||||
auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {});
|
||||
auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx);
|
||||
|
||||
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
||||
loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
|
||||
@@ -1028,14 +1029,22 @@ struct LoadOpConversion
|
||||
// create inline asm string
|
||||
// ---
|
||||
|
||||
const std::string writeConstrait =
|
||||
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||
const std::string readConstrait =
|
||||
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
const std::string writeConstrait =
|
||||
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
PtxIOInstr &ld = *ptxBuilder.create<PtxIOInstr>("ld");
|
||||
|
||||
// prepare asm operands
|
||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||
for (int i = 0; i < n_words; i++) {
|
||||
auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations
|
||||
dstsOpr->listAppend(opr);
|
||||
}
|
||||
auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off);
|
||||
|
||||
// Define the instruction opcode
|
||||
ld.predicate(pred, "b")
|
||||
.o("violatile", op.isVolatile())
|
||||
@@ -1049,14 +1058,6 @@ struct LoadOpConversion
|
||||
.v(n_words)
|
||||
.b(width);
|
||||
|
||||
// prepare asm operands
|
||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||
for (int i = 0; i < n_words; i++) {
|
||||
auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations
|
||||
dstsOpr->listAppend(opr);
|
||||
}
|
||||
auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off);
|
||||
|
||||
PTXBuilder::Operand *evictOpr{};
|
||||
// Here lack a mlir::Value to bind to this operation, so disabled.
|
||||
// if (has_l2_evict_policy)
|
||||
|
@@ -65,7 +65,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
|
||||
// maxntid
|
||||
if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
|
||||
auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
|
||||
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
|
||||
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getSInt();
|
||||
hasMetadata = true;
|
||||
}
|
||||
|
||||
|
@@ -15,11 +15,22 @@ from setuptools import Extension, setup
|
||||
from setuptools.command.build_ext import build_ext
|
||||
|
||||
|
||||
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
|
||||
def check_env_flag(name: str, default: str = "") -> bool:
|
||||
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
|
||||
|
||||
|
||||
def get_llvm():
|
||||
# download if nothing is installed
|
||||
system = platform.system()
|
||||
suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
|
||||
name = f'clang+llvm-14.0.0-x86_64-{suffix}'
|
||||
system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
|
||||
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
|
||||
if use_assert_enabled_llvm:
|
||||
name = 'llvm+mlir-14.0.0-x86_64-{}-assert'.format(system_suffix)
|
||||
url = "https://github.com/shintaro-iwasaki/llvm-releases/releases/download/llvm-14.0.0-329fda39c507/{}.tar.xz".format(name)
|
||||
else:
|
||||
name = 'clang+llvm-14.0.0-x86_64-{}'.format(system_suffix)
|
||||
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{}.tar.xz".format(name)
|
||||
dir = '/tmp'
|
||||
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
|
||||
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
|
||||
@@ -28,7 +39,6 @@ def get_llvm():
|
||||
shutil.rmtree(os.path.join(dir, name))
|
||||
except Exception:
|
||||
pass
|
||||
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{name}.tar.xz".format(name=name)
|
||||
print('downloading and extracting ' + url + '...')
|
||||
ftpstream = urllib.request.urlopen(url)
|
||||
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
|
||||
|
@@ -29,8 +29,8 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
|
||||
%vs = tt.splat %a : (f32) -> tensor<128xf32>
|
||||
%mask = tt.splat %true : (i1) -> tensor<128xi1>
|
||||
|
||||
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };",
|
||||
// CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
||||
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$1 st.global.b32 [ $2 + 0 ], { $0 };",
|
||||
// CHECK-SAME: "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
|
||||
tt.store %ptrs, %vs, %mask : tensor<128xf32>
|
||||
|
||||
return
|
||||
|
@@ -183,9 +183,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_store
|
||||
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
||||
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
|
||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
||||
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
|
||||
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user