[python] fixed various issues in pytorch support

Philippe Tillet
2019-09-05 00:19:42 -04:00
parent 945b5d0de9
commit ed0f706005
8 changed files with 182 additions and 92 deletions


@@ -1,14 +1,13 @@
 import numpy as np
-import tensorflow as tf
 import triton

-def run_dot():
+def run_tf():
+    import tensorflow as tf
     M, N, K = 128, 128, 128
     a = tf.placeholder(tf.float32, shape=[M, K])
     b = tf.placeholder(tf.float32, shape=[N, K])
-    _dot = triton.ops.dot.apply
-    tr_c = _dot(a, b, transpose_a = False, transpose_b = True)
-    tr_d = _dot(tr_c, b, transpose_a = True, transpose_b = False)
+    tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True)
+    tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False)
     tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
     tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
     # Gradient
@@ -28,4 +27,13 @@ def run_dot():
     dif = np.abs(result[0][0] - result[1][0])
     print("dif: %f" % np.max(dif))
-run_dot()
+def run_torch():
+    import torch as th
+    M, N, K = 128, 128, 128
+    a = th.randn(M, K).cuda()
+    b = th.randn(K, N).cuda()
+    th_c = th.matmul(a, b)
+    tr_c = triton.ops.dot(a, b)
+    # print the largest deviation from torch.matmul
+    print("dif: %f" % (tr_c - th_c).abs().max())
+run_torch()
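
Side note (not part of the commit): a minimal standalone sketch of how the updated triton.ops.dot entry point could be checked against torch.matmul, assuming it accepts CUDA tensors and the same transpose_a / transpose_b keywords used in the TensorFlow path above; the helper name check_dot is only for illustration.

import torch as th
import triton

def check_dot(M=128, N=128, K=128):
    # assumption: triton.ops.dot accepts CUDA tensors and transpose flags,
    # mirroring the TensorFlow usage in the first hunk
    a = th.randn(M, K).cuda()
    b = th.randn(N, K).cuda()
    tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True)
    th_c = th.matmul(a, b.t())
    # largest absolute deviation between the two results
    print("dif: %f" % (tr_c - th_c).abs().max())

check_dot()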