[python][examples] added template for blocksparse
@@ -42,6 +42,8 @@ if(BUILD_PYTHON_MODULE)
   file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cc)
   add_library(extra_tf_ops SHARED ${EXTRA_TF_OPS_SRC})
   target_link_libraries(extra_tf_ops triton ${TF_LIBS})
+  target_compile_definitions(extra_tf_ops PRIVATE "-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI}")
+
 endif()

@@ -250,10 +250,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
   try{
     dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
   }catch(exception::cuda::base const &){
-    //#ifdef TRITON_LOG_PTX_ERROR
+    #ifdef TRITON_LOG_PTX_ERROR
     std::cerr << "Compilation Failed! Log: " << std::endl;
     std::cerr << errbuf << std::endl;
-    //#endif
+    #endif
     throw;
   }
 }

python/examples/blocksparse.py (new file, 158 lines)
@@ -0,0 +1,158 @@
import tensorflow as tf
import triton
import numpy as np

src = '''
#if AT == 1
#define USE_A ^a
#define STRIDE_AK lda
#define STRIDE_AM 1
#define BROADCAST_AK :, newaxis
#define BROADCAST_AM newaxis, :
#define SHAPE_A TK, TM
#else
#define USE_A a
#define STRIDE_AK 1
#define STRIDE_AM lda
#define BROADCAST_AK newaxis, :
#define BROADCAST_AM :, newaxis
#define SHAPE_A TM, TK
#endif

#if BT == 1
#define USE_B ^b
#define STRIDE_BK 1
#define STRIDE_BN ldb
#define BROADCAST_BN newaxis, :
#define BROADCAST_BK :, newaxis
#define SHAPE_B TN, TK
#else
#define USE_B b
#define STRIDE_BK ldb
#define STRIDE_BN 1
#define BROADCAST_BN :, newaxis
#define BROADCAST_BK newaxis, :
#define SHAPE_B TK, TN
#endif

void dot (TYPE* A __readonly __noalias __align(16),
          TYPE* B __readonly __noalias __align(16),
          TYPE* C __writeonly __noalias __align(16),
          int lda, int ldb, int ldc,
          int N, int* lut,
          int* locks, int nlocks) {
  int ridx = get_program_id(0);
  float c[TM, TN] = 0;
  int rka[TK] = 0 ... TK;
  int rkb[TK] = 0 ... TK;
  // load LUT header
  int *header = lut + get_program_id(1) * 4;
  int offset = *(header + 0);
  int K = *(header + 1);
  int column = *(header + 2);
  int lockid = *(header + 3);
  int *plut = lut + offset * 2;
  int offx = ridx;
  int offy = 0;
  // compute x, y offsets
  int rxa[TM] = offx * TM + (0 ... TM);
  int ryb[TN] = offy * TN + (0 ... TN);
  // bounds checking
  bool checka[SHAPE_A] = (rxa < N)[:, newaxis];
  bool checkb[SHAPE_B] = 1;
  // base offset
  int offa[SHAPE_A] = rxa[BROADCAST_AM] * STRIDE_AM + rka[BROADCAST_AK] * STRIDE_AK;
  int offb[SHAPE_B] = ryb[BROADCAST_BN] * STRIDE_BN + rkb[BROADCAST_BK] * STRIDE_BK;
  for(int k = K; k > 0; k -= 1) {
    // fetch block indices
    int ak = *(plut + 0);
    int bk = *(plut + 1);
    plut += 2;
    // compute pointers to blocks
    TYPE* pa[SHAPE_A] = A + offa + ak * TK * lda;
    TYPE* pb[SHAPE_B] = B + offb + bk * TK * TN;
    // load blocks
    TYPE a[SHAPE_A] = checka ? *pa : 0;
    TYPE b[SHAPE_B] = *pb;
    // multiply blocks
    c += USE_A @ USE_B;
  }
  int rxc[TM] = ridx * TM + (0 ... TM);
  int ryc[TN] = column * TN + (0 ... TN);
  TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
  bool checkc[TM, TN] = (rxc < N)[:, newaxis];
  if(lockid == 0) {
    *?(checkc) pc = c;
  }
  else {
    int *plock = locks + ridx*nlocks + lockid - 1;
    int *pcount = plock + get_num_program(0)*nlocks;
    while(__atomic_cas(plock, 0, 1));
    int count = *pcount;
    if(count == 0)
      *?(checkc) pc = c;
    else
      *?(checkc) pc = c + *pc;
    __atomic_exch(pcount, 1);
    __atomic_exch(plock, 0);
  }
}
'''

# std::string dot::triton_c_src_dw() const {
#   bool AT = (op_ == WGRAD);
#   bool BT = (op_ == FPROP);
#   std::string usea = AT ? "trans(a)" : "a";
#   std::string useb = BT ? "trans(b)" : "b";
#   std::string sizea = AT ? "TK, TM" : "TM, TK";
#   std::string sizeb = BT ? "TN, TK" : "TK, TN";
#   std::string bca0 = AT ? "newaxis, :" : ":, newaxis";
#   std::string bca1 = AT ? ":, newaxis" : "newaxis, :";
#   std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
#   std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
#   std::string lda0 = AT ? "*lda" : "";
#   std::string lda1 = AT ? "" : "*lda";
#   std::string ldb0 = BT ? "" : "*ldb";
#   std::string ldb1 = BT ? "*ldb" : "" ;
#   std::string result =
#   R"(
#   const tunable int TM = {)" + std::to_string(BS_) + R"(};
#   const tunable int TN = {)" + std::to_string(BS_) + R"(};
#   const tunable int TK = {32};
#   void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
#              restrict read_only align(16) )" + ab_ty_ + R"( *B,
#              )" + c_ty_ + R"(* C,
#              int lda, int ldb, int ldc,
#              int N, int* lut,
#              int* locks, int nlocks) {
#     int ridx = get_range_id(0);
#     float acc[TM, TN] = 0;
#     int rka[TK] = 0 ... TK;
#     int rkb[TK] = 0 ... TK;
#     int *header = lut + ridx * 2;
#     int offx = *(header + 0);
#     int offy = *(header + 1);
#     int rxa[TM] = offx*TM + (0 ... TM);
#     int ryb[TN] = offy*TN + (0 ... TN);
#     bool checka[TK, TM] = (rka < N)[:, newaxis];
#     bool checkb[TK, TN] = (rkb < N)[:, newaxis];
#     int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
#     int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
#     )" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa;
#     )" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb;
#     )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
#     )" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0;
#     for(int k = N; k > 0; k = k - TK) {
#       acc = dot()" + usea + ", " + useb + R"(, acc);
#       pa = pa + TK)" + lda1 + R"(;
#       pb = pb + TK)" + ldb1 + R"(;
#       a = checka ? *pa : 0;
#       b = checkb ? *pb : 0;
#     }
#     int rxc[TM] = (0 ... TM);
#     int ryc[TN] = (0 ... TN);
#     )" + c_ty_ + R"( c[TM, TN] = acc;
#     )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN;
#     *pc = c;
#   })";

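The kernel above reads its schedule from a look-up table: one 4-int header per launched block-column (pair offset, number of non-zero blocks K, output column, lock id), followed by (ak, bk) block-index pairs starting at pair index `offset`. The sketch below is not part of the commit; the `make_lut` helper and the exact layout are inferred from the kernel source and are only meant to illustrate the data structure.

import numpy as np

def make_lut(columns, lockid=0):
    # columns: one list of (ak, bk) block-index pairs per launched block-column.
    # Header ints per column: [offset, K, column, lockid]; each header occupies 2 pairs.
    headers, body = [], []
    offset = 2 * len(columns)            # first free pair slot after all headers
    for column, pairs in enumerate(columns):
        headers += [offset, len(pairs), column, lockid]
        body += [i for pair in pairs for i in pair]
        offset += len(pairs)
    return np.array(headers + body, dtype=np.int32)

lut = make_lut([[(0, 0), (2, 1)], [(1, 0)]])  # two block-columns with 2 and 1 blocks
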
@@ -3,15 +3,16 @@ import triton
 import numpy as np

 src = """
+// Templates for accessing A
 #if AT == 1
-#define USEA ^a
+#define USE_A ^a
 #define STRIDE_AK lda
 #define STRIDE_AM 1
 #define BROADCAST_AK :, newaxis
 #define BROADCAST_AM newaxis, :
 #define SHAPE_A TK, TM
 #else
-#define USEA a
+#define USE_A a
 #define STRIDE_AK 1
 #define STRIDE_AM lda
 #define BROADCAST_AK newaxis, :
@@ -19,15 +20,16 @@ src = """
 #define SHAPE_A TM, TK
 #endif

+// Templates for accessing B
 #if BT == 1
-#define USEB ^b
+#define USE_B ^b
 #define STRIDE_BK 1
 #define STRIDE_BN ldb
 #define BROADCAST_BK newaxis, :
 #define BROADCAST_BN :, newaxis
 #define SHAPE_B TN, TK
 #else
-#define USEB b
+#define USE_B b
 #define STRIDE_BK ldb
 #define STRIDE_BN 1
 #define BROADCAST_BK :, newaxis
@@ -56,7 +58,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
   TYPE b[SHAPE_B] = *pb;
   // reduction loop
   for(int k = K; k > 0; k-= TK){
-    c += USEA @ USEB;
+    c += USE_A @ USE_B;
     pa = pa + TK * STRIDE_AK;
     pb = pb + TK * STRIDE_BK;
     a = *pa;
@@ -71,57 +73,54 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
 }
 """

-def cdiv(a, b):
-    return -(-a // b)
-
 class dot_op:

-    def __init__(self, trans_a = False, trans_b = False):
+    def __init__(self, transpose_a = False, transpose_b = False):
         self.dot = triton.op(src, ['C'])
-        self.trans_a = trans_a
-        self.trans_b = trans_b
+        self.transpose_a = transpose_a
+        self.transpose_b = transpose_b

     def __call__(self, a, b):
+        # extract shapes
         shape_a = triton.shape(a)
         shape_b = triton.shape(b)
-        M = shape_a[0]
-        Ka = shape_a[1]
-        Kb = shape_b[0]
-        N = shape_b[1]
+        M, Ka = shape_a[0], shape_a[1]
+        Kb, N = shape_b[0], shape_b[1]
         # transpose shapes
-        if self.trans_a:
+        if self.transpose_a:
             M, Ka = Ka, M
-        if self.trans_b:
+        if self.transpose_b:
             Kb, N = N, Kb
-        K = Ka
         # contiguous dimensions
-        lda = Ka
-        ldb = N
+        lda = M if self.transpose_a else Ka
+        ldb = Kb if self.transpose_b else N
         ldc = N
+        # allocate output
         c = triton.empty([M, N])
-        return self.dot(a, b, c, M, N, K, lda, ldb, ldc,
-                        lambda opt: [cdiv(M, opt.d('TM')), cdiv(N, opt.d('TN'))],
-                        AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16,
-                        TM = [128], TN = [ 128], TK = [32])
+        # compute
+        return self.dot(a, b, c, M, N, Ka, lda, ldb, ldc,
+                        lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))],
+                        AT = self.transpose_a, BT = self.transpose_b, TYPE = tf.float16,
+                        TM = [128], TN = [128], TK = [32])


-def dot(a, b, trans_a = False, trans_b = False):
-    if (trans_a, trans_b) not in dot.ops:
-        dot.ops[trans_a, trans_b] = dot_op(trans_a, trans_b)
-    return dot.ops[trans_a, trans_b](a, b)
+def dot(a, b, transpose_a = False, transpose_b = False):
+    if (transpose_a, transpose_b) not in dot.ops:
+        dot.ops[transpose_a, transpose_b] = dot_op(transpose_a, transpose_b)
+    return dot.ops[transpose_a, transpose_b](a, b)
 dot.ops = dict()

-# @triton.register_gradient(dot_op)
-# def _dot_grad(op, dy):
-#     a = op.inputs[0]
-#     b = op.inputs[1]
-#     return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
+@tf.RegisterGradient("Dot")
+def _dot_grad(op, dy):
+    a = op.inputs[0]
+    b = op.inputs[1]
+    return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]

 def run_dot():
     M, N, K = 128, 128, 128
     a = tf.placeholder(tf.float16, shape=[M, K])
     b = tf.placeholder(tf.float16, shape=[N, K])
-    c = dot(a, b, trans_a = False, trans_b = True)
+    c = dot(a, b, transpose_a = False, transpose_b = False)
     # Reference
     ha = np.random.rand(M, K).astype(np.float16)
     hb = np.random.rand(K, N).astype(np.float16)
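The new `lda = M if self.transpose_a else Ka` and `ldb = Kb if self.transpose_b else N` lines fix the leading dimensions when an operand is stored transposed: for a row-major buffer the leading dimension is simply the length of one stored row, whatever the kernel later calls the axes. A small NumPy illustration, not part of the diff:

import numpy as np

a  = np.zeros((2, 3), dtype=np.float32)   # stored (M=2, K=3): kernel reads A,   lda = Ka = 3
at = np.zeros((3, 2), dtype=np.float32)   # stored (K=3, M=2): kernel reads A^T, lda = M  = 2
assert a.strides[0]  // a.itemsize  == 3
assert at.strides[0] // at.itemsize == 2
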
@@ -131,7 +130,8 @@ def run_dot():
     result = sess.run([c], feed_dict = {a: ha,
                                         b: hb})[0]
     # Test
-    hresult = np.dot(ha, hb.T)
+    print(result)
+    hresult = np.dot(ha, hb)
     dif = np.abs(result - hresult)
     np.savetxt('dif.dat', dif, '%2.4f')
     print("dif: %f" % np.max(dif))

@@ -44,6 +44,7 @@ class CMakeBuild(build_ext):
         import tensorflow as tf
         tf_include_dirs = tf.sysconfig.get_include()
         tf_lib_dirs = tf.sysconfig.get_lib()
+        tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0
         tf_libs = 'tensorflow_framework'

         cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
@@ -52,7 +53,8 @@ class CMakeBuild(build_ext):
                       '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
                       '-DTF_INCLUDE_DIRS=' + tf_include_dirs,
                       '-DTF_LIB_DIRS=' + tf_lib_dirs,
-                      '-DTF_LIBS=' + tf_libs]
+                      '-DTF_LIBS=' + tf_libs,
+                      '-DTF_ABI=' + str(tf_abi)]

         cfg = 'Debug' if self.debug else 'Release'
         build_args = ['--config', cfg]

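The new `-DTF_ABI` argument closes the loop with the `target_compile_definitions(... _GLIBCXX_USE_CXX11_ABI=${TF_ABI})` change in the CMake hunk above: the TensorFlow ops library has to be built with the same libstdc++ ABI as the installed TensorFlow wheel, otherwise std::string symbols fail to resolve when the shared object is loaded. A stand-alone sketch of the detection logic, using the same expression as the diff:

import tensorflow as tf

# TensorFlow exposes the ABI it was compiled with; fall back to the old ABI if absent.
tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0
print('-DTF_ABI=' + str(tf_abi))   # forwarded to CMake, then to -D_GLIBCXX_USE_CXX11_ABI
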
@@ -4,7 +4,7 @@
 #include <string>
 #include <regex>
 #include <algorithm>
-#include "triton/codegen/selection/selection.h"
+#include "triton/codegen/selection.h"
 #include "triton/runtime/function.h"
 #include "triton/lang/code_gen.h"
 #include "triton/lang/parser.h"

@@ -102,12 +102,15 @@ def _build(src, path, framework):
     # libraries
     libraries = ['triton']
     # add framework
+    extra_compile_args = []
     if framework == tensorflow_id:
         _import_tensorflow()
         library_dirs += [tensorflow.sysconfig.get_lib()]
         include_dirs += [tensorflow.sysconfig.get_include()]
         include_dirs += ['/usr/local/cuda/include/']
         libraries += ['tensorflow_framework']
+        ABI = tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tensorflow.__dict__ else 0
+        extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)]
     elif framework == torch_id:
         _import_torch()
         prefix = os.path.dirname(torch.__file__)
@@ -120,7 +123,6 @@ def _build(src, path, framework):
     else:
         assert False
     # extra arguments
-    extra_compile_args = []
     extra_link_args = []
     # dependences
     depends = [os.path.realpath(libtriton.__file__)]
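Moving `extra_compile_args = []` above the framework branch is what lets the TensorFlow branch append the ABI define; with the old placement the list was only created after the branch, so appending there would fail or the flag would be lost. A minimal sketch of the required ordering, with a made-up framework value for illustration:

extra_compile_args = []                       # must exist before the branch
framework = 'tensorflow'                      # hypothetical value for the sketch
if framework == 'tensorflow':
    extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI=1']
print(extra_compile_args)                     # ['-D_GLIBCXX_USE_CXX11_ABI=1']
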
@@ -254,14 +256,14 @@ class op:
         return op(*op_args, id=op_id)


-# class register_gradient:
+class register_gradient:

-#     def __init__(self, op):
-#         self.op = op
+    def __init__(self, op):
+        self.op = op

-#     def __call__(self, f):
-#         name = 'Dot'
-#         ops.RegisterGradient(name)(f)
+    def __call__(self, f):
+        name = 'Dot'
+        ops.RegisterGradient(name)(f)


 def empty(shapes, framework = None):
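`register_gradient` wraps TensorFlow's own gradient registry: `ops.RegisterGradient(name)(f)` registers `f` as the gradient function for any op whose type string is `name`, which is what the `@tf.RegisterGradient("Dot")` decorator in the example does directly. A hedged stand-alone sketch of that mechanism; the op type and the gradient body are made up for illustration:

import tensorflow as tf
from tensorflow.python.framework import ops

def _identity_grad(op, dy):
    # Illustrative only: pass the incoming gradient straight through.
    return [dy]

ops.RegisterGradient("MyHypotheticalOp")(_identity_grad)
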
@@ -276,6 +278,9 @@ def empty(shapes, framework = None):
         _import_torch()
         return torch.empty(*shapes)

+def cdiv(a, b):
+    return -(-a // b)
+
 class scalar:

     def __init__(self, x):

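`cdiv` is ceiling division, used above to size the launch grid: one program per TM-by-TN tile, rounded up so partial tiles are still covered. For example:

def cdiv(a, b):
    return -(-a // b)   # ceil(a / b) for positive integers

assert cdiv(128, 128) == 1
assert cdiv(129, 128) == 2   # a 129-row output still needs two 128-row tiles
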
@@ -22,8 +22,8 @@ std::vector<double> do_bench(drv::stream* stream, int32_t N){
   // create options
   rt::function::options_space_t opt;
   opt.defines.push_back({"TYPE", {ty}});
-  opt.defines.push_back({"TN", {"512"}});
-  opt.num_warps = {4};
+  opt.defines.push_back({"TN", {"128"}});
+  opt.num_warps = {1, 2, 4, 8};
   // create function
   rt::function function(src::copy1d, opt);
   // benchmark available libraries
@@ -42,7 +42,7 @@ int main() {
   triton::driver::stream* stream = triton::driver::stream::create(context);
   // shapes to benchmark
   typedef std::tuple<int> config_t;
-  std::vector<config_t> configs = { 1024*1024*16 };
+  std::vector<config_t> configs = { 1024*1024*32 };
   int N;
   for(const auto& c: configs){
     std::tie(N) = c;

@@ -29,6 +29,7 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
 std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
   typedef float NumericT;
   std::string ty = "float";
+  cublasDataType_t cuty = CUDA_R_32F;
   size_t dt_nbytes = sizeof(NumericT);
   drv::context* context = stream->context();
   // leading dimensions
@@ -44,10 +45,10 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
   opt.defines.push_back({"TYPE", {ty}});
   opt.defines.push_back({"AT", {AT?"1":"0"}});
   opt.defines.push_back({"BT", {BT?"1":"0"}});
-  opt.defines.push_back({"TM", {"128"}});
-  opt.defines.push_back({"TN", {"64"}});
+  opt.defines.push_back({"TM", {"64", "128"}});
+  opt.defines.push_back({"TN", {"64", "128"}});
   opt.defines.push_back({"TK", {"8"}});
-  opt.num_warps = {4};
+  opt.num_warps = {2, 4, 8};
   // create function
   rt::function function(src::dot, opt);
   // benchmark available libraries
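Listing several values per define turns the option space into a small tuning sweep: the runtime can compile one kernel per combination and keep the fastest. With the values above that is 2 x 2 x 1 x 3 = 12 candidates per (AT, BT, M, N, K) point. A quick way to enumerate them, in illustrative Python rather than the benchmark's own code:

import itertools

TM, TN, TK, num_warps = ["64", "128"], ["64", "128"], ["8"], [2, 4, 8]
configs = list(itertools.product(TM, TN, TK, num_warps))
print(len(configs))   # 12 candidate (TM, TN, TK, num_warps) configurations
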
@@ -57,10 +58,11 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
   if(cublas::cublasinit()){
     NumericT alpha(static_cast<double>(1));
     NumericT beta(static_cast<double>(0));
-    cublasGemmAlgo_t fastest = CUBLAS_GEMM_ALGO5;
-    // cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
-    double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
-                                                               &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, nullptr, fastest); }, stream);
+    cublasGemmAlgo_t fastest;
+    cublasGemm(cuty, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
+    double cublas_ms = triton::tools::bench([&]() { cublasGemm(cuty, stream, AT, BT, M, N, K,
+                                                               &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
+                                                               ldc, nullptr, fastest); }, stream);
     result.push_back(tflops(cublas_ms));
   }
   // triton
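For reference, `tflops(ms)` in this benchmark converts a GEMM time into throughput; a dense M x N x K product performs 2*M*N*K floating-point operations. The helper below is an illustrative reimplementation with a made-up problem size, not the definition used in the repository:

def tflops(ms, M=2048, N=2048, K=2048):
    return 2.0 * M * N * K / (ms * 1e-3) * 1e-12

print(round(tflops(1.0), 1))   # ~17.2 TFLOP/s for a 2048^3 GEMM that takes 1 ms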