[python] fixed various issues in pytorch supoport
This commit is contained in:
@@ -1,14 +1,13 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import triton
|
||||
|
||||
def run_dot():
|
||||
def run_tf():
|
||||
import tensorflow as tf
|
||||
M, N, K = 128, 128, 128
|
||||
a = tf.placeholder(tf.float32, shape=[M, K])
|
||||
b = tf.placeholder(tf.float32, shape=[N, K])
|
||||
_dot = triton.ops.dot.apply
|
||||
tr_c = _dot(a, b, transpose_a = False, transpose_b = True)
|
||||
tr_d = _dot(tr_c, b, transpose_a = True, transpose_b = False)
|
||||
tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True)
|
||||
tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False)
|
||||
tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
|
||||
tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
|
||||
# Gradient
|
||||
@@ -28,4 +27,13 @@ def run_dot():
|
||||
dif = np.abs(result[0][0] - result[1][0])
|
||||
print("dif: %f" % np.max(dif))
|
||||
|
||||
run_dot()
|
||||
def run_torch():
|
||||
import torch as th
|
||||
M, N, K = 128, 128, 128
|
||||
a = th.randn(M, K).cuda()
|
||||
b = th.randn(K, N).cuda()
|
||||
th_c = th.matmul(a, b)
|
||||
tr_c = triton.ops.dot(a, b)
|
||||
print(c)
|
||||
|
||||
run_torch()
|
Reference in New Issue
Block a user