diff --git a/cmake/FindTensorFlow.cmake b/cmake/FindTensorFlow.cmake
index dcbb43924..405febbeb 100644
--- a/cmake/FindTensorFlow.cmake
+++ b/cmake/FindTensorFlow.cmake
@@ -17,5 +17,4 @@ if(TensorFlow_FOUND)
   set(TensorFlow_ABI ${TF_ABI})
 endif()
 
-# hide locals from GUI
 mark_as_advanced(TF_INC TF_LIB TF_ABI)
diff --git a/cmake/FindTorch.cmake b/cmake/FindTorch.cmake
index 906f021f3..56b1e7c16 100644
--- a/cmake/FindTorch.cmake
+++ b/cmake/FindTorch.cmake
@@ -1,101 +1,11 @@
-# FindTorch
-# -------
-#
-# Finds the Torch library
-#
-# This will define the following variables:
-#
-#   TORCH_FOUND        -- True if the system has the Torch library
-#   TORCH_INCLUDE_DIRS -- The include directories for torch
-#   TORCH_LIBRARIES    -- Libraries to link against
-#   TORCH_CXX_FLAGS    -- Additional (required) compiler flags
-#
-# and the following imported targets:
-#
-#   torch
-
 include(FindPackageHandleStandardArgs)
 
+execute_process(COMMAND python -c "import torch; import os; print(os.path.dirname(torch.__file__))"
+                OUTPUT_VARIABLE TORCH_INSTALL_PREFIX OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET)
-if (DEFINED ENV{TORCH_INSTALL_PREFIX})
-  set(TORCH_INSTALL_PREFIX $ENV{TORCH_INSTALL_PREFIX})
-else()
-  # Assume we are in <install-prefix>/share/cmake/Torch/TorchConfig.cmake
-  get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
-  get_filename_component(TORCH_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)
+find_package_handle_standard_args(TORCH DEFAULT_MSG TORCH_INSTALL_PREFIX)
+if(TORCH_INSTALL_PREFIX)
+  set(TORCH_INCLUDE_DIRS ${TORCH_INSTALL_PREFIX}/lib/include/ ${TORCH_INSTALL_PREFIX}/lib/include/torch/csrc/api/include)
+  set(TORCH_LIBRARY_DIRS ${TORCH_INSTALL_PREFIX}/lib/)
 endif()
 
-# Include directories.
-if (EXISTS "${TORCH_INSTALL_PREFIX}/include")
-  set(TORCH_INCLUDE_DIRS
-    ${TORCH_INSTALL_PREFIX}/include
-    ${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include)
-else()
-  set(TORCH_INCLUDE_DIRS
-    ${TORCH_INSTALL_PREFIX}/include
-    ${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include)
-endif()
-
-# Library dependencies.
-if (@BUILD_SHARED_LIBS@)
-  find_package(Caffe2 REQUIRED PATHS ${CMAKE_CURRENT_LIST_DIR}/../Caffe2)
-endif()
-
-if (NOT ANDROID)
-  find_library(TORCH_LIBRARY torch PATHS "${TORCH_INSTALL_PREFIX}/lib")
-else()
-  find_library(TORCH_LIBRARY NO_CMAKE_FIND_ROOT_PATH torch PATHS "${TORCH_INSTALL_PREFIX}/lib")
-endif()
-add_library(torch UNKNOWN IMPORTED)
-set(TORCH_LIBRARIES torch ${Caffe2_MAIN_LIBS})
-
-if (NOT ANDROID)
-  find_library(C10_LIBRARY c10 PATHS "${TORCH_INSTALL_PREFIX}/lib")
-else()
-  find_library(C10_LIBRARY c10 NO_CMAKE_FIND_ROOT_PATH PATHS "${TORCH_INSTALL_PREFIX}/lib")
-endif()
-list(APPEND TORCH_LIBRARIES ${C10_LIBRARY})
-
-if (@USE_CUDA@)
-  if(MSVC)
-    set(NVTOOLEXT_HOME "C:/Program Files/NVIDIA Corporation/NvToolsExt")
-    if ($ENV{NVTOOLEXT_HOME})
-      set(NVTOOLEXT_HOME $ENV{NVTOOLEXT_HOME})
-    endif()
-    set(TORCH_CUDA_LIBRARIES
-      ${NVTOOLEXT_HOME}/lib/x64/nvToolsExt64_1.lib
-      ${CUDA_LIBRARIES})
-    list(APPEND TORCH_INCLUDE_DIRS ${NVTOOLEXT_HOME}/include)
-  elseif(APPLE)
-    set(TORCH_CUDA_LIBRARIES
-      ${CUDA_TOOLKIT_ROOT_DIR}/lib/libcudart.dylib
-      ${CUDA_TOOLKIT_ROOT_DIR}/lib/libnvrtc.dylib
-      ${CUDA_TOOLKIT_ROOT_DIR}/lib/libnvToolsExt.dylib
-      ${CUDA_LIBRARIES})
-  else()
-    find_library(LIBNVTOOLSEXT libnvToolsExt.so PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64/)
-    set(TORCH_CUDA_LIBRARIES
-      ${CUDA_CUDA_LIB}
-      ${CUDA_NVRTC_LIB}
-      ${LIBNVTOOLSEXT}
-      ${CUDA_LIBRARIES})
-  endif()
-  find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib")
-  list(APPEND TORCH_CUDA_LIBRARIES ${C10_CUDA_LIBRARY})
-  list(APPEND TORCH_LIBRARIES ${TORCH_CUDA_LIBRARIES})
-endif()
-
-# When we build libtorch with the old GCC ABI, dependent libraries must too.
-if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
-  set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=@GLIBCXX_USE_CXX11_ABI@")
-endif()
-
-set_target_properties(torch PROPERTIES
-    IMPORTED_LOCATION "${TORCH_LIBRARY}"
-    INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}"
-    CXX_STANDARD 11
-)
-if (TORCH_CXX_FLAGS)
-  set_property(TARGET torch PROPERTY INTERFACE_COMPILE_OPTIONS "${TORCH_CXX_FLAGS}")
-endif()
-
-find_package_handle_standard_args(torch DEFAULT_MSG TORCH_LIBRARY TORCH_INCLUDE_DIRS)
+mark_as_advanced(TORCH_INCLUDE_DIRS TORCH_LIBRARY_DIRS)
diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp
index 150fafb91..f8bec004e 100644
--- a/examples/cpp/conv.cpp
+++ b/examples/cpp/conv.cpp
@@ -217,7 +217,7 @@ int main() {
     16, 2, 64,
     32, 2, 64,
     16, 8, 2, 2,
-    8, 8,
+    8, 1,
     8, 4
   };
 //  jit.autotune("conv", src, benchmark);
diff --git a/examples/python/pytorch/CMakeLists.txt b/examples/python/pytorch/CMakeLists.txt
index b400e1ef4..22e52c65d 100644
--- a/examples/python/pytorch/CMakeLists.txt
+++ b/examples/python/pytorch/CMakeLists.txt
@@ -1,6 +1,10 @@
 find_package(Torch)
-if(${Torch_FOUND})
+if(${TORCH_FOUND})
+    set(CUDA_HOME "/usr/local/cuda")
+    include_directories(${TORCH_INCLUDE_DIRS})
+    include_directories("${CUDA_HOME}/include")
+    link_directories(${TORCH_LIBRARY_DIRS})
+    add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
     add_library(torch_triton SHARED conv.cpp)
-    target_compile_features(torch_triton PRIVATE cxx_range_for)
-    target_link_libraries(torch_triton "${TORCH_LIBRARIES}")
+    target_link_libraries(torch_triton torch triton)
 endif()
"triton/driver/stream.h" #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) -at::Tensor conv_forward( - const at::Tensor data, - const at::Tensor weight) { +const char* src = +R"( +const tunable int32 TM = {16, 32, 64}; +const tunable int32 TN = {16, 32, 64}; +const tunable int32 TK = {8}; + +__constant__ int32* delta = alloc_const int32[18]; +__constant__ int32* masks = alloc_const int32[1024]; + +void conv(read_only restrict fp32 *a, + read_only restrict fp32 *b, + fp32 *c, + int32 M, int32 N, int32 K, + int32 AN, int32 AH, int32 AW, + int32 CN, int32 CK, int32 CP, int32 CQ, + int32 AC, int32 AR, int32 AS, + int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w, + int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q, + int32 pad_h, int32 pad_w, + int32 bound){ + int32 rxa[TM] = get_global_range[TM](0); + int32 rb0[TN] = get_global_range[TN](1); + int32 rka[TK] = 0 ... TK; + int32 rb1[TK] = 0 ... TK; + fp32 C[TM, TN] = 0; + int32 ranh[TM] = rxa / CQ; + int32 raw[TM] = rxa % CQ - pad_w; + int32 ran[TM] = ranh / CP; + int32 rah[TM] = ranh % CP - pad_h; + int32 ra0[TM] = ran*lda_n + rah*lda_h + raw*lda_w; + int32 racr[TK] = rka / AS; + int32 ras[TK] = rka % AS; + int32 rac[TK] = racr / AR; + int32 rar[TK] = racr % AR; + int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w; + fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis]; + fp32* pb[TN, TK] = b + rb1[newaxis, :]*CK + rb0[:, newaxis]; + __constant__ int32* pincd[TK] = delta + rka; + __constant__ int32* pd[TK] = delta + AR*AS + rka; + int32 d[TK] = *pd; + int32 incd[TK] = *pincd; + int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0); + int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0); + __constant__ int32* pm[TM] = masks + AR*AS + maskw*AR*AS + maskh*AR*AS*(2*pad_w + 1); + __constant__ int32* pincm[TM] = delta; + int32 incm[TM] = *pincm; + int32 checka0[TM] = *pm; + int32 checka1[TK] = 1 << rka; + int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0; + fp32 a[TM, TK] = checka ? *pa : 0; + fp32 b[TN, TK] = *pb; + for(int32 k = K; k > 0; k = k - TK){ + C = dot(a, trans(b), C); + pb = pb + TK*CK; + pa = pa + d[newaxis, :]; + b = *pb; + pd = pd + incd; + pincd = pincd + incd; + d = *pd; + incd = *pincd; + pm = pm + incm; + pincm = pincm + incm; + incm = *pincm; + checka0 = *pm; + checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0; + a = checka ? 
@@ -21,10 +104,30 @@ at::Tensor conv_forward(
   const auto R = weight.size(1);
   const auto S = weight.size(2);
   const auto K = weight.size(3);
-  // Create output
+  // Allocate output
   AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
-  return at::empty({B, K, H, W}, at::kFloat);
+  torch::Tensor output = torch::empty({B, K, H, W}, torch::kFloat);
+  // Wrap CUDA handles
+  triton::driver::cu_stream sstream(at::cuda::getCurrentCUDAStream(), false);
+  triton::driver::stream* stream = &sstream;
+  triton::driver::context* ctx = stream->context();
+  triton::driver::cu_buffer d(ctx, (CUdeviceptr)data.storage().data(), false);
+  triton::driver::cu_buffer w(ctx, (CUdeviceptr)weight.storage().data(), false);
+  // Create JIT
+  triton::jit jit(ctx);
+  std::vector<unsigned> params = {
+    16, 2, 64,
+    32, 2, 64,
+    16, 8, 2, 2,
+    8, 8,
+    4
+  };
+  jit.add_module("conv", src, params);
+  triton::driver::kernel* kernel = jit.get_function("conv");
+  triton::jit::launch_information info = jit.get_launch_info("conv");
+
+  return output;
 }
 
 static auto registry =
-  torch::jit::RegisterOperators("triton::conv::forward", &conv_forward);
+  torch::jit::RegisterOperators("triton::conv_forward", &conv_forward);
diff --git a/examples/python/pytorch/main.py b/examples/python/pytorch/main.py
index b9984438b..d4b11e316 100644
--- a/examples/python/pytorch/main.py
+++ b/examples/python/pytorch/main.py
@@ -1,11 +1,9 @@
-import math
-import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from torch.autograd import Variable
-from torch.utils.cpp_extension import load
-from torch.distributions import categorical
-from itertools import product
 
-conv_triton = load( 'conv_triton', ['conv.cpp', 'conv.cu'], extra_cflags=['-O3'])
+torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
+
+d = torch.empty(64, 64, 64, 64).uniform_(0, 1).cuda()
+w = torch.empty(64, 3, 3, 64).uniform_(0, 1).cuda()
+a = torch.ops.triton.conv_forward(d, w)
+print(a)
diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp
index 1a1562c8f..9b71aea4f 100644
--- a/lib/codegen/tune.cpp
+++ b/lib/codegen/tune.cpp
@@ -171,11 +171,19 @@ void tune::run(ir::module &mod) {
   // Simplify metaparameters
   for(ir::function *fn: mod.get_function_list())
   for(ir::basic_block *block: fn->blocks())
-  for(ir::instruction *i : block->get_inst_list())
-    if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
-      ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
-      *params_.at(i).at("nts.d0") = *tmp;
+  for(ir::instruction *i : block->get_inst_list()){
+    if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
+      ir::type *ty = mod.get_builder().get_int32_ty();
+      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
+      *params_.at(i).at("nts.d0") = *tmp;
+    }
+    if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
+      ir::type *ty = mod.get_builder().get_int32_ty();
+      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 2));
+      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 2));
+      *params_.at(i).at("nts.d0") = *tmp1;
+      *params_.at(i).at("nts.d1") = *tmp2;
+    }
   }
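
The tune.cpp hunk extends the metaparameter-simplification pass: previously only tile-typed loads had their per-thread tile size ("nts.d0") pinned through a metaparameter created with what looks like a [2, 2] range, i.e. fixed to 2; now tile-typed dot instructions additionally get both "nts.d0" and "nts.d1" pinned, matching their 2-D output tile C[TM, TN]. A self-contained toy illustration of this dynamic_cast-driven dispatch pattern follows; every class and map name below is invented for illustration and is not Triton's real ir:: API:

    #include <map>
    #include <memory>
    #include <string>
    #include <vector>

    struct inst { virtual ~inst() = default; };   // polymorphic base, like ir::instruction
    struct load_inst : inst {};
    struct dot_inst  : inst {};

    // Per-instruction metaparameters keyed by name, like params_ in tune.cpp.
    using param_map = std::map<std::string, int>;

    void simplify(const std::vector<std::unique_ptr<inst>>& insts,
                  std::map<inst*, param_map>& params) {
      for (const auto& i : insts) {
        // Loads: pin the per-thread tile size along the contiguous dimension only.
        if (dynamic_cast<load_inst*>(i.get()))
          params[i.get()]["nts.d0"] = 2;
        // Dots: the 2-D output tile is register-tiled in both dimensions.
        if (dynamic_cast<dot_inst*>(i.get())) {
          params[i.get()]["nts.d0"] = 2;
          params[i.get()]["nts.d1"] = 2;
        }
      }
    }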