[FRONTEND] provide device kwargs && fix fstring error for py<3.8 (#515)
Co-authored-by: Philippe Tillet <phil@openai.com>
@@ -40,7 +40,7 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
     # create op
     tflops = lambda ms: num_flops / ms * 1e3
     if provider == 'triton':
-        op = triton.ops.blocksparse.matmul(layout, block, op_mode, trans_a=AT, trans_b=BT)
+        op = triton.ops.blocksparse.matmul(layout, block, op_mode, device="cuda", trans_a=AT, trans_b=BT)
         # inputs
         a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a
         b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
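
For context (not part of this diff): the benchmark builds the block-sparse matmul op once from a binary block layout and then calls it like a regular function. Below is a minimal sketch of the updated construction with the explicit device kwarg; the 'sdd' mode, shapes, and sizes are illustrative assumptions, and it requires triton and a CUDA GPU.

    import torch
    import triton

    Z, H, M, N, K, block = 1, 2, 128, 128, 128, 16
    # one (M//block) x (N//block) binary layout of non-zero blocks per head
    layout = torch.randint(0, 2, (H, M // block, N // block))
    # build the op on an explicit device, as the updated benchmark now does
    op = triton.ops.blocksparse.matmul(layout, block, 'sdd', device="cuda", trans_a=False, trans_b=False)
    a = torch.randn((Z, H, M, K), dtype=torch.float16, device='cuda')
    b = torch.randn((Z, H, K, N), dtype=torch.float16, device='cuda')
    c = op(a, b)  # block-sparse result: one (block x block) tile per non-zero layout entry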
@@ -83,7 +83,7 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
     a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda')
     if provider == 'triton':
         a = triton.testing.sparsify_tensor(a, layout, block)
-        op = triton.ops.blocksparse.softmax(layout, block)
+        op = triton.ops.blocksparse.softmax(layout, block, device="cuda")
         gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
         mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
         return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
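
Similarly, and continuing the illustrative sketch above, the block-sparse softmax now takes the same explicit device argument and is applied to a tensor that has been compacted to the layout's blocks:

    sm = triton.ops.blocksparse.softmax(layout, block, device="cuda")
    x = torch.randn((Z, H, M, N), dtype=torch.float16, device='cuda')
    x = triton.testing.sparsify_tensor(x, layout, block)  # keep only the blocks selected by layout
    y = sm(x)  # row-wise softmax over the retained blocks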
@@ -644,7 +644,7 @@ def test_f16_to_f8_rounding():
     )
     assert torch.all(
         torch.logical_not(mismatch)
-    ), f"{f16_input[mismatch]=} {f16_output[mismatch]=} {abs_error[mismatch]=} {min_error[mismatch]=}"
+    ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"


 # ---------------
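
The assertion-message change in the hunk above replaces the f-string self-documenting '=' specifier, which only exists on Python 3.8+, with an explicit name={value} spelling that is valid on older interpreters. A standalone illustration (not from the repo):

    x = 3.14
    print(f"{x=}")   # Python >= 3.8 only; prints: x=3.14
    print(f"x={x}")  # same output, also works on Python < 3.8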