[STYLE] run autopep8 and isort (#421)

Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that this is all boring whitespace changes, and any config file changes will be in a different change to make it easier to review.
This commit is contained in:
Madeleine Thompson
2022-01-06 14:34:17 -08:00
committed by GitHub
parent 120cda015e
commit 8bf551ae7a
30 changed files with 742 additions and 623 deletions

View File

@@ -1,4 +1,5 @@
import torch
import triton
# -------------------------------
@@ -8,18 +9,18 @@ import triton
nt = {False: 'n', True: 't'}
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg = 'block',
line_vals = [16, 32, 64, 128],
line_names = ['Block16', 'Block32', 'Block64', 'Block128'],
ylabel = 'TFLOPS',
plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args = {'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)\
for AT in [False] for BT in [False] \
for op_mode in ['dsd'] for layout_mode in ['dense']
x_names=['M', 'N', 'K'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg='block',
line_vals=[16, 32, 64, 128],
line_names=['Block16', 'Block32', 'Block64', 'Block128'],
ylabel='TFLOPS',
plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args={'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)
for AT in [False] for BT in [False]
for op_mode in ['dsd'] for layout_mode in ['dense']
]
@@ -27,7 +28,7 @@ square_confs = [
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
Z, H = 1, 1
make_layout = {
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode]
# create layout
@@ -45,10 +46,10 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = {
'sdd': 2 * Z * K * float(layout.sum()) * block * block,\
'dsd': 2 * Z * N * float(layout.sum()) * block * block,\
'sdd': 2 * Z * K * float(layout.sum()) * block * block,
'dsd': 2 * Z * N * float(layout.sum()) * block * block,
'dds': 2 * Z * M * float(layout.sum()) * block * block
}[op_mode]*1e-12
}[op_mode] * 1e-12
return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
@@ -58,15 +59,15 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg = 'block',
line_vals = [16, 32, 64],
line_names = ['Block16', 'Block32', 'Block64'],
ylabel = 'GBPS',
plot_name = f'{layout_mode}-square',
args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)\
x_names=['M', 'N'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg='block',
line_vals=[16, 32, 64],
line_names=['Block16', 'Block32', 'Block64'],
ylabel='GBPS',
plot_name=f'{layout_mode}-square',
args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)
for layout_mode in ['dense', 'tril']
]
@@ -88,4 +89,4 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
bench_matmul.run(print_data=True, show_plots=True)
bench_matmul.run(print_data=True, show_plots=True)

View File

@@ -1,17 +1,18 @@
import torch
import triton
confs = [
triton.testing.Benchmark(
x_names = ['N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
line_arg = 'provider',
line_vals = ['triton', 'torch'],
line_names = ['Triton', 'Torch'],
ylabel = 'GBPS',
plot_name = f'{mode}-2048',
args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
)\
x_names=['N'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'Torch'],
ylabel='GBPS',
plot_name=f'{mode}-2048',
args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
)
for mode in ['forward', 'backward']
]
@@ -24,8 +25,8 @@ def bench_op(M, N, dtype, mode, provider):
num_gb = (2 * x.numel() * x.element_size() * 1e-9)
gbps = lambda ms: num_gb / ms * 1e3
# forward pass
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \
'triton': triton.ops.cross_entropy}[provider]
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
'triton': triton.ops.cross_entropy}[provider]
if mode == 'forward':
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
if mode == 'backward':
@@ -37,4 +38,4 @@ def bench_op(M, N, dtype, mode, provider):
if __name__ == '__main__':
bench_op.run(print_data=True)
bench_op.run(print_data=True)

View File

@@ -1,6 +1,6 @@
import triton
import torch
import os
import triton
def rounded_linspace(low, high, steps, div):
@@ -29,16 +29,16 @@ square_confs = [
transformer_confs = [
triton.testing.Benchmark(
x_names=[x],
x_vals = rounded_linspace(NK//16, NK, 32, 128),
x_vals=rounded_linspace(NK // 16, NK, 32, 128),
line_arg="provider",
line_vals=["cublas", "triton", "cutlass"],
line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS",
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [12288]\
for i, x in enumerate(["N", "K"])\
for M in [2048]
args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [12288]
for i, x in enumerate(["N", "K"])
for M in [2048]
]
@@ -46,8 +46,10 @@ transformer_confs = [
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT: a = a.t()
if BT: b = b.t()
if AT:
a = a.t()
if BT:
b = b.t()
num_flops = 2 * M * N * K
tflops = lambda ms: 2. * M * N * K / ms * 1e-9
if provider == "cublas":
@@ -61,6 +63,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
try:
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
except:
except Exception:
return None
return None

View File

@@ -1,7 +1,8 @@
import argparse
import sys
import os
import inspect
import os
import sys
import triton