[examples] added skeleton for pytorch wrapper
This commit is contained in:
@@ -1 +1,2 @@
|
||||
add_subdirectory(tensorflow)
|
||||
add_subdirectory(pytorch)
|
||||
|
6
examples/python/pytorch/CMakeLists.txt
Normal file
6
examples/python/pytorch/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
find_package(Torch)
|
||||
if(${Torch_FOUND})
|
||||
add_library(torch_triton SHARED conv.cpp)
|
||||
target_compile_features(torch_triton PRIVATE cxx_range_for)
|
||||
target_link_libraries(torch_triton "${TORCH_LIBRARIES}")
|
||||
endif()
|
30
examples/python/pytorch/conv.cpp
Normal file
30
examples/python/pytorch/conv.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
at::Tensor conv_forward(
|
||||
const at::Tensor data,
|
||||
const at::Tensor weight) {
|
||||
// Check
|
||||
CHECK_INPUT(data);
|
||||
CHECK_INPUT(weight);
|
||||
// Unpack data shapes
|
||||
const auto B = data.size(0);
|
||||
const auto Ci = data.size(1);
|
||||
const auto H = data.size(2);
|
||||
const auto W = data.size(3);
|
||||
// Unpack weight shapes
|
||||
const auto Cf = weight.size(0);
|
||||
const auto R = weight.size(1);
|
||||
const auto S = weight.size(2);
|
||||
const auto K = weight.size(3);
|
||||
// Create output
|
||||
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
|
||||
return at::empty({B, K, H, W}, at::kFloat);
|
||||
}
|
||||
|
||||
static auto registry =
|
||||
torch::jit::RegisterOperators("triton::conv::forward", &conv_forward);
|
11
examples/python/pytorch/main.py
Normal file
11
examples/python/pytorch/main.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
from torch.utils.cpp_extension import load
|
||||
from torch.distributions import categorical
|
||||
from itertools import product
|
||||
|
||||
conv_triton = load( 'conv_triton', ['conv.cpp', 'conv.cu'], extra_cflags=['-O3'])
|
@@ -1,14 +1,10 @@
|
||||
execute_process(COMMAND python -c "from os.path import dirname; import tensorflow as tf; print(dirname(dirname(tf.sysconfig.get_include())))"
|
||||
OUTPUT_VARIABLE TF_INC OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.sysconfig.get_lib())"
|
||||
OUTPUT_VARIABLE TF_LIB OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)"
|
||||
OUTPUT_VARIABLE TF_ABI OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
set(CUDA_HOME "/usr/local/cuda")
|
||||
include_directories("${TF_INC}/tensorflow/include")
|
||||
include_directories("${CUDA_HOME}/include")
|
||||
link_directories(${TF_LIB})
|
||||
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
|
||||
add_library(tf_blocksparse SHARED blocksparse.cpp)
|
||||
target_link_libraries(tf_blocksparse tensorflow_framework triton)
|
||||
find_package(TensorFlow)
|
||||
if(${TensorFlow_FOUND})
|
||||
set(CUDA_HOME "/usr/local/cuda")
|
||||
include_directories("${TF_INC}/tensorflow/include")
|
||||
include_directories("${CUDA_HOME}/include")
|
||||
link_directories(${TF_LIB})
|
||||
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
|
||||
add_library(tf_blocksparse SHARED blocksparse.cpp)
|
||||
target_link_libraries(tf_blocksparse tensorflow_framework triton)
|
||||
endif()
|
||||
|
Reference in New Issue
Block a user