[examples] added skeleton for pytorch wrapper
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 2.8)
|
cmake_minimum_required(VERSION 2.8)
|
||||||
project(triton)
|
project(triton)
|
||||||
include(CTest)
|
include(CTest)
|
||||||
|
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||||
|
|
||||||
# FLEX/YACC
|
# FLEX/YACC
|
||||||
find_package(BISON)
|
find_package(BISON)
|
||||||
@@ -16,7 +17,6 @@ include_directories(${BISON_Parser_INCLUDE_DIRECTORIES})
|
|||||||
|
|
||||||
# LLVM
|
# LLVM
|
||||||
find_package(LLVM REQUIRED CONFIG)
|
find_package(LLVM REQUIRED CONFIG)
|
||||||
message(STATUS ${LLVM_INCLUDE_DIRS})
|
|
||||||
include_directories(${LLVM_INCLUDE_DIRS})
|
include_directories(${LLVM_INCLUDE_DIRS})
|
||||||
add_definitions(${LLVM_DEFINITIONS})
|
add_definitions(${LLVM_DEFINITIONS})
|
||||||
#llvm_map_components_to_libnames(llvm_libs all)
|
#llvm_map_components_to_libnames(llvm_libs all)
|
||||||
|
21
cmake/FindTensorFlow.cmake
Normal file
21
cmake/FindTensorFlow.cmake
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
include(FindPackageHandleStandardArgs)
|
||||||
|
unset(TENSORFLOW_FOUND)
|
||||||
|
|
||||||
|
execute_process(COMMAND python -c "from os.path import dirname; import tensorflow as tf; print(dirname(dirname(tf.sysconfig.get_include())))"
|
||||||
|
OUTPUT_VARIABLE TF_INC OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET)
|
||||||
|
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.sysconfig.get_lib())"
|
||||||
|
OUTPUT_VARIABLE TF_LIB OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET)
|
||||||
|
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)"
|
||||||
|
OUTPUT_VARIABLE TF_ABI OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET)
|
||||||
|
|
||||||
|
find_package_handle_standard_args(TensorFlow DEFAULT_MSG TF_INC TF_LIB)
|
||||||
|
|
||||||
|
# set external variables for usage in CMakeLists.txt
|
||||||
|
if(TensorFlow_FOUND)
|
||||||
|
set(TensorFlow_LIBRARIES ${TF_LIB})
|
||||||
|
set(TensorFlow_INCLUDE_DIRS ${TF_INC})
|
||||||
|
set(TensorFlow_ABI ${TF_ABI})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# hide locals from GUI
|
||||||
|
mark_as_advanced(TF_INC TF_LIB TF_ABI)
|
101
cmake/FindTorch.cmake
Normal file
101
cmake/FindTorch.cmake
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
# FindTorch
|
||||||
|
# -------
|
||||||
|
#
|
||||||
|
# Finds the Torch library
|
||||||
|
#
|
||||||
|
# This will define the following variables:
|
||||||
|
#
|
||||||
|
# TORCH_FOUND -- True if the system has the Torch library
|
||||||
|
# TORCH_INCLUDE_DIRS -- The include directories for torch
|
||||||
|
# TORCH_LIBRARIES -- Libraries to link against
|
||||||
|
# TORCH_CXX_FLAGS -- Additional (required) compiler flags
|
||||||
|
#
|
||||||
|
# and the following imported targets:
|
||||||
|
#
|
||||||
|
# torch
|
||||||
|
|
||||||
|
include(FindPackageHandleStandardArgs)
|
||||||
|
|
||||||
|
if (DEFINED ENV{TORCH_INSTALL_PREFIX})
|
||||||
|
set(TORCH_INSTALL_PREFIX $ENV{TORCH_INSTALL_PREFIX})
|
||||||
|
else()
|
||||||
|
# Assume we are in <install-prefix>/share/cmake/Torch/TorchConfig.cmake
|
||||||
|
get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
|
||||||
|
get_filename_component(TORCH_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Include directories.
|
||||||
|
if (EXISTS "${TORCH_INSTALL_PREFIX}/include")
|
||||||
|
set(TORCH_INCLUDE_DIRS
|
||||||
|
${TORCH_INSTALL_PREFIX}/include
|
||||||
|
${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include)
|
||||||
|
else()
|
||||||
|
set(TORCH_INCLUDE_DIRS
|
||||||
|
${TORCH_INSTALL_PREFIX}/include
|
||||||
|
${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Library dependencies.
|
||||||
|
if (@BUILD_SHARED_LIBS@)
|
||||||
|
find_package(Caffe2 REQUIRED PATHS ${CMAKE_CURRENT_LIST_DIR}/../Caffe2)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (NOT ANDROID)
|
||||||
|
find_library(TORCH_LIBRARY torch PATHS "${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
else()
|
||||||
|
find_library(TORCH_LIBRARY NO_CMAKE_FIND_ROOT_PATH torch PATHS "${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
endif()
|
||||||
|
add_library(torch UNKNOWN IMPORTED)
|
||||||
|
set(TORCH_LIBRARIES torch ${Caffe2_MAIN_LIBS})
|
||||||
|
|
||||||
|
if (NOT ANDROID)
|
||||||
|
find_library(C10_LIBRARY c10 PATHS "${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
else()
|
||||||
|
find_library(C10_LIBRARY c10 NO_CMAKE_FIND_ROOT_PATH PATHS "${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
endif()
|
||||||
|
list(APPEND TORCH_LIBRARIES ${C10_LIBRARY})
|
||||||
|
|
||||||
|
if (@USE_CUDA@)
|
||||||
|
if(MSVC)
|
||||||
|
set(NVTOOLEXT_HOME "C:/Program Files/NVIDIA Corporation/NvToolsExt")
|
||||||
|
if ($ENV{NVTOOLEXT_HOME})
|
||||||
|
set(NVTOOLEXT_HOME $ENV{NVTOOLEXT_HOME})
|
||||||
|
endif()
|
||||||
|
set(TORCH_CUDA_LIBRARIES
|
||||||
|
${NVTOOLEXT_HOME}/lib/x64/nvToolsExt64_1.lib
|
||||||
|
${CUDA_LIBRARIES})
|
||||||
|
list(APPEND TORCH_INCLUDE_DIRS ${NVTOOLEXT_HOME}/include)
|
||||||
|
elseif(APPLE)
|
||||||
|
set(TORCH_CUDA_LIBRARIES
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib/libcudart.dylib
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib/libnvrtc.dylib
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib/libnvToolsExt.dylib
|
||||||
|
${CUDA_LIBRARIES})
|
||||||
|
else()
|
||||||
|
find_library(LIBNVTOOLSEXT libnvToolsExt.so PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64/)
|
||||||
|
set(TORCH_CUDA_LIBRARIES
|
||||||
|
${CUDA_CUDA_LIB}
|
||||||
|
${CUDA_NVRTC_LIB}
|
||||||
|
${LIBNVTOOLSEXT}
|
||||||
|
${CUDA_LIBRARIES})
|
||||||
|
endif()
|
||||||
|
find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
list(APPEND TORCH_CUDA_LIBRARIES ${C10_CUDA_LIBRARY})
|
||||||
|
list(APPEND TORCH_LIBRARIES ${TORCH_CUDA_LIBRARIES})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# When we build libtorch with the old GCC ABI, dependent libraries must too.
|
||||||
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||||
|
set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=@GLIBCXX_USE_CXX11_ABI@")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set_target_properties(torch PROPERTIES
|
||||||
|
IMPORTED_LOCATION "${TORCH_LIBRARY}"
|
||||||
|
INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}"
|
||||||
|
CXX_STANDARD 11
|
||||||
|
)
|
||||||
|
if (TORCH_CXX_FLAGS)
|
||||||
|
set_property(TARGET torch PROPERTY INTERFACE_COMPILE_OPTIONS "${TORCH_CXX_FLAGS}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_package_handle_standard_args(torch DEFAULT_MSG TORCH_LIBRARY TORCH_INCLUDE_DIRS)
|
@@ -1 +1,2 @@
|
|||||||
add_subdirectory(tensorflow)
|
add_subdirectory(tensorflow)
|
||||||
|
add_subdirectory(pytorch)
|
||||||
|
6
examples/python/pytorch/CMakeLists.txt
Normal file
6
examples/python/pytorch/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
find_package(Torch)
|
||||||
|
if(${Torch_FOUND})
|
||||||
|
add_library(torch_triton SHARED conv.cpp)
|
||||||
|
target_compile_features(torch_triton PRIVATE cxx_range_for)
|
||||||
|
target_link_libraries(torch_triton "${TORCH_LIBRARIES}")
|
||||||
|
endif()
|
30
examples/python/pytorch/conv.cpp
Normal file
30
examples/python/pytorch/conv.cpp
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#include <torch/torch.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||||
|
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||||
|
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||||
|
|
||||||
|
at::Tensor conv_forward(
|
||||||
|
const at::Tensor data,
|
||||||
|
const at::Tensor weight) {
|
||||||
|
// Check
|
||||||
|
CHECK_INPUT(data);
|
||||||
|
CHECK_INPUT(weight);
|
||||||
|
// Unpack data shapes
|
||||||
|
const auto B = data.size(0);
|
||||||
|
const auto Ci = data.size(1);
|
||||||
|
const auto H = data.size(2);
|
||||||
|
const auto W = data.size(3);
|
||||||
|
// Unpack weight shapes
|
||||||
|
const auto Cf = weight.size(0);
|
||||||
|
const auto R = weight.size(1);
|
||||||
|
const auto S = weight.size(2);
|
||||||
|
const auto K = weight.size(3);
|
||||||
|
// Create output
|
||||||
|
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
|
||||||
|
return at::empty({B, K, H, W}, at::kFloat);
|
||||||
|
}
|
||||||
|
|
||||||
|
static auto registry =
|
||||||
|
torch::jit::RegisterOperators("triton::conv::forward", &conv_forward);
|
11
examples/python/pytorch/main.py
Normal file
11
examples/python/pytorch/main.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from torch.utils.cpp_extension import load
|
||||||
|
from torch.distributions import categorical
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
conv_triton = load( 'conv_triton', ['conv.cpp', 'conv.cu'], extra_cflags=['-O3'])
|
@@ -1,14 +1,10 @@
|
|||||||
execute_process(COMMAND python -c "from os.path import dirname; import tensorflow as tf; print(dirname(dirname(tf.sysconfig.get_include())))"
|
find_package(TensorFlow)
|
||||||
OUTPUT_VARIABLE TF_INC OUTPUT_STRIP_TRAILING_WHITESPACE)
|
if(${TensorFlow_FOUND})
|
||||||
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.sysconfig.get_lib())"
|
set(CUDA_HOME "/usr/local/cuda")
|
||||||
OUTPUT_VARIABLE TF_LIB OUTPUT_STRIP_TRAILING_WHITESPACE)
|
include_directories("${TF_INC}/tensorflow/include")
|
||||||
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)"
|
include_directories("${CUDA_HOME}/include")
|
||||||
OUTPUT_VARIABLE TF_ABI OUTPUT_STRIP_TRAILING_WHITESPACE)
|
link_directories(${TF_LIB})
|
||||||
|
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
|
||||||
set(CUDA_HOME "/usr/local/cuda")
|
add_library(tf_blocksparse SHARED blocksparse.cpp)
|
||||||
include_directories("${TF_INC}/tensorflow/include")
|
target_link_libraries(tf_blocksparse tensorflow_framework triton)
|
||||||
include_directories("${CUDA_HOME}/include")
|
endif()
|
||||||
link_directories(${TF_LIB})
|
|
||||||
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
|
|
||||||
add_library(tf_blocksparse SHARED blocksparse.cpp)
|
|
||||||
target_link_libraries(tf_blocksparse tensorflow_framework triton)
|
|
||||||
|
Reference in New Issue
Block a user