[PYTHON][KERNEL] Enforcing shapes to be known at compile-time for

TensorFlow Graph Execution
This commit is contained in:
Philippe Tillet
2019-10-28 17:12:37 -04:00
parent e9c787ef05
commit 448f4433d9
9 changed files with 82 additions and 52 deletions

View File

@@ -53,11 +53,11 @@ class ProdKeyTest(tf.test.TestCase):
B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
a = tf.placeholder(tf.float16, a_shape, name="a")
b = tf.placeholder(tf.float16, b_shape, name="b")
e = tf.placeholder(tf.float16, c_shape, name="e")
feed_dict = { a: A.astype(np.float16),
b: B.astype(np.float16),
a = tf.placeholder(tf.float32, a_shape, name="a")
b = tf.placeholder(tf.float32, b_shape, name="b")
e = tf.placeholder(tf.float32, c_shape, name="e")
feed_dict = { a: A.astype(np.float32),
b: B.astype(np.float32),
e: E }
c = triton.ops.einsum(einsum, a, b, bench=bench)