[BACKEND][FRONTEND] Fix problems with test_matmul (#973)

1. Handle the induction variable when the loop step is negative
2. Restore an async_wait that was accidentally deleted
3. Add the missing induction variable mapping in prefetch
4. Add device property functions

Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
Author: Keren Zhou
Date: 2022-12-10 20:34:58 -08:00
Committed by: GitHub
Parent: 24fd953f9a
Commit: be2f70699c
12 changed files with 217 additions and 58 deletions


@@ -4783,10 +4783,15 @@ private:
       decomposed = true;
     });
-    // async wait is supported in Ampere and later
     mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
-      if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) ||
-          decomposed) {
+      if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
+        // async wait is supported in Ampere and later
         asyncWaitOp.erase();
+      } else if (decomposed) {
+        // Wait for all previous async ops
+        OpBuilder builder(asyncWaitOp);
+        auto newAsyncWaitOp =
+            builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
+        asyncWaitOp.erase();
       }
     });
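In plain terms: on pre-Ampere targets the wait has nothing to wait for and is erased, while on Ampere and later, if the async copies were decomposed into synchronous ones, the wait is replaced by one that waits for zero outstanding operations. A toy Python sketch of that decision (hypothetical helpers, not the MLIR API):

def rewrite_async_wait(op, compute_capability, decomposed):
    # async wait is only supported on Ampere (sm_80) and later
    if compute_capability < 80:
        op.erase()
    elif decomposed:
        # the async copies became synchronous, so wait for all
        # (i.e., zero outstanding) previous async operations
        emit_async_wait(op.loc, num_outstanding=0)  # hypothetical helper
        op.erase()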


@@ -262,10 +262,10 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
     // For now, this behaves like generic, but this will evolve when
     // we add support for `can_reorder=False`
     Type retType = this->getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
+                                               adaptor.getOperands());
     return success();
   }
 };

 struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
@@ -450,13 +450,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
            TritonGenericPattern<triton::IntToPtrOp>,
            TritonGenericPattern<triton::PtrToIntOp>,
            TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
-           TritonGenericPattern<triton::AddPtrOp>,
-           TritonCatPattern,
-           TritonReducePattern,
-           TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
-           TritonDotPattern, TritonLoadPattern, TritonStorePattern,
-           TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
-      typeConverter, context);
+           TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
+           TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
+           TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
+           TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
+           TritonAtomicRMWPattern>(typeConverter, context);
 }

 //


@@ -782,8 +782,8 @@ public:
                                newRetType.getEncoding()));
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
-    auto newDot = rewriter.create<triton::DotOp>(
-        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
+    auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
+                                                 b, newAcc, dotOp.allowTF32());
     rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
         op, oldRetType, newDot.getResult());


@@ -225,6 +225,7 @@ scf::ForOp Prefetcher::createNewForOp() {
   BlockAndValueMapping mapping;
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());

   for (Operation &op : forOp.getBody()->without_terminator()) {
     Operation *newOp = builder.clone(op, mapping);
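For context, a toy model of why this one-line mapping matters (plain Python, not the MLIR API): when a loop body is cloned into a new loop, every value the old body references must be remapped, including the induction variable, or the clones keep pointing at the old loop.

def clone_body(body_ops, mapping):
    # replace every mapped value in each cloned op, keep the rest as-is
    return [tuple(mapping.get(v, v) for v in op) for op in body_ops]

old_iv, new_iv = "iv_old", "iv_new"
body = [("load", old_iv), ("add", old_iv)]
# without mapping the induction variable, clones still use iv_old:
assert clone_body(body, {}) == [("load", "iv_old"), ("add", "iv_old")]
# with the mapping added by this fix, uses are redirected:
assert clone_body(body, {old_iv: new_iv}) == [("load", "iv_new"), ("add", "iv_new")]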

python/tests/test_matmul.py (new file, 101 lines)

@@ -0,0 +1,101 @@
import itertools

import pytest
import torch

import triton
import triton._C.libtriton.triton as _triton


@pytest.mark.parametrize(
    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
    itertools.chain(
        *[
            [
                # 1 warp
                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                # 2 warp
                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                # 4 warp
                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                # 8 warp
                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                # split-k
                (64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
                # variable input
                (128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
        ],
        # n-stage
        *[
            [
                (16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
                # split-k
                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
        ]
    ),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
    capability = torch.cuda.get_device_capability()
    if capability[0] < 7:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")
    if capability[0] < 8 and DTYPE == "bfloat16":
        pytest.skip("Only test bfloat16 on devices with sm >= 80")
    # if DTYPE == "bfloat16" and SPLIT_K != 1:
    #     pytest.skip("bfloat16 matmuls don't allow split_k for now")
    if DTYPE == "bfloat16":
        pytest.skip("bfloat16 matmuls are not supported for now")
    torch.manual_seed(0)
    # nuke kernel decorators -- will set meta-parameters manually
    kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
    pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
    kernel = triton.ops._matmul.kernel
    kernel.configs = configs
    # kernel.run = kernel.run.run.run
    # get matrix shape
    M = BLOCK_M if M is None else M
    N = BLOCK_N if N is None else N
    K = BLOCK_K * SPLIT_K if K is None else K
    # allocate/transpose inputs
    DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
    a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
    b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
    a = a.t() if AT else a
    b = b.t() if BT else b
    # run test
    th_c = torch.matmul(a, b)
    tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
    triton.testing.assert_almost_equal(th_c, tt_c)
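To exercise a slice of these cases locally, something like the following should work (an assumed invocation; any standard pytest filter does the job):

import pytest

# run only the float32 cases of the new test file
pytest.main(["-q", "python/tests/test_matmul.py", "-k", "float32"])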


@@ -68,7 +68,7 @@ def test_vecadd_scf_mask(shape, num_warps, block_size, iter_size):
     @num_elements: number of elements
     '''
     pid = tl.program_id(axis=0)
-    for i in range(math.ceil(block_size / iter_size)):
+    for i in range(tl.cdiv(block_size, iter_size)):
         # TODO: there is a bug here; if the offset is computed outside the for loop, a GPU misaligned error occurs.
         offset = pid * block_size + tl.arange(0, iter_size)
         x_ptrs = x_ptr + offset
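tl.cdiv is Triton's integer ceiling division; outside a kernel, the same identity can be checked with plain ints (a quick sanity sketch):

def cdiv(n, d):
    # equivalent to math.ceil(n / d) for positive ints, without floats
    return (n + d - 1) // d

assert cdiv(1024, 128) == 8
assert cdiv(1000, 128) == 8   # 7.8125 rounds up to 8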


@@ -329,10 +329,6 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_BinOp(self, node):
         lhs = self.visit(node.left)
         rhs = self.visit(node.right)
-        if isinstance(lhs, triton.language.constexpr):
-            lhs = lhs.value
-        if isinstance(rhs, triton.language.constexpr):
-            rhs = rhs.value
         fn = {
             ast.Add: '__add__',
             ast.Sub: '__sub__',
@@ -591,8 +587,10 @@ class CodeGenerator(ast.NodeVisitor):
             ast.NodeVisitor.generic_visit(self, stmt)
             return
         # handle negative constant step (not supported by scf.for in MLIR)
+        negative_step = False
         if isinstance(step, triton.language.constexpr) and step.value < 0:
             step = triton.language.constexpr(-step.value)
+            negative_step = True
             lb, ub = ub, lb
         # lb/ub/step might be constexpr, we need to cast them to tensor
         lb = triton.language.core._to_tensor(lb, self.builder).handle
@@ -640,6 +638,9 @@ class CodeGenerator(ast.NodeVisitor):
         # update induction variable with actual value, and replace all uses
         self.builder.set_insertion_point_to_start(for_op.get_body(0))
         iv = self.builder.create_index_to_si(for_op.get_induction_var())
+        if negative_step:
+            ub_si = self.builder.create_index_to_si(ub)
+            iv = self.builder.create_sub(ub_si, iv)
         self.lscope[node.target.id].handle.replace_all_uses_with(iv)
         self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32))
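The two hunks above implement one rewrite: scf.for only accepts positive steps, so the codegen negates the step, swaps the bounds, and recovers the original induction variable as ub - iv. A pure-Python model of that rewrite, checked against Python's own range semantics:

def emulate_negative_step(lb, ub, step):
    # mirror of the codegen rewrite: negate the step, swap the bounds,
    # and map the induction variable back as ub_new - iv
    assert step < 0
    step, lb, ub = -step, ub, lb
    return [ub - iv for iv in range(lb, ub, step)]

assert emulate_negative_step(10, 0, -2) == list(range(10, 0, -2))  # [10, 8, 6, 4, 2]
assert emulate_negative_step(9, 0, -2) == list(range(9, 0, -2))    # [9, 7, 5, 3, 1]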
@@ -890,9 +891,9 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     pm = _triton.ir.pass_manager(mod.context)
     pm.add_convert_triton_to_tritongpu_pass(num_warps)
     pm.enable_debug()
-    # Convert blocked layout to mma layout for dot ops so that pipeline
-    # can get shared memory swizzled correctly.
     pm.add_coalesce_pass()
+    # The combine pass converts blocked layout to mma layout
+    # for dot ops so that pipeline can get shared memory swizzled correctly.
     pm.add_triton_gpu_combine_pass(compute_capability)
     pm.add_tritongpu_pipeline_pass(num_stages)
     # Prefetch must be done after pipeline pass because pipeline pass
@@ -1395,20 +1396,20 @@ def compile(fn, **kwargs):
     extern_libs = kwargs.get("extern_libs", dict())
     device = kwargs.get("device", torch.cuda.current_device())
     capability = torch.cuda.get_device_capability()
-    capability = capability[0]*10 + capability[1]
+    capability = capability[0] * 10 + capability[1]
     # build compilation stages
     stages = {
-        "ast" : (lambda path: fn, None),
-        "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
-                 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
-        "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
-                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
-        "llir": (lambda path: Path(path).read_bytes(),
-                 lambda src: ttgir_to_llir(src, extern_libs, capability)),
-        "ptx": (lambda path: Path(path).read_text(),
-                lambda src: llir_to_ptx(src, capability)),
-        "cubin": (lambda path: Path(path).read_bytes(),
-                  lambda src: ptx_to_cubin(src, capability))
+        "ast": (lambda path: fn, None),
+        "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
+                 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
+        "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
+                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
+        "llir": (lambda path: Path(path).read_bytes(),
+                 lambda src: ttgir_to_llir(src, extern_libs, capability)),
+        "ptx": (lambda path: Path(path).read_text(),
+                lambda src: llir_to_ptx(src, capability)),
+        "cubin": (lambda path: Path(path).read_bytes(),
+                  lambda src: ptx_to_cubin(src, capability))
     }
     # find out the signature of the function
     if isinstance(fn, triton.runtime.JITFunction):
@@ -1467,8 +1468,8 @@ def compile(fn, **kwargs):
         if ir == ext:
             next_module = parse(fn)
         elif os.path.exists(path) and\
-             ir in metadata["ctime"] and\
-             os.path.getctime(path) == metadata["ctime"][ir]:
+                ir in metadata["ctime"] and\
+                os.path.getctime(path) == metadata["ctime"][ir]:
             next_module = parse(path)
         else:
             next_module = compile(module)
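Each stages entry maps a stage name to a pair: a loader that re-parses a cached artifact from disk, and a compile function that lowers the previous stage's output. A simplified sketch of how such a table can be driven (a hypothetical helper, not the actual compile loop, which also tracks timestamps as in the hunk above):

def run_stages(stages, src, cached_paths):
    # hypothetical, simplified driver over a {name: (parse, compile)} table
    module = src
    for name, (parse, compile_fn) in stages.items():
        if compile_fn is None:
            continue                       # "ast" stage: input already in hand
        path = cached_paths.get(name)
        if path is not None:
            module = parse(path)           # reuse the cached artifact
        else:
            module = compile_fn(module)    # lower from the previous stage
    return module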
@@ -1504,8 +1505,7 @@ class CompiledKernel:
         self.asm = asm
         device = torch.cuda.current_device()
         global cuda_utils
-        if cuda_utils is None:
-            cuda_utils = CudaUtils()
+        init_cuda_utils()
         mod, func, n_regs, n_spills = cuda_utils.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
         self.cu_module = mod
         self.cu_function = func
@@ -1562,6 +1562,34 @@ class CudaUtils(object):
     #define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; }

+    static PyObject* getDeviceProperties(PyObject* self, PyObject* args){
+        int device_id;
+        if(!PyArg_ParseTuple(args, "i", &device_id))
+            return NULL;
+        // Get device handle
+        CUdevice device;
+        cuDeviceGet(&device, device_id);
+        // create a struct to hold device properties
+        int max_shared_mem;
+        int multiprocessor_count;
+        int sm_clock_rate;
+        int mem_clock_rate;
+        int mem_bus_width;
+        CUDA_CHECK(cuDeviceGetAttribute(&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, device));
+        CUDA_CHECK(cuDeviceGetAttribute(&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
+        CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
+        CUDA_CHECK(cuDeviceGetAttribute(&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
+        CUDA_CHECK(cuDeviceGetAttribute(&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
+        return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", max_shared_mem,
+                             "multiprocessor_count", multiprocessor_count,
+                             "sm_clock_rate", sm_clock_rate,
+                             "mem_clock_rate", mem_clock_rate,
+                             "mem_bus_width", mem_bus_width);
+    }
+
     static PyObject* loadBinary(PyObject* self, PyObject* args) {
         const char* name;
         const char* data;
@@ -1601,6 +1629,7 @@ class CudaUtils(object):
     static PyMethodDef ModuleMethods[] = {
       {"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"},
+      {"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"},
       {NULL, NULL, 0, NULL} // sentinel
     };
@@ -1640,6 +1669,13 @@ class CudaUtils(object):
         mod = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(mod)
         self.load_binary = mod.load_binary
+        self.get_device_properties = mod.get_device_properties
+
+
+def init_cuda_utils():
+    global cuda_utils
+    if cuda_utils is None:
+        cuda_utils = CudaUtils()
+

 cuda_utils = None
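With these additions, device properties become reachable from Python without going through _triton.runtime. A minimal usage sketch (device 0 assumed):

import triton

triton.compiler.init_cuda_utils()
props = triton.compiler.cuda_utils.get_device_properties(0)
print(props["multiprocessor_count"], props["max_shared_mem"],
      props["sm_clock_rate"], props["mem_clock_rate"], props["mem_bus_width"])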


@@ -9,6 +9,7 @@ from triton._C.libtriton.triton import ir

 T = TypeVar('T')

+
 def _to_tensor(x, builder):
     if isinstance(x, bool):
         return tensor(builder.get_int1(x), int1)
@@ -348,6 +349,9 @@ class constexpr:
     def __mul__(self, other):
         return constexpr(self.value * other.value)

+    def __mod__(self, other):
+        return constexpr(self.value % other.value)
+
     def __rmul__(self, other):
         return constexpr(other.value * self.value)
@@ -726,10 +730,12 @@ def broadcast_to(input, shape, _builder=None):
     """
     return semantic.broadcast_impl_shape(input, shape, _builder)


+@builtin
+def trans(input, _builder=None):
+    return semantic.trans(input, _builder)
+
+
 @builtin
 def cat(input, other, can_reorder=False, _builder=None):
     """
@@ -762,6 +768,7 @@ def view(input, shape, _builder=None):
     shape = [x.value for x in shape]
     return semantic.view(input, shape, _builder)

+
 @builtin
 def reshape(input, shape, _builder=None):
     # TODO: should be more than just a view


@@ -481,7 +481,8 @@ def view(input: tl.tensor,
          dst_shape: List[int],
          builder: ir.builder) -> tl.tensor:
     # TODO: disable when TritonToTritonGPU handles views properly
-    assert len(input.shape) == len(dst_shape)
+    # assert len(input.shape) == len(dst_shape)
+
     numel = 1
     for s in dst_shape:
         numel *= s


@@ -26,9 +26,6 @@ def get_configs_io_bound():
     return configs


-@triton.heuristics({
-    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-})
 @triton.autotune(
     configs=[
         # basic configs for compute-bound matmuls
@@ -59,6 +56,9 @@ def get_configs_io_bound():
         'top_k': 10
     },
 )
+@triton.heuristics({
+    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
+})
 @triton.jit
 def _kernel(A, B, C, M, N, K,
             stride_am, stride_ak,
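Python decorators apply bottom-up, so moving @triton.heuristics below @triton.autotune makes the heuristic the inner wrapper and lets it see the meta-parameters the autotuner picked (my reading of the intent; the general mechanics can be checked with a toy example):

def outer(f):
    return lambda: ("outer", f())

def inner(f):
    return lambda: ("inner", f())

@outer          # applied last, wraps the result of inner
@inner          # applied first, wraps h directly
def h():
    return "h"

assert h() == ("outer", ("inner", "h"))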


@@ -10,7 +10,9 @@ from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
 def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
     ''' return compute throughput in TOPS '''
     total_warps = num_ctas * min(num_warps, 4)
-    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+    triton.compiler.init_cuda_utils()
+    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4  # on recent GPUs
     tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
     return tflops
@@ -18,14 +20,14 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
 def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
     ''' return compute throughput in TOPS '''
     total_warps = num_ctas * min(num_warps, 4)
-    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4  # on recent GPUs
     tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
     return tflops


 def get_tflops(backend, device, num_ctas, num_warps, dtype):
-    cc = _triton.runtime.cc(backend, device)
-    if cc < 80 and dtype == torch.float32:
+    capability = torch.cuda.get_device_capability(device)
+    if capability[0] < 8 and dtype == torch.float32:
         return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
     return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
@@ -59,7 +61,7 @@ def estimate_matmul_time(
     compute_ms = total_ops / tput

     # time to load data
-    num_sm = _triton.runtime.num_sm(backend, device)
+    num_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
     active_cta_ratio = min(1, num_ctas / num_sm)
     active_cta_ratio_bw1 = min(1, num_ctas / 32)  # 32 active ctas are enough to saturate
     active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0)  # 32-108, remaining 5%
@@ -99,7 +101,7 @@ def estimate_matmul_time(
 def early_config_prune(configs, named_args):
     backend = _triton.runtime.backend.CUDA
     device = torch.cuda.current_device()
-    cc = _triton.runtime.cc(backend, device)
+    capability = torch.cuda.get_device_capability()
     # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
     dtsize = named_args['A'].element_size()
     dtype = named_args['A'].dtype
@@ -110,7 +112,10 @@ def early_config_prune(configs, named_args):
         kw = config.kwargs
         BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
             kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
-        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
+
+        # TODO: move to `cuda_utils` submodule
+        triton.compiler.init_cuda_utils()
+        max_shared_memory = triton.compiler.cuda_utils.get_device_properties(device)["max_shared_mem"]
         required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
         if required_shared_memory <= max_shared_memory:
             pruned_configs.append(config)
@@ -136,7 +141,7 @@ def early_config_prune(configs, named_args):
     pruned_configs = []
     for k, v in configs_map.items():
         BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
-        if cc >= 80:
+        if capability[0] >= 8:
             # compute cycles (only works for ampere GPUs)
             mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
             mma_cycles = mmas / min(4, num_warps) * 8
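A worked instance of the cycle estimate above, for a 128x128x32 tile with 4 warps (numbers only; 16x8x16 is the Ampere mma tile shape):

BLOCK_M, BLOCK_N, BLOCK_K, num_warps = 128, 128, 32, 4
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)  # 524288 / 2048
mma_cycles = mmas / min(4, num_warps) * 8
assert mmas == 256 and mma_cycles == 512.0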


@@ -16,6 +16,9 @@ except ImportError:
     _cutlass = None
     has_cutlass = False

+# TODO: move to separate module
+import triton
+

 def catch_oor(kernel, pytest_handle=None):
     try:
@@ -330,8 +333,8 @@ def get_dram_gbps(backend=None, device=None):
         backend = _triton.runtime.backend.CUDA
     if not device:
         device = torch.cuda.current_device()
-    mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
-    bus_width = _triton.runtime.global_memory_bus_width(backend, device)
+    mem_clock_khz = triton.compiler.cuda_utils.get_device_properties(device)["mem_clock_rate"]  # in kHz
+    bus_width = triton.compiler.cuda_utils.get_device_properties(device)["mem_bus_width"]
     bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8  # In GB/s
     return bw_gbps
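Plugging in A100-like numbers (assumed values: 1215 MHz memory clock, 5120-bit HBM2 bus) recovers the advertised bandwidth:

mem_clock_khz = 1_215_000       # 1215 MHz, in kHz
bus_width = 5120                # bits
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8  # DDR x2, bits -> bytes
assert round(bw_gbps) == 1555   # ~1.56 TB/s, matching the A100 spec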
@@ -341,11 +344,13 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clock_rate=None):
         backend = _triton.runtime.backend.CUDA
     if not device:
         device = torch.cuda.current_device()
-    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+    triton.compiler.init_cuda_utils()
+    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
     if not clock_rate:
-        clock_rate = _triton.runtime.clock_rate(backend, device)  # in kHz
-    cc = _triton.runtime.cc(backend, device)
-    if cc < 80:
+        clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"]  # in kHz
+    capability = torch.cuda.get_device_capability(device)
+    if capability[0] < 8:
         assert dtype == torch.float16
         ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
     else: