[CI] add assert-enabled MLIR option (#78)
This deprecates the use of release-build LLVM hosted by the LLVM project, which makes debugging harder for developers. This PR implements the following solution: 1. Create LLVM release tarballs with assert enabled on our own (using Docker) 2. Host them in our own GitHub repositories 3. Use our LLVM for CI and/or development if `TRITON_USE_ASSERT_ENABLED_LLVM=1` is set.
This commit is contained in:
2
.github/workflows/integration-tests.yml
vendored
2
.github/workflows/integration-tests.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
alias python='python3'
|
alias python='python3'
|
||||||
cd python
|
cd python
|
||||||
pip3 install -e '.[tests]'
|
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
|
||||||
|
|
||||||
- name: Run lit tests
|
- name: Run lit tests
|
||||||
run: |
|
run: |
|
||||||
|
@@ -192,12 +192,12 @@ target_link_libraries(triton
|
|||||||
# optimizations
|
# optimizations
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
MLIRIR
|
|
||||||
MLIRLLVMIR
|
MLIRLLVMIR
|
||||||
MLIRSupport
|
MLIRSupport
|
||||||
MLIRTargetLLVMIRExport
|
MLIRTargetLLVMIRExport
|
||||||
MLIRExecutionEngine
|
MLIRExecutionEngine
|
||||||
MLIRNVVMToLLVMIRTranslation
|
MLIRNVVMToLLVMIRTranslation
|
||||||
|
MLIRIR
|
||||||
)
|
)
|
||||||
|
|
||||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||||
|
@@ -23,7 +23,8 @@ def Triton_Dialect : Dialect {
|
|||||||
let dependentDialects = [
|
let dependentDialects = [
|
||||||
"arith::ArithmeticDialect",
|
"arith::ArithmeticDialect",
|
||||||
"StandardOpsDialect",
|
"StandardOpsDialect",
|
||||||
"scf::SCFDialect"
|
"scf::SCFDialect",
|
||||||
|
"gpu::GPUDialect",
|
||||||
|
|
||||||
// Since LLVM 15
|
// Since LLVM 15
|
||||||
// "cf::ControlFlowDialect",
|
// "cf::ControlFlowDialect",
|
||||||
|
@@ -721,18 +721,11 @@ struct StoreOpConversion
|
|||||||
PTXBuilder ptxBuilder;
|
PTXBuilder ptxBuilder;
|
||||||
auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
|
auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
|
||||||
|
|
||||||
Value maskVal =
|
|
||||||
llMask ? maskElems[vecIdx]
|
|
||||||
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
|
|
||||||
rewriter.getIntegerType(1), 1);
|
|
||||||
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
|
|
||||||
|
|
||||||
llvm::SmallVector<std::string> asmArgs;
|
llvm::SmallVector<std::string> asmArgs;
|
||||||
|
|
||||||
Type valArgTy = IntegerType::get(ctx, width);
|
Type valArgTy = IntegerType::get(ctx, width);
|
||||||
auto wordTy = VectorType::get(wordNElems, valueElemTy);
|
auto wordTy = VectorType::get(wordNElems, valueElemTy);
|
||||||
|
|
||||||
auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off);
|
|
||||||
auto *asmArgList = ptxBuilder.newListOperand();
|
auto *asmArgList = ptxBuilder.newListOperand();
|
||||||
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
|
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
|
||||||
// llWord is a width-len composition
|
// llWord is a width-len composition
|
||||||
@@ -757,13 +750,21 @@ struct StoreOpConversion
|
|||||||
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
|
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value maskVal =
|
||||||
|
llMask ? maskElems[vecIdx]
|
||||||
|
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
|
||||||
|
rewriter.getIntegerType(1), 1);
|
||||||
|
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
|
||||||
|
|
||||||
|
auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off);
|
||||||
|
|
||||||
ptxStoreInstr(asmAddr, asmArgList);
|
ptxStoreInstr(asmAddr, asmArgList);
|
||||||
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
||||||
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
|
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
|
||||||
for (int i = 0; i < nWords; i++)
|
for (int i = 0; i < nWords; i++)
|
||||||
argTys.push_back(valArgTy);
|
argTys.push_back(valArgTy);
|
||||||
|
|
||||||
auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {});
|
auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx);
|
||||||
|
|
||||||
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
||||||
loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
|
loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
|
||||||
@@ -1028,14 +1029,22 @@ struct LoadOpConversion
|
|||||||
// create inline asm string
|
// create inline asm string
|
||||||
// ---
|
// ---
|
||||||
|
|
||||||
const std::string writeConstrait =
|
|
||||||
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
|
||||||
const std::string readConstrait =
|
const std::string readConstrait =
|
||||||
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||||
|
const std::string writeConstrait =
|
||||||
|
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||||
|
|
||||||
PTXBuilder ptxBuilder;
|
PTXBuilder ptxBuilder;
|
||||||
PtxIOInstr &ld = *ptxBuilder.create<PtxIOInstr>("ld");
|
PtxIOInstr &ld = *ptxBuilder.create<PtxIOInstr>("ld");
|
||||||
|
|
||||||
|
// prepare asm operands
|
||||||
|
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||||
|
for (int i = 0; i < n_words; i++) {
|
||||||
|
auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations
|
||||||
|
dstsOpr->listAppend(opr);
|
||||||
|
}
|
||||||
|
auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off);
|
||||||
|
|
||||||
// Define the instruction opcode
|
// Define the instruction opcode
|
||||||
ld.predicate(pred, "b")
|
ld.predicate(pred, "b")
|
||||||
.o("violatile", op.isVolatile())
|
.o("violatile", op.isVolatile())
|
||||||
@@ -1049,14 +1058,6 @@ struct LoadOpConversion
|
|||||||
.v(n_words)
|
.v(n_words)
|
||||||
.b(width);
|
.b(width);
|
||||||
|
|
||||||
// prepare asm operands
|
|
||||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
|
||||||
for (int i = 0; i < n_words; i++) {
|
|
||||||
auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations
|
|
||||||
dstsOpr->listAppend(opr);
|
|
||||||
}
|
|
||||||
auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off);
|
|
||||||
|
|
||||||
PTXBuilder::Operand *evictOpr{};
|
PTXBuilder::Operand *evictOpr{};
|
||||||
// Here lack a mlir::Value to bind to this operation, so disabled.
|
// Here lack a mlir::Value to bind to this operation, so disabled.
|
||||||
// if (has_l2_evict_policy)
|
// if (has_l2_evict_policy)
|
||||||
|
@@ -65,7 +65,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
|
|||||||
// maxntid
|
// maxntid
|
||||||
if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
|
if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
|
||||||
auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
|
auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
|
||||||
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
|
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getSInt();
|
||||||
hasMetadata = true;
|
hasMetadata = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -15,11 +15,22 @@ from setuptools import Extension, setup
|
|||||||
from setuptools.command.build_ext import build_ext
|
from setuptools.command.build_ext import build_ext
|
||||||
|
|
||||||
|
|
||||||
|
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
|
||||||
|
def check_env_flag(name: str, default: str = "") -> bool:
|
||||||
|
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
|
||||||
|
|
||||||
|
|
||||||
def get_llvm():
|
def get_llvm():
|
||||||
# download if nothing is installed
|
# download if nothing is installed
|
||||||
system = platform.system()
|
system = platform.system()
|
||||||
suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
|
system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
|
||||||
name = f'clang+llvm-14.0.0-x86_64-{suffix}'
|
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
|
||||||
|
if use_assert_enabled_llvm:
|
||||||
|
name = 'llvm+mlir-14.0.0-x86_64-{}-assert'.format(system_suffix)
|
||||||
|
url = "https://github.com/shintaro-iwasaki/llvm-releases/releases/download/llvm-14.0.0-329fda39c507/{}.tar.xz".format(name)
|
||||||
|
else:
|
||||||
|
name = 'clang+llvm-14.0.0-x86_64-{}'.format(system_suffix)
|
||||||
|
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{}.tar.xz".format(name)
|
||||||
dir = '/tmp'
|
dir = '/tmp'
|
||||||
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
|
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
|
||||||
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
|
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
|
||||||
@@ -28,7 +39,6 @@ def get_llvm():
|
|||||||
shutil.rmtree(os.path.join(dir, name))
|
shutil.rmtree(os.path.join(dir, name))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{name}.tar.xz".format(name=name)
|
|
||||||
print('downloading and extracting ' + url + '...')
|
print('downloading and extracting ' + url + '...')
|
||||||
ftpstream = urllib.request.urlopen(url)
|
ftpstream = urllib.request.urlopen(url)
|
||||||
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
|
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
|
||||||
|
@@ -29,8 +29,8 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
|
|||||||
%vs = tt.splat %a : (f32) -> tensor<128xf32>
|
%vs = tt.splat %a : (f32) -> tensor<128xf32>
|
||||||
%mask = tt.splat %true : (i1) -> tensor<128xi1>
|
%mask = tt.splat %true : (i1) -> tensor<128xi1>
|
||||||
|
|
||||||
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };",
|
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$1 st.global.b32 [ $2 + 0 ], { $0 };",
|
||||||
// CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
// CHECK-SAME: "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
|
||||||
tt.store %ptrs, %vs, %mask : tensor<128xf32>
|
tt.store %ptrs, %vs, %mask : tensor<128xf32>
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@@ -183,9 +183,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
// CHECK-LABEL: basic_store
|
// CHECK-LABEL: basic_store
|
||||||
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||||
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||||
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
|
||||||
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
|
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user