[BUILD] Fix compilation problems in the release build (#897)
This commit is contained in:
@@ -1438,9 +1438,10 @@ struct BroadcastOpConversion
|
|||||||
SmallVector<int64_t> resultLogicalShape(2 * rank);
|
SmallVector<int64_t> resultLogicalShape(2 * rank);
|
||||||
SmallVector<unsigned> broadcastDims;
|
SmallVector<unsigned> broadcastDims;
|
||||||
for (unsigned d = 0; d < rank; ++d) {
|
for (unsigned d = 0; d < rank; ++d) {
|
||||||
unsigned resultShapePerCTA = triton::gpu::getSizePerThread(resultLayout)[d] *
|
unsigned resultShapePerCTA =
|
||||||
triton::gpu::getThreadsPerWarp(resultLayout)[d] *
|
triton::gpu::getSizePerThread(resultLayout)[d] *
|
||||||
triton::gpu::getWarpsPerCTA(resultLayout)[d];
|
triton::gpu::getThreadsPerWarp(resultLayout)[d] *
|
||||||
|
triton::gpu::getWarpsPerCTA(resultLayout)[d];
|
||||||
int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA);
|
int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA);
|
||||||
if (srcShape[d] != resultShape[d]) {
|
if (srcShape[d] != resultShape[d]) {
|
||||||
assert(srcShape[d] == 1);
|
assert(srcShape[d] == 1);
|
||||||
@@ -1450,10 +1451,12 @@ struct BroadcastOpConversion
|
|||||||
std::max<unsigned>(1, triton::gpu::getSizePerThread(srcLayout)[d]);
|
std::max<unsigned>(1, triton::gpu::getSizePerThread(srcLayout)[d]);
|
||||||
} else {
|
} else {
|
||||||
srcLogicalShape[d] = numCtas;
|
srcLogicalShape[d] = numCtas;
|
||||||
srcLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];
|
srcLogicalShape[d + rank] =
|
||||||
|
triton::gpu::getSizePerThread(resultLayout)[d];
|
||||||
}
|
}
|
||||||
resultLogicalShape[d] = numCtas;
|
resultLogicalShape[d] = numCtas;
|
||||||
resultLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];
|
resultLogicalShape[d + rank] =
|
||||||
|
triton::gpu::getSizePerThread(resultLayout)[d];
|
||||||
|
|
||||||
srcLogicalOrder[d] = order[d] + rank;
|
srcLogicalOrder[d] = order[d] + rank;
|
||||||
srcLogicalOrder[d + rank] = order[d];
|
srcLogicalOrder[d + rank] = order[d];
|
||||||
@@ -1968,6 +1971,7 @@ struct PrintfOpConversion
|
|||||||
return "%u";
|
return "%u";
|
||||||
}
|
}
|
||||||
assert(false && "not supported type");
|
assert(false && "not supported type");
|
||||||
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
// declare vprintf(i8*, i8*) as external function
|
// declare vprintf(i8*, i8*) as external function
|
||||||
@@ -5482,6 +5486,7 @@ Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert(false && "Unsupported mma layout found");
|
assert(false && "Unsupported mma layout found");
|
||||||
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
|
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
|
||||||
|
@@ -6,7 +6,6 @@ import shutil
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tarfile
|
import tarfile
|
||||||
import tempfile
|
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from distutils.version import LooseVersion
|
from distutils.version import LooseVersion
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
@@ -26,7 +25,9 @@ def get_build_type():
|
|||||||
elif check_env_flag("REL_WITH_DEB_INFO"):
|
elif check_env_flag("REL_WITH_DEB_INFO"):
|
||||||
return "RelWithDebInfo"
|
return "RelWithDebInfo"
|
||||||
else:
|
else:
|
||||||
return "Release"
|
return "Debug"
|
||||||
|
# TODO(Keren): Restore this before we merge into master
|
||||||
|
#return "Release"
|
||||||
|
|
||||||
|
|
||||||
# --- third party packages -----
|
# --- third party packages -----
|
||||||
@@ -124,19 +125,14 @@ class CMakeBuild(build_ext):
|
|||||||
self.build_extension(ext)
|
self.build_extension(ext)
|
||||||
|
|
||||||
def build_extension(self, ext):
|
def build_extension(self, ext):
|
||||||
self.debug = True
|
|
||||||
lit_dir = shutil.which('lit')
|
lit_dir = shutil.which('lit')
|
||||||
triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
|
triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
|
||||||
# lit is used by the test suite
|
# lit is used by the test suite
|
||||||
thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
|
thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
|
||||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||||
# create build directories
|
# create build directories
|
||||||
build_suffix = 'debug' if self.debug else 'release'
|
|
||||||
llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm-" + build_suffix)
|
|
||||||
if not os.path.exists(self.build_temp):
|
if not os.path.exists(self.build_temp):
|
||||||
os.makedirs(self.build_temp)
|
os.makedirs(self.build_temp)
|
||||||
if not os.path.exists(llvm_build_dir):
|
|
||||||
os.makedirs(llvm_build_dir)
|
|
||||||
# python directories
|
# python directories
|
||||||
python_include_dir = distutils.sysconfig.get_python_inc()
|
python_include_dir = distutils.sysconfig.get_python_inc()
|
||||||
cmake_args = [
|
cmake_args = [
|
||||||
@@ -145,13 +141,13 @@ class CMakeBuild(build_ext):
|
|||||||
"-DTRITON_BUILD_TUTORIALS=OFF",
|
"-DTRITON_BUILD_TUTORIALS=OFF",
|
||||||
"-DTRITON_BUILD_PYTHON_MODULE=ON",
|
"-DTRITON_BUILD_PYTHON_MODULE=ON",
|
||||||
# '-DPYTHON_EXECUTABLE=' + sys.executable,
|
# '-DPYTHON_EXECUTABLE=' + sys.executable,
|
||||||
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
|
'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
|
||||||
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
|
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
|
||||||
"-DLLVM_EXTERNAL_LIT=" + lit_dir
|
"-DLLVM_EXTERNAL_LIT=" + lit_dir
|
||||||
] + thirdparty_cmake_args
|
] + thirdparty_cmake_args
|
||||||
|
|
||||||
# configuration
|
# configuration
|
||||||
cfg = "Debug" if self.debug else "Release"
|
cfg = get_build_type()
|
||||||
build_args = ["--config", cfg]
|
build_args = ["--config", cfg]
|
||||||
|
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
@@ -183,7 +179,11 @@ setup(
|
|||||||
"torch",
|
"torch",
|
||||||
"lit",
|
"lit",
|
||||||
],
|
],
|
||||||
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
|
package_data={
|
||||||
|
"triton/ops": ["*.c"],
|
||||||
|
"triton/ops/blocksparse": ["*.c"],
|
||||||
|
"triton/language": ["*.bc"]
|
||||||
|
},
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
ext_modules=[CMakeExtension("triton", "triton/_C/")],
|
ext_modules=[CMakeExtension("triton", "triton/_C/")],
|
||||||
cmdclass={"build_ext": CMakeBuild},
|
cmdclass={"build_ext": CMakeBuild},
|
||||||
|
@@ -551,7 +551,7 @@ void init_triton_ir(py::module &&m) {
|
|||||||
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
|
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
|
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
|
||||||
mlir::ArrayRef<mlir::NamedAttribute> attrs = {
|
llvm::SmallVector<mlir::NamedAttribute> attrs = {
|
||||||
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
|
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
|
||||||
self.getStringAttr(visibility))};
|
self.getStringAttr(visibility))};
|
||||||
return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs);
|
return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs);
|
||||||
|
Reference in New Issue
Block a user