[GENERAL] Cleaned polymorphic structure of layouts analysis pass

This commit is contained in:
Philippe Tillet
2020-01-20 15:15:32 -05:00
parent 382ca2c745
commit 78b98fb7cf
17 changed files with 500 additions and 480 deletions

View File

@@ -79,7 +79,7 @@ for N, T, H, S, E in NTHSE:
# 1D Dense convolution
NCHKR = [
# (1, 1152, 12602, 512, 3)
(1, 1152, 12602, 512, 3)
]
for N, C, H, K, R in NCHKR:
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
@@ -92,10 +92,10 @@ for N, C, H, K, R in NCHKR:
# 2D Dense convolution
NCHWKRS = [
#(8, 64, 128, 128, 768, 3, 3),
#(8, 128, 64, 64, 256, 3, 3),
#(8, 256, 32, 32, 512, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
(8, 64, 128, 128, 768, 3, 3),
(8, 128, 64, 64, 256, 3, 3),
(8, 256, 32, 32, 512, 3, 3),
(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
@@ -108,10 +108,10 @@ for N, C, H, W, K, R, S in NCHWKRS:
# 3D Dense Convolution
NCDHWKTRS = [
#(8, 32, 27, 100, 100, 64, 3, 3, 3),
#(8, 64, 23, 48, 48, 256, 3, 3, 3),
#(8, 256, 19, 22, 22, 640, 3, 3, 3),
#(8, 640, 15, 36, 36, 384, 3, 3, 3)
(8, 32, 27, 100, 100, 64, 3, 3, 3),
(8, 64, 23, 48, 48, 256, 3, 3, 3),
(8, 256, 19, 22, 22, 640, 3, 3, 3),
(8, 640, 15, 36, 36, 384, 3, 3, 3)
]
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
@@ -168,7 +168,7 @@ for N, C, H, W, K, R, S in NCHWKRS:
# Benchmark
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
dtype = torch.cuda.FloatTensor
dtype = torch.cuda.HalfTensor
# initialize input tensors
a = torch.rand(*a_shape).type(dtype).cuda()
b = torch.rand(*b_shape).type(dtype).cuda()