diff --git a/python/examples/einsum.py b/python/examples/einsum.py
index 4ec2dea36..d86edf847 100644
--- a/python/examples/einsum.py
+++ b/python/examples/einsum.py
@@ -13,7 +13,7 @@ configs = []
 MNK = [
        (512, 512 ,512),
        (2048, 2048, 2048),
-       (8192, 8192, 8192),
+       #(8192, 8192, 8192),

       # (64, 64, 64000),
       # (64, 64, 128000),
@@ -68,7 +68,7 @@ NTHSE = [
       # (128, 1024, 8, 64, 64),
       # (128, 1024, 8, 128, 128),
       # (128, 1024, 8, 256, 256),
-        (128, 1024, 8, 256, 512)
+        #(128, 1024, 8, 256, 512)
         ]
 for N, T, H, S, E in NTHSE:
     configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
@@ -92,10 +92,10 @@ for N, C, H, K, R in NCHKR:

 # 2D Dense convolution
 NCHWKRS = [
-           (8, 64, 128, 128, 768, 3, 3),
+           #(8, 64, 128, 128, 768, 3, 3),
            (8, 128, 64, 64, 256, 3, 3),
            (8, 256, 32, 32, 512, 3, 3),
-           (8, 512, 32, 32, 1024, 3, 3)
+           #(8, 512, 32, 32, 1024, 3, 3)
           ]
 for N, C, H, W, K, R, S in NCHWKRS:
     torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
@@ -108,10 +108,10 @@ for N, C, H, W, K, R, S in NCHWKRS:

 # 3D Dense Convolution
 NCDHWKTRS = [
-             (8, 32, 27, 100, 100, 64, 3, 3, 3),
-             (8, 64, 23, 48, 48, 256, 3, 3, 3),
-             (8, 256, 19, 22, 22, 640, 3, 3, 3),
-             (8, 640, 15, 36, 36, 384, 3, 3, 3)
+             #(8, 32, 27, 100, 100, 64, 3, 3, 3),
+             #(8, 64, 23, 48, 48, 256, 3, 3, 3),
+             #(8, 256, 19, 22, 22, 640, 3, 3, 3),
+             #(8, 640, 15, 36, 36, 384, 3, 3, 3)
              ]
 for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
     torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
index 936a5fced..005ea8812 100644
--- a/python/triton/ops/einsum.py
+++ b/python/triton/ops/einsum.py
@@ -646,5 +646,4 @@ __global__ void {name}(
         return None, da, None, None, None


-
 einsum = _einsum.apply
\ No newline at end of file