[STYLE] run autopep8 and isort (#421)
Run: ``` isort ./python autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py') ``` with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that this is all boring whitespace changes, and any config file changes will be in a different change to make it easier to review.
This commit is contained in:
committed by
GitHub
parent
120cda015e
commit
8bf551ae7a
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
|
||||
# -------------------------------
|
||||
@@ -8,18 +9,18 @@ import triton
|
||||
nt = {False: 'n', True: 't'}
|
||||
square_confs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names = ['M', 'N', 'K'],
|
||||
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
|
||||
line_arg = 'block',
|
||||
line_vals = [16, 32, 64, 128],
|
||||
line_names = ['Block16', 'Block32', 'Block64', 'Block128'],
|
||||
ylabel = 'TFLOPS',
|
||||
plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
|
||||
args = {'layout_mode': layout_mode, 'op_mode': op_mode,
|
||||
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
|
||||
)\
|
||||
for AT in [False] for BT in [False] \
|
||||
for op_mode in ['dsd'] for layout_mode in ['dense']
|
||||
x_names=['M', 'N', 'K'],
|
||||
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
|
||||
line_arg='block',
|
||||
line_vals=[16, 32, 64, 128],
|
||||
line_names=['Block16', 'Block32', 'Block64', 'Block128'],
|
||||
ylabel='TFLOPS',
|
||||
plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
|
||||
args={'layout_mode': layout_mode, 'op_mode': op_mode,
|
||||
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
|
||||
)
|
||||
for AT in [False] for BT in [False]
|
||||
for op_mode in ['dsd'] for layout_mode in ['dense']
|
||||
]
|
||||
|
||||
|
||||
@@ -27,7 +28,7 @@ square_confs = [
|
||||
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
|
||||
Z, H = 1, 1
|
||||
make_layout = {
|
||||
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
|
||||
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
|
||||
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
|
||||
}[layout_mode]
|
||||
# create layout
|
||||
@@ -45,10 +46,10 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
|
||||
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
|
||||
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
|
||||
num_flops = {
|
||||
'sdd': 2 * Z * K * float(layout.sum()) * block * block,\
|
||||
'dsd': 2 * Z * N * float(layout.sum()) * block * block,\
|
||||
'sdd': 2 * Z * K * float(layout.sum()) * block * block,
|
||||
'dsd': 2 * Z * N * float(layout.sum()) * block * block,
|
||||
'dds': 2 * Z * M * float(layout.sum()) * block * block
|
||||
}[op_mode]*1e-12
|
||||
}[op_mode] * 1e-12
|
||||
return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
|
||||
|
||||
|
||||
@@ -58,15 +59,15 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
|
||||
|
||||
square_confs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names = ['M', 'N'],
|
||||
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
|
||||
line_arg = 'block',
|
||||
line_vals = [16, 32, 64],
|
||||
line_names = ['Block16', 'Block32', 'Block64'],
|
||||
ylabel = 'GBPS',
|
||||
plot_name = f'{layout_mode}-square',
|
||||
args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
|
||||
)\
|
||||
x_names=['M', 'N'],
|
||||
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
|
||||
line_arg='block',
|
||||
line_vals=[16, 32, 64],
|
||||
line_names=['Block16', 'Block32', 'Block64'],
|
||||
ylabel='GBPS',
|
||||
plot_name=f'{layout_mode}-square',
|
||||
args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
|
||||
)
|
||||
for layout_mode in ['dense', 'tril']
|
||||
]
|
||||
|
||||
@@ -88,4 +89,4 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
|
||||
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
|
||||
|
||||
|
||||
bench_matmul.run(print_data=True, show_plots=True)
|
||||
bench_matmul.run(print_data=True, show_plots=True)
|
||||
|
Reference in New Issue
Block a user