[BACKEND][FRONTEND] Fix problems with test_matmul (#973)
1. Handle the induction variable when the step is negative
2. Restore an async_wait that was accidentally deleted
3. Add the missing induction variable mapping in the prefetch pass
4. Add device property functions

Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
@@ -4783,10 +4783,15 @@ private:
       decomposed = true;
     });
 
-    // async wait is supported in Ampere and later
     mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
-      if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) ||
-          decomposed) {
+      if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
+        // async wait is supported in Ampere and later
+        asyncWaitOp.erase();
+      } else if (decomposed) {
+        // Wait for all previous async ops
+        OpBuilder builder(asyncWaitOp);
+        auto newAsyncWaitOp =
+            builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
         asyncWaitOp.erase();
       }
     });
@@ -262,10 +262,10 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
     // For now, this behaves like generic, but this will evolve when
     // we add support for `can_reorder=False`
     Type retType = this->getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
+                                               adaptor.getOperands());
     return success();
   }
 };
 
 struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
@@ -450,13 +450,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       TritonGenericPattern<triton::IntToPtrOp>,
       TritonGenericPattern<triton::PtrToIntOp>,
       TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
-      TritonGenericPattern<triton::AddPtrOp>,
-      TritonCatPattern,
-      TritonReducePattern,
-      TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
-      TritonDotPattern, TritonLoadPattern, TritonStorePattern,
-      TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
-      typeConverter, context);
+      TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
+      TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
+      TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
+      TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
+      TritonAtomicRMWPattern>(typeConverter, context);
 }
 
 //
@@ -782,8 +782,8 @@ public:
                                       newRetType.getEncoding()));
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
-    auto newDot = rewriter.create<triton::DotOp>(
-        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
+    auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
+                                                 b, newAcc, dotOp.allowTF32());
 
     rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
         op, oldRetType, newDot.getResult());
@@ -225,6 +225,7 @@ scf::ForOp Prefetcher::createNewForOp() {
   BlockAndValueMapping mapping;
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
 
   for (Operation &op : forOp.getBody()->without_terminator()) {
     Operation *newOp = builder.clone(op, mapping);
python/tests/test_matmul.py  (new file, 101 lines)
@@ -0,0 +1,101 @@
+import itertools
+
+import pytest
+import torch
+
+import triton
+import triton._C.libtriton.triton as _triton
+
+
+@pytest.mark.parametrize(
+    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
+    itertools.chain(
+        *[
+            [
+                # 1 warp
+                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                # 2 warp
+                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                # 4 warp
+                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                # 8 warp
+                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
+                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
+                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
+                # split-k
+                (64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
+                (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
+                (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
+                # variable input
+                (128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
+                (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
+                (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
+                (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
+            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
+        ],
+        # n-stage
+        *[
+            [
+                (16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+                (64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+                (128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+                (256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+                (128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
+                # split-k
+                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
+            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
+        ]
+    ),
+)
+def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
+    capability = torch.cuda.get_device_capability()
+    if capability[0] < 7:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
+    if capability[0] < 8 and DTYPE == "bfloat16":
+        pytest.skip("Only test bfloat16 on devices with sm >= 80")
+    # if DTYPE == "bfloat16" and SPLIT_K != 1:
+    #     pytest.skip("bfloat16 matmuls don't allow split_k for now")
+    if DTYPE == "bfloat16":
+        pytest.skip("bfloat16 matmuls doesn't support for now")
+    torch.manual_seed(0)
+    # nuke kernel decorators -- will set meta-parameters manually
+    kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
+    pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
+    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
+    kernel = triton.ops._matmul.kernel
+    kernel.configs = configs
+    # kernel.run = kernel.run.run.run
+
+    # get matrix shape
+    M = BLOCK_M if M is None else M
+    N = BLOCK_N if N is None else N
+    K = BLOCK_K * SPLIT_K if K is None else K
+    # allocate/transpose inputs
+    DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
+    a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
+    b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
+    a = a.t() if AT else a
+    b = b.t() if BT else b
+    # run test
+    th_c = torch.matmul(a, b)
+    tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
+    triton.testing.assert_almost_equal(th_c, tt_c)
@@ -68,7 +68,7 @@ def test_vecadd_scf_mask(shape, num_warps, block_size, iter_size):
         @num_elements: number of elements
         '''
         pid = tl.program_id(axis=0)
-        for i in range(math.ceil(block_size / iter_size)):
+        for i in range(tl.cdiv(block_size, iter_size)):
             # TODO: a bug here, if put the offset outside the forloop, there will be a GPU mis-aligned error.
             offset = pid * block_size + tl.arange(0, iter_size)
             x_ptrs = x_ptr + offset
@@ -329,10 +329,6 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_BinOp(self, node):
         lhs = self.visit(node.left)
         rhs = self.visit(node.right)
-        if isinstance(lhs, triton.language.constexpr):
-            lhs = lhs.value
-        if isinstance(rhs, triton.language.constexpr):
-            rhs = rhs.value
         fn = {
             ast.Add: '__add__',
             ast.Sub: '__sub__',
@@ -591,8 +587,10 @@ class CodeGenerator(ast.NodeVisitor):
                 ast.NodeVisitor.generic_visit(self, stmt)
             return
         # handle negative constant step (not supported by scf.for in MLIR)
+        negative_step = False
         if isinstance(step, triton.language.constexpr) and step.value < 0:
             step = triton.language.constexpr(-step.value)
+            negative_step = True
             lb, ub = ub, lb
         # lb/ub/step might be constexpr, we need to cast them to tensor
         lb = triton.language.core._to_tensor(lb, self.builder).handle
@@ -640,6 +638,9 @@ class CodeGenerator(ast.NodeVisitor):
         # update induction variable with actual value, and replace all uses
         self.builder.set_insertion_point_to_start(for_op.get_body(0))
         iv = self.builder.create_index_to_si(for_op.get_induction_var())
+        if negative_step:
+            ub_si = self.builder.create_index_to_si(ub)
+            iv = self.builder.create_sub(ub_si, iv)
         self.lscope[node.target.id].handle.replace_all_uses_with(iv)
         self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32))
 
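The two hunks above implement item 1 of the commit message: MLIR's scf.for only accepts a positive step, so a constant negative step is negated, the bounds are swapped, and the original induction variable is recovered inside the loop body as ub - iv. A minimal Python sketch of the same rewrite, for illustration only (the helper name rewrite_descending_range is hypothetical, not part of the codebase):

    def rewrite_descending_range(lb, ub, step):
        # Mirror of the code-generator logic: scf.for needs step > 0, so
        # negate the step, swap the bounds, and recover the original
        # induction variable as (new ub) - iv inside the body.
        assert step < 0
        step = -step
        lb, ub = ub, lb
        return lb, ub, step, (lambda iv: ub - iv)

    # range(10, 0, -2) becomes an ascending loop over range(0, 10, 2)
    lb, ub, step, recover = rewrite_descending_range(10, 0, -2)
    assert [recover(iv) for iv in range(lb, ub, step)] == [10, 8, 6, 4, 2]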
@@ -890,9 +891,9 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     pm = _triton.ir.pass_manager(mod.context)
     pm.add_convert_triton_to_tritongpu_pass(num_warps)
     pm.enable_debug()
-    # Convert blocked layout to mma layout for dot ops so that pipeline
-    # can get shared memory swizzled correctly.
     pm.add_coalesce_pass()
+    # The combine pass converts blocked layout to mma layout
+    # for dot ops so that pipeline can get shared memory swizzled correctly.
     pm.add_triton_gpu_combine_pass(compute_capability)
     pm.add_tritongpu_pipeline_pass(num_stages)
     # Prefetch must be done after pipeline pass because pipeline pass
@@ -1395,20 +1396,20 @@ def compile(fn, **kwargs):
     extern_libs = kwargs.get("extern_libs", dict())
     device = kwargs.get("device", torch.cuda.current_device())
     capability = torch.cuda.get_device_capability()
-    capability = capability[0]*10 + capability[1]
+    capability = capability[0] * 10 + capability[1]
     # build compilation stages
     stages = {
-        "ast" : (lambda path: fn, None),
+        "ast": (lambda path: fn, None),
         "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
                  lambda src: ast_to_ttir(src, signature, configs[0], constants)),
         "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
                   lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
         "llir": (lambda path: Path(path).read_bytes(),
                  lambda src: ttgir_to_llir(src, extern_libs, capability)),
         "ptx": (lambda path: Path(path).read_text(),
                 lambda src: llir_to_ptx(src, capability)),
         "cubin": (lambda path: Path(path).read_bytes(),
                   lambda src: ptx_to_cubin(src, capability))
     }
     # find out the signature of the function
     if isinstance(fn, triton.runtime.JITFunction):
@@ -1467,8 +1468,8 @@ def compile(fn, **kwargs):
         if ir == ext:
             next_module = parse(fn)
         elif os.path.exists(path) and\
                 ir in metadata["ctime"] and\
                 os.path.getctime(path) == metadata["ctime"][ir]:
             next_module = parse(path)
         else:
             next_module = compile(module)
@@ -1504,8 +1505,7 @@ class CompiledKernel:
         self.asm = asm
         device = torch.cuda.current_device()
         global cuda_utils
-        if cuda_utils is None:
-            cuda_utils = CudaUtils()
+        init_cuda_utils()
         mod, func, n_regs, n_spills = cuda_utils.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
         self.cu_module = mod
         self.cu_function = func
@@ -1562,6 +1562,34 @@ class CudaUtils(object):
 
 #define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; }
 
+static PyObject* getDeviceProperties(PyObject* self, PyObject* args){
+    int device_id;
+    if(!PyArg_ParseTuple(args, "i", &device_id))
+        return NULL;
+    // Get device handle
+    CUdevice device;
+    cuDeviceGet(&device, device_id);
+
+    // create a struct to hold device properties
+    int max_shared_mem;
+    int multiprocessor_count;
+    int sm_clock_rate;
+    int mem_clock_rate;
+    int mem_bus_width;
+    CUDA_CHECK(cuDeviceGetAttribute(&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, device));
+    CUDA_CHECK(cuDeviceGetAttribute(&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
+    CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
+    CUDA_CHECK(cuDeviceGetAttribute(&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
+    CUDA_CHECK(cuDeviceGetAttribute(&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
+
+    return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", max_shared_mem,
+                         "multiprocessor_count", multiprocessor_count,
+                         "sm_clock_rate", sm_clock_rate,
+                         "mem_clock_rate", mem_clock_rate,
+                         "mem_bus_width", mem_bus_width);
+}
+
 static PyObject* loadBinary(PyObject* self, PyObject* args) {
     const char* name;
     const char* data;
@@ -1601,6 +1629,7 @@ class CudaUtils(object):
 
 static PyMethodDef ModuleMethods[] = {
   {"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"},
+  {"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"},
   {NULL, NULL, 0, NULL} // sentinel
 };
 
@@ -1640,6 +1669,13 @@ class CudaUtils(object):
         mod = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(mod)
         self.load_binary = mod.load_binary
+        self.get_device_properties = mod.get_device_properties
+
+
+def init_cuda_utils():
+    global cuda_utils
+    if cuda_utils is None:
+        cuda_utils = CudaUtils()
 
 
 cuda_utils = None
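The new get_device_properties binding and the init_cuda_utils() helper above are what the later hunks use in place of the old _triton.runtime.* queries. A rough usage sketch of the resulting Python-side API, assuming the field names defined in the C snippet above:

    import torch
    import triton

    device = torch.cuda.current_device()
    triton.compiler.init_cuda_utils()
    props = triton.compiler.cuda_utils.get_device_properties(device)
    # dict with: max_shared_mem (bytes), multiprocessor_count,
    # sm_clock_rate (kHz), mem_clock_rate (kHz), mem_bus_width (bits)
    print(props["multiprocessor_count"], props["max_shared_mem"])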
@@ -9,6 +9,7 @@ from triton._C.libtriton.triton import ir
 
 T = TypeVar('T')
 
+
 def _to_tensor(x, builder):
     if isinstance(x, bool):
         return tensor(builder.get_int1(x), int1)
@@ -348,6 +349,9 @@ class constexpr:
     def __mul__(self, other):
         return constexpr(self.value * other.value)
 
+    def __mod__(self, other):
+        return constexpr(self.value % other.value)
+
     def __rmul__(self, other):
         return constexpr(other.value * self.value)
 
@@ -726,10 +730,12 @@ def broadcast_to(input, shape, _builder=None):
     """
     return semantic.broadcast_impl_shape(input, shape, _builder)
 
+
 @builtin
 def trans(input, _builder=None):
     return semantic.trans(input, _builder)
 
+
 @builtin
 def cat(input, other, can_reorder=False, _builder=None):
     """
@@ -762,6 +768,7 @@ def view(input, shape, _builder=None):
     shape = [x.value for x in shape]
     return semantic.view(input, shape, _builder)
 
+
 @builtin
 def reshape(input, shape, _builder=None):
     # TODO: should be more than just a view
@@ -481,7 +481,8 @@ def view(input: tl.tensor,
          dst_shape: List[int],
          builder: ir.builder) -> tl.tensor:
     # TODO: disable when TritonToTritonGPU handles views properly
-    assert len(input.shape) == len(dst_shape)
+    # assert len(input.shape) == len(dst_shape)
     numel = 1
     for s in dst_shape:
         numel *= s
@@ -26,9 +26,6 @@ def get_configs_io_bound():
     return configs
 
 
-@triton.heuristics({
-    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-})
 @triton.autotune(
     configs=[
         # basic configs for compute-bound matmuls
@@ -59,6 +56,9 @@ def get_configs_io_bound():
         'top_k': 10
     },
 )
+@triton.heuristics({
+    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
+})
 @triton.jit
 def _kernel(A, B, C, M, N, K,
             stride_am, stride_ak,
@@ -10,7 +10,9 @@ from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
 def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
     ''' return compute throughput in TOPS '''
     total_warps = num_ctas * min(num_warps, 4)
-    num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
+    triton.compiler.init_cuda_utils()
+
+    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
     tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
     return tflops
 
@@ -18,14 +20,14 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
 def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
     ''' return compute throughput in TOPS '''
     total_warps = num_ctas * min(num_warps, 4)
-    num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
+    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
     tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
     return tflops
 
 
 def get_tflops(backend, device, num_ctas, num_warps, dtype):
-    cc = _triton.runtime.cc(backend, device)
-    if cc < 80 and dtype == torch.float32:
+    capability = torch.cuda.get_device_capability(device)
+    if capability[0] < 8 and dtype == torch.float32:
         return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
     return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
 
@@ -59,7 +61,7 @@ def estimate_matmul_time(
     compute_ms = total_ops / tput
 
     # time to load data
-    num_sm = _triton.runtime.num_sm(backend, device)
+    num_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
     active_cta_ratio = min(1, num_ctas / num_sm)
     active_cta_ratio_bw1 = min(1, num_ctas / 32)  # 32 active ctas are enough to saturate
     active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0)  # 32-108, remaining 5%
@@ -99,7 +101,7 @@ def early_config_prune(configs, named_args):
 def early_config_prune(configs, named_args):
     backend = _triton.runtime.backend.CUDA
     device = torch.cuda.current_device()
-    cc = _triton.runtime.cc(backend, device)
+    capability = torch.cuda.get_device_capability()
     # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
     dtsize = named_args['A'].element_size()
     dtype = named_args['A'].dtype
@@ -110,7 +112,10 @@ def early_config_prune(configs, named_args):
         kw = config.kwargs
         BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
             kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
-        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
+
+        # TODO: move to `cuda_utils` submodule
+        triton.compiler.init_cuda_utils()
+        max_shared_memory = triton.compiler.cuda_utils.get_device_properties(device)["max_shared_mem"]
         required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
         if required_shared_memory <= max_shared_memory:
            pruned_configs.append(config)
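To make the pruning rule above concrete: a 128x128x32 tile with num_stages=2 in float16 (dtsize=2) needs (128 + 128) * 32 * 2 * 2 = 32768 bytes of shared memory, so it survives on any device whose reported max_shared_mem is at least 32 KiB, while larger tiles or deeper pipelines are dropped when they exceed the device limit. A quick check of that arithmetic, using those example values:

    BLOCK_M, BLOCK_N, BLOCK_K, num_stages, dtsize = 128, 128, 32, 2, 2
    required = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
    assert required == 32768  # 32 KiB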
@@ -136,7 +141,7 @@ def early_config_prune(configs, named_args):
     pruned_configs = []
     for k, v in configs_map.items():
         BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
-        if cc >= 80:
+        if capability[0] >= 8:
             # compute cycles (only works for ampere GPUs)
             mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
             mma_cycles = mmas / min(4, num_warps) * 8
@@ -16,6 +16,9 @@ except ImportError:
     _cutlass = None
     has_cutlass = False
 
+# TODO: move to separate module
+import triton
+
 
 def catch_oor(kernel, pytest_handle=None):
     try:
@@ -330,8 +333,8 @@ def get_dram_gbps(backend=None, device=None):
         backend = _triton.runtime.backend.CUDA
     if not device:
         device = torch.cuda.current_device()
-    mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
-    bus_width = _triton.runtime.global_memory_bus_width(backend, device)
+    mem_clock_khz = triton.compiler.cuda_utils.get_device_properties(device)["mem_clock_rate"] # in kHz
+    bus_width = triton.compiler.cuda_utils.get_device_properties(device)["mem_bus_width"]
     bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
     return bw_gbps
 
@@ -341,11 +344,13 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clock_rate=None):
         backend = _triton.runtime.backend.CUDA
     if not device:
         device = torch.cuda.current_device()
-    num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
+
+    triton.compiler.init_cuda_utils()
+    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
     if not clock_rate:
-        clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
+        clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"] # in kHz
-    cc = _triton.runtime.cc(backend, device)
-    if cc < 80:
+    capability = torch.cuda.get_device_capability(device)
+    if capability[0] < 8:
         assert dtype == torch.float16
         ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
     else: