[BACKEND][FRONTEND] Fix problems with test_matmul (#973)
This commit fixes several problems uncovered by test_matmul:
1. Handle the induction variable when the loop step is negative.
2. Restore an async_wait that was accidentally deleted.
3. Add the missing induction variable mapping in the prefetch pass.
4. Add device property functions.

Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
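The new test module added by this commit (python/tests/test_matmul.py, shown below) exercises all four fixes together; a minimal, assumed way to run just that module from the repository root is:

import pytest

# run only the new matmul tests added in this commit
pytest.main(["-v", "python/tests/test_matmul.py"])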
@@ -4783,10 +4783,15 @@ private:
    decomposed = true;
  });

  // async wait is supported in Ampere and later
  mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
    if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) ||
        decomposed) {
    if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
      // async wait is supported in Ampere and later
      asyncWaitOp.erase();
    } else if (decomposed) {
      // Wait for all previous async ops
      OpBuilder builder(asyncWaitOp);
      auto newAsyncWaitOp =
          builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
      asyncWaitOp.erase();
    }
  });
@@ -262,10 +262,10 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
    // For now, this behaves like generic, but this will evolve when
    // we add support for `can_reorder=False`
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
                                               adaptor.getOperands());
    return success();
  }
};

struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
@@ -450,13 +450,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
               TritonGenericPattern<triton::IntToPtrOp>,
               TritonGenericPattern<triton::PtrToIntOp>,
               TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
               TritonGenericPattern<triton::AddPtrOp>,
               TritonCatPattern,
               TritonReducePattern,
               TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
               TritonDotPattern, TritonLoadPattern, TritonStorePattern,
               TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
                   typeConverter, context);
               TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
               TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
               TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
               TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
               TritonAtomicRMWPattern>(typeConverter, context);
}

//
@@ -782,8 +782,8 @@ public:
        newRetType.getEncoding()));
    a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
    b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
    auto newDot = rewriter.create<triton::DotOp>(
        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
    auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
                                                 b, newAcc, dotOp.allowTF32());

    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
        op, oldRetType, newDot.getResult());
@@ -225,6 +225,7 @@ scf::ForOp Prefetcher::createNewForOp() {
  BlockAndValueMapping mapping;
  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
    mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());

  for (Operation &op : forOp.getBody()->without_terminator()) {
    Operation *newOp = builder.clone(op, mapping);
python/tests/test_matmul.py (new file, 101 lines)
@@ -0,0 +1,101 @@
import itertools

import pytest
import torch

import triton
import triton._C.libtriton.triton as _triton


@pytest.mark.parametrize(
    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
    itertools.chain(
        *[
            [
                # 1 warp
                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                # 2 warp
                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                # 4 warp
                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                # 8 warp
                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                # split-k
                (64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
                # variable input
                (128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
        ],
        # n-stage
        *[
            [
                (16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
                # split-k
                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
        ]
    ),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
    capability = torch.cuda.get_device_capability()
    if capability[0] < 7:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")
    if capability[0] < 8 and DTYPE == "bfloat16":
        pytest.skip("Only test bfloat16 on devices with sm >= 80")
    # if DTYPE == "bfloat16" and SPLIT_K != 1:
    #     pytest.skip("bfloat16 matmuls don't allow split_k for now")
    if DTYPE == "bfloat16":
        pytest.skip("bfloat16 matmuls are not supported for now")
    torch.manual_seed(0)
    # nuke kernel decorators -- will set meta-parameters manually
    kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
    pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
    kernel = triton.ops._matmul.kernel
    kernel.configs = configs
    # kernel.run = kernel.run.run.run

    # get matrix shape
    M = BLOCK_M if M is None else M
    N = BLOCK_N if N is None else N
    K = BLOCK_K * SPLIT_K if K is None else K
    # allocate/transpose inputs
    DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
    a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
    b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
    a = a.t() if AT else a
    b = b.t() if BT else b
    # run test
    th_c = torch.matmul(a, b)
    tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
    triton.testing.assert_almost_equal(th_c, tt_c)
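For reference, a single row of the parameter table can also be exercised directly, without running the full sweep. The call below is an assumed example (not part of the committed test file) corresponding to the first "1 warp" entry with float16 inputs and no transposition; it requires a CUDA device with sm >= 70:

# hypothetical direct invocation of one parametrization
test_op(BLOCK_M=16, BLOCK_N=16, BLOCK_K=16, SPLIT_K=1, NWARP=1, NSTAGE=2,
        M=None, N=None, K=None, AT=False, BT=False, DTYPE="float16")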
@@ -68,7 +68,7 @@ def test_vecadd_scf_mask(shape, num_warps, block_size, iter_size):
    @num_elements: number of elements
    '''
    pid = tl.program_id(axis=0)
    for i in range(math.ceil(block_size / iter_size)):
    for i in range(tl.cdiv(block_size, iter_size)):
        # TODO: there is a bug here; if the offset is computed outside the for loop, a GPU misaligned error occurs.
        offset = pid * block_size + tl.arange(0, iter_size)
        x_ptrs = x_ptr + offset
@@ -329,10 +329,6 @@ class CodeGenerator(ast.NodeVisitor):
    def visit_BinOp(self, node):
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        if isinstance(lhs, triton.language.constexpr):
            lhs = lhs.value
        if isinstance(rhs, triton.language.constexpr):
            rhs = rhs.value
        fn = {
            ast.Add: '__add__',
            ast.Sub: '__sub__',
@@ -591,8 +587,10 @@ class CodeGenerator(ast.NodeVisitor):
            ast.NodeVisitor.generic_visit(self, stmt)
            return
        # handle negative constant step (not supported by scf.for in MLIR)
        negative_step = False
        if isinstance(step, triton.language.constexpr) and step.value < 0:
            step = triton.language.constexpr(-step.value)
            negative_step = True
            lb, ub = ub, lb
        # lb/ub/step might be constexpr, we need to cast them to tensor
        lb = triton.language.core._to_tensor(lb, self.builder).handle
@@ -640,6 +638,9 @@ class CodeGenerator(ast.NodeVisitor):
        # update induction variable with actual value, and replace all uses
        self.builder.set_insertion_point_to_start(for_op.get_body(0))
        iv = self.builder.create_index_to_si(for_op.get_induction_var())
        if negative_step:
            ub_si = self.builder.create_index_to_si(ub)
            iv = self.builder.create_sub(ub_si, iv)
        self.lscope[node.target.id].handle.replace_all_uses_with(iv)
        self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32))
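The two hunks above implement the negative-step rewrite: negate the step, swap the bounds, and recover the user-visible induction variable as `ub - iv` inside the body. A quick pure-Python sanity check of that mapping (an illustration only; the mapping is exact when the iteration range is a multiple of the step):

# original descending loop
lb, ub, step = 10, 0, -2
original = list(range(lb, ub, step))                     # [10, 8, 6, 4, 2]
# rewritten ascending loop with the induction variable recovered in the body
step2 = -step
lb2, ub2 = ub, lb
rewritten = [ub2 - iv for iv in range(lb2, ub2, step2)]  # [10, 8, 6, 4, 2]
assert rewritten == original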
@@ -890,9 +891,9 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
    pm = _triton.ir.pass_manager(mod.context)
    pm.add_convert_triton_to_tritongpu_pass(num_warps)
    pm.enable_debug()
    # Convert blocked layout to mma layout for dot ops so that pipeline
    # can get shared memory swizzled correctly.
    pm.add_coalesce_pass()
    # The combine pass converts blocked layout to mma layout
    # for dot ops so that pipeline can get shared memory swizzled correctly.
    pm.add_triton_gpu_combine_pass(compute_capability)
    pm.add_tritongpu_pipeline_pass(num_stages)
    # Prefetch must be done after pipeline pass because pipeline pass
@@ -1395,20 +1396,20 @@ def compile(fn, **kwargs):
    extern_libs = kwargs.get("extern_libs", dict())
    device = kwargs.get("device", torch.cuda.current_device())
    capability = torch.cuda.get_device_capability()
    capability = capability[0]*10 + capability[1]
    capability = capability[0] * 10 + capability[1]
    # build compilation stages
    stages = {
        "ast" : (lambda path: fn, None),
        "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
                 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
        "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
        "llir": (lambda path: Path(path).read_bytes(),
                 lambda src: ttgir_to_llir(src, extern_libs, capability)),
        "ptx": (lambda path: Path(path).read_text(),
                lambda src: llir_to_ptx(src, capability)),
        "cubin": (lambda path: Path(path).read_bytes(),
                  lambda src: ptx_to_cubin(src, capability))
        "ast": (lambda path: fn, None),
        "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
                 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
        "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
        "llir": (lambda path: Path(path).read_bytes(),
                 lambda src: ttgir_to_llir(src, extern_libs, capability)),
        "ptx": (lambda path: Path(path).read_text(),
                lambda src: llir_to_ptx(src, capability)),
        "cubin": (lambda path: Path(path).read_bytes(),
                  lambda src: ptx_to_cubin(src, capability))
    }
    # find out the signature of the function
    if isinstance(fn, triton.runtime.JITFunction):
@@ -1467,8 +1468,8 @@ def compile(fn, **kwargs):
        if ir == ext:
            next_module = parse(fn)
        elif os.path.exists(path) and\
                ir in metadata["ctime"] and\
                os.path.getctime(path) == metadata["ctime"][ir]:
                ir in metadata["ctime"] and\
                os.path.getctime(path) == metadata["ctime"][ir]:
            next_module = parse(path)
        else:
            next_module = compile(module)
@@ -1504,8 +1505,7 @@ class CompiledKernel:
        self.asm = asm
        device = torch.cuda.current_device()
        global cuda_utils
        if cuda_utils is None:
            cuda_utils = CudaUtils()
        init_cuda_utils()
        mod, func, n_regs, n_spills = cuda_utils.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
        self.cu_module = mod
        self.cu_function = func
@@ -1562,6 +1562,34 @@ class CudaUtils(object):
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; }

static PyObject* getDeviceProperties(PyObject* self, PyObject* args){
    int device_id;
    if(!PyArg_ParseTuple(args, "i", &device_id))
        return NULL;
    // Get device handle
    CUdevice device;
    cuDeviceGet(&device, device_id);

    // create a struct to hold device properties
    int max_shared_mem;
    int multiprocessor_count;
    int sm_clock_rate;
    int mem_clock_rate;
    int mem_bus_width;
    CUDA_CHECK(cuDeviceGetAttribute(&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, device));
    CUDA_CHECK(cuDeviceGetAttribute(&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
    CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
    CUDA_CHECK(cuDeviceGetAttribute(&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
    CUDA_CHECK(cuDeviceGetAttribute(&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));

    return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", max_shared_mem,
                         "multiprocessor_count", multiprocessor_count,
                         "sm_clock_rate", sm_clock_rate,
                         "mem_clock_rate", mem_clock_rate,
                         "mem_bus_width", mem_bus_width);
}

static PyObject* loadBinary(PyObject* self, PyObject* args) {
    const char* name;
    const char* data;
@@ -1601,6 +1629,7 @@ class CudaUtils(object):
static PyMethodDef ModuleMethods[] = {
  {"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"},
  {"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"},
  {NULL, NULL, 0, NULL} // sentinel
};
@@ -1640,6 +1669,13 @@ class CudaUtils(object):
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties


def init_cuda_utils():
    global cuda_utils
    if cuda_utils is None:
        cuda_utils = CudaUtils()


cuda_utils = None
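Together with the C function above, these hunks expose the new device query to Python. A minimal usage sketch (assuming a CUDA device is present; the keys are exactly those returned by getDeviceProperties):

import torch
import triton

device = torch.cuda.current_device()
triton.compiler.init_cuda_utils()  # lazily builds and loads the C utility module
props = triton.compiler.cuda_utils.get_device_properties(device)
# props["max_shared_mem"], props["multiprocessor_count"], props["sm_clock_rate"],
# props["mem_clock_rate"], props["mem_bus_width"]
print(props)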
@@ -9,6 +9,7 @@ from triton._C.libtriton.triton import ir
T = TypeVar('T')


def _to_tensor(x, builder):
    if isinstance(x, bool):
        return tensor(builder.get_int1(x), int1)
@@ -348,6 +349,9 @@ class constexpr:
    def __mul__(self, other):
        return constexpr(self.value * other.value)

    def __mod__(self, other):
        return constexpr(self.value % other.value)

    def __rmul__(self, other):
        return constexpr(other.value * self.value)
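The new __mod__ lets constexpr values be combined with % during compilation, like the other arithmetic operators already defined on the class. A tiny illustrative check, exercising the class directly rather than inside a kernel (an assumption for illustration, not part of the commit):

from triton.language import constexpr

# standalone check of the operator added above
assert (constexpr(10) % constexpr(3)).value == 1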
@@ -726,10 +730,12 @@ def broadcast_to(input, shape, _builder=None):
    """
    return semantic.broadcast_impl_shape(input, shape, _builder)


@builtin
def trans(input, _builder=None):
    return semantic.trans(input, _builder)


@builtin
def cat(input, other, can_reorder=False, _builder=None):
    """
@@ -762,6 +768,7 @@ def view(input, shape, _builder=None):
    shape = [x.value for x in shape]
    return semantic.view(input, shape, _builder)


@builtin
def reshape(input, shape, _builder=None):
    # TODO: should be more than just a view
@@ -481,7 +481,8 @@ def view(input: tl.tensor,
         dst_shape: List[int],
         builder: ir.builder) -> tl.tensor:
    # TODO: disable when TritonToTritonGPU handles views properly
    assert len(input.shape) == len(dst_shape)

    # assert len(input.shape) == len(dst_shape)
    numel = 1
    for s in dst_shape:
        numel *= s
@@ -26,9 +26,6 @@ def get_configs_io_bound():
    return configs


@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.autotune(
    configs=[
        # basic configs for compute-bound matmuls
@@ -59,6 +56,9 @@ def get_configs_io_bound():
        'top_k': 10
    },
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, M, N, K,
            stride_am, stride_ak,
@@ -10,7 +10,9 @@ from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcor
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
    ''' return compute throughput in TOPS '''
    total_warps = num_ctas * min(num_warps, 4)
    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
    triton.compiler.init_cuda_utils()

    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4  # on recent GPUs
    tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
    return tflops
@@ -18,14 +20,14 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
    ''' return compute throughput in TOPS '''
    total_warps = num_ctas * min(num_warps, 4)
    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4  # on recent GPUs
    tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
    return tflops


def get_tflops(backend, device, num_ctas, num_warps, dtype):
    cc = _triton.runtime.cc(backend, device)
    if cc < 80 and dtype == torch.float32:
    capability = torch.cuda.get_device_capability(device)
    if capability[0] < 8 and dtype == torch.float32:
        return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
    return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
@@ -59,7 +61,7 @@ def estimate_matmul_time(
    compute_ms = total_ops / tput

    # time to load data
    num_sm = _triton.runtime.num_sm(backend, device)
    num_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
    active_cta_ratio = min(1, num_ctas / num_sm)
    active_cta_ratio_bw1 = min(1, num_ctas / 32)  # 32 active ctas are enough to saturate
    active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0)  # 32-108, remaining 5%
@@ -99,7 +101,7 @@ def estimate_matmul_time(
def early_config_prune(configs, named_args):
    backend = _triton.runtime.backend.CUDA
    device = torch.cuda.current_device()
    cc = _triton.runtime.cc(backend, device)
    capability = torch.cuda.get_device_capability()
    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
    dtsize = named_args['A'].element_size()
    dtype = named_args['A'].dtype
@@ -110,7 +112,10 @@ def early_config_prune(configs, named_args):
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
            kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)

        # TODO: move to `cuda_utils` submodule
        triton.compiler.init_cuda_utils()
        max_shared_memory = triton.compiler.cuda_utils.get_device_properties(device)["max_shared_mem"]
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory <= max_shared_memory:
            pruned_configs.append(config)
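A worked instance of the pruning rule above, with assumed (illustrative) config values rather than ones produced by the autotuner:

# a 128x128x32 fp16 config with 3 pipeline stages
BLOCK_M, BLOCK_N, BLOCK_K, num_stages, dtsize = 128, 128, 32, 3, 2
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
# 49152 bytes; the config is kept only if this fits in the device's max_shared_mem
print(required_shared_memory)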
@@ -136,7 +141,7 @@ def early_config_prune(configs, named_args):
    pruned_configs = []
    for k, v in configs_map.items():
        BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
        if cc >= 80:
        if capability[0] >= 8:
            # compute cycles (only works for ampere GPUs)
            mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
            mma_cycles = mmas / min(4, num_warps) * 8
@@ -16,6 +16,9 @@ except ImportError:
    _cutlass = None
    has_cutlass = False

# TODO: move to separate module
import triton


def catch_oor(kernel, pytest_handle=None):
    try:
@@ -330,8 +333,8 @@ def get_dram_gbps(backend=None, device=None):
        backend = _triton.runtime.backend.CUDA
    if not device:
        device = torch.cuda.current_device()
    mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
    bus_width = _triton.runtime.global_memory_bus_width(backend, device)
    mem_clock_khz = triton.compiler.cuda_utils.get_device_properties(device)["mem_clock_rate"]  # in kHz
    bus_width = triton.compiler.cuda_utils.get_device_properties(device)["mem_bus_width"]
    bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8  # In GB/s
    return bw_gbps
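A worked instance of the bandwidth formula above, using assumed A100-class numbers (not queried from a device): a ~1215 MHz memory clock and a 5120-bit bus.

mem_clock_khz = 1_215_000  # ~1215 MHz, expressed in kHz
bus_width = 5120           # bits
# factor 2 for DDR, /8 to convert bits to bytes, /1e6 to land in GB/s
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8
print(bw_gbps)  # ~1555 GB/s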
@@ -341,11 +344,13 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clo
        backend = _triton.runtime.backend.CUDA
    if not device:
        device = torch.cuda.current_device()
    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs

    triton.compiler.init_cuda_utils()
    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
    if not clock_rate:
        clock_rate = _triton.runtime.clock_rate(backend, device)  # in kHz
    cc = _triton.runtime.cc(backend, device)
    if cc < 80:
        clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"]  # in kHz
    capability = torch.cuda.get_device_capability(device)
    if capability[0] < 8:
        assert dtype == torch.float16
        ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
    else: