[python][example] now executing tensorflow and/or pytorch example
automatically
This commit is contained in:
@@ -2,7 +2,6 @@ import numpy as np
|
||||
import triton
|
||||
|
||||
def run_tf():
|
||||
import tensorflow as tf
|
||||
M, N, K = 128, 128, 128
|
||||
a = tf.placeholder(tf.float32, shape=[M, K])
|
||||
b = tf.placeholder(tf.float32, shape=[N, K])
|
||||
@@ -28,7 +27,6 @@ def run_tf():
|
||||
print("dif: %f" % np.max(dif))
|
||||
|
||||
def run_torch():
|
||||
import torch
|
||||
torch.manual_seed(0)
|
||||
M, N, K = 128, 128, 128
|
||||
a = torch.randn(M, K).cuda()
|
||||
@@ -56,5 +54,14 @@ def run_torch():
|
||||
print('Diff DA:', (torch_da - triton_da).max())
|
||||
print('Diff DB:', (torch_db - triton_db).max())
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
run_tf()
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
run_torch()
|
||||
try:
|
||||
import torch
|
||||
run_torch()
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
Reference in New Issue
Block a user