[python] more progress towards tensorflow/pytorch unification

This commit is contained in:
Philippe Tillet
2019-09-04 12:47:59 -04:00
parent cdbc9d4ecd
commit f6e9c24fe8
3 changed files with 132 additions and 104 deletions

View File

@@ -2,41 +2,10 @@ import tensorflow as tf
import triton
import numpy as np
src = """
// Templates for accessing A
#if AT == 1
#define USE_A ^a
#define STRIDE_AK lda
#define STRIDE_AM 1
#define BROADCAST_AK :, newaxis
#define BROADCAST_AM newaxis, :
#define SHAPE_A TK, TM
#else
#define USE_A a
#define STRIDE_AK 1
#define STRIDE_AM lda
#define BROADCAST_AK newaxis, :
#define BROADCAST_AM :, newaxis
#define SHAPE_A TM, TK
#endif
// Templates for accessing B
#if BT == 1
#define USE_B ^b
#define STRIDE_BK 1
#define STRIDE_BN ldb
#define BROADCAST_BK newaxis, :
#define BROADCAST_BN :, newaxis
#define SHAPE_B TN, TK
#else
#define USE_B b
#define STRIDE_BK ldb
#define STRIDE_BN 1
#define BROADCAST_BK :, newaxis
#define BROADCAST_BN newaxis, :
#define SHAPE_B TK, TN
#endif
class dot(triton.function):
src = """
void dot(TYPE * A, TYPE * B, TYPE * C,
int M, int N, int K,
int lda __multipleof(8),
@@ -73,71 +42,99 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
}
"""
class dot_op(triton.op2):
op = triton.op(src, ['C'])
def __init__(self, transpose_a = False, transpose_b = False):
self.dot = triton.op(src, ['C'])
self.transpose_a = transpose_a
self.transpose_b = transpose_b
def forward(self, a, b):
dtype = a.dtype
@staticmethod
def _call(a, b, transpose_a, transpose_b):
# extract shapes
shape_a = triton.shape(a)
shape_b = triton.shape(b)
M, Ka = shape_a[0], shape_a[1]
Kb, N = shape_b[0], shape_b[1]
# transpose shapes
if self.transpose_a:
if transpose_a:
M, Ka = Ka, M
if self.transpose_b:
if transpose_b:
Kb, N = N, Kb
# contiguous dimensions
lda = M if self.transpose_a else Ka
ldb = Kb if self.transpose_b else N
lda = M if transpose_a else Ka
ldb = Kb if transpose_b else N
ldc = N
# data-type
dtype = a.dtype
# allocate output
c = triton.empty([M, N], dtype = dtype)
# compute
return self.dot(a, b, c, M, N, Ka, lda, ldb, ldc,
lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))],
AT = self.transpose_a, BT = self.transpose_b, TYPE = dtype,
TM = [128], TN = [128], TK = [8])
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
# macros -- not necessary but makes kernel source-code simpler
macros = {# handle A transposition
'USE_A' : '^a' if transpose_a else 'a',
'STRIDE_AK' : 'lda' if transpose_a else '1',
'STRIDE_AM' : '1' if transpose_a else 'lda',
'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
'SHAPE_A' : 'TK, TM' if transpose_a else 'TM, TK',
# handle B transposition
'USE_B' : '^b' if transpose_b else 'b',
'STRIDE_BK' : '1' if transpose_b else 'ldb',
'STRIDE_BN' : 'ldb' if transpose_b else '1',
'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'}
return dot.op(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
AT = transpose_a, BT = transpose_b, TYPE = dtype,
TM = [64, 128], TN = [64, 128], TK = [8], **macros)
def backward(self, op, dy):
a = op.inputs[0]
b = op.inputs[1]
da = dot_op(self.transpose_a, self.transpose_b).forward(dy, b)
db = dot_op(self.transpose_a, self.transpose_b).forward(a, dy)
@staticmethod
def forward(ctx, a, b, transpose_a = False, transpose_b = False):
ctx.save_for_backward(a, b, transpose_a, transpose_b)
return dot._call(a, b, transpose_a, transpose_b)
@staticmethod
def backward(ctx, dy):
a, b, t_a, t_b = ctx.saved_tensors
if not t_a and not t_b:
da = dot._call(dy, b, False, True)
db = dot._call(a, dy, True, False)
elif not t_a and t_b:
da = dot._call(dy, b, False, False)
db = dot._call(dy, a, True, False)
elif t_a and not t_b:
da = dot._call(b, dy, False, True)
db = dot._call(a, dy, False, False)
elif t_a and t_b:
da = dot._call(b, dy, True, True)
db = dot._call(dy, a, True, True)
else:
assert False
return [da, db, None, None, None, None, None, None, None]
def dot(a, b, transpose_a = False, transpose_b = False):
if (transpose_a, transpose_b) not in dot.ops:
dot.ops[transpose_a, transpose_b] = dot_op(transpose_a, transpose_b)
return dot.ops[transpose_a, transpose_b](a, b)
dot.ops = dict()
def run_dot():
M, N, K = 128, 128, 128
a = tf.placeholder(tf.float32, shape=[M, K])
b = tf.placeholder(tf.float32, shape=[N, K])
c = dot(a, b, transpose_a = False, transpose_b = False)
da = tf.gradients(c, [a])
_dot = dot.apply
tr_c = _dot(a, b, transpose_a = False, transpose_b = True)
tr_d = _dot(tr_c, b, transpose_a = True, transpose_b = False)
tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
# Gradient
tr_da = tf.gradients(tr_d, [a])
tf_da = tf.gradients(tf_d, [a])
# Reference
ha = np.random.rand(M, K).astype(np.float32)
hb = np.random.rand(K, N).astype(np.float32)
# Run
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([da], feed_dict = {a: ha,
b: hb})[0]
result = sess.run([tr_da, tf_da], feed_dict = {a: ha,
b: hb})
# Test
print(result)
hresult = np.dot(ha, hb)
dif = np.abs(result - hresult)
np.savetxt('dif.dat', dif, '%2.4f')
print(result[0][0])
print(result[1][0])
dif = np.abs(result[0][0] - result[1][0])
print("dif: %f" % np.max(dif))
run_dot()