[python][examples] added template for blocksparse

This commit is contained in:
Philippe Tillet
2019-09-03 20:44:27 -04:00
parent 5e03f0a065
commit 2ccc915011
9 changed files with 225 additions and 56 deletions

View File

@@ -42,6 +42,8 @@ if(BUILD_PYTHON_MODULE)
file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cc)
add_library(extra_tf_ops SHARED ${EXTRA_TF_OPS_SRC})
target_link_libraries(extra_tf_ops triton ${TF_LIBS})
target_compile_definitions(extra_tf_ops PRIVATE "-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI}")
endif()

View File

@@ -250,10 +250,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
try{
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
}catch(exception::cuda::base const &){
//#ifdef TRITON_LOG_PTX_ERROR
#ifdef TRITON_LOG_PTX_ERROR
std::cerr << "Compilation Failed! Log: " << std::endl;
std::cerr << errbuf << std::endl;
//#endif
#endif
throw;
}
}

View File

@@ -0,0 +1,158 @@
import tensorflow as tf
import triton
import numpy as np
src = '''
#if AT == 1
#define USE_A ^a
#define STRIDE_AK lda
#define STRIDE_AM 1
#define BROADCAST_AK :, newaxis
#define BROADCAST_AM newaxis, :
#define SHAPE_A TK, TM
#else
#define USE_A a
#define STRIDE_AK 1
#define STRIDE_AM lda
#define BROADCAST_AK newaxis, :
#define BROADCAST_AM :, newaxis
#define SHAPE_A TM, TK
#endif
#if BT == 1
#define USE_B ^b
#define STRIDE_BK 1
#define STRIDE_BM ldb
#define BROADCAST_BN newaxis, :
#define BROADCAST_BK :, newaxis
#define SHAPE_B TN, TK
#else
#define USE_B b
#define STRIDE_BK ldb
#define STRIDE_BM 1
#define BROADCAST_BN :, newaxis
#define BROADCAST_BK newaxis, :
#define SHAPE_B TK, TN
#endif
void dot (TYPE* A __readonly __noalias __align(16),
TYPE* B __readonly __noalias __align(16),
TYPE* C __writeonly __noalias __align(16),
int lda, int ldb, int ldc,
int N, int* lut,
int* locks, int nlocks) {
int ridx = get_program_id(0);
float c[TM, TN] = 0;
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
// load LUT header
int *header = lut + get_program_id(1) * 4;
int offset = *(header + 0);
int K = *(header + 1);
int column = *(header + 2);
int lockid = *(header + 3);
int *plut = lut + offset * 2;
int offx = ridx;
int offy = 0;
// compute x, y offsets
int rxa[TM] = offx * TM + (0 ... TM);
int ryb[TN] = offy * TN + (0 ... TN);
// bounds checking
bool checka[SHAPE_A] = (rxa < N)[:, newaxis];
bool checkb[SHAPE_B] = 1;
// base offset
int offa[SHAPE_A] = rxa[BROADCAST_AM] * STRIDE_AM + rka[BROADCAST_AK] * STRIDE_AK;
int offb[SHAPE_B] = ryb[BROADCAST_BN] * STRIDE_BN + rkb[BROADCAST_BK] * STRIDE_BK;
for(int k = K; k > 0; k -= 1) {
// fetch block indices
int ak = *(plut + 0);
int bk = *(plut + 1);
lut += 2;
// compute pointers to blocks
TYPE* pa[SHAPE_A] = A + offa + ak * TK * lda;
TYPE* pb[SHAPE_B] = B + offb + bk * TK * TN;
// load blocks
TYPE a[SHAPE_A] = checka ? *pa : 0;
TYPE b[SHAPE_B] = *pb;
// multiply blocks
c += USE_A @ USE_B;
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = column * TN + (0 ... TN);
TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
bool checkc[TM, TN] = (rxc < N)[:, newaxis];
if(lockid == 0) {
*?(checkc) pc = c;
}
else {
int *plock = locks + ridx*nlocks + lockid - 1;
int *pcount = plock + get_num_program(0)*nlocks;
while(__atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *pc;
__atomic_exch(pcount, 1);
__atomic_exch(plock, 0);
}
}
'''
# std::string dot::triton_c_src_dw() const {
# bool AT = (op_ == WGRAD);
# bool BT = (op_ == FPROP);
# std::string usea = AT ? "trans(a)" : "a";
# std::string useb = BT ? "trans(b)" : "b";
# std::string sizea = AT ? "TK, TM" : "TM, TK";
# std::string sizeb = BT ? "TN, TK" : "TK, TN";
# std::string bca0 = AT ? "newaxis, :" : ":, newaxis";
# std::string bca1 = AT ? ":, newaxis" : "newaxis, :";
# std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
# std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
# std::string lda0 = AT ? "*lda" : "";
# std::string lda1 = AT ? "" : "*lda";
# std::string ldb0 = BT ? "" : "*ldb";
# std::string ldb1 = BT ? "*ldb" : "" ;
# std::string result =
# R"(
# const tunable int TM = {)" + std::to_string(BS_) + R"(};
# const tunable int TN = {)" + std::to_string(BS_) + R"(};
# const tunable int TK = {32};
# void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
# restrict read_only align(16) )" + ab_ty_ + R"( *B,
# )" + c_ty_ + R"(* C,
# int lda, int ldb, int ldc,
# int N, int* lut,
# int* locks, int nlocks) {
# int ridx = get_range_id(0);
# float acc[TM, TN] = 0;
# int rka[TK] = 0 ... TK;
# int rkb[TK] = 0 ... TK;
# int *header = lut + ridx * 2;
# int offx = *(header + 0);
# int offy = *(header + 1);
# int rxa[TM] = offx*TM + (0 ... TM);
# int ryb[TN] = offy*TN + (0 ... TN);
# bool checka[TK, TM] = (rka < N)[:, newaxis];
# bool checkb[TK, TN] = (rkb < N)[:, newaxis];
# int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
# int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
# )" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa;
# )" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb;
# )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
# )" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0;
# for(int k = N; k > 0; k = k - TK) {
# acc = dot()" + usea + ", " + useb + R"(, acc);
# pa = pa + TK)" + lda1 + R"(;
# pb = pb + TK)" + ldb1 + R"(;
# a = checka ? *pa : 0;
# b = checkb ? *pb : 0;
# }
# int rxc[TM] = (0 ... TM);
# int ryc[TN] = (0 ... TN);
# )" + c_ty_ + R"( c[TM, TN] = acc;
# )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN;
# *pc = c;
# })";

