[python] more progress towards tensorflow/pytorch unification
This commit is contained in:
@@ -2,41 +2,10 @@ import tensorflow as tf
|
||||
import triton
|
||||
import numpy as np
|
||||
|
||||
src = """
|
||||
// Templates for accessing A
|
||||
#if AT == 1
|
||||
#define USE_A ^a
|
||||
#define STRIDE_AK lda
|
||||
#define STRIDE_AM 1
|
||||
#define BROADCAST_AK :, newaxis
|
||||
#define BROADCAST_AM newaxis, :
|
||||
#define SHAPE_A TK, TM
|
||||
#else
|
||||
#define USE_A a
|
||||
#define STRIDE_AK 1
|
||||
#define STRIDE_AM lda
|
||||
#define BROADCAST_AK newaxis, :
|
||||
#define BROADCAST_AM :, newaxis
|
||||
#define SHAPE_A TM, TK
|
||||
#endif
|
||||
|
||||
// Templates for accessing B
|
||||
#if BT == 1
|
||||
#define USE_B ^b
|
||||
#define STRIDE_BK 1
|
||||
#define STRIDE_BN ldb
|
||||
#define BROADCAST_BK newaxis, :
|
||||
#define BROADCAST_BN :, newaxis
|
||||
#define SHAPE_B TN, TK
|
||||
#else
|
||||
#define USE_B b
|
||||
#define STRIDE_BK ldb
|
||||
#define STRIDE_BN 1
|
||||
#define BROADCAST_BK :, newaxis
|
||||
#define BROADCAST_BN newaxis, :
|
||||
#define SHAPE_B TK, TN
|
||||
#endif
|
||||
class dot(triton.function):
|
||||
|
||||
src = """
|
||||
void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
int M, int N, int K,
|
||||
int lda __multipleof(8),
|
||||
@@ -73,71 +42,99 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
}
|
||||
"""
|
||||
|
||||
class dot_op(triton.op2):
|
||||
op = triton.op(src, ['C'])
|
||||
|
||||
def __init__(self, transpose_a = False, transpose_b = False):
|
||||
self.dot = triton.op(src, ['C'])
|
||||
self.transpose_a = transpose_a
|
||||
self.transpose_b = transpose_b
|
||||
|
||||
def forward(self, a, b):
|
||||
dtype = a.dtype
|
||||
@staticmethod
|
||||
def _call(a, b, transpose_a, transpose_b):
|
||||
# extract shapes
|
||||
shape_a = triton.shape(a)
|
||||
shape_b = triton.shape(b)
|
||||
M, Ka = shape_a[0], shape_a[1]
|
||||
Kb, N = shape_b[0], shape_b[1]
|
||||
# transpose shapes
|
||||
if self.transpose_a:
|
||||
if transpose_a:
|
||||
M, Ka = Ka, M
|
||||
if self.transpose_b:
|
||||
if transpose_b:
|
||||
Kb, N = N, Kb
|
||||
# contiguous dimensions
|
||||
lda = M if self.transpose_a else Ka
|
||||
ldb = Kb if self.transpose_b else N
|
||||
lda = M if transpose_a else Ka
|
||||
ldb = Kb if transpose_b else N
|
||||
ldc = N
|
||||
# data-type
|
||||
dtype = a.dtype
|
||||
# allocate output
|
||||
c = triton.empty([M, N], dtype = dtype)
|
||||
# compute
|
||||
return self.dot(a, b, c, M, N, Ka, lda, ldb, ldc,
|
||||
lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))],
|
||||
AT = self.transpose_a, BT = self.transpose_b, TYPE = dtype,
|
||||
TM = [128], TN = [128], TK = [8])
|
||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
|
||||
# macros -- not necessary but makes kernel source-code simpler
|
||||
macros = {# handle A transposition
|
||||
'USE_A' : '^a' if transpose_a else 'a',
|
||||
'STRIDE_AK' : 'lda' if transpose_a else '1',
|
||||
'STRIDE_AM' : '1' if transpose_a else 'lda',
|
||||
'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
|
||||
'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
|
||||
'SHAPE_A' : 'TK, TM' if transpose_a else 'TM, TK',
|
||||
# handle B transposition
|
||||
'USE_B' : '^b' if transpose_b else 'b',
|
||||
'STRIDE_BK' : '1' if transpose_b else 'ldb',
|
||||
'STRIDE_BN' : 'ldb' if transpose_b else '1',
|
||||
'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
|
||||
'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
|
||||
'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'}
|
||||
return dot.op(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
|
||||
AT = transpose_a, BT = transpose_b, TYPE = dtype,
|
||||
TM = [64, 128], TN = [64, 128], TK = [8], **macros)
|
||||
|
||||
def backward(self, op, dy):
|
||||
a = op.inputs[0]
|
||||
b = op.inputs[1]
|
||||
da = dot_op(self.transpose_a, self.transpose_b).forward(dy, b)
|
||||
db = dot_op(self.transpose_a, self.transpose_b).forward(a, dy)
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, transpose_a = False, transpose_b = False):
|
||||
ctx.save_for_backward(a, b, transpose_a, transpose_b)
|
||||
return dot._call(a, b, transpose_a, transpose_b)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy):
|
||||
a, b, t_a, t_b = ctx.saved_tensors
|
||||
if not t_a and not t_b:
|
||||
da = dot._call(dy, b, False, True)
|
||||
db = dot._call(a, dy, True, False)
|
||||
elif not t_a and t_b:
|
||||
da = dot._call(dy, b, False, False)
|
||||
db = dot._call(dy, a, True, False)
|
||||
elif t_a and not t_b:
|
||||
da = dot._call(b, dy, False, True)
|
||||
db = dot._call(a, dy, False, False)
|
||||
elif t_a and t_b:
|
||||
da = dot._call(b, dy, True, True)
|
||||
db = dot._call(dy, a, True, True)
|
||||
else:
|
||||
assert False
|
||||
return [da, db, None, None, None, None, None, None, None]
|
||||
|
||||
|
||||
def dot(a, b, transpose_a = False, transpose_b = False):
|
||||
if (transpose_a, transpose_b) not in dot.ops:
|
||||
dot.ops[transpose_a, transpose_b] = dot_op(transpose_a, transpose_b)
|
||||
return dot.ops[transpose_a, transpose_b](a, b)
|
||||
dot.ops = dict()
|
||||
|
||||
|
||||
def run_dot():
|
||||
M, N, K = 128, 128, 128
|
||||
a = tf.placeholder(tf.float32, shape=[M, K])
|
||||
b = tf.placeholder(tf.float32, shape=[N, K])
|
||||
c = dot(a, b, transpose_a = False, transpose_b = False)
|
||||
da = tf.gradients(c, [a])
|
||||
_dot = dot.apply
|
||||
tr_c = _dot(a, b, transpose_a = False, transpose_b = True)
|
||||
tr_d = _dot(tr_c, b, transpose_a = True, transpose_b = False)
|
||||
tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
|
||||
tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
|
||||
# Gradient
|
||||
tr_da = tf.gradients(tr_d, [a])
|
||||
tf_da = tf.gradients(tf_d, [a])
|
||||
# Reference
|
||||
ha = np.random.rand(M, K).astype(np.float32)
|
||||
hb = np.random.rand(K, N).astype(np.float32)
|
||||
# Run
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([da], feed_dict = {a: ha,
|
||||
b: hb})[0]
|
||||
result = sess.run([tr_da, tf_da], feed_dict = {a: ha,
|
||||
b: hb})
|
||||
# Test
|
||||
print(result)
|
||||
hresult = np.dot(ha, hb)
|
||||
dif = np.abs(result - hresult)
|
||||
np.savetxt('dif.dat', dif, '%2.4f')
|
||||
print(result[0][0])
|
||||
print(result[1][0])
|
||||
dif = np.abs(result[0][0] - result[1][0])
|
||||
print("dif: %f" % np.max(dif))
|
||||
|
||||
run_dot()
|
Reference in New Issue
Block a user