[codegen/tune] bugfix in heuristics for nano-tile sizes
@@ -17,5 +17,4 @@ if(TensorFlow_FOUND)
  set(TensorFlow_ABI ${TF_ABI})
endif()

# hide locals from GUI
mark_as_advanced(TF_INC TF_LIB TF_ABI)

@@ -1,101 +1,11 @@
# FindTorch
# -------
#
# Finds the Torch library
#
# This will define the following variables:
#
#   TORCH_FOUND        -- True if the system has the Torch library
#   TORCH_INCLUDE_DIRS -- The include directories for torch
#   TORCH_LIBRARIES    -- Libraries to link against
#   TORCH_CXX_FLAGS    -- Additional (required) compiler flags
#
# and the following imported targets:
#
#   torch

include(FindPackageHandleStandardArgs)
execute_process(COMMAND python -c "import torch; import os; print(os.path.dirname(torch.__file__))"
                OUTPUT_VARIABLE TORCH_INSTALL_PREFIX OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET)

if (DEFINED ENV{TORCH_INSTALL_PREFIX})
  set(TORCH_INSTALL_PREFIX $ENV{TORCH_INSTALL_PREFIX})
else()
  # Assume we are in <install-prefix>/share/cmake/Torch/TorchConfig.cmake
  get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
  get_filename_component(TORCH_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)
find_package_handle_standard_args(TORCH DEFAULT_MSG TORCH_INSTALL_PREFIX)
if(TORCH_INSTALL_PREFIX)
  set(TORCH_INCLUDE_DIRS ${TORCH_INSTALL_PREFIX}/lib/include/ ${TORCH_INSTALL_PREFIX}/lib/include/torch/csrc/api/include)
  set(TORCH_LIBRARY_DIRS ${TORCH_INSTALL_PREFIX}/lib/)
endif()

# Include directories.
if (EXISTS "${TORCH_INSTALL_PREFIX}/include")
  set(TORCH_INCLUDE_DIRS
    ${TORCH_INSTALL_PREFIX}/include
    ${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include)
else()
  set(TORCH_INCLUDE_DIRS
    ${TORCH_INSTALL_PREFIX}/lib/include
    ${TORCH_INSTALL_PREFIX}/lib/include/torch/csrc/api/include)
endif()

# Library dependencies.
if (@BUILD_SHARED_LIBS@)
  find_package(Caffe2 REQUIRED PATHS ${CMAKE_CURRENT_LIST_DIR}/../Caffe2)
endif()

if (NOT ANDROID)
  find_library(TORCH_LIBRARY torch PATHS "${TORCH_INSTALL_PREFIX}/lib")
else()
  find_library(TORCH_LIBRARY NO_CMAKE_FIND_ROOT_PATH torch PATHS "${TORCH_INSTALL_PREFIX}/lib")
endif()
add_library(torch UNKNOWN IMPORTED)
set(TORCH_LIBRARIES torch ${Caffe2_MAIN_LIBS})

if (NOT ANDROID)
  find_library(C10_LIBRARY c10 PATHS "${TORCH_INSTALL_PREFIX}/lib")
else()
  find_library(C10_LIBRARY c10 NO_CMAKE_FIND_ROOT_PATH PATHS "${TORCH_INSTALL_PREFIX}/lib")
endif()
list(APPEND TORCH_LIBRARIES ${C10_LIBRARY})

if (@USE_CUDA@)
  if(MSVC)
    set(NVTOOLEXT_HOME "C:/Program Files/NVIDIA Corporation/NvToolsExt")
    if ($ENV{NVTOOLEXT_HOME})
      set(NVTOOLEXT_HOME $ENV{NVTOOLEXT_HOME})
    endif()
    set(TORCH_CUDA_LIBRARIES
      ${NVTOOLEXT_HOME}/lib/x64/nvToolsExt64_1.lib
      ${CUDA_LIBRARIES})
    list(APPEND TORCH_INCLUDE_DIRS ${NVTOOLEXT_HOME}/include)
  elseif(APPLE)
    set(TORCH_CUDA_LIBRARIES
      ${CUDA_TOOLKIT_ROOT_DIR}/lib/libcudart.dylib
      ${CUDA_TOOLKIT_ROOT_DIR}/lib/libnvrtc.dylib
      ${CUDA_TOOLKIT_ROOT_DIR}/lib/libnvToolsExt.dylib
      ${CUDA_LIBRARIES})
  else()
    find_library(LIBNVTOOLSEXT libnvToolsExt.so PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64/)
    set(TORCH_CUDA_LIBRARIES
      ${CUDA_CUDA_LIB}
      ${CUDA_NVRTC_LIB}
      ${LIBNVTOOLSEXT}
      ${CUDA_LIBRARIES})
  endif()
  find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib")
  list(APPEND TORCH_CUDA_LIBRARIES ${C10_CUDA_LIBRARY})
  list(APPEND TORCH_LIBRARIES ${TORCH_CUDA_LIBRARIES})
endif()

# When we build libtorch with the old GCC ABI, dependent libraries must too.
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
  set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=@GLIBCXX_USE_CXX11_ABI@")
endif()

set_target_properties(torch PROPERTIES
  IMPORTED_LOCATION "${TORCH_LIBRARY}"
  INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}"
  CXX_STANDARD 11
)
if (TORCH_CXX_FLAGS)
  set_property(TARGET torch PROPERTY INTERFACE_COMPILE_OPTIONS "${TORCH_CXX_FLAGS}")
endif()

find_package_handle_standard_args(torch DEFAULT_MSG TORCH_LIBRARY TORCH_INCLUDE_DIRS)
mark_as_advanced(TORCH_INCLUDE_DIRS TORCH_LIBRARY_DIRS)

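The comment block above documents the variables (TORCH_FOUND, TORCH_INCLUDE_DIRS, TORCH_LIBRARIES, TORCH_CXX_FLAGS) and the imported target (torch) that this find module exposes. For illustration only, a minimal downstream CMakeLists.txt consuming them could look like the sketch below; the project name, source file, and module path are placeholders, not part of this commit.

# Hypothetical consumer of the FindTorch module above (names are illustrative).
cmake_minimum_required(VERSION 3.5)
project(torch_consumer CXX)

# Make FindTorch.cmake visible to find_package(Torch).
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
find_package(Torch)

if(TORCH_FOUND)
  include_directories(${TORCH_INCLUDE_DIRS})
  link_directories(${TORCH_LIBRARY_DIRS})
  # TORCH_CXX_FLAGS carries -D_GLIBCXX_USE_CXX11_ABI=... when libtorch was built with the old GCC ABI.
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
  add_executable(consumer main.cpp)
  target_link_libraries(consumer ${TORCH_LIBRARIES})
endif()
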
@@ -217,7 +217,7 @@ int main() {
    16, 2, 64,
    32, 2, 64,
    16, 8, 2, 2,
    8, 8,
    8, 1, 8,
    4
  };
  // jit.autotune("conv", src, benchmark);

@@ -1,6 +1,10 @@
find_package(Torch)
if(${Torch_FOUND})
if(${TORCH_FOUND})
  set(CUDA_HOME "/usr/local/cuda")
  include_directories(${TORCH_INCLUDE_DIRS})
  include_directories("${CUDA_HOME}/include")
  link_directories(${TORCH_LIBRARY_DIRS})
  add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
  add_library(torch_triton SHARED conv.cpp)
  target_compile_features(torch_triton PRIVATE cxx_range_for)
  target_link_libraries(torch_triton "${TORCH_LIBRARIES}")
  target_link_libraries(torch_triton torch triton)
endif()

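As a side note (a sketch only, not part of this commit), the hardcoded CUDA prefix and -D_GLIBCXX_USE_CXX11_ABI=0 definition above could instead be derived from variables the find modules already export; the use of find_package(CUDA) and of TORCH_CXX_FLAGS here is an assumption about the build environment.

# Alternative sketch: reuse exported variables instead of hardcoding paths and the ABI define.
find_package(Torch)
find_package(CUDA)
if(TORCH_FOUND AND CUDA_FOUND)
  include_directories(${TORCH_INCLUDE_DIRS} ${CUDA_INCLUDE_DIRS})
  link_directories(${TORCH_LIBRARY_DIRS})
  # TORCH_CXX_FLAGS already contains the ABI define when libtorch requires it.
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
  add_library(torch_triton SHARED conv.cpp)
  target_compile_features(torch_triton PRIVATE cxx_range_for)
  target_link_libraries(torch_triton torch triton "${TORCH_LIBRARIES}")
endif()
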
@@ -1,13 +1,96 @@
#include <torch/torch.h>
#include <torch/script.h>
#include "ATen/cuda/CUDAContext.h"
#include <vector>
#include "triton/jit.h"
#include "triton/driver/stream.h"

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor conv_forward(
    const at::Tensor data,
    const at::Tensor weight) {
  const char* src =
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {8};

__constant__ int32* delta = alloc_const int32[18];
__constant__ int32* masks = alloc_const int32[1024];

void conv(read_only restrict fp32 *a,
          read_only restrict fp32 *b,
          fp32 *c,
          int32 M, int32 N, int32 K,
          int32 AN, int32 AH, int32 AW,
          int32 CN, int32 CK, int32 CP, int32 CQ,
          int32 AC, int32 AR, int32 AS,
          int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w,
          int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q,
          int32 pad_h, int32 pad_w,
          int32 bound){
  int32 rxa[TM] = get_global_range[TM](0);
  int32 rb0[TN] = get_global_range[TN](1);
  int32 rka[TK] = 0 ... TK;
  int32 rb1[TK] = 0 ... TK;
  fp32 C[TM, TN] = 0;
  int32 ranh[TM] = rxa / CQ;
  int32 raw[TM] = rxa % CQ - pad_w;
  int32 ran[TM] = ranh / CP;
  int32 rah[TM] = ranh % CP - pad_h;
  int32 ra0[TM] = ran*lda_n + rah*lda_h + raw*lda_w;
  int32 racr[TK] = rka / AS;
  int32 ras[TK] = rka % AS;
  int32 rac[TK] = racr / AR;
  int32 rar[TK] = racr % AR;
  int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
  fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
  fp32* pb[TN, TK] = b + rb1[newaxis, :]*CK + rb0[:, newaxis];
  __constant__ int32* pincd[TK] = delta + rka;
  __constant__ int32* pd[TK] = delta + AR*AS + rka;
  int32 d[TK] = *pd;
  int32 incd[TK] = *pincd;
  int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
  int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
  __constant__ int32* pm[TM] = masks + AR*AS + maskw*AR*AS + maskh*AR*AS*(2*pad_w + 1);
  __constant__ int32* pincm[TM] = delta;
  int32 incm[TM] = *pincm;
  int32 checka0[TM] = *pm;
  int32 checka1[TK] = 1 << rka;
  int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
  fp32 a[TM, TK] = checka ? *pa : 0;
  fp32 b[TN, TK] = *pb;
  for(int32 k = K; k > 0; k = k - TK){
    C = dot(a, trans(b), C);
    pb = pb + TK*CK;
    pa = pa + d[newaxis, :];
    b = *pb;
    pd = pd + incd;
    pincd = pincd + incd;
    d = *pd;
    incd = *pincd;
    pm = pm + incm;
    pincm = pincm + incm;
    incm = *pincm;
    checka0 = *pm;
    checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
    a = checka ? *pa : 0;
  }
  int32 rxc[TM] = get_global_range[TM](0);
  int32 rc1[TN] = get_global_range[TN](1);
  int32 rcn[TM] = rxc / (CP*CQ);
  int32 rcpq[TM] = rxc % (CP*CQ);
  int32 rc0[TM] = rcn * ldc_n + rcpq;
  fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
  int1 checkc0[TM] = rxc < M;
  int1 checkc1[TN] = rc1 < N;
  int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
  @checkc *pc = C;
})";

torch::Tensor conv_forward(
    const torch::Tensor data,
    const torch::Tensor weight) {
  // Check
  CHECK_INPUT(data);
  CHECK_INPUT(weight);

@@ -21,10 +104,30 @@ at::Tensor conv_forward(
  const auto R = weight.size(1);
  const auto S = weight.size(2);
  const auto K = weight.size(3);
  // Create output
  // Allocate output
  AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
  return at::empty({B, K, H, W}, at::kFloat);
  torch::Tensor output = torch::empty({B, K, H, W}, torch::kFloat);
  // Wrap CUDA handles
  triton::driver::cu_stream sstream(at::cuda::getCurrentCUDAStream(), false);
  triton::driver::stream* stream = &sstream;
  triton::driver::context* ctx = stream->context();
  triton::driver::cu_buffer d(ctx, (CUdeviceptr)data.storage().data(), false);
  triton::driver::cu_buffer w(ctx, (CUdeviceptr)weight.storage().data(), false);
  // Create JIT
  triton::jit jit(ctx);
  std::vector<unsigned> params = {
    16, 2, 64,
    32, 2, 64,
    16, 8, 2, 2,
    8, 8,
    4
  };
  jit.add_module("conv", src, params);
  triton::driver::kernel* kernel = jit.get_function("conv");
  triton::jit::launch_information info = jit.get_launch_info("conv");

  return output;
}

static auto registry =
  torch::jit::RegisterOperators("triton::conv::forward", &conv_forward);
  torch::jit::RegisterOperators("triton::conv_forward", &conv_forward);

@@ -1,11 +1,9 @@
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.cpp_extension import load
from torch.distributions import categorical
from itertools import product

conv_triton = load('conv_triton', ['conv.cpp', 'conv.cu'], extra_cflags=['-O3'])
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")

d = torch.empty(64, 64, 64, 64).uniform_(0, 1).cuda()
w = torch.empty(64, 3, 3, 64).uniform_(0, 1).cuda()
a = torch.ops.triton.conv_forward(d, w)
print(a)

@@ -171,11 +171,19 @@ void tune::run(ir::module &mod) {
  // Simplify metaparameters
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i : block->get_inst_list())
    if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
      ir::type *ty = mod.get_builder().get_int32_ty();
      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
      *params_.at(i).at("nts.d0") = *tmp;
  for(ir::instruction *i : block->get_inst_list()){
    if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
      ir::type *ty = mod.get_builder().get_int32_ty();
      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
      *params_.at(i).at("nts.d0") = *tmp;
    }
    if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
      ir::type *ty = mod.get_builder().get_int32_ty();
      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 2));
      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 2));
      *params_.at(i).at("nts.d0") = *tmp1;
      *params_.at(i).at("nts.d1") = *tmp2;
    }
  }
}