[BUILD] Added automatic nightly build releases to pip in CI; removed build-time dependence on LLVM and PyTorch (#77)
Recently there have been more and more reports about installation issues:
- Installing Triton before upgrading PyTorch can create issues, because Triton uses some torch headers.
- llvm-10-dev is not available on some platforms; llvm-11-dev is not available on e.g. Ubuntu.
- There are no nightly builds.

This PR should fix all these issues. Some CMake tricks are used to download and install LLVM at build time. The Triton Python bindings were modified to remove their dependence on PyTorch ops. A nightly CI job was added to generate binary wheels for Triton and upload them to PyPI's new triton-nightly project. This PR will also make it very easy to use LLVM forks in the future for whatever needs we have.
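For context, the CMake trick works by generating a tiny download-only ExternalProject at configure time, building it immediately with execute_process so the LLVM sources land in the build tree, and then pulling those sources into the main build with add_subdirectory while trimming LLVM's target list. Below is a minimal, self-contained sketch of that pattern; the URL and directory layout follow this diff, but the snippet is illustrative rather than the exact code merged here (the real template lives in cmake/DownloadLLVM.in, and the real build uses TRITON_LLVM_BUILD_DIR instead of CMAKE_BINARY_DIR).

# Write a download-only sub-project into the build tree at configure time.
file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/llvm-download)
file(WRITE ${CMAKE_BINARY_DIR}/llvm-download/CMakeLists.txt "
cmake_minimum_required(VERSION 3.6)
project(llvm-download NONE)
include(ExternalProject)
ExternalProject_Add(llvm
  URL \"https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz\"
  SOURCE_DIR \"${CMAKE_BINARY_DIR}/llvm-src\"
  BINARY_DIR \"${CMAKE_BINARY_DIR}/llvm-build\"
  CONFIGURE_COMMAND \"\" BUILD_COMMAND \"\" INSTALL_COMMAND \"\" TEST_COMMAND \"\")
")

# Configure and 'build' the sub-project right away: its only job is to
# download and unpack the LLVM source tarball before configuration continues.
execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
                WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/llvm-download")
execute_process(COMMAND "${CMAKE_COMMAND}" --build .
                WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/llvm-download")

# Build only what Triton needs, then compile LLVM as part of this project.
set(LLVM_TARGETS_TO_BUILD "NVPTX" CACHE INTERNAL "")
set(LLVM_INCLUDE_TESTS OFF CACHE INTERNAL "")
add_subdirectory(${CMAKE_BINARY_DIR}/llvm-src ${CMAKE_BINARY_DIR}/llvm-build)
include_directories(${CMAKE_BINARY_DIR}/llvm-build/include ${CMAKE_BINARY_DIR}/llvm-src/include)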
Committed by Philippe Tillet (parent 3ad0a4d7be, commit 2f80a98776)
@@ -8,7 +8,7 @@ variables:
   value: venv

 # Run CI when something pushed to master
-trigger: [ master ]
+# trigger: [ master ]
 # Run CI when a PR is created or updated from master
 pr:
 - master
.ci/build-wheels.yml (new file, 35 lines)
@@ -0,0 +1,35 @@
+trigger: none
+pr: none
+
+jobs:
+- job: linux
+
+  timeoutInMinutes: 180
+
+  pool: default
+
+  steps:
+  - bash: |
+      set -o errexit
+      python3 --version
+      python3 -m pip install --upgrade pip
+      pip3 install cibuildwheel==1.10.0
+      pip3 install twine
+    displayName: Install dependencies
+  - bash: |
+      sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
+      sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev`date '+%Y%m%d'`\"/g" python/setup.py
+      echo "" >> python/setup.cfg
+      echo "[build_ext]" >> python/setup.cfg
+      echo "base-dir=/project" >> python/setup.cfg
+    displayName: Patch setup.py
+  - bash: |
+      export CIBW_BEFORE_BUILD="pip install cmake"
+      export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64"
+      python3 -m cibuildwheel python --output-dir wheelhouse
+    displayName: Build wheels
+  - task: PublishBuildArtifacts@1
+    inputs: {pathtoPublish: 'wheelhouse'}
+  - bash: |
+      python3 -m twine upload wheelhouse/* --skip-existing -u $(PYPI_USERNAME) -p $(PYPI_PASSWORD)
+    displayName: Upload wheels to PyPI
@@ -1,4 +1,11 @@
-cmake_minimum_required(VERSION 2.8)
+cmake_minimum_required(VERSION 3.6)
+include(ExternalProject)
+
+if(NOT TRITON_LLVM_BUILD_DIR)
+  set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
+endif()
+
+
 project(triton)
 include(CTest)
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
@@ -7,12 +14,6 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
 option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)

-# LLVM
-find_package(LLVM REQUIRED)
-link_directories(${LLVM_LIBRARY_DIRS})
-include_directories(${LLVM_INCLUDE_DIRS})
-add_definitions(${LLVM_DEFINITIONS})
-
 # Default build type
 if(NOT CMAKE_BUILD_TYPE)
   message(STATUS "Default build type: Release")
@@ -23,38 +24,77 @@ endif()
|
|||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=default -std=gnu++14")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=default -std=gnu++14")
|
||||||
|
|
||||||
# Tests
|
|
||||||
if(BUILD_TUTORIALS)
|
|
||||||
message(STATUS "Adding C++ tutorials")
|
##########
|
||||||
add_subdirectory(tutorials)
|
# LLVM
|
||||||
endif()
|
##########
|
||||||
|
get_cmake_property(_variableNames VARIABLES)
|
||||||
|
set(__variableNames ${_variableNames})
|
||||||
|
|
||||||
|
configure_file(cmake/DownloadLLVM.in ${TRITON_LLVM_BUILD_DIR}/llvm-download/CMakeLists.txt)
|
||||||
|
execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
|
||||||
|
WORKING_DIRECTORY "${TRITON_LLVM_BUILD_DIR}/llvm-download"
|
||||||
|
)
|
||||||
|
execute_process(COMMAND "${CMAKE_COMMAND}" --build .
|
||||||
|
WORKING_DIRECTORY "${TRITON_LLVM_BUILD_DIR}/llvm-download"
|
||||||
|
)
|
||||||
|
set(LLVM_TARGETS_TO_BUILD "NVPTX" CACHE INTERNAL "")
|
||||||
|
set(LLVM_BUILD_RUNTIME "OFF" CACHE INTERNAL "")
|
||||||
|
set(LLVM_BUILD_RUNTIMES "OFF" CACHE INTERNAL "")
|
||||||
|
set(LLVM_BUILD_TOOLS "OFF" CACHE INTERNAL "")
|
||||||
|
set(LLVM_BUILD_UTILS "OFF" CACHE INTERNAL "")
|
||||||
|
set(LLVM_INCLUDE_BENCHMARKS "OFF" CACHE INTERNAL "")
|
||||||
|
set(LLVM_INCLUDE_DOCS "OFF" CACHE INTERNAL "")
|
||||||
|
set(LLVM_INCLUDE_EXAMPLES "OFF" CACHE INTERNAL "")
|
||||||
|
set(LLVM_INCLUDE_GO_TESTS "OFF" CACHE INTERNAL "")
|
||||||
|
set(LLVM_INCLUDE_RUNTIME "OFF" CACHE INTERNAL "")
|
||||||
|
set(LLVM_INCLUDE_TESTS "OFF" CACHE INTERNAL "")
|
||||||
|
set(LLVM_INCLUDE_TOOLS "OFF" CACHE INTERNAL "")
|
||||||
|
set(LLVM_INCLUDE_UTILS "OFF" CACHE INTERNAL "")
|
||||||
|
add_subdirectory(${TRITON_LLVM_BUILD_DIR}/llvm-src
|
||||||
|
${TRITON_LLVM_BUILD_DIR}/llvm-build)
|
||||||
|
get_property(LLVM_LIBRARIES GLOBAL PROPERTY LLVM_COMPONENT_LIBS)
|
||||||
|
# remove LLVM-specific variables so we don't pollute GUI
|
||||||
|
get_cmake_property(_variableNames VARIABLES)
|
||||||
|
list(REMOVE_ITEM _variableNames ${__variableNames})
|
||||||
|
list(REMOVE_ITEM _variableNames ${LLVM_LIBRARIES})
|
||||||
|
foreach (_variableName ${_variableNames})
|
||||||
|
unset(${_variableName} CACHE)
|
||||||
|
endforeach()
|
||||||
|
include_directories("${TRITON_LLVM_BUILD_DIR}/llvm-build/include/"
|
||||||
|
"${TRITON_LLVM_BUILD_DIR}/llvm-src/include/")
|
||||||
|
|
||||||
# Python module
|
# Python module
|
||||||
if(BUILD_PYTHON_MODULE)
|
if(BUILD_PYTHON_MODULE)
|
||||||
message(STATUS "Adding Python module")
|
message(STATUS "Adding Python module")
|
||||||
# PyBind11 wrapper source file
|
|
||||||
file(GLOB_RECURSE TORCH_SRC torch/*.cc)
|
|
||||||
# Build CUTLASS python wrapper if requested
|
# Build CUTLASS python wrapper if requested
|
||||||
|
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
|
||||||
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
|
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
|
||||||
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
|
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
|
||||||
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
|
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
|
||||||
set(TORCH_SRC ${TORCH_SRC} cutlass.cc)
|
set(CUTLASS_SRC ${PYTHON_SRC_PATH}/cutlass.cc)
|
||||||
add_definitions(-DWITH_CUTLASS_BINDINGS)
|
add_definitions(-DWITH_CUTLASS_BINDINGS)
|
||||||
set(CUTLASS_LIBRARIES "cutlass.a")
|
set(CUTLASS_LIBRARIES "cutlass.a")
|
||||||
endif()
|
endif()
|
||||||
message(STATUS ${CUTLASS_INCLUDE_PATH})
|
message(STATUS ${CUTLASS_INCLUDE_PATH})
|
||||||
set(PYTHON_SRC main.cc triton.cc ${TORCH_SRC})
|
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
|
||||||
set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ${CUTLASS_OPT}")
|
|
||||||
include_directories("." ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
|
|
||||||
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
|
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
|
||||||
|
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
# Triton
|
# Triton
|
||||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||||
target_link_libraries(triton ${LLVM_LIBRARIES} ${LLVM_SYSTEM_LIBS})
|
target_link_libraries(triton ${LLVM_LIBRARIES})
|
||||||
|
|
||||||
if(BUILD_PYTHON_MODULE)
|
if(BUILD_PYTHON_MODULE)
|
||||||
target_link_libraries(triton ${TORCH_LIBRARIES} ${CUTLASS_LIBRARIES})
|
target_link_libraries(triton ${TORCH_LIBRARIES} ${CUTLASS_LIBRARIES})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Tutorials
|
||||||
|
if(BUILD_TUTORIALS)
|
||||||
|
message(STATUS "Adding C++ tutorials")
|
||||||
|
add_subdirectory(tutorials)
|
||||||
|
endif()
|
||||||
cmake/DownloadLLVM.in (new file, 15 lines)
@@ -0,0 +1,15 @@
+cmake_minimum_required(VERSION 3.6)
+
+project(llvm-download NONE)
+include(ExternalProject)
+
+
+ExternalProject_Add(llvm
+  URL "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz"
+  SOURCE_DIR "${TRITON_LLVM_BUILD_DIR}/llvm-src"
+  BINARY_DIR "${TRITON_LLVM_BUILD_DIR}/llvm-build"
+  CONFIGURE_COMMAND ""
+  BUILD_COMMAND ""
+  INSTALL_COMMAND ""
+  TEST_COMMAND ""
+)
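The top-level CMakeLists.txt (diffed above) consumes this download-only project: it is configured and built with execute_process during Triton's own configure step, the unpacked sources are added with add_subdirectory using a trimmed LLVM target list, and the variable list is snapshotted before and after so that LLVM's many cache entries can be scrubbed again. The fragment below mirrors that logic from the CMakeLists.txt hunk of this diff, lightly reformatted here for readability.

get_cmake_property(_variableNames VARIABLES)   # snapshot before LLVM is added
set(__variableNames ${_variableNames})

add_subdirectory(${TRITON_LLVM_BUILD_DIR}/llvm-src ${TRITON_LLVM_BUILD_DIR}/llvm-build)
get_property(LLVM_LIBRARIES GLOBAL PROPERTY LLVM_COMPONENT_LIBS)

# remove LLVM-specific cache variables so they don't pollute the CMake GUI,
# keeping only the component-library list that Triton links against
get_cmake_property(_variableNames VARIABLES)
list(REMOVE_ITEM _variableNames ${__variableNames})
list(REMOVE_ITEM _variableNames ${LLVM_LIBRARIES})
foreach(_variableName ${_variableNames})
  unset(${_variableName} CACHE)
endforeach()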
@@ -2,6 +2,17 @@
 Installation
 ==============

+---------------------
+Binary Distributions
+---------------------
+
+You can install the latest nightly release of Triton from pip:
+
+.. code-block:: bash
+
+  pip install triton-nightly
+
+
 --------------
 From Source
 --------------
@@ -14,11 +25,12 @@ You can install the Python package from source by running the following commands

 .. code-block:: bash

-  sudo apt-get install llvm-10-dev
   git clone https://github.com/ptillet/triton.git;
   cd triton/python;
   pip install -e .

+This may take a while (10-20 minutes) as it will download and compile LLVM from source.
+
 You can then test your installation by running the unit tests:

 .. code-block:: bash
@@ -40,17 +52,10 @@ Those not interested in Python integration may want to use the internals of Trit

 .. code-block:: bash

-  sudo apt-get install llvm-10-dev
   git clone https://github.com/ptillet/triton.git;
   mkdir build;
   cd build;
   cmake ../;
   make -j8;

-A custom llvm-config binary can also be provided:
-
-.. code-block:: bash
-
-  cmake ../ -DLLVM_CONFIG=/path/to/llvm-config
-
 Note that while direct usage of the C++ API is not officially supported, a usage tutorial can be found `here <https://github.com/ptillet/triton/blob/master/tutorials/01-matmul.cc>`_
@@ -6,28 +6,12 @@ import platform
|
|||||||
import subprocess
|
import subprocess
|
||||||
import distutils
|
import distutils
|
||||||
import glob
|
import glob
|
||||||
|
import tempfile
|
||||||
from distutils.version import LooseVersion
|
from distutils.version import LooseVersion
|
||||||
from setuptools import setup, Extension, find_packages
|
from setuptools import setup, Extension, find_packages
|
||||||
from torch.utils.cpp_extension import include_paths, library_paths
|
|
||||||
from setuptools.command.build_ext import build_ext
|
from setuptools.command.build_ext import build_ext
|
||||||
from setuptools.command.test import test as TestCommand
|
from setuptools.command.test import test as TestCommand
|
||||||
import distutils.spawn
|
import distutils.spawn
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def find_llvm():
|
|
||||||
versions = ["-10", "-10.0", ""]
|
|
||||||
supported = ["llvm-config{v}".format(v=v) for v in versions]
|
|
||||||
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
|
|
||||||
paths = [p for p in paths if p is not None]
|
|
||||||
if paths:
|
|
||||||
return paths[0]
|
|
||||||
config = distutils.spawn.find_executable("llvm-config")
|
|
||||||
instructions = "Please install llvm-10-dev"
|
|
||||||
if config is None:
|
|
||||||
raise RuntimeError("Could not find llvm-config. " + instructions)
|
|
||||||
version = os.popen("{config} --version".format(config=config)).read()
|
|
||||||
raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)
|
|
||||||
|
|
||||||
|
|
||||||
class CMakeExtension(Extension):
|
class CMakeExtension(Extension):
|
||||||
@@ -38,6 +22,16 @@ class CMakeExtension(Extension):
|
|||||||
|
|
||||||
|
|
||||||
class CMakeBuild(build_ext):
|
class CMakeBuild(build_ext):
|
||||||
|
|
||||||
|
user_options = build_ext.user_options + [('base-dir=', None, 'base directory of Triton')]
|
||||||
|
|
||||||
|
def initialize_options(self):
|
||||||
|
build_ext.initialize_options(self)
|
||||||
|
self.base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
|
||||||
|
|
||||||
|
def finalize_options(self):
|
||||||
|
build_ext.finalize_options(self)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
try:
|
try:
|
||||||
out = subprocess.check_output(["cmake", "--version"])
|
out = subprocess.check_output(["cmake", "--version"])
|
||||||
@@ -58,23 +52,23 @@ class CMakeBuild(build_ext):
|
|||||||
# self.debug = True
|
# self.debug = True
|
||||||
self.debug = False
|
self.debug = False
|
||||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||||
|
# create build directories
|
||||||
|
llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm")
|
||||||
|
if not os.path.exists(self.build_temp):
|
||||||
|
os.makedirs(self.build_temp)
|
||||||
|
if not os.path.exists(llvm_build_dir):
|
||||||
|
os.makedirs(llvm_build_dir)
|
||||||
# python directories
|
# python directories
|
||||||
python_include_dirs = distutils.sysconfig.get_python_inc()
|
python_include_dirs = distutils.sysconfig.get_python_inc()
|
||||||
python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR")
|
python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR")
|
||||||
torch_include_dirs = include_paths(True)
|
|
||||||
torch_library_dirs = library_paths(True)
|
|
||||||
cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
|
|
||||||
cmake_args = [
|
cmake_args = [
|
||||||
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
|
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
|
||||||
"-DBUILD_TUTORIALS=OFF",
|
"-DBUILD_TUTORIALS=OFF",
|
||||||
"-DBUILD_PYTHON_MODULE=ON",
|
"-DBUILD_PYTHON_MODULE=ON",
|
||||||
#'-DPYTHON_EXECUTABLE=' + sys.executable,
|
#'-DPYTHON_EXECUTABLE=' + sys.executable,
|
||||||
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
|
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
|
||||||
"-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs] + include_paths(True)),
|
"-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
|
||||||
"-DPYTHON_LINK_DIRS=" + ";".join(library_paths(True)),
|
"-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs])
|
||||||
"-DTORCH_CXX11_ABI=" + cxx11abi,
|
|
||||||
"-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton",
|
|
||||||
"-DLLVM_CONFIG=" + find_llvm(),
|
|
||||||
]
|
]
|
||||||
# configuration
|
# configuration
|
||||||
cfg = "Debug" if self.debug else "Release"
|
cfg = "Debug" if self.debug else "Release"
|
||||||
@@ -87,13 +81,10 @@ class CMakeBuild(build_ext):
|
|||||||
build_args += ["--", "/m"]
|
build_args += ["--", "/m"]
|
||||||
else:
|
else:
|
||||||
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
|
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
|
||||||
build_args += ["--", "-j4"]
|
build_args += ["--", "-j8"]
|
||||||
|
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
if not os.path.exists(self.build_temp):
|
subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env)
|
||||||
os.makedirs(self.build_temp)
|
|
||||||
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))
|
|
||||||
subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
|
|
||||||
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
|
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
|
||||||
|
|
||||||
|
|
||||||
@@ -106,7 +97,10 @@ setup(
|
|||||||
long_description="",
|
long_description="",
|
||||||
packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
|
packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
|
||||||
install_requires=["numpy", "torch"],
|
install_requires=["numpy", "torch"],
|
||||||
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
|
package_data={
|
||||||
|
"triton/ops": ["*.c"],
|
||||||
|
"triton/ops/blocksparse": ["*.c"]
|
||||||
|
},
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
ext_modules=[CMakeExtension("triton", "triton/_C/")],
|
ext_modules=[CMakeExtension("triton", "triton/_C/")],
|
||||||
cmdclass={"build_ext": CMakeBuild},
|
cmdclass={"build_ext": CMakeBuild},
|
||||||
@@ -116,10 +110,10 @@ setup(
|
|||||||
url="https://github.com/ptillet/triton/",
|
url="https://github.com/ptillet/triton/",
|
||||||
download_url="https://github.com/ptillet/triton/archive/v0.1.tar.gz",
|
download_url="https://github.com/ptillet/triton/archive/v0.1.tar.gz",
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 3 - Alpha", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers", # Define that your audience are developers
|
"Intended Audience :: Developers",
|
||||||
"Topic :: Software Development :: Build Tools",
|
"Topic :: Software Development :: Build Tools",
|
||||||
"License :: OSI Approved :: MIT License", # Again, pick a license
|
"License :: OSI Approved :: MIT License",
|
||||||
"Programming Language :: Python :: 3.6",
|
"Programming Language :: Python :: 3.6",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@@ -1 +0,0 @@
-../../CMakeLists.txt
@@ -1 +0,0 @@
-../../cmake/
@@ -4,8 +4,6 @@
|
|||||||
#include "cutlass/library/singleton.h"
|
#include "cutlass/library/singleton.h"
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
#include "triton/tools/bench.hpp"
|
#include "triton/tools/bench.hpp"
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
using namespace cutlass;
|
using namespace cutlass;
|
||||||
using namespace cutlass::library;
|
using namespace cutlass::library;
|
||||||
@@ -132,58 +130,56 @@ const Operation *autotune(int M, int N, int K,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// map of torch datatypes to cutlass datatypes
|
// map of torch datatypes to cutlass datatypes
|
||||||
std::map<caffe2::TypeIdentifier, NumericTypeID> type_map = {
|
std::map<std::string, NumericTypeID> type_map = {
|
||||||
{caffe2::TypeMeta::Id<at::Half>(), NumericTypeID::kF16},
|
{"float16", NumericTypeID::kF16},
|
||||||
{caffe2::TypeMeta::Id<float>(), NumericTypeID::kF32},
|
{"float32", NumericTypeID::kF32},
|
||||||
{caffe2::TypeMeta::Id<double>(), NumericTypeID::kF64}};
|
{"float64", NumericTypeID::kF64}};
|
||||||
|
|
||||||
void cutlass_matmul(torch::Tensor A, torch::Tensor B, torch::Tensor C) {
|
void cutlass_matmul(uintptr_t A, uintptr_t B, uintptr_t C,
|
||||||
size_t M = A.size(0);
|
size_t M, size_t N, size_t K,
|
||||||
size_t N = B.size(1);
|
size_t stride_a_0, size_t stride_a_1,
|
||||||
size_t K = A.size(1);
|
size_t stride_b_0, size_t stride_b_1,
|
||||||
size_t lda = A.stride(0);
|
size_t stride_c_0, size_t stride_c_1,
|
||||||
size_t ldb = B.stride(0);
|
std::string type_a, std::string type_b, std::string type_c,
|
||||||
size_t ldc = C.stride(1);
|
size_t dev_id, uint64_t stream_handle) {
|
||||||
size_t ldd = C.stride(1);
|
void *ptr_A = (void *)A;
|
||||||
void *ptr_A = A.data_ptr();
|
void *ptr_B = (void *)B;
|
||||||
void *ptr_B = B.data_ptr();
|
void *ptr_C = (void *)C;
|
||||||
void *ptr_C = C.data_ptr();
|
|
||||||
void *ptr_D = ptr_C;
|
void *ptr_D = ptr_C;
|
||||||
|
size_t lda = stride_a_0;
|
||||||
|
size_t ldb = stride_b_0;
|
||||||
|
size_t ldc = stride_c_1;
|
||||||
|
size_t ldd = ldc;
|
||||||
float alpha = 1.0f;
|
float alpha = 1.0f;
|
||||||
float beta = 0.0f;
|
float beta = 0.0f;
|
||||||
// layout for A
|
// layout for A
|
||||||
LayoutTypeID layout_A;
|
LayoutTypeID layout_A;
|
||||||
if (A.stride(0) == 1)
|
if (stride_a_0 == 1)
|
||||||
layout_A = LayoutTypeID::kColumnMajor;
|
layout_A = LayoutTypeID::kColumnMajor;
|
||||||
else if (A.stride(1) == 1)
|
else if (stride_a_1 == 1)
|
||||||
layout_A = LayoutTypeID::kRowMajor;
|
layout_A = LayoutTypeID::kRowMajor;
|
||||||
else {
|
else
|
||||||
A = A.contiguous();
|
throw std::runtime_error("A layout is not supported");
|
||||||
layout_A = LayoutTypeID::kRowMajor;
|
|
||||||
}
|
|
||||||
// layout for B
|
// layout for B
|
||||||
LayoutTypeID layout_B;
|
LayoutTypeID layout_B;
|
||||||
if (B.stride(0) == 1)
|
if (stride_b_0 == 1)
|
||||||
layout_B = LayoutTypeID::kColumnMajor;
|
layout_B = LayoutTypeID::kColumnMajor;
|
||||||
else if (B.stride(1) == 1)
|
else if (stride_b_1 == 1)
|
||||||
layout_B = LayoutTypeID::kRowMajor;
|
layout_B = LayoutTypeID::kRowMajor;
|
||||||
else {
|
else
|
||||||
B = B.contiguous();
|
throw std::runtime_error("B layout is not supported");
|
||||||
layout_B = LayoutTypeID::kRowMajor;
|
|
||||||
}
|
|
||||||
// data types
|
// data types
|
||||||
NumericTypeID element_compute = NumericTypeID::kF32;
|
NumericTypeID element_compute = NumericTypeID::kF32;
|
||||||
NumericTypeID element_A = type_map[A.dtype().id()];
|
NumericTypeID element_A = type_map[type_a];
|
||||||
NumericTypeID element_B = type_map[B.dtype().id()];
|
NumericTypeID element_B = type_map[type_b];
|
||||||
NumericTypeID element_C = type_map[C.dtype().id()];
|
NumericTypeID element_C = type_map[type_c];
|
||||||
// misc. flags
|
// misc. flags
|
||||||
ScalarPointerMode scalar_mode = ScalarPointerMode::kHost;
|
ScalarPointerMode scalar_mode = ScalarPointerMode::kHost;
|
||||||
NumericTypeID element_scalar = NumericTypeID::kF32;
|
NumericTypeID element_scalar = NumericTypeID::kF32;
|
||||||
ComplexTransform transform_A = ComplexTransform::kNone;
|
ComplexTransform transform_A = ComplexTransform::kNone;
|
||||||
ComplexTransform transform_B = ComplexTransform::kNone;
|
ComplexTransform transform_B = ComplexTransform::kNone;
|
||||||
// runtime flags
|
// runtime flags
|
||||||
size_t dev_id = C.device().index();
|
cudaStream_t stream = (cudaStream_t)stream_handle;
|
||||||
cudaStream_t stream = c10::cuda::getCurrentCUDAStream(dev_id).stream();
|
|
||||||
// auto-tune
|
// auto-tune
|
||||||
std::vector<size_t> tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C,
|
std::vector<size_t> tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C,
|
||||||
dev_id, (size_t)element_compute, (size_t)scalar_mode};
|
dev_id, (size_t)element_compute, (size_t)scalar_mode};
|
||||||
|
@@ -1 +0,0 @@
-../../include/
@@ -1 +0,0 @@
-../../lib/
@@ -8,7 +8,6 @@ void init_cutlass(pybind11::module &m);
|
|||||||
PYBIND11_MODULE(libtriton, m) {
|
PYBIND11_MODULE(libtriton, m) {
|
||||||
m.doc() = "Python bindings to the C++ Triton API";
|
m.doc() = "Python bindings to the C++ Triton API";
|
||||||
init_triton(m);
|
init_triton(m);
|
||||||
init_torch_utils(m);
|
|
||||||
init_superblocking(m);
|
init_superblocking(m);
|
||||||
#ifdef WITH_CUTLASS_BINDINGS
|
#ifdef WITH_CUTLASS_BINDINGS
|
||||||
init_cutlass(m);
|
init_cutlass(m);
|
||||||
python/src/superblock.cc (new file, 119 lines)
@@ -0,0 +1,119 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <pybind11/numpy.h>
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <vector>
|
||||||
|
#ifdef _OPENMP
|
||||||
|
#include <omp.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// row-major 3d tensor
|
||||||
|
class tensor_3d {
|
||||||
|
public:
|
||||||
|
tensor_3d(int size_0, int size_1, int size_2, int *data = nullptr) : data_(size_0 * size_1 * size_2, 0) {
|
||||||
|
if (data)
|
||||||
|
std::copy(data, data + data_.size(), data_.begin());
|
||||||
|
stride_0_ = size_1 * size_2;
|
||||||
|
stride_1_ = size_2;
|
||||||
|
stride_2_ = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int &operator()(int i, int j, int k) {
|
||||||
|
return data_[i * stride_0_ + j * stride_1_ + k];
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<int> data_;
|
||||||
|
int stride_0_;
|
||||||
|
int stride_1_;
|
||||||
|
int stride_2_;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width, int H, int M, int N) {
|
||||||
|
tensor_3d tmp(H, M, N);
|
||||||
|
std::vector<int> current(H, 0);
|
||||||
|
int num = 0;
|
||||||
|
std::vector<int> lut(H * M * N * 4);
|
||||||
|
for (size_t h = 0; h < H; h++) {
|
||||||
|
// surrounding indices
|
||||||
|
std::vector<int> ii_left(max_width, -1);
|
||||||
|
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
|
||||||
|
// start the dynamic programming algorithm
|
||||||
|
for (size_t m = 0; m < M; m++) {
|
||||||
|
for (size_t n = 0; n < N; n++) {
|
||||||
|
int v = layout(h, m, n);
|
||||||
|
if (v == 0)
|
||||||
|
continue;
|
||||||
|
int n_left = ii_left[max_width - 1];
|
||||||
|
int m_top = ii_top[max_width - 1][n];
|
||||||
|
int top = (m_top >= 0) ? tmp(h, m_top, n) : 0;
|
||||||
|
int left = (n_left >= 0) ? tmp(h, m, n_left) : 0;
|
||||||
|
int topleft = (m_top >= 0 && n_left >= 0) ? tmp(h, m_top, n_left) : 0;
|
||||||
|
int width = std::min(left, std::min(top, topleft)) + 1;
|
||||||
|
// reset width if blocks cannot be
|
||||||
|
// packed together (i.e., there's a 1 "in the middle")
|
||||||
|
for (int nn = n_left + 1; nn < n; nn++)
|
||||||
|
if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n])
|
||||||
|
width = 1;
|
||||||
|
tmp(h, m, n) = width;
|
||||||
|
// update n_left ring buffer
|
||||||
|
for (int k = 0; k < max_width - 1; k++)
|
||||||
|
ii_left[k] = ii_left[k + 1];
|
||||||
|
ii_left[max_width - 1] = n;
|
||||||
|
// update ii_top ring buffer
|
||||||
|
for (int k = 0; k < max_width - 1; k++)
|
||||||
|
ii_top[k][n] = ii_top[k + 1][n];
|
||||||
|
ii_top[max_width - 1][n] = m;
|
||||||
|
// block is too small -- skip
|
||||||
|
if (width != max_width)
|
||||||
|
continue;
|
||||||
|
// retained blocks are set to zeros
|
||||||
|
for (size_t km = 0; km < max_width; km++)
|
||||||
|
for (size_t kn = 0; kn < max_width; kn++) {
|
||||||
|
int mm = ii_top[km][n];
|
||||||
|
int nn = ii_left[kn];
|
||||||
|
if (mm < 0 || nn < 0)
|
||||||
|
continue;
|
||||||
|
layout(h, mm, nn) = 0;
|
||||||
|
tmp(h, mm, nn) = 0;
|
||||||
|
lut[num++] = (int)h;
|
||||||
|
lut[num++] = (int)mm;
|
||||||
|
lut[num++] = (int)nn;
|
||||||
|
lut[num++] = idx(h, mm, nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lut.resize(num);
|
||||||
|
return lut;
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef std::pair<int, pybind11::array_t<int>> lut_t;
|
||||||
|
|
||||||
|
std::vector<lut_t> superblock(uintptr_t LAYOUT, int H, int M, int N, int start_width) {
|
||||||
|
std::vector<lut_t> ret;
|
||||||
|
int current = 0;
|
||||||
|
tensor_3d layout(H, M, N, (int *)LAYOUT);
|
||||||
|
tensor_3d idx(H, M, N);
|
||||||
|
for (int64_t h = 0; h < H; h++)
|
||||||
|
for (int64_t m = 0; m < M; m++)
|
||||||
|
for (int64_t n = 0; n < N; n++) {
|
||||||
|
if (layout(h, m, n) == 0)
|
||||||
|
continue;
|
||||||
|
idx(h, m, n) = current++;
|
||||||
|
}
|
||||||
|
// create lut
|
||||||
|
for (int max_width = start_width; max_width > 0; max_width /= 2) {
|
||||||
|
auto lut = segment_blocks(layout, idx, max_width, H, M, N);
|
||||||
|
if (lut.size() == 0)
|
||||||
|
continue;
|
||||||
|
ret.push_back(std::make_pair(max_width, pybind11::array_t<int>(lut.size(), lut.data())));
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
void init_superblocking(pybind11::module &m) {
|
||||||
|
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
|
||||||
|
}
|
@@ -1,117 +0,0 @@
|
|||||||
#include <torch/extension.h>
|
|
||||||
#include <string>
|
|
||||||
#include <tuple>
|
|
||||||
#include <vector>
|
|
||||||
#ifdef _OPENMP
|
|
||||||
#include <omp.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
|
|
||||||
|
|
||||||
void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch, int max_width, ret_t& ret){
|
|
||||||
size_t H = layout.size(0);
|
|
||||||
size_t M = layout.size(1);
|
|
||||||
size_t N = layout.size(2);
|
|
||||||
torch::Tensor tmp = torch::zeros_like(layout);
|
|
||||||
auto _tmp = tmp.accessor <int, 3>();
|
|
||||||
auto _layout = layout.accessor <int, 3>();
|
|
||||||
auto _idx = idx.accessor <int, 3>();
|
|
||||||
auto _scratch = scratch.accessor<int, 3>();
|
|
||||||
std::vector<int> current(H, 0);
|
|
||||||
#ifdef _OPENMP
|
|
||||||
#pragma omp parallel for
|
|
||||||
#endif
|
|
||||||
for(size_t h = 0; h < H; h++){
|
|
||||||
// surrounding indices
|
|
||||||
std::vector<int> ii_left(max_width, -1);
|
|
||||||
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
|
|
||||||
|
|
||||||
for(size_t m = 0; m < M; m++){
|
|
||||||
for(size_t n = 0; n < N; n++){
|
|
||||||
int v = _layout[h][m][n];
|
|
||||||
if(v == 0)
|
|
||||||
continue;
|
|
||||||
int n_left= ii_left[max_width-1];
|
|
||||||
int m_top = ii_top [max_width-1][n];
|
|
||||||
int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
|
|
||||||
int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
|
|
||||||
int topleft = (m_top >=0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
|
|
||||||
int width = std::min(left, std::min(top, topleft)) + 1;
|
|
||||||
|
|
||||||
// reset width if blocks cannot be
|
|
||||||
// packed together (i.e., there's a 1 "in the middle")
|
|
||||||
for(int nn = n_left + 1; nn < n; nn++)
|
|
||||||
if(ii_top[max_width-1][nn] > ii_top[max_width-1][n])
|
|
||||||
width = 1;
|
|
||||||
_tmp[h][m][n] = width;
|
|
||||||
|
|
||||||
// update n_left ring buffer
|
|
||||||
for(int k = 0; k < max_width-1; k++)
|
|
||||||
ii_left[k] = ii_left[k+1];
|
|
||||||
ii_left[max_width-1] = n;
|
|
||||||
|
|
||||||
// update ii_top ring buffer
|
|
||||||
for(int k = 0; k < max_width-1; k++)
|
|
||||||
ii_top[k][n] = ii_top[k+1][n];
|
|
||||||
ii_top[max_width-1][n] = m;
|
|
||||||
|
|
||||||
// block is too small -- skip
|
|
||||||
if(width != max_width)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// retained blocks are set to zeros
|
|
||||||
for(size_t km = 0; km < max_width; km++)
|
|
||||||
for(size_t kn = 0; kn < max_width; kn++)
|
|
||||||
{
|
|
||||||
int mm = ii_top[km][n];
|
|
||||||
int nn = ii_left[kn];
|
|
||||||
if(mm < 0 || nn < 0)
|
|
||||||
continue;
|
|
||||||
_layout[h][mm][nn] = 0;
|
|
||||||
_tmp[h][mm][nn] = 0;
|
|
||||||
_scratch[h][current[h]][0] = (int)h;
|
|
||||||
_scratch[h][current[h]][1] = (int)mm;
|
|
||||||
_scratch[h][current[h]][2] = (int)nn;
|
|
||||||
_scratch[h][current[h]][3] = _idx[h][mm][nn];
|
|
||||||
current[h]++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::vector<torch::Tensor> to_cat;
|
|
||||||
for(size_t h = 0; h < H; h++)
|
|
||||||
if(current[h] > 0)
|
|
||||||
to_cat.push_back(scratch[h].slice(0, 0, current[h]));
|
|
||||||
if(!to_cat.empty())
|
|
||||||
ret.push_back(std::make_tuple(max_width, torch::cat(to_cat)));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
ret_t superblock(torch::Tensor layout, int start_width) {
|
|
||||||
ret_t ret;
|
|
||||||
// block index
|
|
||||||
torch::Tensor idx = torch::zeros_like(layout);
|
|
||||||
int current = 0;
|
|
||||||
int64_t H = layout.size(0);
|
|
||||||
int64_t M = layout.size(1);
|
|
||||||
int64_t N = layout.size(2);
|
|
||||||
auto _layout = layout.accessor <int, 3>();
|
|
||||||
auto _idx = idx.accessor<int, 3>();
|
|
||||||
for(int64_t h = 0; h < H; h++)
|
|
||||||
for(int64_t m = 0; m < M; m++)
|
|
||||||
for(int64_t n = 0; n < N; n++){
|
|
||||||
if(_layout[h][m][n] == 0)
|
|
||||||
continue;
|
|
||||||
_idx[h][m][n] = current++;
|
|
||||||
}
|
|
||||||
// scratch memory
|
|
||||||
torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
|
|
||||||
for(int max_width = start_width; max_width > 0; max_width /= 2)
|
|
||||||
segment_blocks(layout, idx, scratch, max_width, ret);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void init_superblocking(pybind11::module &m) {
|
|
||||||
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
|
|
||||||
}
|
|
@@ -1,32 +0,0 @@
|
|||||||
|
|
||||||
#include "triton/driver/device.h"
|
|
||||||
#include "triton/driver/stream.h"
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <cuda_runtime_api.h>
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
namespace torch_utils {
|
|
||||||
|
|
||||||
uint64_t cu_device(int64_t dev_id) {
|
|
||||||
CUdevice handle;
|
|
||||||
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
|
|
||||||
return (uint64_t)handle;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t cu_stream(int64_t dev_id) {
|
|
||||||
return (uint64_t)c10::cuda::getCurrentCUDAStream(dev_id).stream();
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_device(int64_t dev_id) {
|
|
||||||
if (dev_id >= 0)
|
|
||||||
C10_CUDA_CHECK(cudaSetDevice(dev_id));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace torch_utils
|
|
||||||
|
|
||||||
void init_torch_utils(pybind11::module &m) {
|
|
||||||
pybind11::module subm = m.def_submodule("torch_utils");
|
|
||||||
subm.def("cu_device", &torch_utils::cu_device);
|
|
||||||
subm.def("cu_stream", &torch_utils::cu_stream);
|
|
||||||
subm.def("set_device", &torch_utils::set_device);
|
|
||||||
}
|
|
@@ -89,7 +89,11 @@ void init_triton_driver(py::module &&m) {
|
|||||||
py::class_<drv::device>(m, "device");
|
py::class_<drv::device>(m, "device");
|
||||||
// cuda device
|
// cuda device
|
||||||
py::class_<drv::cu_device, driver::device>(m, "cu_device")
|
py::class_<drv::cu_device, driver::device>(m, "cu_device")
|
||||||
.def(py::init<CUdevice, bool>());
|
.def(py::init([](int dev_id, bool take_ownership) {
|
||||||
|
CUdevice handle;
|
||||||
|
drv::dispatch::cuDeviceGet(&handle, dev_id);
|
||||||
|
return new drv::cu_device(handle, take_ownership);
|
||||||
|
}));
|
||||||
// host device
|
// host device
|
||||||
py::class_<drv::host_device, driver::device>(m, "host_device")
|
py::class_<drv::host_device, driver::device>(m, "host_device")
|
||||||
.def(py::init<>());
|
.def(py::init<>());
|
||||||
|
@@ -2,12 +2,15 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
|
"MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
|
||||||
[(mode, at, bt, block, dtype) for dtype in ["float16", "float32"] for mode in ["sdd", "dsd", "dds"]
|
[
|
||||||
for at in [False, True] for bt in [False, True] for block in [16, 32, 64]],
|
(mode, at, bt, block, dtype) for dtype in ["float16", "float32"] for mode in ["sdd", "dsd", "dds"]
|
||||||
|
for at in [False, True] for bt in [False, True] for block in [16, 32, 64]
|
||||||
|
],
|
||||||
)
|
)
|
||||||
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=128, N=256, K=384):
|
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
|
||||||
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
|
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
|
||||||
# set seed
|
# set seed
|
||||||
torch.random.manual_seed(0)
|
torch.random.manual_seed(0)
|
||||||
@@ -36,6 +39,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=128, N=256, K=
|
|||||||
# compare
|
# compare
|
||||||
assert triton.testing.allclose(rc, tc)
|
assert triton.testing.allclose(rc, tc)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"BLOCK, WIDTH",
|
"BLOCK, WIDTH",
|
||||||
[(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
|
[(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
|
||||||
@@ -76,6 +80,7 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
|
|||||||
# compare
|
# compare
|
||||||
assert triton.testing.allclose(ry, ty)
|
assert triton.testing.allclose(ry, ty)
|
||||||
|
|
||||||
|
|
||||||
def test_attention_fwd_bwd(
|
def test_attention_fwd_bwd(
|
||||||
input_scale=1.0,
|
input_scale=1.0,
|
||||||
tol=2e-2,
|
tol=2e-2,
|
||||||
@@ -88,9 +93,7 @@ def test_attention_fwd_bwd(
|
|||||||
):
|
):
|
||||||
# inputs
|
# inputs
|
||||||
qkv_shape = (batch_size, n_heads, n_ctx, 64)
|
qkv_shape = (batch_size, n_heads, n_ctx, 64)
|
||||||
qkvs = [
|
qkvs = [torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)]
|
||||||
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
|
|
||||||
]
|
|
||||||
attn_mask = torch.tril(
|
attn_mask = torch.tril(
|
||||||
torch.ones(
|
torch.ones(
|
||||||
[n_ctx, n_ctx],
|
[n_ctx, n_ctx],
|
||||||
@@ -134,6 +137,7 @@ def test_attention_fwd_bwd(
|
|||||||
for g1, g2 in zip(grads, torch_grads):
|
for g1, g2 in zip(grads, torch_grads):
|
||||||
torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol)
|
torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol)
|
||||||
|
|
||||||
|
|
||||||
def triton_attention(
|
def triton_attention(
|
||||||
layout,
|
layout,
|
||||||
block: int,
|
block: int,
|
||||||
|
@@ -5,8 +5,6 @@ import torch
|
|||||||
from . import testing
|
from . import testing
|
||||||
from .kernel import *
|
from .kernel import *
|
||||||
from . import ops
|
from . import ops
|
||||||
# C bindings
|
|
||||||
import triton._C.libtriton.torch_utils as _torch_utils
|
|
||||||
|
|
||||||
# version
|
# version
|
||||||
__version__ = '1.0.0'
|
__version__ = '1.0.0'
|
@@ -4,14 +4,19 @@ from typing import Optional, Dict, List
|
|||||||
import torch
|
import torch
|
||||||
# C bindings
|
# C bindings
|
||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
import triton._C.libtriton.torch_utils as _torch_utils
|
|
||||||
|
|
||||||
codes = {
|
codes = {
|
||||||
_triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I',
|
_triton.runtime.arg_type.int1: 'B',
|
||||||
_triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f',
|
_triton.runtime.arg_type.int8: 'B',
|
||||||
_triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P'
|
_triton.runtime.arg_type.int32: 'I',
|
||||||
|
_triton.runtime.arg_type.int64: 'Q',
|
||||||
|
_triton.runtime.arg_type.half: 'H',
|
||||||
|
_triton.runtime.arg_type.float: 'f',
|
||||||
|
_triton.runtime.arg_type.double: 'd',
|
||||||
|
_triton.runtime.arg_type.buffer: 'P'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def th_to_triton(obj):
|
def th_to_triton(obj):
|
||||||
tys = {
|
tys = {
|
||||||
torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long',\
|
torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long',\
|
||||||
@@ -21,9 +26,11 @@ def th_to_triton(obj):
|
|||||||
return tys[obj]
|
return tys[obj]
|
||||||
return str(obj)
|
return str(obj)
|
||||||
|
|
||||||
|
|
||||||
def cdiv(a, b):
|
def cdiv(a, b):
|
||||||
return (a + b - 1) // b
|
return (a + b - 1) // b
|
||||||
|
|
||||||
|
|
||||||
def read(path, kernel_names: Optional[List] = None):
|
def read(path, kernel_names: Optional[List] = None):
|
||||||
if kernel_names is None:
|
if kernel_names is None:
|
||||||
kernel_names = []
|
kernel_names = []
|
||||||
@@ -32,11 +39,20 @@ def read(path, kernel_names: Optional[List] = None):
|
|||||||
source = _triton.tools.extract_kernels(source, kernel_names)
|
source = _triton.tools.extract_kernels(source, kernel_names)
|
||||||
return source
|
return source
|
||||||
|
|
||||||
|
|
||||||
config = _triton.runtime.config
|
config = _triton.runtime.config
|
||||||
|
|
||||||
|
|
||||||
class kernel:
|
class kernel:
|
||||||
def __init__(self, src, device, defines: Optional[Dict] = None, num_warps: int = 4,
|
def __init__(
|
||||||
autotune_vals: Optional[List] = None, autotune_key: Optional[List] = None):
|
self,
|
||||||
|
src,
|
||||||
|
device,
|
||||||
|
defines: Optional[Dict] = None,
|
||||||
|
num_warps: int = 4,
|
||||||
|
autotune_vals: Optional[List] = None,
|
||||||
|
autotune_key: Optional[List] = None
|
||||||
|
):
|
||||||
if defines is None:
|
if defines is None:
|
||||||
defines = {}
|
defines = {}
|
||||||
if autotune_vals is None:
|
if autotune_vals is None:
|
||||||
@@ -51,13 +67,14 @@ class kernel:
|
|||||||
assert device.type in ['cuda', 'cpu']
|
assert device.type in ['cuda', 'cpu']
|
||||||
if device.type == 'cuda':
|
if device.type == 'cuda':
|
||||||
self.device_id = torch.cuda.current_device() if device.index is None else device.index
|
self.device_id = torch.cuda.current_device() if device.index is None else device.index
|
||||||
self.device = _triton.driver.cu_device(_torch_utils.cu_device(self.device_id), False)
|
self.device = _triton.driver.cu_device(self.device_id, False)
|
||||||
self.stream = _triton.driver.cu_stream(_torch_utils.cu_stream(self.device_id), False)
|
cu_stream = torch.cuda.current_stream(self.device_id).cuda_stream
|
||||||
|
self.stream = _triton.driver.cu_stream(cu_stream, False)
|
||||||
if device.type == 'cpu':
|
if device.type == 'cpu':
|
||||||
self.device_id = -1
|
self.device_id = -1
|
||||||
self.device = _triton.driver.host_device()
|
self.device = _triton.driver.host_device()
|
||||||
self.device = _triton.driver.host_stream()
|
self.device = _triton.driver.host_stream()
|
||||||
_torch_utils.set_device(self.device_id)
|
torch.cuda.set_device(self.device_id)
|
||||||
# function
|
# function
|
||||||
self.opt = _triton.runtime.options()
|
self.opt = _triton.runtime.options()
|
||||||
self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
|
self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
|
||||||
@@ -68,7 +85,7 @@ class kernel:
|
|||||||
|
|
||||||
def __call__(self, *args, grid):
|
def __call__(self, *args, grid):
|
||||||
# make sure that the executing thread is on the right device
|
# make sure that the executing thread is on the right device
|
||||||
_torch_utils.set_device(self.device_id)
|
torch.cuda.set_device(self.device_id)
|
||||||
# pack parameters into a byte buffer
|
# pack parameters into a byte buffer
|
||||||
params = struct.pack(self.tys, *args)
|
params = struct.pack(self.tys, *args)
|
||||||
kernel = self.fn.autotune(params, grid, self.stream)
|
kernel = self.fn.autotune(params, grid, self.stream)
|
||||||
|
@@ -6,6 +6,7 @@ import math
|
|||||||
|
|
||||||
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
|
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
|
||||||
|
|
||||||
|
|
||||||
##############
|
##############
|
||||||
# MAIN API #
|
# MAIN API #
|
||||||
##############
|
##############
|
||||||
@@ -82,16 +83,13 @@ class _matmul(torch.autograd.Function):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def make_sdd_lut(layout, block, dtype, device):
|
def make_sdd_lut(layout, block, dtype, device):
|
||||||
start_width = 128 // block
|
start_width = 128 // block
|
||||||
superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
|
layout = layout.type(torch.int32)
|
||||||
|
superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width)
|
||||||
luts, widths, packs = [], [], []
|
luts, widths, packs = [], [], []
|
||||||
for size, nnz in superblocks:
|
for size, nnz in superblocks:
|
||||||
|
nnz = nnz.reshape(-1, 4)
|
||||||
width = nnz.shape[0] // (size * size)
|
width = nnz.shape[0] // (size * size)
|
||||||
h = nnz[:, 0]
|
luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
|
||||||
i = nnz[:, 1]
|
|
||||||
j = nnz[:, 2]
|
|
||||||
b = nnz[:, 3]
|
|
||||||
lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
|
|
||||||
luts.append(lut.type(torch.int32).to(device))
|
|
||||||
widths.append(width)
|
widths.append(width)
|
||||||
packs.append(size)
|
packs.append(size)
|
||||||
# create locks
|
# create locks
|
||||||
@@ -126,10 +124,21 @@ class _matmul(torch.autograd.Function):
|
|||||||
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
|
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
|
||||||
if key not in _matmul.sdd_cache:
|
if key not in _matmul.sdd_cache:
|
||||||
defines = {
|
defines = {
|
||||||
'TM': block * pack, 'TN': block * pack, 'TMN': block * block * pack * pack, 'BLOCK': block, 'TK':
|
'TM': block * pack,
|
||||||
32, 'TYPE': dtype, 'STRIDE_AM': '1' if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else '1',
|
'TN': block * pack,
|
||||||
'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb', 'STRIDE_CM': 'ldc',
|
'TMN': block * block * pack * pack,
|
||||||
'STRIDE_CN': '1', 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'
|
'BLOCK': block,
|
||||||
|
'TK': 32,
|
||||||
|
'TYPE': dtype,
|
||||||
|
'STRIDE_AM': '1' if trans_a else 'lda',
|
||||||
|
'STRIDE_AK': 'lda' if trans_a else '1',
|
||||||
|
'STRIDE_BN': 'ldb' if trans_b else '1',
|
||||||
|
'STRIDE_BK': '1' if trans_b else 'ldb',
|
||||||
|
'STRIDE_CM': 'ldc',
|
||||||
|
'STRIDE_CN': '1',
|
||||||
|
'SDD': True,
|
||||||
|
'TZ': 1,
|
||||||
|
'NAME': 'sdd_kernel'
|
||||||
}
|
}
|
||||||
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
|
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
|
||||||
|
|
||||||
@@ -141,10 +150,28 @@ class _matmul(torch.autograd.Function):
|
|||||||
# kernel calls
|
# kernel calls
|
||||||
max_width = 49152
|
max_width = 49152
|
||||||
for off_width in range(0, width, max_width):
|
for off_width in range(0, width, max_width):
|
||||||
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), b.stride(2), block, a.stride(0),
|
kernel(
|
||||||
b.stride(0), c.stride(0), a.stride(1), b.stride(1), c.stride(0), AS2, AS2, AS3, off_width,
|
a.data_ptr(),
|
||||||
lut.data_ptr(), locks.data_ptr(), num_lock,
|
b.data_ptr(),
|
||||||
grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
|
c.data_ptr(),
|
||||||
|
a.stride(2),
|
||||||
|
b.stride(2),
|
||||||
|
block,
|
||||||
|
a.stride(0),
|
||||||
|
b.stride(0),
|
||||||
|
c.stride(0),
|
||||||
|
a.stride(1),
|
||||||
|
b.stride(1),
|
||||||
|
c.stride(0),
|
||||||
|
AS2,
|
||||||
|
AS2,
|
||||||
|
AS3,
|
||||||
|
off_width,
|
||||||
|
lut.data_ptr(),
|
||||||
|
locks.data_ptr(),
|
||||||
|
num_lock,
|
||||||
|
grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0]
|
||||||
|
)
|
||||||
# save for backward pass
|
# save for backward pass
|
||||||
return c
|
return c
|
||||||
|
|
||||||
@@ -258,10 +285,19 @@ class _matmul(torch.autograd.Function):
|
|||||||
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
|
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
|
||||||
if key not in _matmul.dds_cache:
|
if key not in _matmul.dds_cache:
|
||||||
defines = {
|
defines = {
|
||||||
'TM': 128, 'TN': block, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else 'lda',
|
'TM': 128,
|
||||||
'STRIDE_AK': 'lda' if trans_a else 1, 'STRIDE_BN': block if trans_b else 1, 'STRIDE_BK':
|
'TN': block,
|
||||||
1 if trans_b else block, 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': 'ldc' if trans_c else '1',
|
'TK': 16,
|
||||||
'NAME': 'dds_kernel', 'DDS': True
|
'BLOCK': block,
|
||||||
|
'TYPE': dtype,
|
||||||
|
'STRIDE_AM': 1 if trans_a else 'lda',
|
||||||
|
'STRIDE_AK': 'lda' if trans_a else 1,
|
||||||
|
'STRIDE_BN': block if trans_b else 1,
|
||||||
|
'STRIDE_BK': 1 if trans_b else block,
|
||||||
|
'STRIDE_CM': '1' if trans_c else 'ldc',
|
||||||
|
'STRIDE_CN': 'ldc' if trans_c else '1',
|
||||||
|
'NAME': 'dds_kernel',
|
||||||
|
'DDS': True
|
||||||
}
|
}
|
||||||
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
|
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
|
||||||
kernel = _matmul.dds_cache[key]
|
kernel = _matmul.dds_cache[key]
|
||||||
@@ -272,9 +308,28 @@ class _matmul(torch.autograd.Function):
|
|||||||
CS3 = AS2 if trans_c else BS2
|
CS3 = AS2 if trans_c else BS2
|
||||||
locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
|
locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
|
||||||
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
||||||
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), block, c.stride(2), a.stride(0), b.stride(0),
|
kernel(
|
||||||
c.stride(0), a.stride(1), b.stride(1), c.stride(1), AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(),
|
a.data_ptr(),
|
||||||
num_locks, grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
|
b.data_ptr(),
|
||||||
|
c.data_ptr(),
|
||||||
|
a.stride(2),
|
||||||
|
block,
|
||||||
|
c.stride(2),
|
||||||
|
a.stride(0),
|
||||||
|
b.stride(0),
|
||||||
|
c.stride(0),
|
||||||
|
a.stride(1),
|
||||||
|
b.stride(1),
|
||||||
|
c.stride(1),
|
||||||
|
AS2,
|
||||||
|
BS2,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
lut.data_ptr(),
|
||||||
|
locks.data_ptr(),
|
||||||
|
num_locks,
|
||||||
|
grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0]
|
||||||
|
)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -292,10 +347,19 @@ class _matmul(torch.autograd.Function):
|
|||||||
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
|
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
|
||||||
if key not in _matmul.dsd_cache:
|
if key not in _matmul.dsd_cache:
|
||||||
defines = {
|
defines = {
|
||||||
'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else block,
|
'TM': block,
|
||||||
'STRIDE_AK': block if trans_a else 1, 'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK':
|
'TN': 128,
|
||||||
'1' if trans_b else 'ldb', 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN':
|
'TK': 16,
|
||||||
'ldc' if trans_c else '1', 'NAME': 'dsd_kernel', 'DSD': True
|
'BLOCK': block,
|
||||||
|
'TYPE': dtype,
|
||||||
|
'STRIDE_AM': 1 if trans_a else block,
|
||||||
|
'STRIDE_AK': block if trans_a else 1,
|
||||||
|
'STRIDE_BN': 'ldb' if trans_b else '1',
|
||||||
|
'STRIDE_BK': '1' if trans_b else 'ldb',
|
||||||
|
'STRIDE_CM': '1' if trans_c else 'ldc',
|
||||||
|
'STRIDE_CN': 'ldc' if trans_c else '1',
|
||||||
|
'NAME': 'dsd_kernel',
|
||||||
|
'DSD': True
|
||||||
}
|
}
|
||||||
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
|
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
|
||||||
kernel = _matmul.dsd_cache[key]
|
kernel = _matmul.dsd_cache[key]
|
||||||
@@ -306,16 +370,37 @@ class _matmul(torch.autograd.Function):
|
|||||||
CS3 = AS1 if trans_c else BS3
|
CS3 = AS1 if trans_c else BS3
|
||||||
locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
|
locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
|
||||||
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
||||||
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), block, b.stride(2), c.stride(2), a.stride(0), b.stride(0),
|
kernel(
|
||||||
c.stride(0), a.stride(1), b.stride(1), c.stride(1), BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(),
|
a.data_ptr(),
|
||||||
num_locks, grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
|
b.data_ptr(),
|
||||||
|
c.data_ptr(),
|
||||||
|
block,
|
||||||
|
b.stride(2),
|
||||||
|
c.stride(2),
|
||||||
|
a.stride(0),
|
||||||
|
b.stride(0),
|
||||||
|
c.stride(0),
|
||||||
|
a.stride(1),
|
||||||
|
b.stride(1),
|
||||||
|
c.stride(1),
|
||||||
|
BS3,
|
||||||
|
AS1,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
lut.data_ptr(),
|
||||||
|
locks.data_ptr(),
|
||||||
|
num_locks,
|
||||||
|
grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0]
|
||||||
|
)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}
|
fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut,
|
def forward(
|
||||||
da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs):
|
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut, da_num_locks,
|
||||||
|
da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
|
||||||
|
):
|
||||||
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)
|
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)
|
||||||
# save for backward
|
# save for backward
|
||||||
ctx.save_for_backward(a, b)
|
ctx.save_for_backward(a, b)
|
||||||
@@ -342,19 +427,24 @@ class _matmul(torch.autograd.Function):
|
|||||||
# gradients w.r.t. a
|
# gradients w.r.t. a
|
||||||
if ctx.needs_input_grad[0]:
|
if ctx.needs_input_grad[0]:
|
||||||
mode_da = mode[1] + mode[0] + mode[2]
|
mode_da = mode[1] + mode[0] + mode[2]
|
||||||
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut,
|
da = _matmul.fn[mode_da](
|
||||||
ctx.da_num_locks, ctx.da_width, ctx.da_packs)
|
dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_num_locks, ctx.da_width,
|
||||||
|
ctx.da_packs
|
||||||
|
)
|
||||||
# gradients w.r.t. b
|
# gradients w.r.t. b
|
||||||
if ctx.needs_input_grad[1]:
|
if ctx.needs_input_grad[1]:
|
||||||
mode_db = mode[2] + mode[1] + mode[0]
|
mode_db = mode[2] + mode[1] + mode[0]
|
||||||
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut,
|
db = _matmul.fn[mode_db](
|
||||||
ctx.db_num_locks, ctx.db_width, ctx.db_packs)
|
a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_num_locks, ctx.db_width,
|
||||||
|
ctx.db_packs
|
||||||
|
)
|
||||||
return da, db, None, None, None,\
|
return da, db, None, None, None,\
|
||||||
None, None, None, None,\
|
None, None, None, None,\
|
||||||
None, None, None, None, None, None,\
|
None, None, None, None, None, None,\
|
||||||
None, None, None, None, None, None,\
|
None, None, None, None, None, None,\
|
||||||
None, None, None, None, None, None
|
None, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
class matmul:
|
class matmul:
|
||||||
def make_lut(self, dtype, device):
|
def make_lut(self, dtype, device):
|
||||||
key = (dtype, device)
|
key = (dtype, device)
|
||||||
@@ -375,8 +465,7 @@ class matmul:
|
|||||||
elif self.mode == 'dsd':
|
elif self.mode == 'dsd':
|
||||||
da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
|
da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
|
||||||
elif self.mode == 'dds':
|
elif self.mode == 'dds':
|
||||||
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b,
|
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
|
||||||
device)
|
|
||||||
# DB look-up table
|
# DB look-up table
|
||||||
if self.mode == 'sdd':
|
if self.mode == 'sdd':
|
||||||
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
|
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
|
||||||
@@ -419,7 +508,8 @@ class matmul:
|
|||||||
a = matmul._pad_shape(a, self.mode == 'dsd')
|
a = matmul._pad_shape(a, self.mode == 'dsd')
|
||||||
b = matmul._pad_shape(b, self.mode == 'dds')
|
b = matmul._pad_shape(b, self.mode == 'dds')
|
||||||
# execute
|
# execute
|
||||||
c = _matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut,
|
c = _matmul.apply(
|
||||||
c_num_locks, c_width, c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks,
|
a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, c_packs,
|
||||||
db_width, db_packs)
|
da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
|
||||||
|
)
|
||||||
return c
|
return c
|
||||||
|
@@ -18,8 +18,22 @@ def cutlass_matmul(a, b):
|
|||||||
if _cutlass is None:
|
if _cutlass is None:
|
||||||
raise RuntimeError("Cannot find cutlass library")
|
raise RuntimeError("Cannot find cutlass library")
|
||||||
M, N = a.shape[0], b.shape[1]
|
M, N = a.shape[0], b.shape[1]
|
||||||
|
Ka, Kb = a.shape[1], b.shape[0]
|
||||||
|
assert Ka == Kb
|
||||||
|
assert a.dtype == b.dtype
|
||||||
|
assert a.device == b.device
|
||||||
|
# allocate output
|
||||||
c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
|
c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
|
||||||
_cutlass.matmul(a, b, c)
|
# run function
|
||||||
|
dtype = str(a.dtype).split('.')[-1]
|
||||||
|
_cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
|
||||||
|
M, N, Ka,\
|
||||||
|
a.stride(0), a.stride(1),\
|
||||||
|
b.stride(0), b.stride(1),\
|
||||||
|
c.stride(0), c.stride(1),\
|
||||||
|
dtype, dtype, dtype,
|
||||||
|
a.device.index, torch.cuda.current_stream(a.device).cuda_stream)
|
||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
@@ -189,14 +189,14 @@ float triton_dot(drv::context* context, drv::stream* stream,
|
|||||||
// grid
|
// grid
|
||||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
||||||
auto grid = [ceil, M, N](const rt::options_t& x) {
|
auto grid = [ceil, M, N](const rt::options_t& x) {
|
||||||
return rt::grid_t{ceil(M, x.D<int>("TM"))*
|
return rt::kernel::grid_t{ceil(M, x.D<int>("TM"))*
|
||||||
ceil(N, x.D<int>("TN")),
|
ceil(N, x.D<int>("TN")),
|
||||||
(size_t)x.D<int>("TZ")};
|
(size_t)x.D<int>("TZ")};
|
||||||
};
|
};
|
||||||
|
|
||||||
// metrics
|
// metrics
|
||||||
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
||||||
double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream);}, stream);
|
double triton_ns = triton::tools::bench([&]() { function(oss.str(), grid, stream);}, stream);
|
||||||
return tflops(triton_ns);
|
return tflops(triton_ns);
|
||||||
}
|
}
|
||||||