[CODEGEN][TRANSFORM] some bug-fixes for FP32 einsum

Philippe Tillet
2020-01-19 19:58:40 -05:00
parent f278d9741a
commit fbf2a3f56f
4 changed files with 4 additions and 8 deletions


@@ -168,7 +168,7 @@ for N, C, H, W, K, R, S in NCHWKRS:
 # Benchmark
 torch.set_num_threads(1)
 for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
-    dtype = torch.cuda.HalfTensor
+    dtype = torch.cuda.FloatTensor
     # initialize input tensors
     a = torch.rand(*a_shape).type(dtype).cuda()
     b = torch.rand(*b_shape).type(dtype).cuda()
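For reference, a minimal sketch of the FP32 benchmark setup this hunk produces, assuming a CUDA device is available; the shapes and the einsum expression below are illustrative, not taken from the patch:

import torch

torch.set_num_threads(1)
dtype = torch.cuda.FloatTensor                # was torch.cuda.HalfTensor before this fix
a = torch.rand(64, 128).type(dtype).cuda()    # FP32 input; .type(dtype) already places it on the GPU
b = torch.rand(128, 32).type(dtype).cuda()
c = torch.einsum('mk,kn->mn', a, b)           # torch reference result for the benchmarked expression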