Merge branch 'master' of https://github.com/ptillet/triton
This commit is contained in:
@@ -2,7 +2,6 @@ import numpy as np
|
|||||||
import triton
|
import triton
|
||||||
|
|
||||||
def run_tf():
|
def run_tf():
|
||||||
import tensorflow as tf
|
|
||||||
M, N, K = 128, 128, 128
|
M, N, K = 128, 128, 128
|
||||||
a = tf.placeholder(tf.float32, shape=[M, K])
|
a = tf.placeholder(tf.float32, shape=[M, K])
|
||||||
b = tf.placeholder(tf.float32, shape=[N, K])
|
b = tf.placeholder(tf.float32, shape=[N, K])
|
||||||
@@ -28,7 +27,6 @@ def run_tf():
|
|||||||
print("dif: %f" % np.max(dif))
|
print("dif: %f" % np.max(dif))
|
||||||
|
|
||||||
def run_torch():
|
def run_torch():
|
||||||
import torch
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
M, N, K = 128, 128, 128
|
M, N, K = 128, 128, 128
|
||||||
a = torch.randn(M, K).cuda()
|
a = torch.randn(M, K).cuda()
|
||||||
@@ -56,5 +54,14 @@ def run_torch():
|
|||||||
print('Diff DA:', (torch_da - triton_da).max())
|
print('Diff DA:', (torch_da - triton_da).max())
|
||||||
print('Diff DB:', (torch_db - triton_db).max())
|
print('Diff DB:', (torch_db - triton_db).max())
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tensorflow as tf
|
||||||
|
run_tf()
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
run_torch()
|
run_torch()
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
pass
|
||||||
|
Reference in New Issue
Block a user