View File

@@ -3,15 +3,16 @@ import triton
import numpy as np
src = """
// Templates for accessing A
#if AT == 1
#define USEA ^a
#define USE_A ^a
#define STRIDE_AK lda
#define STRIDE_AM 1
#define BROADCAST_AK :, newaxis
#define BROADCAST_AM newaxis, :
#define SHAPE_A TK, TM
#else
#define USEA a
#define USE_A a
#define STRIDE_AK 1
#define STRIDE_AM lda
#define BROADCAST_AK newaxis, :
@@ -19,15 +20,16 @@ src = """
#define SHAPE_A TM, TK
#endif
// Templates for accessing B
#if BT == 1
#define USEB ^b
#define USE_B ^b
#define STRIDE_BK 1
#define STRIDE_BN ldb
#define BROADCAST_BK newaxis, :
#define BROADCAST_BN :, newaxis
#define SHAPE_B TN, TK
#else
#define USEB b
#define USE_B b
#define STRIDE_BK ldb
#define STRIDE_BN 1
#define BROADCAST_BK :, newaxis
@@ -56,7 +58,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
TYPE b[SHAPE_B] = *pb;
// reduction loop
for(int k = K; k > 0; k-= TK){
c += USEA @ USEB;
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
@@ -71,57 +73,54 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
}
"""
def cdiv(a, b):
return -(-a // b)
class dot_op:
def __init__(self, trans_a = False, trans_b = False):
def __init__(self, transpose_a = False, transpose_b = False):
self.dot = triton.op(src, ['C'])
self.trans_a = trans_a
self.trans_b = trans_b
self.transpose_a = transpose_a
self.transpose_b = transpose_b
def __call__(self, a, b):
# extract shapes
shape_a = triton.shape(a)
shape_b = triton.shape(b)
M = shape_a[0]
Ka = shape_a[1]
Kb = shape_b[0]
N = shape_b[1]
M, Ka = shape_a[0], shape_a[1]
Kb, N = shape_b[0], shape_b[1]
# transpose shapes
if self.trans_a:
if self.transpose_a:
M, Ka = Ka, M
if self.trans_b:
if self.transpose_b:
Kb, N = N, Kb
K = Ka
# contiguous dimensions
lda = Ka
ldb = N
lda = M if self.transpose_a else Ka
ldb = Kb if self.transpose_b else N
ldc = N
# allocate output
c = triton.empty([M, N])
return self.dot(a, b, c, M, N, K, lda, ldb, ldc,
lambda opt: [cdiv(M, opt.d('TM')), cdiv(N, opt.d('TN'))],
AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16,
TM = [128], TN = [ 128], TK = [32])
# compute
return self.dot(a, b, c, M, N, Ka, lda, ldb, ldc,
lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))],
AT = self.transpose_a, BT = self.transpose_b, TYPE = tf.float16,
TM = [128], TN = [128], TK = [32])
def dot(a, b, trans_a = False, trans_b = False):
if (trans_a, trans_b) not in dot.ops:
dot.ops[trans_a, trans_b] = dot_op(trans_a, trans_b)
return dot.ops[trans_a, trans_b](a, b)
def dot(a, b, transpose_a = False, transpose_b = False):
if (transpose_a, transpose_b) not in dot.ops:
dot.ops[transpose_a, transpose_b] = dot_op(transpose_a, transpose_b)
return dot.ops[transpose_a, transpose_b](a, b)
dot.ops = dict()
# @triton.register_gradient(dot_op)
# def _dot_grad(op, dy):
# a = op.inputs[0]
# b = op.inputs[1]
# return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
@tf.RegisterGradient("Dot")
def _dot_grad(op, dy):
a = op.inputs[0]
b = op.inputs[1]
return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
def run_dot():
M, N, K = 128, 128, 128
a = tf.placeholder(tf.float16, shape=[M, K])
b = tf.placeholder(tf.float16, shape=[N, K])
c = dot(a, b, trans_a = False, trans_b = True)
c = dot(a, b, transpose_a = False, transpose_b = False)
# Reference
ha = np.random.rand(M, K).astype(np.float16)
hb = np.random.rand(K, N).astype(np.float16)
@@ -131,7 +130,8 @@ def run_dot():
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
# Test
hresult = np.dot(ha, hb.T)
print(result)
hresult = np.dot(ha, hb)
dif = np.abs(result - hresult)
np.savetxt('dif.dat', dif, '%2.4f')
print("dif: %f" % np.max(dif))

View File

@@ -44,6 +44,7 @@ class CMakeBuild(build_ext):
import tensorflow as tf
tf_include_dirs = tf.sysconfig.get_include()
tf_lib_dirs = tf.sysconfig.get_lib()
tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0
tf_libs = 'tensorflow_framework'
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
@@ -52,7 +53,8 @@ class CMakeBuild(build_ext):
'-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
'-DTF_INCLUDE_DIRS=' + tf_include_dirs,
'-DTF_LIB_DIRS=' + tf_lib_dirs,
'-DTF_LIBS=' + tf_libs]
'-DTF_LIBS=' + tf_libs,
'-DTF_ABI=' + str(tf_abi)]
cfg = 'Debug' if self.debug else 'Release'
build_args = ['--config', cfg]

View File

@@ -4,7 +4,7 @@
#include <string>
#include <regex>
#include <algorithm>
#include "triton/codegen/selection/selection.h"
#include "triton/codegen/selection.h"
#include "triton/runtime/function.h"
#include "triton/lang/code_gen.h"
#include "triton/lang/parser.h"

View File

@@ -102,12 +102,15 @@ def _build(src, path, framework):
# libraries
libraries = ['triton']
# add framework
extra_compile_args = []
if framework == tensorflow_id:
_import_tensorflow()
library_dirs += [tensorflow.sysconfig.get_lib()]
include_dirs += [tensorflow.sysconfig.get_include()]
include_dirs += ['/usr/local/cuda/include/']
libraries += ['tensorflow_framework']
ABI = tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tensorflow.__dict__ else 0
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)]
elif framework == torch_id:
_import_torch()
prefix = os.path.dirname(torch.__file__)
@@ -120,7 +123,6 @@ def _build(src, path, framework):
else:
assert False
# extra arguments
extra_compile_args = []
extra_link_args = []
# dependences
depends = [os.path.realpath(libtriton.__file__)]
@@ -254,14 +256,14 @@ class op:
return op(*op_args, id=op_id)
# class register_gradient:
class register_gradient:
# def __init__(self, op):
# self.op = op
def __init__(self, op):
self.op = op
# def __call__(self, f):
# name = 'Dot'
# ops.RegisterGradient(name)(f)
def __call__(self, f):
name = 'Dot'
ops.RegisterGradient(name)(f)
def empty(shapes, framework = None):
@@ -276,6 +278,9 @@ def empty(shapes, framework = None):
_import_torch()
return torch.empty(*shapes)
def cdiv(a, b):
return -(-a // b)
class scalar:
def __init__(self, x):

View File

@@ -22,8 +22,8 @@ std::vector<double> do_bench(drv::stream* stream, int32_t N){
// create options
rt::function::options_space_t opt;
opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"TN", {"512"}});
opt.num_warps = {4};
opt.defines.push_back({"TN", {"128"}});
opt.num_warps = {1, 2, 4, 8};
// create function
rt::function function(src::copy1d, opt);
// benchmark available libraries
@@ -42,7 +42,7 @@ int main() {
triton::driver::stream* stream = triton::driver::stream::create(context);
// shapes to benchmark
typedef std::tuple<int> config_t;
std::vector<config_t> configs = { 1024*1024*16 };
std::vector<config_t> configs = { 1024*1024*32 };
int N;
for(const auto& c: configs){
std::tie(N) = c;

View File

@@ -29,6 +29,7 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float NumericT;
std::string ty = "float";
cublasDataType_t cuty = CUDA_R_32F;
size_t dt_nbytes = sizeof(NumericT);
drv::context* context = stream->context();
// leading dimensions
@@ -44,10 +45,10 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"AT", {AT?"1":"0"}});
opt.defines.push_back({"BT", {BT?"1":"0"}});
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"64"}});
opt.defines.push_back({"TM", {"64", "128"}});
opt.defines.push_back({"TN", {"64", "128"}});
opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {4};
opt.num_warps = {2, 4, 8};
// create function
rt::function function(src::dot, opt);
// benchmark available libraries
@@ -57,10 +58,11 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
if(cublas::cublasinit()){
NumericT alpha(static_cast<double>(1));
NumericT beta(static_cast<double>(0));
cublasGemmAlgo_t fastest = CUBLAS_GEMM_ALGO5;
// cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
&alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, nullptr, fastest); }, stream);
cublasGemmAlgo_t fastest;
cublasGemm(cuty, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
double cublas_ms = triton::tools::bench([&]() { cublasGemm(cuty, stream, AT, BT, M, N, K,
&alpha, &*da, lda, &*db, ldb, &beta, &*dc,
ldc, nullptr, fastest); }, stream);
result.push_back(tflops(cublas_ms));
}
// triton