diff --git a/.ci/azure-pipelines.yml b/.ci/azure-pipelines.yml index bfe07a8c1..d6891d460 100644 --- a/.ci/azure-pipelines.yml +++ b/.ci/azure-pipelines.yml @@ -8,7 +8,7 @@ variables: value: venv # Run CI when something pushed to master -trigger: [ master ] +# trigger: [ master ] # Run CI when a PR is created or updated from master pr: - master diff --git a/.ci/build-wheels.yml b/.ci/build-wheels.yml new file mode 100644 index 000000000..71f79e06c --- /dev/null +++ b/.ci/build-wheels.yml @@ -0,0 +1,35 @@ +trigger: none +pr: none + +jobs: +- job: linux + + timeoutInMinutes: 180 + + pool: default + + steps: + - bash: | + set -o errexit + python3 --version + python3 -m pip install --upgrade pip + pip3 install cibuildwheel==1.10.0 + pip3 install twine + displayName: Install dependencies + - bash: | + sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py + sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev`date '+%Y%m%d'`\"/g" python/setup.py + echo "" >> python/setup.cfg + echo "[build_ext]" >> python/setup.cfg + echo "base-dir=/project" >> python/setup.cfg + displayName: Patch setup.py + - bash: | + export CIBW_BEFORE_BUILD="pip install cmake" + export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64" + python3 -m cibuildwheel python --output-dir wheelhouse + displayName: Build wheels + - task: PublishBuildArtifacts@1 + inputs: {pathtoPublish: 'wheelhouse'} + - bash: | + python3 -m twine upload wheelhouse/* --skip-existing -u $(PYPI_USERNAME) -p $(PYPI_PASSWORD) + displayName: Upload wheels to PyPI \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 7faea69d6..b6681ecdc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,11 @@ -cmake_minimum_required(VERSION 2.8) +cmake_minimum_required(VERSION 3.6) +include(ExternalProject) + +if(NOT TRITON_LLVM_BUILD_DIR) + set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR}) +endif() + + project(triton) include(CTest) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") @@ -7,12 +14,6 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON) option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) -# LLVM -find_package(LLVM REQUIRED) -link_directories(${LLVM_LIBRARY_DIRS}) -include_directories(${LLVM_INCLUDE_DIRS}) -add_definitions(${LLVM_DEFINITIONS}) - # Default build type if(NOT CMAKE_BUILD_TYPE) message(STATUS "Default build type: Release") @@ -23,38 +24,77 @@ endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=default -std=gnu++14") -# Tests -if(BUILD_TUTORIALS) - message(STATUS "Adding C++ tutorials") - add_subdirectory(tutorials) -endif() + + +########## +# LLVM +########## +get_cmake_property(_variableNames VARIABLES) +set(__variableNames ${_variableNames}) + +configure_file(cmake/DownloadLLVM.in ${TRITON_LLVM_BUILD_DIR}/llvm-download/CMakeLists.txt) +execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${TRITON_LLVM_BUILD_DIR}/llvm-download" +) +execute_process(COMMAND "${CMAKE_COMMAND}" --build . 
+ WORKING_DIRECTORY "${TRITON_LLVM_BUILD_DIR}/llvm-download" +) +set(LLVM_TARGETS_TO_BUILD "NVPTX" CACHE INTERNAL "") +set(LLVM_BUILD_RUNTIME "OFF" CACHE INTERNAL "") +set(LLVM_BUILD_RUNTIMES "OFF" CACHE INTERNAL "") +set(LLVM_BUILD_TOOLS "OFF" CACHE INTERNAL "") +set(LLVM_BUILD_UTILS "OFF" CACHE INTERNAL "") +set(LLVM_INCLUDE_BENCHMARKS "OFF" CACHE INTERNAL "") +set(LLVM_INCLUDE_DOCS "OFF" CACHE INTERNAL "") +set(LLVM_INCLUDE_EXAMPLES "OFF" CACHE INTERNAL "") +set(LLVM_INCLUDE_GO_TESTS "OFF" CACHE INTERNAL "") +set(LLVM_INCLUDE_RUNTIME "OFF" CACHE INTERNAL "") +set(LLVM_INCLUDE_TESTS "OFF" CACHE INTERNAL "") +set(LLVM_INCLUDE_TOOLS "OFF" CACHE INTERNAL "") +set(LLVM_INCLUDE_UTILS "OFF" CACHE INTERNAL "") +add_subdirectory(${TRITON_LLVM_BUILD_DIR}/llvm-src + ${TRITON_LLVM_BUILD_DIR}/llvm-build) +get_property(LLVM_LIBRARIES GLOBAL PROPERTY LLVM_COMPONENT_LIBS) +# remove LLVM-specific variables so we don't pollute GUI +get_cmake_property(_variableNames VARIABLES) +list(REMOVE_ITEM _variableNames ${__variableNames}) +list(REMOVE_ITEM _variableNames ${LLVM_LIBRARIES}) +foreach (_variableName ${_variableNames}) + unset(${_variableName} CACHE) +endforeach() +include_directories("${TRITON_LLVM_BUILD_DIR}/llvm-build/include/" + "${TRITON_LLVM_BUILD_DIR}/llvm-src/include/") # Python module if(BUILD_PYTHON_MODULE) message(STATUS "Adding Python module") - # PyBind11 wrapper source file - file(GLOB_RECURSE TORCH_SRC torch/*.cc) # Build CUTLASS python wrapper if requested + set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}") set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}") if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL "")) - set(TORCH_SRC ${TORCH_SRC} cutlass.cc) + set(CUTLASS_SRC ${PYTHON_SRC_PATH}/cutlass.cc) add_definitions(-DWITH_CUTLASS_BINDINGS) set(CUTLASS_LIBRARIES "cutlass.a") endif() message(STATUS ${CUTLASS_INCLUDE_PATH}) - set(PYTHON_SRC main.cc triton.cc ${TORCH_SRC}) - set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ${CUTLASS_OPT}") - include_directories("." ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR}) + include_directories("." 
${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR}) link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR}) + set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC}) endif() # Triton file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) -target_link_libraries(triton ${LLVM_LIBRARIES} ${LLVM_SYSTEM_LIBS}) +target_link_libraries(triton ${LLVM_LIBRARIES}) if(BUILD_PYTHON_MODULE) target_link_libraries(triton ${TORCH_LIBRARIES} ${CUTLASS_LIBRARIES}) endif() + +# Tutorials +if(BUILD_TUTORIALS) + message(STATUS "Adding C++ tutorials") + add_subdirectory(tutorials) +endif() diff --git a/cmake/DownloadLLVM.in b/cmake/DownloadLLVM.in new file mode 100644 index 000000000..afe3d8362 --- /dev/null +++ b/cmake/DownloadLLVM.in @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.6) + +project(llvm-download NONE) +include(ExternalProject) + + +ExternalProject_Add(llvm + URL "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz" + SOURCE_DIR "${TRITON_LLVM_BUILD_DIR}/llvm-src" + BINARY_DIR "${TRITON_LLVM_BUILD_DIR}/llvm-build" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/docs/getting-started/installation.rst b/docs/getting-started/installation.rst index b4e3fd262..fb56eca43 100644 --- a/docs/getting-started/installation.rst +++ b/docs/getting-started/installation.rst @@ -2,6 +2,17 @@ Installation ============== +--------------------- +Binary Distributions +--------------------- + +You can install the latest nightly release of Triton from pip: + +.. code-block:: bash + + pip install triton-nightly + + -------------- From Source -------------- @@ -14,11 +25,12 @@ You can install the Python package from source by running the following commands .. code-block:: bash - sudo apt-get install llvm-10-dev git clone https://github.com/ptillet/triton.git; cd triton/python; pip install -e . +This may take a while (10-20 minutes) as it will download and compile LLVM from source. + You can then test your installation by running the unit tests: .. code-block:: bash @@ -40,17 +52,10 @@ Those not interested in Python integration may want to use the internals of Trit .. code-block:: bash - sudo apt-get install llvm-10-dev git clone https://github.com/ptillet/triton.git; mkdir build; cd build; cmake ../; make -j8; -A custom llvm-config binary can also be provided: - -.. 
code-block:: bash - - cmake ../ -DLLVM_CONFIG=/path/to/llvm-config - Note that while direct usage of the C++ API is not officially supported, a usage tutorial can be found `here `_ diff --git a/python/setup.py b/python/setup.py index 4e1e9a092..ef45a3a37 100644 --- a/python/setup.py +++ b/python/setup.py @@ -6,28 +6,12 @@ import platform import subprocess import distutils import glob +import tempfile from distutils.version import LooseVersion from setuptools import setup, Extension, find_packages -from torch.utils.cpp_extension import include_paths, library_paths from setuptools.command.build_ext import build_ext from setuptools.command.test import test as TestCommand import distutils.spawn -import torch - - -def find_llvm(): - versions = ["-10", "-10.0", ""] - supported = ["llvm-config{v}".format(v=v) for v in versions] - paths = [distutils.spawn.find_executable(cfg) for cfg in supported] - paths = [p for p in paths if p is not None] - if paths: - return paths[0] - config = distutils.spawn.find_executable("llvm-config") - instructions = "Please install llvm-10-dev" - if config is None: - raise RuntimeError("Could not find llvm-config. " + instructions) - version = os.popen("{config} --version".format(config=config)).read() - raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions) class CMakeExtension(Extension): @@ -38,6 +22,16 @@ class CMakeExtension(Extension): class CMakeBuild(build_ext): + + user_options = build_ext.user_options + [('base-dir=', None, 'base directory of Triton')] + + def initialize_options(self): + build_ext.initialize_options(self) + self.base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) + + def finalize_options(self): + build_ext.finalize_options(self) + def run(self): try: out = subprocess.check_output(["cmake", "--version"]) @@ -58,23 +52,23 @@ class CMakeBuild(build_ext): # self.debug = True self.debug = False extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) + # create build directories + llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm") + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + if not os.path.exists(llvm_build_dir): + os.makedirs(llvm_build_dir) # python directories python_include_dirs = distutils.sysconfig.get_python_inc() python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR") - torch_include_dirs = include_paths(True) - torch_library_dirs = library_paths(True) - cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI)) cmake_args = [ "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DBUILD_TUTORIALS=OFF", "-DBUILD_PYTHON_MODULE=ON", #'-DPYTHON_EXECUTABLE=' + sys.executable, #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON, - "-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs] + include_paths(True)), - "-DPYTHON_LINK_DIRS=" + ";".join(library_paths(True)), - "-DTORCH_CXX11_ABI=" + cxx11abi, - "-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton", - "-DLLVM_CONFIG=" + find_llvm(), + "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir, + "-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs]) ] # configuration cfg = "Debug" if self.debug else "Release" @@ -87,13 +81,10 @@ class CMakeBuild(build_ext): build_args += ["--", "/m"] else: cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] - build_args += ["--", "-j4"] + build_args += ["--", "-j8"] env = os.environ.copy() - if not os.path.exists(self.build_temp): - os.makedirs(self.build_temp) - sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "src")) - 
subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env) + subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp) @@ -106,7 +97,10 @@ setup( long_description="", packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"], install_requires=["numpy", "torch"], - package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]}, + package_data={ + "triton/ops": ["*.c"], + "triton/ops/blocksparse": ["*.c"] + }, include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], cmdclass={"build_ext": CMakeBuild}, @@ -116,10 +110,10 @@ setup( url="https://github.com/ptillet/triton/", download_url="https://github.com/ptillet/triton/archive/v0.1.tar.gz", classifiers=[ - "Development Status :: 3 - Alpha", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package - "Intended Audience :: Developers", # Define that your audience are developers + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", "Topic :: Software Development :: Build Tools", - "License :: OSI Approved :: MIT License", # Again, pick a license + "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.6", ], ) diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt deleted file mode 120000 index 8c50e0213..000000000 --- a/python/src/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -../../CMakeLists.txt \ No newline at end of file diff --git a/python/src/cmake b/python/src/cmake deleted file mode 120000 index c06bb027c..000000000 --- a/python/src/cmake +++ /dev/null @@ -1 +0,0 @@ -../../cmake/ \ No newline at end of file diff --git a/python/src/cutlass.cc b/python/src/cutlass.cc index e680d83e2..14da81330 100644 --- a/python/src/cutlass.cc +++ b/python/src/cutlass.cc @@ -4,8 +4,6 @@ #include "cutlass/library/singleton.h" #include "pybind11/pybind11.h" #include "triton/tools/bench.hpp" -#include -#include using namespace cutlass; using namespace cutlass::library; @@ -132,58 +130,56 @@ const Operation *autotune(int M, int N, int K, } // map of torch datatypes to cutlass datatypes -std::map type_map = { - {caffe2::TypeMeta::Id(), NumericTypeID::kF16}, - {caffe2::TypeMeta::Id(), NumericTypeID::kF32}, - {caffe2::TypeMeta::Id(), NumericTypeID::kF64}}; +std::map type_map = { + {"float16", NumericTypeID::kF16}, + {"float32", NumericTypeID::kF32}, + {"float64", NumericTypeID::kF64}}; -void cutlass_matmul(torch::Tensor A, torch::Tensor B, torch::Tensor C) { - size_t M = A.size(0); - size_t N = B.size(1); - size_t K = A.size(1); - size_t lda = A.stride(0); - size_t ldb = B.stride(0); - size_t ldc = C.stride(1); - size_t ldd = C.stride(1); - void *ptr_A = A.data_ptr(); - void *ptr_B = B.data_ptr(); - void *ptr_C = C.data_ptr(); +void cutlass_matmul(uintptr_t A, uintptr_t B, uintptr_t C, + size_t M, size_t N, size_t K, + size_t stride_a_0, size_t stride_a_1, + size_t stride_b_0, size_t stride_b_1, + size_t stride_c_0, size_t stride_c_1, + std::string type_a, std::string type_b, std::string type_c, + size_t dev_id, uint64_t stream_handle) { + void *ptr_A = (void *)A; + void *ptr_B = (void *)B; + void *ptr_C = (void *)C; void *ptr_D = ptr_C; + size_t lda = stride_a_0; + size_t ldb = stride_b_0; + size_t ldc = stride_c_1; + size_t ldd = ldc; float alpha = 1.0f; float beta = 0.0f; // layout for A LayoutTypeID layout_A; - if (A.stride(0) == 1) + if (stride_a_0 == 1) layout_A = 
LayoutTypeID::kColumnMajor; - else if (A.stride(1) == 1) + else if (stride_a_1 == 1) layout_A = LayoutTypeID::kRowMajor; - else { - A = A.contiguous(); - layout_A = LayoutTypeID::kRowMajor; - } + else + throw std::runtime_error("A layout is not supported"); // layout for B LayoutTypeID layout_B; - if (B.stride(0) == 1) + if (stride_b_0 == 1) layout_B = LayoutTypeID::kColumnMajor; - else if (B.stride(1) == 1) + else if (stride_b_1 == 1) layout_B = LayoutTypeID::kRowMajor; - else { - B = B.contiguous(); - layout_B = LayoutTypeID::kRowMajor; - } + else + throw std::runtime_error("B layout is not supported"); // data types NumericTypeID element_compute = NumericTypeID::kF32; - NumericTypeID element_A = type_map[A.dtype().id()]; - NumericTypeID element_B = type_map[B.dtype().id()]; - NumericTypeID element_C = type_map[C.dtype().id()]; + NumericTypeID element_A = type_map[type_a]; + NumericTypeID element_B = type_map[type_b]; + NumericTypeID element_C = type_map[type_c]; // misc. flags ScalarPointerMode scalar_mode = ScalarPointerMode::kHost; NumericTypeID element_scalar = NumericTypeID::kF32; ComplexTransform transform_A = ComplexTransform::kNone; ComplexTransform transform_B = ComplexTransform::kNone; // runtime flags - size_t dev_id = C.device().index(); - cudaStream_t stream = c10::cuda::getCurrentCUDAStream(dev_id).stream(); + cudaStream_t stream = (cudaStream_t)stream_handle; // auto-tune std::vector tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C, dev_id, (size_t)element_compute, (size_t)scalar_mode}; diff --git a/python/src/include b/python/src/include deleted file mode 120000 index 3611dd266..000000000 --- a/python/src/include +++ /dev/null @@ -1 +0,0 @@ -../../include/ \ No newline at end of file diff --git a/python/src/lib b/python/src/lib deleted file mode 120000 index bc1a1ee04..000000000 --- a/python/src/lib +++ /dev/null @@ -1 +0,0 @@ -../../lib/ \ No newline at end of file diff --git a/python/src/main.cc b/python/src/main.cc index 1d664f8f8..48fc69e0d 100644 --- a/python/src/main.cc +++ b/python/src/main.cc @@ -8,7 +8,6 @@ void init_cutlass(pybind11::module &m); PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; init_triton(m); - init_torch_utils(m); init_superblocking(m); #ifdef WITH_CUTLASS_BINDINGS init_cutlass(m); diff --git a/python/src/superblock.cc b/python/src/superblock.cc new file mode 100644 index 000000000..35b7e9de4 --- /dev/null +++ b/python/src/superblock.cc @@ -0,0 +1,119 @@ +#include +#include +#include +#include +#include +#include +#include +#ifdef _OPENMP +#include +#endif + +// row-major 3d tensor +class tensor_3d { +public: + tensor_3d(int size_0, int size_1, int size_2, int *data = nullptr) : data_(size_0 * size_1 * size_2, 0) { + if (data) + std::copy(data, data + data_.size(), data_.begin()); + stride_0_ = size_1 * size_2; + stride_1_ = size_2; + stride_2_ = 1; + } + + int &operator()(int i, int j, int k) { + return data_[i * stride_0_ + j * stride_1_ + k]; + } + +private: + std::vector data_; + int stride_0_; + int stride_1_; + int stride_2_; +}; + +std::vector segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width, int H, int M, int N) { + tensor_3d tmp(H, M, N); + std::vector current(H, 0); + int num = 0; + std::vector lut(H * M * N * 4); + for (size_t h = 0; h < H; h++) { + // surrounding indices + std::vector ii_left(max_width, -1); + std::vector> ii_top(max_width, std::vector(N, -1)); + // start the dynamic programming algorithm + for (size_t m = 0; m < M; m++) { + for (size_t n = 
0; n < N; n++) { + int v = layout(h, m, n); + if (v == 0) + continue; + int n_left = ii_left[max_width - 1]; + int m_top = ii_top[max_width - 1][n]; + int top = (m_top >= 0) ? tmp(h, m_top, n) : 0; + int left = (n_left >= 0) ? tmp(h, m, n_left) : 0; + int topleft = (m_top >= 0 && n_left >= 0) ? tmp(h, m_top, n_left) : 0; + int width = std::min(left, std::min(top, topleft)) + 1; + // reset width if blocks cannot be + // packed together (i.e., there's a 1 "in the middle") + for (int nn = n_left + 1; nn < n; nn++) + if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) + width = 1; + tmp(h, m, n) = width; + // update n_left ring buffer + for (int k = 0; k < max_width - 1; k++) + ii_left[k] = ii_left[k + 1]; + ii_left[max_width - 1] = n; + // update ii_top ring buffer + for (int k = 0; k < max_width - 1; k++) + ii_top[k][n] = ii_top[k + 1][n]; + ii_top[max_width - 1][n] = m; + // block is too small -- skip + if (width != max_width) + continue; + // retained blocks are set to zeros + for (size_t km = 0; km < max_width; km++) + for (size_t kn = 0; kn < max_width; kn++) { + int mm = ii_top[km][n]; + int nn = ii_left[kn]; + if (mm < 0 || nn < 0) + continue; + layout(h, mm, nn) = 0; + tmp(h, mm, nn) = 0; + lut[num++] = (int)h; + lut[num++] = (int)mm; + lut[num++] = (int)nn; + lut[num++] = idx(h, mm, nn); + } + } + } + } + lut.resize(num); + return lut; +} + +typedef std::pair> lut_t; + +std::vector superblock(uintptr_t LAYOUT, int H, int M, int N, int start_width) { + std::vector ret; + int current = 0; + tensor_3d layout(H, M, N, (int *)LAYOUT); + tensor_3d idx(H, M, N); + for (int64_t h = 0; h < H; h++) + for (int64_t m = 0; m < M; m++) + for (int64_t n = 0; n < N; n++) { + if (layout(h, m, n) == 0) + continue; + idx(h, m, n) = current++; + } + // create lut + for (int max_width = start_width; max_width > 0; max_width /= 2) { + auto lut = segment_blocks(layout, idx, max_width, H, M, N); + if (lut.size() == 0) + continue; + ret.push_back(std::make_pair(max_width, pybind11::array_t(lut.size(), lut.data()))); + } + return ret; +} + +void init_superblocking(pybind11::module &m) { + m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication"); +} \ No newline at end of file diff --git a/python/src/torch/superblock.cc b/python/src/torch/superblock.cc deleted file mode 100644 index 2243eec79..000000000 --- a/python/src/torch/superblock.cc +++ /dev/null @@ -1,117 +0,0 @@ -#include -#include -#include -#include -#ifdef _OPENMP -#include -#endif - -typedef std::vector> ret_t; - -void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch, int max_width, ret_t& ret){ - size_t H = layout.size(0); - size_t M = layout.size(1); - size_t N = layout.size(2); - torch::Tensor tmp = torch::zeros_like(layout); - auto _tmp = tmp.accessor (); - auto _layout = layout.accessor (); - auto _idx = idx.accessor (); - auto _scratch = scratch.accessor(); - std::vector current(H, 0); - #ifdef _OPENMP - #pragma omp parallel for - #endif - for(size_t h = 0; h < H; h++){ - // surrounding indices - std::vector ii_left(max_width, -1); - std::vector> ii_top(max_width, std::vector(N, -1)); - - for(size_t m = 0; m < M; m++){ - for(size_t n = 0; n < N; n++){ - int v = _layout[h][m][n]; - if(v == 0) - continue; - int n_left= ii_left[max_width-1]; - int m_top = ii_top [max_width-1][n]; - int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0; - int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0; - int topleft = (m_top >=0 && n_left >= 0) ? 
_tmp[h][m_top][n_left] : 0; - int width = std::min(left, std::min(top, topleft)) + 1; - - // reset width if blocks cannot be - // packed together (i.e., there's a 1 "in the middle") - for(int nn = n_left + 1; nn < n; nn++) - if(ii_top[max_width-1][nn] > ii_top[max_width-1][n]) - width = 1; - _tmp[h][m][n] = width; - - // update n_left ring buffer - for(int k = 0; k < max_width-1; k++) - ii_left[k] = ii_left[k+1]; - ii_left[max_width-1] = n; - - // update ii_top ring buffer - for(int k = 0; k < max_width-1; k++) - ii_top[k][n] = ii_top[k+1][n]; - ii_top[max_width-1][n] = m; - - // block is too small -- skip - if(width != max_width) - continue; - - // retained blocks are set to zeros - for(size_t km = 0; km < max_width; km++) - for(size_t kn = 0; kn < max_width; kn++) - { - int mm = ii_top[km][n]; - int nn = ii_left[kn]; - if(mm < 0 || nn < 0) - continue; - _layout[h][mm][nn] = 0; - _tmp[h][mm][nn] = 0; - _scratch[h][current[h]][0] = (int)h; - _scratch[h][current[h]][1] = (int)mm; - _scratch[h][current[h]][2] = (int)nn; - _scratch[h][current[h]][3] = _idx[h][mm][nn]; - current[h]++; - } - } - } - } - std::vector to_cat; - for(size_t h = 0; h < H; h++) - if(current[h] > 0) - to_cat.push_back(scratch[h].slice(0, 0, current[h])); - if(!to_cat.empty()) - ret.push_back(std::make_tuple(max_width, torch::cat(to_cat))); -} - - -ret_t superblock(torch::Tensor layout, int start_width) { - ret_t ret; - // block index - torch::Tensor idx = torch::zeros_like(layout); - int current = 0; - int64_t H = layout.size(0); - int64_t M = layout.size(1); - int64_t N = layout.size(2); - auto _layout = layout.accessor (); - auto _idx = idx.accessor(); - for(int64_t h = 0; h < H; h++) - for(int64_t m = 0; m < M; m++) - for(int64_t n = 0; n < N; n++){ - if(_layout[h][m][n] == 0) - continue; - _idx[h][m][n] = current++; - } - // scratch memory - torch::Tensor scratch = torch::empty({H, layout.sum().item(), 4}, layout.dtype()); - for(int max_width = start_width; max_width > 0; max_width /= 2) - segment_blocks(layout, idx, scratch, max_width, ret); - return ret; -} - - -void init_superblocking(pybind11::module &m) { - m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication"); -} \ No newline at end of file diff --git a/python/src/torch/utils.cc b/python/src/torch/utils.cc deleted file mode 100644 index c7bf74105..000000000 --- a/python/src/torch/utils.cc +++ /dev/null @@ -1,32 +0,0 @@ - -#include "triton/driver/device.h" -#include "triton/driver/stream.h" -#include -#include -#include - -namespace torch_utils { - -uint64_t cu_device(int64_t dev_id) { - CUdevice handle; - triton::driver::dispatch::cuDeviceGet(&handle, dev_id); - return (uint64_t)handle; -} - -uint64_t cu_stream(int64_t dev_id) { - return (uint64_t)c10::cuda::getCurrentCUDAStream(dev_id).stream(); -} - -void set_device(int64_t dev_id) { - if (dev_id >= 0) - C10_CUDA_CHECK(cudaSetDevice(dev_id)); -} - -} // namespace torch_utils - -void init_torch_utils(pybind11::module &m) { - pybind11::module subm = m.def_submodule("torch_utils"); - subm.def("cu_device", &torch_utils::cu_device); - subm.def("cu_stream", &torch_utils::cu_stream); - subm.def("set_device", &torch_utils::set_device); -} \ No newline at end of file diff --git a/python/src/triton.cc b/python/src/triton.cc index 6eea1717e..05e7b9d10 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -89,7 +89,11 @@ void init_triton_driver(py::module &&m) { py::class_(m, "device"); // cuda device py::class_(m, "cu_device") - .def(py::init()); + .def(py::init([](int 
dev_id, bool take_ownership) { + CUdevice handle; + drv::dispatch::cuDeviceGet(&handle, dev_id); + return new drv::cu_device(handle, take_ownership); + })); // host device py::class_(m, "host_device") .def(py::init<>()); diff --git a/python/test/test_blocksparse.py b/python/test/test_blocksparse.py index fde63b29d..6e748b6c0 100644 --- a/python/test/test_blocksparse.py +++ b/python/test/test_blocksparse.py @@ -2,12 +2,15 @@ import torch import triton import pytest + @pytest.mark.parametrize( "MODE, TRANS_A, TRANS_B, BLOCK, DTYPE", - [(mode, at, bt, block, dtype) for dtype in ["float16", "float32"] for mode in ["sdd", "dsd", "dds"] - for at in [False, True] for bt in [False, True] for block in [16, 32, 64]], + [ + (mode, at, bt, block, dtype) for dtype in ["float16", "float32"] for mode in ["sdd", "dsd", "dds"] + for at in [False, True] for bt in [False, True] for block in [16, 32, 64] + ], ) -def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=128, N=256, K=384): +def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256): DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] # set seed torch.random.manual_seed(0) @@ -36,6 +39,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=128, N=256, K= # compare assert triton.testing.allclose(rc, tc) + @pytest.mark.parametrize( "BLOCK, WIDTH", [(block, width) for block in [32] for width in [256, 576, 1024, 1792]], @@ -76,6 +80,7 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16): # compare assert triton.testing.allclose(ry, ty) + def test_attention_fwd_bwd( input_scale=1.0, tol=2e-2, @@ -88,9 +93,7 @@ def test_attention_fwd_bwd( ): # inputs qkv_shape = (batch_size, n_heads, n_ctx, 64) - qkvs = [ - torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3) - ] + qkvs = [torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)] attn_mask = torch.tril( torch.ones( [n_ctx, n_ctx], @@ -134,6 +137,7 @@ def test_attention_fwd_bwd( for g1, g2 in zip(grads, torch_grads): torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol) + def triton_attention( layout, block: int, diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 8f0c1668e..7841be5c0 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -5,8 +5,6 @@ import torch from . import testing from .kernel import * from . 
import ops -# C bindings -import triton._C.libtriton.torch_utils as _torch_utils # version __version__ = '1.0.0' \ No newline at end of file diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 1f83d2f9c..f84ceae05 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -4,14 +4,19 @@ from typing import Optional, Dict, List import torch # C bindings import triton._C.libtriton.triton as _triton -import triton._C.libtriton.torch_utils as _torch_utils codes = { - _triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I', - _triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f', - _triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P' + _triton.runtime.arg_type.int1: 'B', + _triton.runtime.arg_type.int8: 'B', + _triton.runtime.arg_type.int32: 'I', + _triton.runtime.arg_type.int64: 'Q', + _triton.runtime.arg_type.half: 'H', + _triton.runtime.arg_type.float: 'f', + _triton.runtime.arg_type.double: 'd', + _triton.runtime.arg_type.buffer: 'P' } + def th_to_triton(obj): tys = { torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long',\ @@ -21,9 +26,11 @@ def th_to_triton(obj): return tys[obj] return str(obj) + def cdiv(a, b): return (a + b - 1) // b + def read(path, kernel_names: Optional[List] = None): if kernel_names is None: kernel_names = [] @@ -32,11 +39,20 @@ def read(path, kernel_names: Optional[List] = None): source = _triton.tools.extract_kernels(source, kernel_names) return source + config = _triton.runtime.config + class kernel: - def __init__(self, src, device, defines: Optional[Dict] = None, num_warps: int = 4, - autotune_vals: Optional[List] = None, autotune_key: Optional[List] = None): + def __init__( + self, + src, + device, + defines: Optional[Dict] = None, + num_warps: int = 4, + autotune_vals: Optional[List] = None, + autotune_key: Optional[List] = None + ): if defines is None: defines = {} if autotune_vals is None: @@ -51,13 +67,14 @@ class kernel: assert device.type in ['cuda', 'cpu'] if device.type == 'cuda': self.device_id = torch.cuda.current_device() if device.index is None else device.index - self.device = _triton.driver.cu_device(_torch_utils.cu_device(self.device_id), False) - self.stream = _triton.driver.cu_stream(_torch_utils.cu_stream(self.device_id), False) + self.device = _triton.driver.cu_device(self.device_id, False) + cu_stream = torch.cuda.current_stream(self.device_id).cuda_stream + self.stream = _triton.driver.cu_stream(cu_stream, False) if device.type == 'cpu': self.device_id = -1 self.device = _triton.driver.host_device() self.device = _triton.driver.host_stream() - _torch_utils.set_device(self.device_id) + torch.cuda.set_device(self.device_id) # function self.opt = _triton.runtime.options() self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()} @@ -68,7 +85,7 @@ class kernel: def __call__(self, *args, grid): # make sure that the executing thread is on the right device - _torch_utils.set_device(self.device_id) + torch.cuda.set_device(self.device_id) # pack parameters into a byte buffer params = struct.pack(self.tys, *args) kernel = self.fn.autotune(params, grid, self.stream) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 3ad63dcd3..3eff88060 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -6,6 +6,7 @@ import math src = 
triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c')) + ############## # MAIN API # ############## @@ -82,16 +83,13 @@ class _matmul(torch.autograd.Function): @staticmethod def make_sdd_lut(layout, block, dtype, device): start_width = 128 // block - superblocks = libtriton.superblock(layout.type(torch.int32), start_width) + layout = layout.type(torch.int32) + superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width) luts, widths, packs = [], [], [] for size, nnz in superblocks: + nnz = nnz.reshape(-1, 4) width = nnz.shape[0] // (size * size) - h = nnz[:, 0] - i = nnz[:, 1] - j = nnz[:, 2] - b = nnz[:, 3] - lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous() - luts.append(lut.type(torch.int32).to(device)) + luts.append(torch.from_numpy(nnz).type(torch.int32).to(device)) widths.append(width) packs.append(size) # create locks @@ -126,10 +124,21 @@ class _matmul(torch.autograd.Function): key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple) if key not in _matmul.sdd_cache: defines = { - 'TM': block * pack, 'TN': block * pack, 'TMN': block * block * pack * pack, 'BLOCK': block, 'TK': - 32, 'TYPE': dtype, 'STRIDE_AM': '1' if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else '1', - 'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb', 'STRIDE_CM': 'ldc', - 'STRIDE_CN': '1', 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel' + 'TM': block * pack, + 'TN': block * pack, + 'TMN': block * block * pack * pack, + 'BLOCK': block, + 'TK': 32, + 'TYPE': dtype, + 'STRIDE_AM': '1' if trans_a else 'lda', + 'STRIDE_AK': 'lda' if trans_a else '1', + 'STRIDE_BN': 'ldb' if trans_b else '1', + 'STRIDE_BK': '1' if trans_b else 'ldb', + 'STRIDE_CM': 'ldc', + 'STRIDE_CN': '1', + 'SDD': True, + 'TZ': 1, + 'NAME': 'sdd_kernel' } _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines) @@ -141,10 +150,28 @@ class _matmul(torch.autograd.Function): # kernel calls max_width = 49152 for off_width in range(0, width, max_width): - kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), b.stride(2), block, a.stride(0), - b.stride(0), c.stride(0), a.stride(1), b.stride(1), c.stride(0), AS2, AS2, AS3, off_width, - lut.data_ptr(), locks.data_ptr(), num_lock, - grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0]) + kernel( + a.data_ptr(), + b.data_ptr(), + c.data_ptr(), + a.stride(2), + b.stride(2), + block, + a.stride(0), + b.stride(0), + c.stride(0), + a.stride(1), + b.stride(1), + c.stride(0), + AS2, + AS2, + AS3, + off_width, + lut.data_ptr(), + locks.data_ptr(), + num_lock, + grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0] + ) # save for backward pass return c @@ -258,10 +285,19 @@ class _matmul(torch.autograd.Function): key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) if key not in _matmul.dds_cache: defines = { - 'TM': 128, 'TN': block, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else 'lda', - 'STRIDE_AK': 'lda' if trans_a else 1, 'STRIDE_BN': block if trans_b else 1, 'STRIDE_BK': - 1 if trans_b else block, 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': 'ldc' if trans_c else '1', - 'NAME': 'dds_kernel', 'DDS': True + 'TM': 128, + 'TN': block, + 'TK': 16, + 'BLOCK': block, + 'TYPE': dtype, + 'STRIDE_AM': 1 if trans_a else 'lda', + 'STRIDE_AK': 'lda' if trans_a else 1, + 'STRIDE_BN': block if trans_b else 1, + 'STRIDE_BK': 1 if trans_b else block, + 'STRIDE_CM': '1' if 
trans_c else 'ldc', + 'STRIDE_CN': 'ldc' if trans_c else '1', + 'NAME': 'dds_kernel', + 'DDS': True } _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines) kernel = _matmul.dds_cache[key] @@ -272,9 +308,28 @@ class _matmul(torch.autograd.Function): CS3 = AS2 if trans_c else BS2 locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device) c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) - kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), block, c.stride(2), a.stride(0), b.stride(0), - c.stride(0), a.stride(1), b.stride(1), c.stride(1), AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), - num_locks, grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0]) + kernel( + a.data_ptr(), + b.data_ptr(), + c.data_ptr(), + a.stride(2), + block, + c.stride(2), + a.stride(0), + b.stride(0), + c.stride(0), + a.stride(1), + b.stride(1), + c.stride(1), + AS2, + BS2, + 0, + 0, + lut.data_ptr(), + locks.data_ptr(), + num_locks, + grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0] + ) return c @staticmethod @@ -292,10 +347,19 @@ class _matmul(torch.autograd.Function): key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) if key not in _matmul.dsd_cache: defines = { - 'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else block, - 'STRIDE_AK': block if trans_a else 1, 'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': - '1' if trans_b else 'ldb', 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': - 'ldc' if trans_c else '1', 'NAME': 'dsd_kernel', 'DSD': True + 'TM': block, + 'TN': 128, + 'TK': 16, + 'BLOCK': block, + 'TYPE': dtype, + 'STRIDE_AM': 1 if trans_a else block, + 'STRIDE_AK': block if trans_a else 1, + 'STRIDE_BN': 'ldb' if trans_b else '1', + 'STRIDE_BK': '1' if trans_b else 'ldb', + 'STRIDE_CM': '1' if trans_c else 'ldc', + 'STRIDE_CN': 'ldc' if trans_c else '1', + 'NAME': 'dsd_kernel', + 'DSD': True } _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines) kernel = _matmul.dsd_cache[key] @@ -306,16 +370,37 @@ class _matmul(torch.autograd.Function): CS3 = AS1 if trans_c else BS3 locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device) c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) - kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), block, b.stride(2), c.stride(2), a.stride(0), b.stride(0), - c.stride(0), a.stride(1), b.stride(1), c.stride(1), BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), - num_locks, grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0]) + kernel( + a.data_ptr(), + b.data_ptr(), + c.data_ptr(), + block, + b.stride(2), + c.stride(2), + a.stride(0), + b.stride(0), + c.stride(0), + a.stride(1), + b.stride(1), + c.stride(1), + BS3, + AS1, + 0, + 0, + lut.data_ptr(), + locks.data_ptr(), + num_locks, + grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0] + ) return c fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)} @staticmethod - def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut, - da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs): + def forward( + ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut, da_num_locks, + da_width, da_packs, db_lut, db_num_locks, db_width, db_packs + ): c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs) # save 
for backward ctx.save_for_backward(a, b) @@ -342,19 +427,24 @@ class _matmul(torch.autograd.Function): # gradients w.r.t. a if ctx.needs_input_grad[0]: mode_da = mode[1] + mode[0] + mode[2] - da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, - ctx.da_num_locks, ctx.da_width, ctx.da_packs) + da = _matmul.fn[mode_da]( + dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_num_locks, ctx.da_width, + ctx.da_packs + ) # gradients w.r.t. b if ctx.needs_input_grad[1]: mode_db = mode[2] + mode[1] + mode[0] - db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, - ctx.db_num_locks, ctx.db_width, ctx.db_packs) + db = _matmul.fn[mode_db]( + a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_num_locks, ctx.db_width, + ctx.db_packs + ) return da, db, None, None, None,\ None, None, None, None,\ None, None, None, None, None, None,\ None, None, None, None, None, None,\ None, None, None, None, None, None + class matmul: def make_lut(self, dtype, device): key = (dtype, device) @@ -375,8 +465,7 @@ class matmul: elif self.mode == 'dsd': da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device) elif self.mode == 'dds': - da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, - device) + da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device) # DB look-up table if self.mode == 'sdd': db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device) @@ -419,7 +508,8 @@ class matmul: a = matmul._pad_shape(a, self.mode == 'dsd') b = matmul._pad_shape(b, self.mode == 'dds') # execute - c = _matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, - c_num_locks, c_width, c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, - db_width, db_packs) + c = _matmul.apply( + a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, c_packs, + da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs + ) return c diff --git a/python/triton/testing.py b/python/triton/testing.py index 652b4657d..99c0bc1e6 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -18,8 +18,22 @@ def cutlass_matmul(a, b): if _cutlass is None: raise RuntimeError("Cannot find cutlass library") M, N = a.shape[0], b.shape[1] + Ka, Kb = a.shape[1], b.shape[0] + assert Ka == Kb + assert a.dtype == b.dtype + assert a.device == b.device + # allocate output c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device) - _cutlass.matmul(a, b, c) + # run function + dtype = str(a.dtype).split('.')[-1] + _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), \ + M, N, Ka,\ + a.stride(0), a.stride(1),\ + b.stride(0), b.stride(1),\ + c.stride(0), c.stride(1),\ + dtype, dtype, dtype, + a.device.index, torch.cuda.current_stream(a.device).cuda_stream) + return c diff --git a/tutorials/01-matmul.cc b/tutorials/01-matmul.cc index 53817e9c8..08f57084d 100644 --- a/tutorials/01-matmul.cc +++ b/tutorials/01-matmul.cc @@ -189,14 +189,14 @@ float triton_dot(drv::context* context, drv::stream* stream, // grid auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; }; auto grid = [ceil, M, N](const rt::options_t& x) { - return rt::grid_t{ceil(M, x.D("TM"))* - ceil(N, 
x.D("TN")), - (size_t)x.D("TZ")}; + return rt::kernel::grid_t{ceil(M, x.D("TM"))* + ceil(N, x.D("TN")), + (size_t)x.D("TZ")}; }; // metrics auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; - double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream);}, stream); + double triton_ns = triton::tools::bench([&]() { function(oss.str(), grid, stream);}, stream); return tflops(triton_ns); }