[BUILD] Fix compilation problems in the release build (#897)

This commit is contained in:
Keren Zhou
2022-11-20 21:40:36 -08:00
committed by GitHub
parent 23f71daa27
commit 85cccfb81f
3 changed files with 21 additions and 16 deletions

View File

@@ -1438,7 +1438,8 @@ struct BroadcastOpConversion
SmallVector<int64_t> resultLogicalShape(2 * rank); SmallVector<int64_t> resultLogicalShape(2 * rank);
SmallVector<unsigned> broadcastDims; SmallVector<unsigned> broadcastDims;
for (unsigned d = 0; d < rank; ++d) { for (unsigned d = 0; d < rank; ++d) {
unsigned resultShapePerCTA = triton::gpu::getSizePerThread(resultLayout)[d] * unsigned resultShapePerCTA =
triton::gpu::getSizePerThread(resultLayout)[d] *
triton::gpu::getThreadsPerWarp(resultLayout)[d] * triton::gpu::getThreadsPerWarp(resultLayout)[d] *
triton::gpu::getWarpsPerCTA(resultLayout)[d]; triton::gpu::getWarpsPerCTA(resultLayout)[d];
int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA); int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA);
@@ -1450,10 +1451,12 @@ struct BroadcastOpConversion
std::max<unsigned>(1, triton::gpu::getSizePerThread(srcLayout)[d]); std::max<unsigned>(1, triton::gpu::getSizePerThread(srcLayout)[d]);
} else { } else {
srcLogicalShape[d] = numCtas; srcLogicalShape[d] = numCtas;
srcLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d]; srcLogicalShape[d + rank] =
triton::gpu::getSizePerThread(resultLayout)[d];
} }
resultLogicalShape[d] = numCtas; resultLogicalShape[d] = numCtas;
resultLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d]; resultLogicalShape[d + rank] =
triton::gpu::getSizePerThread(resultLayout)[d];
srcLogicalOrder[d] = order[d] + rank; srcLogicalOrder[d] = order[d] + rank;
srcLogicalOrder[d + rank] = order[d]; srcLogicalOrder[d + rank] = order[d];
@@ -1968,6 +1971,7 @@ struct PrintfOpConversion
return "%u"; return "%u";
} }
assert(false && "not supported type"); assert(false && "not supported type");
return "";
} }
// declare vprintf(i8*, i8*) as external function // declare vprintf(i8*, i8*) as external function
@@ -5482,6 +5486,7 @@ Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
} }
assert(false && "Unsupported mma layout found"); assert(false && "Unsupported mma layout found");
return {};
} }
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {

View File

@@ -6,7 +6,6 @@ import shutil
import subprocess import subprocess
import sys import sys
import tarfile import tarfile
import tempfile
import urllib.request import urllib.request
from distutils.version import LooseVersion from distutils.version import LooseVersion
from typing import NamedTuple from typing import NamedTuple
@@ -26,7 +25,9 @@ def get_build_type():
elif check_env_flag("REL_WITH_DEB_INFO"): elif check_env_flag("REL_WITH_DEB_INFO"):
return "RelWithDebInfo" return "RelWithDebInfo"
else: else:
return "Release" return "Debug"
# TODO(Keren): Restore this before we merge into master
#return "Release"
# --- third party packages ----- # --- third party packages -----
@@ -124,19 +125,14 @@ class CMakeBuild(build_ext):
self.build_extension(ext) self.build_extension(ext)
def build_extension(self, ext): def build_extension(self, ext):
self.debug = True
lit_dir = shutil.which('lit') lit_dir = shutil.which('lit')
triton_cache_path = os.path.join(os.environ["HOME"], ".triton") triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
# lit is used by the test suite # lit is used by the test suite
thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path) thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories # create build directories
build_suffix = 'debug' if self.debug else 'release'
llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm-" + build_suffix)
if not os.path.exists(self.build_temp): if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp) os.makedirs(self.build_temp)
if not os.path.exists(llvm_build_dir):
os.makedirs(llvm_build_dir)
# python directories # python directories
python_include_dir = distutils.sysconfig.get_python_inc() python_include_dir = distutils.sysconfig.get_python_inc()
cmake_args = [ cmake_args = [
@@ -145,13 +141,13 @@ class CMakeBuild(build_ext):
"-DTRITON_BUILD_TUTORIALS=OFF", "-DTRITON_BUILD_TUTORIALS=OFF",
"-DTRITON_BUILD_PYTHON_MODULE=ON", "-DTRITON_BUILD_PYTHON_MODULE=ON",
# '-DPYTHON_EXECUTABLE=' + sys.executable, # '-DPYTHON_EXECUTABLE=' + sys.executable,
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir, "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
"-DLLVM_EXTERNAL_LIT=" + lit_dir "-DLLVM_EXTERNAL_LIT=" + lit_dir
] + thirdparty_cmake_args ] + thirdparty_cmake_args
# configuration # configuration
cfg = "Debug" if self.debug else "Release" cfg = get_build_type()
build_args = ["--config", cfg] build_args = ["--config", cfg]
if platform.system() == "Windows": if platform.system() == "Windows":
@@ -183,7 +179,11 @@ setup(
"torch", "torch",
"lit", "lit",
], ],
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]}, package_data={
"triton/ops": ["*.c"],
"triton/ops/blocksparse": ["*.c"],
"triton/language": ["*.bc"]
},
include_package_data=True, include_package_data=True,
ext_modules=[CMakeExtension("triton", "triton/_C/")], ext_modules=[CMakeExtension("triton", "triton/_C/")],
cmdclass={"build_ext": CMakeBuild}, cmdclass={"build_ext": CMakeBuild},

View File

@@ -551,7 +551,7 @@ void init_triton_ir(py::module &&m) {
return llvm::dyn_cast<mlir::FuncOp>(funcOperation); return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) { if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
mlir::ArrayRef<mlir::NamedAttribute> attrs = { llvm::SmallVector<mlir::NamedAttribute> attrs = {
mlir::NamedAttribute(self.getStringAttr("sym_visibility"), mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
self.getStringAttr(visibility))}; self.getStringAttr(visibility))};
return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs); return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs);