[PYTHON] Added automated benchmark script (#63)
This adds a bench functionality to the setup.py that can be used to run the benchmark suite and generates a bunch of csv files (and optionally plots) python setup.py bench python setup.py bench --with-plots python setup.py bench --filter=cross_entropy
This commit is contained in:
committed by
Philippe Tillet
parent
66c94f21d7
commit
5e3c7f5a60
@@ -1,160 +0,0 @@
|
||||
import itertools
|
||||
import torch
|
||||
import triton as tt
|
||||
import pytest
|
||||
|
||||
def sparsify_tensor(x, mask, block):
|
||||
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
|
||||
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
|
||||
ret[:, idx, :, :] = x[:, h, i*block: (i+1)*block, j*block: (j+1)*block]
|
||||
return ret
|
||||
|
||||
def mask_tensor(x, mask, block, value = 0):
|
||||
ret = x.clone()
|
||||
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
|
||||
ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
|
||||
return ret
|
||||
|
||||
|
||||
|
||||
## -----------------------------------------------------------------------------
|
||||
## Unit Tests
|
||||
## -----------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
|
||||
[
|
||||
(mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
|
||||
for at in [False, True]\
|
||||
for bt in [False, True]\
|
||||
for block in [16, 32, 64]
|
||||
]
|
||||
)
|
||||
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384):
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
# create inputs
|
||||
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
|
||||
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
|
||||
shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE]
|
||||
layout = torch.randint(2, (H, shape[0]//BLOCK, shape[1]//BLOCK))
|
||||
# triton result
|
||||
op = tt.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
|
||||
ra = sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
|
||||
rb = sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b
|
||||
rc = op(ra, rb)
|
||||
# torch result
|
||||
ta = mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
|
||||
tb = mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b
|
||||
ta = ta.transpose(2, 3) if TRANS_A else ta
|
||||
tb = tb.transpose(2, 3) if TRANS_B else tb
|
||||
tc = torch.matmul(ta, tb)
|
||||
tc = mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
|
||||
tc = sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
|
||||
# compare
|
||||
rtol, atol = {torch.float32: (1e-4, 1e-5),
|
||||
torch.float16: (1e-2, 1e-3)}[DTYPE]
|
||||
assert torch.allclose(rc, tc, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("BLOCK, WIDTH",
|
||||
[
|
||||
(block, width) for block in [32]\
|
||||
for width in [256, 576, 1024, 2048, 4096]
|
||||
]
|
||||
)
|
||||
def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
Z, H, M, N = 2, 4, WIDTH, WIDTH
|
||||
scale = 0.4
|
||||
# create inputs
|
||||
layout = torch.randint(2, (H, M//BLOCK, N//BLOCK))
|
||||
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device='cuda')
|
||||
at_mask = torch.randint(low=0, high=2, size=(N, N), \
|
||||
dtype=torch.bool, requires_grad=False, device='cuda')
|
||||
kp_mask = torch.randint(low=0, high=2, size=(Z, N), \
|
||||
dtype=DTYPE, requires_grad=False, device='cuda')
|
||||
kp_mask[kp_mask==1.] = float('-inf')
|
||||
# triton result
|
||||
op = tt.ops.blocksparse.softmax(layout, BLOCK)
|
||||
tx = sparsify_tensor(x, layout, BLOCK)
|
||||
ty = op(tx, scale=scale)
|
||||
# torch result
|
||||
rx = mask_tensor(x, layout, BLOCK, value=float('-inf'))
|
||||
# if at_mask is not None:
|
||||
# # broadcast at_mask to the same shape as rx
|
||||
# M = at_mask[None, None, :, :] + torch.zeros_like(rx)
|
||||
# rx[M == 0] = float('-inf')
|
||||
# if kp_mask is not None:
|
||||
# rx += kp_mask[:, None, None, :]
|
||||
ry = torch.softmax(rx*scale, -1)
|
||||
ry = sparsify_tensor(ry, layout, BLOCK)
|
||||
# compare
|
||||
rtol, atol = {torch.float32: (1e-4, 1e-5),
|
||||
torch.float16: (1e-2, 1e-3)}[DTYPE]
|
||||
assert torch.allclose(ry , ty, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
## -----------------------------------------------------------------------------
|
||||
## Performance Tests
|
||||
## -----------------------------------------------------------------------------
|
||||
|
||||
def do_bench(fn, warmup = 10, rep = 50):
|
||||
import torch as th
|
||||
start_event = th.cuda.Event(enable_timing=True)
|
||||
end_event = th.cuda.Event(enable_timing=True)
|
||||
ret = fn()
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
th.cuda.synchronize()
|
||||
start_event.record()
|
||||
for i in range(rep):
|
||||
fn()
|
||||
end_event.record()
|
||||
th.cuda.synchronize()
|
||||
time_ms = start_event.elapsed_time(end_event) / rep
|
||||
return time_ms
|
||||
|
||||
def perf_matmul(BLOCK=64, LAYOUT_MODE = 'tril', OP_MODE = 'sdd', TRANS_A=False, TRANS_B=False, DTYPE = torch.float16, warmup=10, rep=50):
|
||||
Z, H = 1, 1
|
||||
K = 512
|
||||
make_layout = {
|
||||
'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
|
||||
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
|
||||
}[LAYOUT_MODE]
|
||||
for N in [128, 256, 512, 1024, 2048, 4096]:
|
||||
# create layout
|
||||
M, N, K = N, N, N
|
||||
shape = {'sdd': (M, N),
|
||||
'dsd': (K, M) if TRANS_A else (M, K),
|
||||
'dds': (N, K) if TRANS_B else (K, N)}[OP_MODE]
|
||||
layout = make_layout(H, shape[0]//BLOCK, shape[1]//BLOCK)
|
||||
# create op
|
||||
op = tt.ops.blocksparse.matmul(layout, BLOCK, OP_MODE, trans_a=TRANS_A, trans_b=TRANS_B)
|
||||
# inputs
|
||||
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
|
||||
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
|
||||
a = sparsify_tensor(a, layout, BLOCK) if OP_MODE == 'dsd' else a
|
||||
b = sparsify_tensor(b, layout, BLOCK) if OP_MODE == 'dds' else b
|
||||
ms = do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
|
||||
num_flops = {'sdd': 2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
|
||||
'dsd': 2 * Z * N * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
|
||||
'dds': 2 * Z * M * float(layout.sum()) * BLOCK * BLOCK * 1e-12}[OP_MODE]
|
||||
triton_tflops = num_flops / ms * 1e3
|
||||
|
||||
def perf_softmax(BLOCK=64, LAYOUT_MODE = 'tril', DTYPE = torch.float16, warmup=10, rep=50):
|
||||
Z, H = 1, 1
|
||||
K = 512
|
||||
make_layout = {
|
||||
'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
|
||||
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
|
||||
}[LAYOUT_MODE]
|
||||
for N in [128, 256, 512, 1024, 2048, 4096]:
|
||||
layout = make_layout(H, N//BLOCK, N//BLOCK)
|
||||
a = torch.randn((Z, H, N, N), dtype=DTYPE, device='cuda')
|
||||
a = sparsify_tensor(a, layout, BLOCK)
|
||||
op = tt.ops.blocksparse.softmax(layout, BLOCK)
|
||||
ms = do_bench(lambda: op(a), warmup=warmup, rep=rep)
|
||||
nbytes = 2 * a.numel() * a.element_size()
|
||||
triton_gbyps = (nbytes*1e-9) / (ms*1e-3)
|
||||
print(triton_gbyps)
|
@@ -1,17 +0,0 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
|
||||
def test_op():
|
||||
torch.manual_seed(0)
|
||||
DTYPE = torch.float16
|
||||
N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
|
||||
pad, stride, = (1, 1), (1, 1)
|
||||
dilation = (1, 1)
|
||||
a = torch.rand((N , CI, H, W ), dtype=DTYPE, device='cuda') / CI**.5
|
||||
b = torch.rand((CI, R , S, CO), dtype=DTYPE, device='cuda') / CI**.5
|
||||
th_c = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, dilation)
|
||||
tt_c = triton.ops.conv(a, b, pad, stride)
|
||||
rtol, atol = {torch.float32: (1e-4, 1e-5),
|
||||
torch.float16: (1e-2, 1e-3)}[DTYPE]
|
||||
assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol)
|
@@ -1,33 +0,0 @@
|
||||
import torch
|
||||
import triton
|
||||
import pytest
|
||||
|
||||
@pytest.mark.parametrize("M, N, dtype, mode",
|
||||
[
|
||||
(M, N, dtype, mode) for M in [1024, 821]
|
||||
for N in [512, 857, 1871, 2089, 8573, 31000]
|
||||
for dtype in ['float16', 'float32']\
|
||||
for mode in ['forward', 'backward']
|
||||
]
|
||||
)
|
||||
def test_op(M, N, dtype, mode):
|
||||
dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
|
||||
# create inputs
|
||||
x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
|
||||
idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
|
||||
# forward pass
|
||||
tt_y = triton.ops.cross_entropy(x, idx)
|
||||
th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
|
||||
if mode == 'forward':
|
||||
assert torch.allclose(th_y, tt_y, atol=1e-3, rtol=1e-2)
|
||||
# backward pass
|
||||
elif mode == 'backward':
|
||||
dy = torch.randn_like(tt_y)
|
||||
# triton backward
|
||||
tt_y.backward(dy)
|
||||
tt_dx = x.grad.clone()
|
||||
# torch backward
|
||||
x.grad.zero_()
|
||||
th_y.backward(dy)
|
||||
th_dx = x.grad.clone()
|
||||
assert torch.allclose(th_dx, tt_dx, atol=1e-3, rtol=1e-2)
|
@@ -1,151 +0,0 @@
|
||||
import pytest
|
||||
import itertools
|
||||
import triton as tt
|
||||
import torch as th
|
||||
|
||||
@pytest.mark.parametrize("TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
|
||||
[
|
||||
# 1 warp
|
||||
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
# 2 warp
|
||||
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 4 warp
|
||||
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
# 8 warp
|
||||
(128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
|
||||
# variable input
|
||||
(128, 128, 32, 1, 4, 256, 256, 256 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 384, 128, 640 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 107, 233, 256 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 107, 233, 311 , AT, BT, DTYPE)
|
||||
]
|
||||
for DTYPE in ['float16']
|
||||
for AT in [False, True]
|
||||
for BT in [False, True]
|
||||
]))
|
||||
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
|
||||
DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
|
||||
th.manual_seed(0)
|
||||
tt.ops._matmul._kernels = dict()
|
||||
tt.ops._matmul._CONFIGS = [({'TM': str(TM) , 'TN': str(TN) , 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
|
||||
if M is None: M = TM
|
||||
if N is None: N = TN
|
||||
if K is None: K = TK*TZ
|
||||
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
|
||||
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
|
||||
a = a.t() if AT else a
|
||||
b = b.t() if BT else b
|
||||
th_c = th.matmul(a, b)
|
||||
tt_c = tt.ops.matmul(a, b)
|
||||
rtol, atol = {th.float32: (1e-4, 1e-5),
|
||||
th.float16: (1e-2, 1e-3)}[DTYPE]
|
||||
assert th.allclose(tt_c, th_c, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
def do_bench(fn, flops = 0, warmup = 10, rep = 50):
|
||||
start_event = th.cuda.Event(enable_timing=True)
|
||||
end_event = th.cuda.Event(enable_timing=True)
|
||||
ret = fn()
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
th.cuda.synchronize()
|
||||
start_event.record()
|
||||
for i in range(rep):
|
||||
fn()
|
||||
end_event.record()
|
||||
th.cuda.synchronize()
|
||||
time_ms = start_event.elapsed_time(end_event) / rep
|
||||
return time_ms
|
||||
|
||||
def time_all(fn, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog=True, plot_name='', **kwargs):
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
df = pd.DataFrame(columns = [x_names[0]] + y_lines)
|
||||
for x in x_vals:
|
||||
x_args = {x_name: x for x_name in x_names}
|
||||
row = [fn(**x_args, **{y_name: y}, **kwargs) for y in y_vals]
|
||||
df.loc[len(df)] = [x] + row
|
||||
print(df)
|
||||
if plot_name:
|
||||
df.plot(x=x_names[0], y=y_lines, ylabel=ylabel, xlabel=' = '.join(x_names), title=f'{plot_name}', loglog=loglog)
|
||||
plt.savefig(f'{plot_name}.pdf')
|
||||
|
||||
def perf_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
|
||||
import os
|
||||
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
|
||||
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
|
||||
if AT: a = a.t()
|
||||
if BT: b = b.t()
|
||||
num_flops = 2*M*N*K
|
||||
if provider == 'torch':
|
||||
torch_ms = do_bench(lambda: th.matmul(a, b), warmup = warmup, rep = rep)
|
||||
torch_tflops = num_flops / torch_ms * 1e-9
|
||||
return torch_tflops
|
||||
if provider == 'triton':
|
||||
triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup = warmup, rep = rep)
|
||||
triton_tflops = num_flops / triton_ms * 1e-9
|
||||
return triton_tflops
|
||||
if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
|
||||
import subprocess
|
||||
import tempfile
|
||||
import pandas as pd
|
||||
# run program specified by CUTLASS_PROFILER env variable
|
||||
layout_a = 'column' if AT else 'row'
|
||||
layout_b = 'column' if BT else 'row'
|
||||
# create temporary file name
|
||||
fd, fname = tempfile.mkstemp()
|
||||
# run program and gets its output
|
||||
cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
|
||||
'--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', '--warmup-iterations=10', \
|
||||
'--profiling-iterations=50', f'--output={fname}', '--verbose=false']
|
||||
# run cmd
|
||||
subprocess.run(cmd, stdout=subprocess.PIPE)
|
||||
# read CSV output
|
||||
df_c = pd.read_csv(f'{fname}.gemm.csv')
|
||||
cutlass_tflops = max(df_c['GFLOPs'])/1e3
|
||||
return cutlass_tflops
|
||||
return None
|
||||
|
||||
if __name__ == '__main__':
|
||||
# # square
|
||||
x_square = [128, 256, 512, 1024, 2048, 3072, 4096, 6144]
|
||||
time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
|
||||
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, dtype = th.float16, loglog=False, plot_name = 'matmul-square-nn')
|
||||
time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
|
||||
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = True, dtype = th.float16, loglog=False, plot_name = 'matmul-square-nt')
|
||||
time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
|
||||
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = True, BT = False, dtype = th.float16, loglog=False, plot_name = 'matmul-square-tn')
|
||||
time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
|
||||
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = True, BT = True, dtype = th.float16, loglog=False, plot_name = 'matmul-square-tt')
|
||||
# tall-skinny
|
||||
x_tall_skinny = [64, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536]
|
||||
time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'],
|
||||
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=2048, K=2048, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-2k-2k')
|
||||
time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'],
|
||||
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=4096, K=4096, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-4k-4k')
|
||||
time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'],
|
||||
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=6144, K=6144, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-6k-6k')
|
Reference in New Issue
Block a user