Deprecation of Triton-C and Replacement by decorated Python functions (#86)

This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes.

See documentations for more information on the new API
This commit is contained in:
Philippe Tillet
2021-04-20 22:29:40 -04:00
committed by Philippe Tillet
parent 1fdb465b71
commit 39f4730305
91 changed files with 4500 additions and 13008 deletions

View File

@@ -49,11 +49,11 @@ class CMakeBuild(build_ext):
self.build_extension(ext)
def build_extension(self, ext):
# self.debug = True
self.debug = False
#self.debug = True
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories
llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm")
build_suffix = 'debug' if self.debug else 'release'
llvm_build_dir = os.path.join(tempfile.gettempdir(), f"llvm-{build_suffix}")
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
if not os.path.exists(llvm_build_dir):

676
python/src/functions.h Normal file
View File

@@ -0,0 +1,676 @@
#include "triton/ir/builder.h"
#include <functional>
#include <iostream>
#include <pybind11/pybind11.h>
namespace ir = triton::ir;
namespace py = pybind11;
static const std::string _builder_doc = R"pbdoc(
:param builder: IR builder to generate code into, optional, set automatically when called inside a @triton.jit function
:type builder: triton.ir.builder
)pbdoc";
#define VA_ARGS(...) , ##__VA_ARGS__
#define DEF_FUNC(MOD, PY_NAME, C_FUNC, ...) \
MOD.def(PY_NAME, C_FUNC, (C_FUNC##_docstr + _builder_doc).c_str(), \
ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)
void throw_not_implemented(std::string key) {
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. This is likely a bug on our side.");
}
void throw_not_int_or_float(std::string key) {
throw std::runtime_error("`" + key + "` only supported for integer and floating point types.");
}
enum type_code {
_bool,
int8,
int16,
int32,
int64,
float16,
float32,
float64
};
ir::type *make_ir(type_code ty, ir::builder *builder) {
switch (ty) {
case float16:
return builder->get_half_ty();
case float32:
return builder->get_float_ty();
default:
throw_not_implemented("make_ir");
}
}
type_code from_ir(ir::type *ty) {
if (ty->is_half_ty())
return float16;
if (ty->is_float_ty())
return float32;
throw_not_implemented("from_ir");
}
/*----------------------------------------------
definition of triton.cast / triton.ir.value.to
----------------------------------------------*/
std::string cast_docstr = R"pbdoc(
Tries to cast a block to a new data type.
:param input: The input block.
:type input: triton.ir.value
)pbdoc";
ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) {
ir::type *src_ty = input->get_type();
ir::type *dst_ty = make_ir(_dtype, builder);
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
// FP Truncation
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
if (truncate_fp)
return builder->create_fp_trunc(input, dst_ty);
// FP Extension
bool ext_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
if (ext_fp)
return builder->create_fp_ext(input, dst_ty);
// Int cast
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
return builder->create_int_cast(input, dst_ty, true);
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty())
return builder->create_fp_to_si(input, dst_ty);
// int -> Float
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty())
return builder->create_si_to_fp(input, dst_ty);
// Ptr -> Ptr
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::BitCast, input, dst_ty);
// * -> Bool
if (dst_sca_ty->is_bool_ty()) {
if (src_sca_ty->is_pointer_ty())
input = cast(input, int64, builder);
ir::value *other = builder->get_int64(0);
if (src_ty->is_bool_ty())
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
throw_not_implemented("cast");
}
/*----------------------------------------------
definition of triton.broadcast_check
----------------------------------------------*/
std::string try_broadcast_docstr = R"pbdoc(
Tries to broadcast two blocks to a common compatible shape.
:param input: The first input block.
:type input: triton.ir.value
:param other: The second input block.
:type other: triton.ir.value
)pbdoc";
std::tuple<ir::value *, ir::value *> try_broadcast(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
// make_shape_compatible(block, scalar)
if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
// make_shape_compatible(scalar, block)
else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
// make_shape_compatible(block, block)
else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
auto lhs_shape = lhs_ty->get_block_shapes();
auto rhs_shape = rhs_ty->get_block_shapes();
if (lhs_shape.size() != rhs_shape.size())
throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
ir::type::block_shapes_t ret_shape;
for (size_t i = 0; i < lhs_shape.size(); ++i) {
unsigned left = lhs_shape[i];
unsigned right = rhs_shape[i];
if (left == 1)
ret_shape.push_back(right);
else if (right == 1)
ret_shape.push_back(left);
else if (left == right)
ret_shape.push_back(left);
else
throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
": " + std::to_string(left) + " and " + std::to_string(right));
}
if (lhs_shape != ret_shape)
lhs = builder->create_broadcast(lhs, ret_shape);
if (rhs_shape != ret_shape)
rhs = builder->create_broadcast(rhs, ret_shape);
}
return std::make_tuple(lhs, rhs);
}
/*----------------------------------------------
definition of triton.broadcast_to
----------------------------------------------*/
std::string broadcast_to_docstr = R"pbdoc(
Tries to broadcast a block to a new shape.
:param input: The input block.
:type input: triton.value
:param shape: The new shape.
:type shape: tuple of int
)pbdoc";
ir::value *broadcast_to(ir::value *input, const ir::type::block_shapes_t &shape, ir::builder *builder) {
if (!input->get_type()->is_block_ty())
return builder->create_splat(input, shape);
auto src_shape = input->get_type()->get_block_shapes();
if (src_shape.size() != shape.size())
throw std::runtime_error("Cannot broadcast");
return builder->create_broadcast(input, shape);
}
/*----------------------------------------------
definition of triton.load
----------------------------------------------*/
std::string load_docstr = R"pbdoc(
Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`.
:param pointer: Pointer to the data to be loaded.
:type pointer: Block of triton.pointer
:param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
:type mask: Block of triton.bool, optional
:param other: if mask[idx] is false, return other[idx] instead of 'pointer[idx]`
:type other: Block of triton.value, optional
)pbdoc";
ir::value *load(ir::value *pointer, std::optional<ir::value *> _mask, std::optional<ir::value *> _other, ir::builder *builder) {
if (!_mask.has_value() && !_other.has_value())
return builder->create_load(pointer);
if (!_mask.has_value())
throw std::runtime_error("`other` cannot be provided without `mask`");
ir::value *mask = _mask.value();
ir::type *elt_ty = pointer->get_type()->get_scalar_ty()->get_pointer_element_ty();
auto shape = pointer->get_type()->get_block_shapes();
ir::value *other = _other.has_value() ? _other.value() : ir::undef_value::get(elt_ty);
other = cast(other, from_ir(elt_ty), builder);
other = broadcast_to(other, shape, builder);
mask = broadcast_to(mask, shape, builder);
return builder->create_masked_load(pointer, mask, other);
}
/*----------------------------------------------
definition of triton.store
----------------------------------------------*/
std::string store_docstr = R"pbdoc(
Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.
:param pointer: The memory locations where the elements of `value` are stored.
:type pointer: Block of triton.pointer
:param value: The block of elements to be stored.
:type value: Block of triton.value
:param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
:type mask: Block of triton.bool, optional
)pbdoc";
ir::value *store(ir::value *ptr, ir::value *val, std::optional<ir::value *> _mask, ir::builder *builder) {
if (!_mask.has_value())
return builder->create_store(ptr, val);
ir::value *mask = _mask.value();
return builder->create_masked_store(ptr, val, mask);
}
/*----------------------------------------------
definition of triton.dot
----------------------------------------------*/
std::string dot_docstr = R"pbdoc(
Returns the matrix product of two blocks.
The two blocks must be two dimensionals and have compatible inner dimensions.
:param input: The first block to be multiplied.
:type input: 2D block of scalar-type in {`float16`, `float32`}
:param other: The second block to be multiplied.
:type other: 2D block of scalar-type in {`float16`, `float32`}
)pbdoc";
ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
ir::value *_0 = builder->get_float32(0);
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
_0 = builder->create_splat(_0, {M, N});
return builder->create_dot(lhs, rhs, _0);
}
/*----------------------------------------------
definition of triton.where
----------------------------------------------*/
std::string where_docstr = R"pbdoc(
Returns a block of elements from either `x` or `y`, depending on `condition`.
Note that `x` and `y` are always evaluated regardless of the value of `condition`.
If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
:param condition: When True (nonzero), yield x, otherwise yield y.
:type condition: Block of triton.bool
:param x: values selected at indices where condition is True.
:param y: values selected at indices where condition is False.
)pbdoc";
ir::value *where(ir::value *condition, ir::value *x, ir::value *y, ir::builder *builder) {
return builder->create_select(condition, x, y);
};
/*----------------------------------------------
definition of triton.arange
----------------------------------------------*/
std::string arange_docstr = R"pbdoc(
Returns contiguous values within the open interval [start, end).
:param start: Start of the interval.
:type start: int
:param stop: End of the interval.
:type stop: int
)pbdoc";
ir::value *arange(int start, int end, ir::builder *builder) {
return builder->get_range(start, end);
};
/*----------------------------------------------
definition of triton.program_id
----------------------------------------------*/
std::string program_id_docstr = R"pbdoc(
Returns the id of the current program instance along the given `axis`.
Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
)pbdoc";
ir::value *program_id(int axis, ir::builder *builder) {
return builder->create_get_program_id(axis);
};
/*----------------------------------------------
definition of triton.num_programs
----------------------------------------------*/
std::string num_programs_docstr = R"pbdoc(
Returns the number of program instances launched along the given `axis`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
)pbdoc";
ir::value *num_programs(int axis, ir::builder *builder) {
return builder->create_get_num_programs(axis);
};
/*----------------------------------------------
definition of triton.zeros
----------------------------------------------*/
std::string zeros_docstr = R"pbdoc(
Returns a block filled with the scalar value 0 and the given shape.
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., triton.float16
:type dtype: triton.ir.dtype
)pbdoc";
ir::value *zeros(ir::type::block_shapes_t shape, type_code _dtype, ir::builder *builder) {
ir::type *dtype = make_ir(_dtype, builder);
ir::value *_0 = ir::constant::get_null_value(dtype);
return builder->create_splat(_0, shape);
};
/*----------------------------------------------
definition of triton.exp
----------------------------------------------*/
std::string _exp_docstr = R"pbdoc(
Returns the element-wise exponential of `input`.
)pbdoc";
ir::value *_exp(ir::value *input, ir::builder *builder) {
return builder->create_exp(input);
};
/*----------------------------------------------
definition of triton.log
----------------------------------------------*/
std::string _log_docstr = R"pbdoc(
Returns the element-wise natural logarithm of `input`.
)pbdoc";
ir::value *_log(ir::value *input, ir::builder *builder) {
return builder->create_log(input);
};
/*----------------------------------------------
definition of triton.sqrt
----------------------------------------------*/
std::string sqrt_docstr = R"pbdoc(
Returns the element-wise square root of `input`.
)pbdoc";
ir::value *sqrt(ir::value *input, ir::builder *builder) {
return builder->create_sqrt(input);
};
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
if (scalar_ty->is_floating_point_ty())
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())
return builder->create_reduce(input, INT_OP, axis);
else
throw_not_int_or_float(name);
}
std::string min_docstr = R"pbdoc(
Returns the minimum value of `input`.
)pbdoc";
ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
};
/*----------------------------------------------
definition of triton.max
----------------------------------------------*/
std::string max_docstr = R"pbdoc(
Returns the maximum value of `input`.
)pbdoc";
ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
};
/*----------------------------------------------
definition of triton.sum
----------------------------------------------*/
std::string sum_docstr = R"pbdoc(
Returns the sum of `input`.
)pbdoc";
ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
};
/*----------------------------------------------
definition of triton.atomic_cas
----------------------------------------------*/
std::string atomic_cas_docstr = R"pbdoc(
Atomic compare-and-swap.
)pbdoc";
ir::value *atomic_cas(ir::value *ptr, ir::value *cmp, ir::value *val, ir::builder *builder) {
return builder->create_atomic_cas(ptr, cmp, val);
};
/*----------------------------------------------
definition of triton.atomic_xchg
----------------------------------------------*/
std::string atomic_xchg_docstr = R"pbdoc(
Atomic exchange.
)pbdoc";
ir::value *atomic_xchg(ir::value *ptr, ir::value *val, ir::builder *builder) {
return builder->create_atomic_exch(ptr, val);
};
/*----------------------------------------------
debug barrier
----------------------------------------------*/
std::string debug_barrier_docstr = R"pbdoc(
Temporary hacky fixup for when the compiler forgets to insert sync barriers
)pbdoc";
ir::value *debug_barrier(ir::builder *builder) {
return builder->create_barrier();
}
#define DEF_BINARY_OP(MOD, PY_NAME, C_FUNC, ...) \
MOD.def(PY_NAME, binary_op(C_FUNC), (C_FUNC##_docstr + _builder_doc).c_str(), \
ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)
template <class FN>
std::function<ir::value *(ir::value *, ir::value *, ir::builder *builder)>
binary_op(const FN &fn) {
auto ret = [&fn](ir::value *self, ir::value *other, ir::builder *builder) {
//std::tie(self, other) = try_broadcast(self, other, builder);
return fn(self, other, builder);
};
return ret;
}
/*----------------------------------------------
definition of self + other
----------------------------------------------*/
std::string add_docstr = R"pbdoc(
Returns self + other, element-wise.
)pbdoc";
ir::value *add(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// ptr + offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(self, {other});
// float + float
else if (scalar_ty->is_floating_point_ty())
return builder->create_fadd(self, other);
// int + int
else if (scalar_ty->is_integer_ty())
return builder->create_add(self, other);
throw_not_implemented("add");
}
/*----------------------------------------------
definition of self - other
----------------------------------------------*/
std::string sub_docstr = R"pbdoc(
Returns self - other, element-wise.
)pbdoc";
ir::value *sub(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// ptr + offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(self, {other});
// float + float
if (scalar_ty->is_floating_point_ty())
return builder->create_fsub(self, other);
// int + int
else if (scalar_ty->is_integer_ty())
return builder->create_sub(self, other);
throw_not_implemented("sub");
}
/*----------------------------------------------
definition of self * other
----------------------------------------------*/
std::string mul_docstr = R"pbdoc(
Returns self * other, element-wise.
)pbdoc";
ir::value *mul(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float * float
if (scalar_ty->is_floating_point_ty())
return builder->create_fmul(self, other);
// int * int
else if (scalar_ty->is_integer_ty())
return builder->create_mul(self, other);
throw_not_implemented("mul");
}
/*----------------------------------------------
definition of self > other
----------------------------------------------*/
std::string greater_than_docstr = R"pbdoc(
Returns self > other, element-wise.
)pbdoc";
ir::value *greater_than(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float > float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGT(self, other);
// int > int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGT(self, other);
throw_not_implemented("greater_than");
}
/*----------------------------------------------
definition of self >= other
----------------------------------------------*/
std::string greater_equal_docstr = R"pbdoc(
Returns self >= other, element-wise.
)pbdoc";
ir::value *greater_equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float >= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(self, other);
// int >= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGE(self, other);
throw_not_implemented("greater_equal");
}
/*----------------------------------------------
definition of self < other
----------------------------------------------*/
std::string less_than_docstr = R"pbdoc(
Returns self < other, element-wise.
)pbdoc";
ir::value *less_than(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(self, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLT(self, other);
throw_not_implemented("less_than");
}
/*----------------------------------------------
definition of self <= other
----------------------------------------------*/
std::string less_equal_docstr = R"pbdoc(
Returns self <= other, element-wise.
)pbdoc";
ir::value *less_equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(self, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLE(self, other);
throw_not_implemented("less_equal");
}
/*----------------------------------------------
definition of self == other
----------------------------------------------*/
std::string equal_docstr = R"pbdoc(
Returns self == other, element-wise.
)pbdoc";
ir::value *equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOEQ(self, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(self, other);
throw_not_implemented("equal");
}
/*----------------------------------------------
definition of self / other
----------------------------------------------*/
std::string _div_docstr = R"pbdoc(
Returns self / other, element-wise.
)pbdoc";
ir::value *_div(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float / float
if (scalar_ty->is_floating_point_ty())
return builder->create_fdiv(self, other);
// int / int
else if (scalar_ty->is_integer_ty())
return builder->create_sdiv(self, other);
throw_not_implemented("div");
}
/*----------------------------------------------
definition of self % other
----------------------------------------------*/
std::string mod_docstr = R"pbdoc(
Returns self % other, element-wise.
)pbdoc";
ir::value *mod(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float % int
if (scalar_ty->is_floating_point_ty())
return builder->create_frem(self, other);
// int % int
else if (scalar_ty->is_integer_ty())
return builder->create_srem(self, other);
throw_not_implemented("mod");
}
/*----------------------------------------------
definition of self & other
----------------------------------------------*/
std::string _and_docstr = R"pbdoc(
Returns self & other, element-wise.
)pbdoc";
ir::value *_and(ir::value *self, ir::value *other, ir::builder *builder) {
return builder->create_and(self, other);
}
/*----------------------------------------------
definition of minimum(self, other)
----------------------------------------------*/
std::string minimum_docstr = R"pbdoc(
Returns element-wise minimum of self and other
)pbdoc";
ir::value *minimum(ir::value *self, ir::value *other, ir::builder *builder) {
return where(less_than(self, other, builder), self, other, builder);
}
/*----------------------------------------------
definition of self[slices]
----------------------------------------------*/
enum slice_mode_t {
NEWAXIS,
ALL
};
std::string subscript_docstr = R"pbdoc(
returns self[slices].
:param slices: The slices to subscript with.
:type slices: List of `None` or `:` slices.
)pbdoc";
ir::value *subscript(ir::value *self, std::vector<py::object> slices, ir::builder *builder) {
std::vector<slice_mode_t> modes;
for (py::object slice : slices) {
py::object none = py::none();
py::object all = py::make_tuple(none, none, none);
if (slice.is(none))
modes.push_back(NEWAXIS);
else if (all.attr("__eq__")(slice))
modes.push_back(ALL);
else
throw std::runtime_error("slice must be None or (None, None, None)");
}
ir::type::block_shapes_t shape;
size_t curr = 0;
for (slice_mode_t mode : modes) {
if (mode == NEWAXIS)
shape.push_back(1);
else {
assert(mode == ALL);
shape.push_back(self->get_type()->get_block_shapes()[curr++]);
}
}
return builder->create_reshape(self, shape);
}

View File

@@ -1,5 +1,13 @@
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/codegen/pass.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/driver/stream.h"
#include "triton/ir/builder.h"
#include "triton/ir/dispatch.h"
#include "triton/ir/enums.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
@@ -8,78 +16,9 @@
#include <string>
namespace py = pybind11;
using namespace triton;
namespace rt = triton::runtime;
namespace ir = triton::ir;
namespace drv = triton::driver;
/*****************************************************************************/
/* Python bindings for triton::tools */
/*****************************************************************************/
/*!
@brief Function for extracting kernels out of a given source-string
This can be important to enable pre-processor macros (or tunable parameters) that should only
be defined within the scope of a single kernel function
*/
std::string extract_kernels(const std::string &str, const std::vector<std::string> &names) {
if (names.empty())
return str;
// search for all regex matches of kernel_regex in str
std::smatch matches;
std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
std::sregex_iterator it(str.begin(), str.end(), regex);
std::sregex_iterator end;
std::vector<std::tuple<std::string, int, int>> kernels;
for (; it != end; ++it) {
int pos = it->position();
int len = it->length();
std::string name = it->str(1);
kernels.push_back(std::make_tuple(name, pos, len));
}
// check that all the kernels provided actually exist
for (const std::string &name : names) {
auto pred = [&name](const std::tuple<std::string, int, int> &t) { return std::get<0>(t) == name; };
bool found = std::any_of(kernels.begin(), kernels.end(), pred);
if (!found)
throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
}
// simple parsing logic to extract the declaration and body of each specified kernel
std::string ret;
for (const auto &k : kernels) {
std::string name;
int pos, len;
std::tie(name, pos, len) = k;
if (std::find(names.begin(), names.end(), name) == names.end())
continue;
std::string def = str.substr(pos, str.size() - pos);
// skip over declaration
// by finding matching ')' for first '('
int count = 1;
pos = def.find('(');
while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
count += def[pos] == '(';
count -= def[pos] == ')';
}
// skip over definition
// by finding matching '{' for first '}'
count = 1;
pos = def.find('{', pos);
while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
count += def[pos] == '{';
count -= def[pos] == '}';
}
ret += def.substr(0, pos);
ret += '\n';
}
return ret;
}
void init_triton_tools(py::module &&m) {
m.def("extract_kernels", &extract_kernels);
}
/*****************************************************************************/
/* Python bindings for triton::driver */
/*****************************************************************************/
@@ -88,14 +27,14 @@ void init_triton_driver(py::module &&m) {
// base device
py::class_<drv::device>(m, "device");
// cuda device
py::class_<drv::cu_device, driver::device>(m, "cu_device")
py::class_<drv::cu_device, drv::device>(m, "cu_device")
.def(py::init([](int dev_id, bool take_ownership) {
CUdevice handle;
drv::dispatch::cuDeviceGet(&handle, dev_id);
return new drv::cu_device(handle, take_ownership);
}));
// host device
py::class_<drv::host_device, driver::device>(m, "host_device")
py::class_<drv::host_device, drv::device>(m, "host_device")
.def(py::init<>());
// base stream
@@ -108,54 +47,236 @@ void init_triton_driver(py::module &&m) {
// py doesn't support opaque pointer (e.g., CUstream) so
// we assume it has been converted to uint64_t
.def(py::init([](uint64_t handle, bool take_ownership) {
return std::unique_ptr<driver::cu_stream>(new driver::cu_stream((CUstream)handle, take_ownership));
}));
return std::unique_ptr<drv::cu_stream>(new drv::cu_stream((CUstream)handle, take_ownership));
}))
.def("enqueue", [](drv::cu_stream *self, drv::kernel *kernel,
size_t grid_0, size_t grid_1, size_t grid_2,
size_t block_0, size_t block_1, size_t block_2,
const std::string &args,
size_t shared_mem) {
return self->enqueue(kernel, {grid_0, grid_1, grid_2}, {block_0, block_1, block_2},
(void *)args.data(), args.size(), shared_mem);
});
py::class_<drv::module>(m, "module");
//py::class_<drv::cu_module, drv::module>(m, "cu_module");
py::class_<drv::kernel>(m, "kernel");
}
/*****************************************************************************/
/* Python bindings for triton::runtime */
/* Python bindings for triton::codegen */
/*****************************************************************************/
void init_triton_runtime(py::module &&m) {
// argument type
py::enum_<rt::arg_type>(m, "arg_type")
.value("int1", rt::INT1_T)
.value("int8", rt::INT8_T)
.value("int16", rt::INT16_T)
.value("int32", rt::INT32_T)
.value("int64", rt::INT64_T)
.value("half", rt::HALF_T)
.value("float", rt::FLOAT_T)
.value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T);
// compilation options
py::class_<rt::options_t>(m, "options", py::dynamic_attr())
.def(py::init<>())
.def_readwrite("defines", &rt::options_t::defines)
.def_readwrite("num_warps", &rt::options_t::num_warps)
.def("__getattr__", [](rt::options_t *opt, const std::string &name) {
return opt->D<int>(name);
});
// kernel
py::class_<rt::kernel>(m, "kernel")
.def("__call__", &rt::kernel::operator())
.def_readonly("opt", &rt::kernel::opt)
.def("asm", &rt::kernel::get_asm);
// tune conf
py::class_<rt::config>(m, "config")
.def(py::init<std::map<std::string, std::string>, int>(),
py::arg("defines") = std::map<std::string, std::string>(),
py::arg("num_warps"));
// function
py::class_<rt::function>(m, "function")
.def(py::init<const std::string &, const rt::options_t &, driver::device *, const std::vector<rt::config> &, const std::vector<std::string> &>())
.def("autotune", &rt::function::autotune, py::return_value_policy::reference_internal)
.def("signature", &rt::function::get_signature);
void init_triton_codegen(py::module &&m) {
m.def(
"add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps) {
drv::module *mod;
drv::kernel *ker;
size_t shared_mem;
triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, mod, ker, shared_mem);
return std::make_tuple(mod, ker, shared_mem);
},
py::return_value_policy::take_ownership);
}
/*****************************************************************************/
/* User-facing language features */
/*****************************************************************************/
void init_triton_frontend(py::module &&m) {
using ret = py::return_value_policy;
// programming model
m.def("program_id", &ir::dispatch::program_id, ret::reference);
m.def("num_programs", &ir::dispatch::num_programs, ret::reference);
// binary
m.def("add", &ir::dispatch::add, ret::reference);
m.def("sub", &ir::dispatch::sub, ret::reference);
m.def("mul", &ir::dispatch::mul, ret::reference);
m.def("truediv", &ir::dispatch::truediv, ret::reference);
m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
m.def("mod", &ir::dispatch::mod, ret::reference);
m.def("and_", &ir::dispatch::and_, ret::reference);
m.def("or_", &ir::dispatch::or_, ret::reference);
m.def("xor_", &ir::dispatch::xor_, ret::reference);
m.def("lshr", &ir::dispatch::lshr, ret::reference);
m.def("shl", &ir::dispatch::shl, ret::reference);
// unary
m.def("plus", &ir::dispatch::plus, ret::reference);
m.def("minus", &ir::dispatch::minus, ret::reference);
m.def("invert", &ir::dispatch::invert, ret::reference);
// comparison
m.def("greater_than", &ir::dispatch::greater_than, ret::reference);
m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference);
m.def("less_than", &ir::dispatch::less_than, ret::reference);
m.def("less_equal", &ir::dispatch::less_equal, ret::reference);
m.def("equal", &ir::dispatch::equal, ret::reference);
m.def("not_equal", &ir::dispatch::not_equal, ret::reference);
// block creation
m.def("arange", &ir::dispatch::arange, ret::reference);
m.def("zeros", &ir::dispatch::zeros, ret::reference);
// type manipuatation
m.def("reshape", &ir::dispatch::reshape, ret::reference);
typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("cast", &ir::dispatch::cast, ret::reference);
// memory
m.def("load", &ir::dispatch::load, ret::reference);
m.def("store", &ir::dispatch::store, ret::reference);
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
// linear algebra
m.def("dot", &ir::dispatch::dot, ret::reference);
// indexing
m.def("where", &ir::dispatch::where, ret::reference);
// reduction
m.def("min", &ir::dispatch::min, ret::reference);
m.def("max", &ir::dispatch::max, ret::reference);
m.def("sum", &ir::dispatch::sum, ret::reference);
// math
m.def("exp", &ir::dispatch::exp, ret::reference);
m.def("log", &ir::dispatch::log, ret::reference);
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
// internal (debugging only)
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference);
}
/*****************************************************************************/
/* Python bindings for triton::ir */
/*****************************************************************************/
void init_triton_ir(py::module &&m) {
using ret = py::return_value_policy;
using namespace pybind11::literals;
py::class_<ir::context>(m, "context")
.def(py::init<>());
auto value = py::class_<ir::value>(m, "value");
value.def_property("name", &ir::value::get_name, &ir::value::set_name);
value.def_property_readonly("type", &ir::value::get_type);
py::class_<ir::user, ir::value>(m, "user");
py::class_<ir::constant, ir::user>(m, "constant");
py::class_<ir::undef_value, ir::constant>(m, "undef")
.def("get", &ir::undef_value::get, ret::reference);
py::class_<ir::constant_int, ir::constant>(m, "constant_int")
.def_property_readonly("value", &ir::constant_int::get_value)
.def("__int__", [](ir::constant_int *self) { return self->get_value(); });
py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
.def_property_readonly("value", &ir::constant_fp::get_value);
py::class_<ir::type>(m, "type")
.def("is_ptr", &ir::type::is_pointer_ty)
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
.def("is_floating", &ir::type::is_floating_point_ty)
.def("is_block", &ir::type::is_block_ty)
.def("make_ptr", &ir::pointer_type::get, ret::reference)
.def("make_function", &ir::function_type::get, ret::reference)
.def("make_block", &ir::block_type::get, ret::reference)
.def("get_void", &ir::type::get_void_ty, ret::reference)
.def("get_fp16", &ir::type::get_half_ty, ret::reference)
.def("get_fp32", &ir::type::get_float_ty, ret::reference)
.def("get_fp64", &ir::type::get_double_ty, ret::reference)
.def("get_int1", &ir::type::get_int1_ty, ret::reference)
.def("get_int8", &ir::type::get_int8_ty, ret::reference)
.def("get_int16", &ir::type::get_int16_ty, ret::reference)
.def("get_int32", &ir::type::get_int32_ty, ret::reference)
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
.def("is_void", &ir::type::is_void_ty)
.def("is_fp16", &ir::type::is_half_ty)
.def("is_fp32", &ir::type::is_float_ty)
.def("is_fp64", &ir::type::is_double_ty)
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
.def_property_readonly("context", &ir::type::get_context, ret::reference);
py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);
py::class_<ir::function_type, ir::type>(m, "function_type");
py::class_<ir::integer_type, ir::type>(m, "integer_type");
py::class_<ir::block_type, ir::type>(m, "block_type")
.def_property_readonly("shape", &ir::block_type::get_shapes)
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
py::class_<ir::scope>(m, "scope")
.def(py::init<>())
.def_property_readonly("values", &ir::scope::get_values)
.def("set_type", &ir::scope::set_type);
py::class_<ir::module>(m, "module")
.def(py::init<std::string, ir::builder &>())
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
.def("add_new_scope", &ir::module::add_new_scope, ret::reference)
.def("seal_block", &ir::module::seal_block)
.def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
.def("pop_scope", &ir::module::pop_scope)
.def_property_readonly("scope", &ir::module::get_scope, ret::reference)
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
using eattr = ir::attribute_kind_t;
py::enum_<eattr>(m, "attribute_kind")
.value("readonly", eattr::readonly)
.value("writeonly", eattr::writeonly)
.value("noalias", eattr::noalias)
.value("aligned", eattr::aligned)
.value("multiple_of", eattr::multiple_of)
.value("retune", eattr::retune)
.value("not_implemented", eattr::not_implemented);
py::class_<ir::attribute>(m, "attribute")
.def(py::init<eattr, int>());
py::class_<ir::function>(m, "function")
.def_property_readonly("args", &ir::function::args)
.def_property_readonly("attrs", &ir::function::attrs)
.def("add_attr", &ir::function::add_attr);
py::class_<ir::argument, ir::value>(m, "argument");
py::class_<ir::basic_block, ir::value>(m, "basic_block")
.def("create", &ir::basic_block::create, ret::reference)
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
.def(py::init<ir::context &>())
// getters
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
// control flow
.def("br", &ir::builder::create_br, ret::reference)
.def("cond_br", &ir::builder::create_cond_br, ret::reference)
.def("ret_void", &ir::builder::create_ret_void, ret::reference)
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
// constants
.def("get_int1", &ir::builder::get_int1, ret::reference)
.def("get_int32", &ir::builder::get_int32, ret::reference)
.def("get_float16", &ir::builder::get_float16, ret::reference)
.def("get_float32", &ir::builder::get_float32, ret::reference)
.def("get_range", &ir::builder::get_range, ret::reference);
}
void init_triton(py::module &m) {
py::module subm = m.def_submodule("triton");
init_triton_codegen(std::move(subm.def_submodule("code_gen")));
init_triton_driver(std::move(subm.def_submodule("driver")));
init_triton_runtime(std::move(subm.def_submodule("runtime")));
init_triton_tools(std::move(subm.def_submodule("tools")));
init_triton_ir(std::move(subm.def_submodule("ir")));
init_triton_frontend(std::move(subm.def_submodule("frontend")));
}

View File

@@ -0,0 +1,209 @@
import torch
import triton
import copy
import pytest
import ast
torch.manual_seed(0)
# convert from string to torch.dtype
# Necessary because doesn't print torch.dtype properly
cvt = {
'bool': torch.bool,
'int8': torch.int8,
'int16': torch.int16,
'int32': torch.int32,
'int64': torch.int64,
'float16': torch.float16,
'float32': torch.float32,
'float64': torch.float64,
}
int_dtypes = ['int8', 'int16', 'int32', 'int64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes
def patch_kernel(template, to_replace):
kernel = copy.deepcopy(template)
for key, value in to_replace.items():
kernel.src = kernel.src.replace(key, value)
return kernel
# generic test functions
def _test_unary(dtype_x, expr, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, **meta):
off = triton.arange(0, meta['SIZE'])
x = triton.load(X + off)
z = GENERATE_TEST_HERE
triton.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
# reference result
z_ref = eval(expr)
# triton result
z_tri = torch.empty_like(z_ref)
kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_allclose(z_ref, z_tri)
def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, Y, **meta):
off = triton.arange(0, meta['SIZE'])
x = triton.load(X + off)
y = triton.load(Y + off)
z = GENERATE_TEST_HERE
triton.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
# reference result
z_ref = eval(expr)
# triton result
z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_allclose(z_ref, z_tri)
# ---------------
# test binary ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
(dtype_x, dtype_y, f' x {op} y') \
for op in ['+', '-', '*', '/', '%'] \
for dtype_x in dtypes \
for dtype_y in dtypes
])
def test_bin_op(dtype_x, dtype_y, expr, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, device=device)
# ---------------
# test bitwise ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
(dtype_x, dtype_y, f' x {op} y') \
for op in ['&', '|', '^'] \
for dtype_x in dtypes \
for dtype_y in dtypes
])
def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
if 'float' in dtype_x + dtype_y:
with pytest.raises(RuntimeError):
_test_binary(dtype_x, dtype_y, expr, device=device)
else:
_test_binary(dtype_x, dtype_y, expr, device=device)
# ---------------
# test compare ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
(dtype_x, dtype_y, f' x {op} y') \
for op in ['==', '!=', '>', '<', '>=', '<='] \
for dtype_x in dtypes \
for dtype_y in dtypes
])
def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, device=device)
# ---------------
# test unary ops
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
(dtype_x, f' -x') for dtype_x in float_dtypes
] + [\
(dtype_x, f' ~x') for dtype_x in int_dtypes
])
def test_unary_op(dtype_x, expr, device='cuda'):
_test_unary(dtype_x, expr, device=device)
# ----------------
# test indexing
# ----------------
def make_ptr_str(name, shape):
rank = len(shape)
offsets = []
stride = 1
for i in reversed(range(rank)):
idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}']
stride *= shape[i]
return f"{name} + {' + '.join(offsets)}"
@pytest.mark.parametrize("expr", [f'x[{s}]' for s in
['None, :', ':, None',\
'None, :, :', ':, :, None']\
])
def test_index1d(expr, device='cuda'):
dtype = torch.int32
rank_x = expr.count(':')
rank_y = expr.count(',') + 1
shape_x = [32 for _ in range(rank_x)]
shape_z = [32 for _ in range(rank_y)]
# Triton kernel
@triton.jit
def kernel(Z, X, **meta):
SIZE = meta['SIZE']
m = triton.arange(0, SIZE)
n = triton.arange(0, SIZE)
x = triton.load(X_PTR_EXPR)
z = GENERATE_TEST_HERE
triton.store(Z_PTR_EXPR, z)
to_replace = {
'X_PTR_EXPR': make_ptr_str('X', shape_x),
'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
'GENERATE_TEST_HERE': expr,
}
kernel = patch_kernel(kernel, to_replace)
# torch result
x = triton.testing.random(shape_x, dtype=dtype, device=device)
y = torch.zeros(shape_z, dtype=dtype, device=device)
z_ref = eval(expr) + y
# triton result
z_tri = torch.empty_like(z_ref)
kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0])
# compare
triton.testing.assert_allclose(z_ref, z_tri)
# ---------------
# test load
# ---------------
# ---------------
# test store
# ---------------
# ---------------
# test if
# ---------------
# ---------------
# test for
# ---------------
# ---------------
# test while
# ---------------

View File

@@ -1,17 +0,0 @@
import torch
import triton
def test_op():
torch.manual_seed(0)
DTYPE = torch.float16
N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
pad, stride, = (1, 1), (1, 1)
dilation = (1, 1)
a = torch.rand((N , CI, H, W ), dtype=DTYPE, device='cuda') / CI**.5
b = torch.rand((CI, R , S, CO), dtype=DTYPE, device='cuda') / CI**.5
th_c = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, dilation)
tt_c = triton.ops.conv(a, b, pad, stride)
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol)

View File

@@ -3,66 +3,74 @@ import itertools
import triton
import torch
@pytest.mark.parametrize(
"TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
itertools.chain(*[
[
# 1 warp
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
# # 2 warp
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
# # 4 warp
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
# 8 warp
# (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
# (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
# (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
# split-k
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
]),
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE",
itertools.chain(
*[
[
# 1 warp
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
# # split-k
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# # variable input
(128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
]
),
)
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE):
torch.manual_seed(0)
defines = {"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}
triton.ops._matmul._kernels = dict()
triton.ops._matmul._CONFIGS = [triton.config(defines=defines, num_warps=NWARP)]
if M is None:
M = TM
if N is None:
N = TN
if K is None:
K = TK * SPLITK
# nuke kernel decorators -- will set meta-parameters manually
META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K, 'GROUP_M': 8}
configs = [triton.Config(meta=META, num_warps=NWARP)]
kernel = triton.ops._matmul.kernel
decorators = kernel.kernel_decorators
kernel.kernel_decorators = []
triton.autotune(configs, [])(kernel)
kernel.kernel_decorators += decorators[1:]
# get matrix shape
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K * SPLIT_K if K is None else K
# allocate/transpose inputs
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
# run test
th_c = torch.matmul(a, b)
tt_c = triton.ops.matmul(a, b)
assert triton.testing.allclose(th_c, tt_c)

View File

@@ -2,9 +2,10 @@
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from . import testing
from .kernel import *
from . import ops
from .code_gen import jit, autotune, heuristics, Config, Autotuner
from .core import *
from . import testing
from . import ops
# version
__version__ = '1.0.0'

648
python/triton/code_gen.py Normal file
View File

@@ -0,0 +1,648 @@
import inspect
import struct
import enum
import types
import torch
import ast
import builtins
import triton._C.libtriton.triton as _triton
import triton
import sys
import textwrap
from abc import ABC, abstractmethod
class CodeGenerator(ast.NodeVisitor):
def get_value(self, name):
# search node.id in local scope
ret = None
if name in self.lscope:
ret = self.lscope[name]
# search node.id in global scope
elif name in self.gscope:
ret = self.gscope[name]
# search node.id in builtins
elif name in self.builtins:
ret = self.builtins[name]
else:
raise ValueError(f'{name} is not defined')
if isinstance(ret, triton.block):
handle = self.module.get_value(name)
return triton.block(handle)
return ret
def set_value(self, name, value):
if isinstance(value, _triton.ir.value):
value = triton.block(value)
if isinstance(value, triton.block):
self.module.set_value(name, value.handle)
self.module.scope.set_type(name, value.handle.type)
self.lscope[name] = value
def is_triton_object(self, value):
return isinstance(value, triton.block)
def visit_compound_statement(self, stmts, add_scope=False):
if add_scope:
self.module.add_new_scope()
for stmt in stmts:
self.last_ret = self.visit(stmt)
if isinstance(stmt, ast.Return):
break
if add_scope:
self.module.pop_scope()
return self.last_ret
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
self.builder = _triton.ir.builder(context)
self.module = _triton.ir.module('', self.builder)
self.prototype = prototype
self.gscope = gscope
self.lscope = dict()
self.attributes = attributes
self.constants = constants
self.kwargs = kwargs
self.last_node = None
self.builtins = {'range': range, 'min': triton.minimum, 'float': float, 'int': int, 'print': print, 'getattr': getattr}
def visit_Module(self, node):
self.module.add_new_scope()
ast.NodeVisitor.generic_visit(self, node)
self.module.pop_scope()
def visit_List(self, node):
ctx = self.visit(node.ctx)
assert ctx is None
elts = [self.visit(elt) for elt in node.elts]
return elts
# By design, only non-kernel functions can return
def visit_Return(self, node):
return self.visit(node.value)
def visit_FunctionDef(self, node, inline=False, arg_values=None):
arg_names, kwarg_names = self.visit(node.args)
# store keyword arguments in local scope
self.lscope[kwarg_names] = self.kwargs
# initialize function
if inline:
pass
else:
fn = self.module.get_or_insert_function(node.name, self.prototype)
arg_values = []
for i, arg_name in enumerate(arg_names):
if i in self.constants:
arg_values.append(self.constants[i])
else:
if i in self.attributes:
is_ptr = fn.args[i].type.is_ptr()
attr = 'aligned' if is_ptr else 'multiple_of'
attr = getattr(_triton.ir.attribute_kind, attr)
attr = _triton.ir.attribute(attr, self.attributes[i])
fn.add_attr(i + 1, attr)
fn.args[i].name = arg_name
arg_values.append(fn.args[i])
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
return self.visit_compound_statement(node.body, add_scope=True)
else:
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
self.module.seal_block(entry)
self.builder.set_insert_block(entry)
# visit function body
self.visit_compound_statement(node.body, add_scope=True)
# finalize function
self.builder.ret_void()
def visit_arguments(self, node):
arg_names = []
for arg in node.args:
arg_names += [self.visit(arg)]
kwarg_names = self.visit(node.kwarg)
return arg_names, kwarg_names
def visit_arg(self, node):
ast.NodeVisitor.generic_visit(self, node)
return node.arg
def visit_Assign(self, node):
names = []
for target in node.targets:
names += [self.visit(target)]
assert len(names) == 1
name = names[0]
value = self.visit(node.value)
self.set_value(names[0], value)
def visit_AugAssign(self, node):
name = node.target.id
lhs = ast.Name(id=name, ctx=ast.Load())
rhs = ast.BinOp(lhs, node.op, node.value)
assign = ast.Assign(targets=[node.target], value=rhs)
self.visit(assign)
return self.get_value(name)
def visit_Name(self, node):
if type(node.ctx) == ast.Store:
return node.id
return self.get_value(node.id)
def visit_Store(self, node):
ast.NodeVisitor.generic_visit(self, node)
def visit_Load(self, node):
ast.NodeVisitor.generic_visit(self, node)
def visit_Tuple(self, node):
args = [self.visit(x) for x in node.elts]
return tuple(args)
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
fn = {
ast.Add: '__add__',
ast.Sub: '__sub__',
ast.Mult: '__mul__',
ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__',
ast.Mod: '__mod__',
ast.Pow: '__pow__',
ast.LShift: '__lshift__',
ast.RShift: '__rshift__',
ast.BitAnd: '__and__',
ast.BitOr: '__or__',
ast.BitXor: '__xor__',
}[type(node.op)]
kws = dict()
if self.is_triton_object(lhs):
kws['builder'] = self.builder
ret = getattr(lhs, fn)(rhs, **kws)
if ret is NotImplemented:
if self.is_triton_object(rhs):
kws['builder'] = self.builder
fn = fn[:2] + 'r' + fn[2:]
ret = getattr(rhs, fn)(lhs, **kws)
return ret
def visit_If(self, node):
cond = self.visit(node.test)
if self.is_triton_object(cond):
current_bb = self.builder.get_insert_block()
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
self.module.seal_block(then_bb)
if else_bb:
self.module.seal_block(else_bb)
self.builder.cond_br(cond.handle, then_bb, else_bb)
else:
self.builder.cond_br(cond.handle, then_bb, endif_bb)
self.builder.set_insert_block(then_bb)
self.visit_compound_statement(node.body, add_scope=True)
# TODO: last statement is a terminator?
self.builder.br(endif_bb)
if else_bb:
self.builder.set_insert_block(else_bb)
self.visit_compound_statement(node.orelse, add_scope=True)
#TODO: last statement is a terminator?
self.builder.br(endif_bb)
self.module.seal_block(endif_bb)
self.builder.set_insert_block(endif_bb)
else:
if cond:
self.visit_compound_statement(node.body)
else:
self.visit_compound_statement(node.orelse)
def visit_IfExp(self, node):
cond = self.visit(node.test)
if cond:
return self.visit(node.body)
else:
return self.visit(node.orelse)
def visit_Pass(self, node):
pass
def visit_Compare(self, node):
assert len(node.comparators) == 1
assert len(node.ops) == 1
lhs = self.visit(node.left)
rhs = self.visit(node.comparators[0])
fn = {
ast.Eq: '__eq__',
ast.NotEq: '__ne__',
ast.Lt: '__lt__',
ast.LtE: '__le__',
ast.Gt: '__gt__',
ast.GtE: '__ge__',
ast.Is: '__eq__',
ast.IsNot: '__ne__',
}[type(node.ops[0])]
if self.is_triton_object(lhs) or self.is_triton_object(rhs):
return getattr(lhs, fn)(rhs, builder=self.builder)
return getattr(lhs, fn)(rhs)
def visit_UnaryOp(self, node):
op = self.visit(node.operand)
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
ast.Invert: '__invert__',
}[type(node.op)]
if self.is_triton_object(op):
return getattr(op, fn)(builder=self.builder)
return getattr(op, fn)()
def visit_While(self, node):
current_bb = self.builder.get_insert_block()
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
def continue_fn():
cond = self.visit(node.test)
return self.builder.cond_br(cond.handle, loop_bb, next_bb)
continue_fn()
self.builder.set_insert_block(loop_bb)
self.visit_compound_statement(node.body, add_scope=True)
continue_fn()
stop_bb = self.builder.get_insert_block()
self.module.seal_block(stop_bb)
self.module.seal_block(loop_bb)
self.module.seal_block(next_bb)
self.builder.set_insert_block(next_bb)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
def visit_Str(self, node):
return ast.literal_eval(node)
def visit_Subscript(self, node):
assert node.ctx.__class__.__name__ == "Load"
lhs = self.visit(node.value)
slices = self.visit(node.slice)
if self.is_triton_object(lhs):
return lhs.__getitem__(slices, builder=self.builder)
return lhs[slices]
def visit_ExtSlice(self, node):
return [self.visit(dim) for dim in node.dims]
def visit_For(self, node):
iterator = self.visit(node.iter.func)
assert iterator == self.builtins['range']
# create nodes
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
init_node = ast.Assign(targets=[st_target], value=node.iter.args[0])
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\
self.visit(neg_cond_node),\
builder=self.builder)
#cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=node.iter.args[2])
# code generation
current_bb = self.builder.get_insert_block()
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
def continue_fn():
self.visit(step_node)
cond = build_cond()
return self.builder.cond_br(cond.handle, loop_bb, next_bb)
self.visit(init_node)
cond = build_cond()
self.builder.cond_br(cond.handle, loop_bb, next_bb)
self.builder.set_insert_block(loop_bb)
self.visit_compound_statement(node.body, add_scope=True)
# TODO: handle case where body breaks control flow
continue_fn()
stop_bb = self.builder.get_insert_block()
self.module.seal_block(stop_bb)
self.module.seal_block(loop_bb)
self.module.seal_block(next_bb)
self.builder.set_insert_block(next_bb)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
def visit_Slice(self, node):
lower = self.visit(node.lower)
upper = self.visit(node.upper)
step = self.visit(node.step)
return slice(lower, upper, step)
def visit_Index(self, node):
return self.visit(node.value)
def visit_NameConstant(self, node):
return node.value
def visit_keyword(self, node):
return {node.arg: self.visit(node.value)}
def visit_Call(self, node):
fn = self.visit(node.func)
kws = dict()
for keyword in node.keywords:
kws.update(self.visit(keyword))
args = [self.visit(arg) for arg in node.args]
if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.core:
return fn(*args, builder=self.builder, **kws)
return fn(*args, **kws)
def visit_Num(self, node):
return node.n
def visit_Attribute(self, node):
lhs = self.visit(node.value)
return getattr(lhs, node.attr)
def visit_Expr(self, node):
ast.NodeVisitor.generic_visit(self, node)
def visit_NoneType(self, node):
return None
def visit(self, node):
if node is not None:
self.last_node = node
return super().visit(node)
def generic_visit(self, node):
typename = type(node).__name__
raise NotImplementedError("Unsupported node: {}".format(typename))
class Binary:
def __init__(self, module, kernel, num_warps, shared_mem):
self.module = module
self.kernel = kernel
self.shared_mem = shared_mem
self.num_warps = num_warps
def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
stream.enqueue(self.kernel, grid_0, grid_1, grid_2, self.num_warps * 32, 1, 1, args, self.shared_mem)
class CompilationError(Exception):
def __init__(self, src, node, err):
self.message = '\n'.join(src.split('\n')[:node.lineno])
self.message += '\n' + ' ' * node.col_offset + '^'
self.message += '\n Error: ' + str(err)
super().__init__(self.message)
class Kernel:
type_names = {
int: 'I',
float: 'f',
bool: 'B',
torch.float16: 'f16',
torch.float32: 'f32',
torch.float64: 'f64',
torch.bool: 'i1',
torch.int8: 'i8',
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
}
@staticmethod
def _to_triton_ir(context, obj):
type_map = {
'I': _triton.ir.type.get_int32,
'f': _triton.ir.type.get_fp32,
'B': _triton.ir.type.get_int1,
'f16': _triton.ir.type.get_fp16,
'f32': _triton.ir.type.get_fp32,
'f64': _triton.ir.type.get_fp64,
'i1': _triton.ir.type.get_int1,
'i8': _triton.ir.type.get_int8,
'i16': _triton.ir.type.get_int16,
'i32': _triton.ir.type.get_int32,
'i64': _triton.ir.type.get_int64,
}
# convert torch.Tensor to Triton IR pointers
if isinstance(obj, torch.Tensor):
name = Kernel.type_names[obj.dtype]
elt_ty = type_map[name](context)
return _triton.ir.type.make_ptr(elt_ty, 1)
# default path returns triton.ir.type directly
name = Kernel.type_names[obj.__class__]
return type_map[name](context)
@staticmethod
def _types_key(*wargs, tensor_idxs):
# type inference
types_key = [None] * len(wargs)
for i, arg in enumerate(wargs):
prefix = 'P' if i in tensor_idxs else ''
suffix = Kernel.type_names[arg.dtype] if i in tensor_idxs else Kernel.type_names[arg.__class__]
types_key[i] = prefix + suffix
return tuple(types_key)
@staticmethod
def pow2_divisor(N):
if N % 16 == 0: return 16
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
return 1
def __init__(self, fn):
self.fn = fn
def _compile(self, *wargs, device, attributes, constants, num_warps, **meta):
# explicitly set device
torch.cuda.set_device(device.index)
# create IR module
context = _triton.ir.context()
# get just-in-time proto-type of kernel
arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs]
ret_type = _triton.ir.type.get_void(context)
prototype = _triton.ir.type.make_function(ret_type, arg_types)
# generate Triton-IR
# export symbols visible from self.fn into code-generator object
gscope = sys.modules[self.fn.module].__dict__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta)
try:
generator.visit(self.fn.parse())
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.fn.src, node, e)
tt_device = _triton.driver.cu_device(device.index, False)
# Compile to machine code
mod, ker, shared_mem = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
return Binary(mod, ker, num_warps, shared_mem)
def __call__(self, *wargs, grid, num_warps=4, **meta):
# device inference
tensor_idxs = [i for i, arg in enumerate(wargs) if isinstance(arg, torch.Tensor)]
if len(tensor_idxs) == 0:
raise ValueError("No Tensor argument found.")
device = wargs[tensor_idxs[0]].device
# attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
# determine if we need to re-compile
types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
attr_key = frozenset(attributes.items())
meta_key = frozenset(meta.items())
const_key = frozenset(constants.items())
key = (device.type, device.index, types_key, attr_key, num_warps, meta_key, const_key)
cache = self.fn.cache
if key not in cache:
# compile and cache configuration if necessary
cache[key] = self._compile(
*wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
)
# pack arguments
fmt = ''.join(['P' if i in tensor_idxs else Kernel.type_names[arg.__class__] for i, arg in enumerate(wargs)])
params = struct.pack(fmt, *args)
# enqueue cached function into stream
binary = cache[key]
cu_stream = torch.cuda.current_stream(device.index).cuda_stream
stream = _triton.driver.cu_stream(cu_stream, False)
grid = grid(meta) if hasattr(grid, '__call__') else grid
binary(stream, params, *grid)
class Launcher:
def __init__(self, kernel, grid):
self.kernel = kernel
self.grid = grid
def __call__(self, *wargs, **kwargs):
self.kernel(*wargs, **kwargs, grid=self.grid)
class Autotuner:
def __init__(self, kernel, arg_names, configs, key):
if not configs:
self.configs = [Config(dict(), num_warps=4)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.cache = dict()
self.kernel = kernel
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.meta.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.meta)
kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, **current)
return triton.testing.do_bench(kernel_call)
def __call__(self, *args, **meta):
if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx])
if key not in self.cache:
timings = {config: self._bench(*args, config=config, **meta) \
for config in self.configs}
self.cache[key] = builtins.min(timings, key=timings.get)
config = self.cache[key]
else:
config = self.configs[0]
self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
class JITFunction:
def __init__(self, fn):
self.module = fn.__module__
self.arg_names = inspect.getfullargspec(fn).args
self.cache = dict()
self.kernel_decorators = []
self.src = textwrap.dedent(inspect.getsource(fn))
self.kernel = None
# we do not parse in the constructor because
# the user might want to monkey-patch self.src dynamically.
# Some unit tests do this, for example.
def parse(self):
tree = ast.parse(self.src)
assert isinstance(tree, ast.Module)
assert len(tree.body) == 1
assert isinstance(tree.body[0], ast.FunctionDef)
return tree
def __call__(self, *args, generator: CodeGenerator, **meta):
try:
return generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node, e)
def _init_kernel(self):
if self.kernel is None:
self.kernel = Kernel(self)
for decorator in reversed(self.kernel_decorators):
self.kernel = decorator(self.kernel)
return self.kernel
def __getitem__(self, grid):
return Launcher(self._init_kernel(), grid)
class Config:
def __init__(self, meta, num_warps=4):
self.meta = meta
self.num_warps = num_warps
def autotune(configs, key):
def decorator(fn):
def wrapper(kernel):
return Autotuner(kernel, fn.arg_names, configs, key)
fn.kernel_decorators.append(wrapper)
return fn
return decorator
def heuristics(values):
def decorator(fn):
def wrapper(kernel):
def fun(*args, **meta):
for v, heur in values.items():
assert v not in meta
meta[v] = heur(*args, **meta)
return kernel(*args, **meta)
return fun
fn.kernel_decorators.append(wrapper)
return fn
return decorator
def jit(fn):
return JITFunction(fn)

499
python/triton/core.py Normal file
View File

@@ -0,0 +1,499 @@
from triton._C.libtriton.triton import ir
from triton._C.libtriton.triton import frontend
import triton
from functools import wraps
def _patch(fn):
# convert block/dtype to ir values
def _to_ir(x, builder):
if isinstance(x, bool):
return builder.get_int1(x)
elif isinstance(x, int):
return builder.get_int32(x)
elif isinstance(x, float):
return builder.get_float32(x)
if isinstance(x, block):
return x.handle
if isinstance(x, dtype):
return x.handle(builder)
return x
def _from_ir(x):
if isinstance(x, ir.value):
if x.type.is_void():
return None
return block(x)
return x
def wrapper(*args, **kwargs):
builder = args[-1]
assert isinstance(builder, ir.builder)
args = [_to_ir(x, builder) for x in args]
kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()}
ret = fn(*args, **kwargs)
if isinstance(ret, tuple):
return map(_from_ir, ret)
return _from_ir(ret)
return wrapper
for name in dir(frontend):
fn = getattr(frontend, name)
if callable(fn):
setattr(frontend, name, _patch(fn))
def builtin(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if 'builder' not in kwargs or \
kwargs['builder'] is None:
raise ValueError("Builder argument must be provided outside of JIT functions")
return fn(*args, **kwargs)
if wrapper.__doc__:
wrapper.__doc__ += """\
:param builder: IR builder to generate code into, optional from within @triton.jit functions
:type builder: triton.ir.builder
"""
return wrapper
class dtype:
def __init__(self, init):
self.init = init
def handle(self, builder):
ctx = builder.context
return self.init(ctx)
class pointer_dtype:
def __init__(self, element_ty):
self.element_ty = element_ty
def handle(self, builder):
return ir.type.make_ptr(self.element_ty, 1)
int1 = dtype(ir.type.get_int1)
int8 = dtype(ir.type.get_int8)
int16 = dtype(ir.type.get_int16)
int32 = dtype(ir.type.get_int32)
int64 = dtype(ir.type.get_int64)
float16 = dtype(ir.type.get_fp16)
float32 = dtype(ir.type.get_fp32)
float64 = dtype(ir.type.get_fp64)
class block:
@staticmethod
def _init_dtype(ir_type):
# primitive type
if ir_type.is_int1(): return int1
if ir_type.is_int8(): return int8
if ir_type.is_int16(): return int16
if ir_type.is_int32(): return int32
if ir_type.is_int64(): return int64
if ir_type.is_fp16(): return float16
if ir_type.is_fp32(): return float32
if ir_type.is_fp64(): return float64
# pointer type
if ir_type.is_ptr():
element_ty = block._init_dtype(ir_type.element)
return pointer_dtype(element_ty)
raise ValueError(f"Unsupported type {ir_type}")
def __init__(self, handle):
# IR handle
self.handle = handle
# Block shape
self.shape = (1, )
if self.handle.type.is_block():
self.shape = self.handle.type.shape
# Data-type wrapper
self.dtype = block._init_dtype(self.handle.type.scalar)
@builtin
def __add__(self, other, builder=None):
return frontend.add(self, other, builder)
def __radd__(self, other, builder=None):
return self.__add__(other, builder=builder)
@builtin
def __sub__(self, other, builder=None):
return frontend.sub(self, other, builder)
@builtin
def __mul__(self, other, builder=None):
return frontend.mul(self, other, builder)
def __rmul__(self, other, builder=None):
return self.__mul__(other, builder=builder)
@builtin
def __truediv__(self, other, builder=None):
return frontend.truediv(self, other, builder)
def __rtruediv__(self, other, builder=None):
return frontend.truediv(other, self, builder)
@builtin
def __floordiv__(self, other, builder=None):
return frontend.floordiv(self, other, builder)
@builtin
def __mod__(self, other, builder=None):
return frontend.mod(self, other, builder)
# unary operators
@builtin
def __neg__(self, builder=None):
return frontend.minus(self, builder)
@builtin
def __invert__(self, builder=None):
return frontend.invert(self, builder)
# bitwise operators
@builtin
def __and__(self, other, builder=None):
return frontend.and_(self, other, builder)
@builtin
def __or__(self, other, builder=None):
return frontend.or_(self, other, builder)
@builtin
def __xor__(self, other, builder=None):
return frontend.xor_(self, other, builder)
@builtin
def __lshift__(self, other, builder=None):
return frontend.shl(self, other, builder)
@builtin
def __rshift__(self, other, builder=None):
return frontend.lshr(self, other, builder)
# comparison operators
@builtin
def __gt__(self, other, builder=None):
return frontend.greater_than(self, other, builder)
@builtin
def __ge__(self, other, builder=None):
return frontend.greater_equal(self, other, builder)
@builtin
def __lt__(self, other, builder=None):
return frontend.less_than(self, other, builder)
@builtin
def __le__(self, other, builder=None):
return frontend.less_equal(self, other, builder)
@builtin
def __eq__(self, other, builder=None):
return frontend.equal(self, other, builder)
@builtin
def __ne__(self, other, builder=None):
return frontend.not_equal(self, other, builder)
@builtin
def __getitem__(self, slices, builder=None):
if isinstance(slices, slice):
slices = [slices]
src_shape = self.shape
dst_shape = []
curr = 0
for sl in slices:
if sl == None:
dst_shape.append(1)
elif sl == slice(None, None, None):
dst_shape.append(src_shape[curr])
curr += 1
ret = frontend.reshape(self, dst_shape, builder)
return ret
@builtin
def to(self, dtype, builder=None):
return frontend.cast(self, dtype.handle(builder), builder)
# -----------------------
# SPMD Programming Model
# -----------------------
@builtin
def program_id(axis, builder=None):
"""
Returns the id of the current program instance along the given `axis`.
Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
"""
return frontend.program_id(axis, builder)
@builtin
def num_programs(axis, builder=None):
"""
Returns the number of program instances launched along the given `axis`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
"""
return frontend.num_programs(axis, builder)
# -----------------------
# Block Initialization
# -----------------------
@builtin
def arange(start, end, builder=None):
"""
Returns contiguous values within the open interval [start, end).
:param start: Start of the interval.
:type start: int
:param stop: End of the interval.
:type stop: int
"""
return frontend.arange(start, end, builder)
@builtin
def zeros(shape, dtype, builder=None):
"""
Returns a block filled with the scalar value 0 and the given shape.
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., triton.float16
:type dtype: triton.ir.dtype
"""
return frontend.zeros(shape, dtype, builder)
# -----------------------
# Shape Manipulation
# -----------------------
@builtin
def broadcast(input, other, builder=None):
"""
Tries to broadcast two blocks to a common compatible shape.
:param input: The first input block.
:type input: triton.ir.value
:param other: The second input block.
:type other: triton.ir.value
"""
return frontend.broadcast(input, other, builder)
@builtin
def broadcast_to(input, shape, builder=None):
"""
Tries to broadcast a block to a new shape.
:param input: The input block.
:type input: triton.value
:param shape: The new shape.
:type shape: tuple of int
"""
return frontend.broadcast_to(input, shape, builder)
@builtin
def reshape(input, shape, builder=None):
"""
Reshapes a block to a new shape.
"""
return frontend.reshape(input, shape, builder)
# -----------------------
# Linear Algebra
# -----------------------
@builtin
def dot(input, other, builder=None):
"""
Returns the matrix product of two blocks.
The two blocks must be two dimensionals and have compatible inner dimensions.
:param input: The first block to be multiplied.
:type input: 2D block of scalar-type in {`float16`, `float32`}
:param other: The second block to be multiplied.
:type other: 2D block of scalar-type in {`float16`, `float32`}
"""
return frontend.dot(input, other, builder)
# -----------------------
# Memory Operations
# -----------------------
@builtin
def load(pointer, mask=None, other=None, builder=None):
"""
Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`.
:param pointer: Pointer to the data to be loaded.
:type pointer: Block of triton.pointer
:param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
:type mask: Block of triton.bool, optional
:param other: if mask[idx] is false, return other[idx] instead of 'pointer[idx]`
:type other: Block of triton.value, optional
"""
return frontend.load(pointer, mask, other, builder)
@builtin
def store(pointer, value, mask=None, builder=None):
"""
Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.
:param pointer: The memory locations where the elements of `value` are stored.
:type pointer: Block of triton.pointer
:param value: The block of elements to be stored.
:type value: Block of triton.value
:param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
:type mask: Block of triton.bool, optional
"""
return frontend.store(pointer, value, mask, builder)
@builtin
def atomic_cas(ptr, cmp, val, builder=None):
return frontend.atomic_cas(ptr, cmp, val, builder)
@builtin
def atomic_xchg(ptr, val, builder=None):
return frontend.atomic_xchg(ptr, val, builder)
# -----------------------
# Conditioning
# -----------------------
@builtin
def where(condition, x, y, builder=None):
"""
Returns a block of elements from either `x` or `y`, depending on `condition`.
Note that `x` and `y` are always evaluated regardless of the value of `condition`.
If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
The shape of `x` and `y` are both broadcast to the shape of `condition`.
`x` and `y` must have the data type.
:param condition: When True (nonzero), yield x, otherwise yield y.
:type condition: Block of triton.bool
:param x: values selected at indices where condition is True.
:param y: values selected at indices where condition is False.
"""
return frontend.where(condition, x, y, builder)
# -----------------------
# Math
# -----------------------
@builtin
def exp(x, builder=None):
return frontend.exp(x, builder)
@builtin
def log(x, builder=None):
return frontend.log(x, builder)
# -----------------------
# Reductions
# -----------------------
@builtin
def max(input, axis, builder=None):
return frontend.max(input, axis, builder)
@builtin
def min(input, axis, builder=None):
return frontend.min(input, axis, builder)
@builtin
def sum(input, axis, builder=None):
return frontend.sum(input, axis, builder)
# -----------------------
# Internal for debugging
# -----------------------
@builtin
def debug_barrier(builder=None):
return frontend.debug_barrier(builder)
@builtin
def multiple_of(x, value, builder=None):
return frontend.multiple_of(x, value, builder)
# -----------------------
# Standard library
# -----------------------
@triton.jit
def minimum(x, y):
return triton.where(x < y, x, y)
@triton.jit
def maximum(x, y):
return triton.where(x > y, x, y)
@triton.jit
def sigmoid(x):
return 1 / (1 + np.exp(-x))
@triton.jit
def ravel(x):
return triton.reshape(x, [x.type.numel])
@triton.jit
def softmax(x):
z = x - triton.max(x, 0)
num = triton.exp(z)
den = triton.sum(num, 0)
return num / den
def cdiv(x, y):
return (x + y - 1) // y

View File

@@ -1,119 +0,0 @@
import os
import struct
from typing import Optional, Dict, List, Callable
import torch
import triton._C.libtriton.triton as _triton
codes = {
_triton.runtime.arg_type.int1: 'B',
_triton.runtime.arg_type.int8: 'B',
_triton.runtime.arg_type.int32: 'I',
_triton.runtime.arg_type.int64: 'Q',
_triton.runtime.arg_type.half: 'H',
_triton.runtime.arg_type.float: 'f',
_triton.runtime.arg_type.double: 'd',
_triton.runtime.arg_type.buffer: 'P',
}
def th_to_triton(obj):
""" Convert a `torch.dtype` to a Triton-C type string. """
tys = {
torch.int8: 'char',
torch.int16: 'short',
torch.int32: 'int',
torch.int64: 'long',
torch.float16: 'half',
torch.float32: 'float',
torch.float64: 'double',
}
if isinstance(obj, torch.dtype):
return tys[obj]
return str(obj)
def cdiv(a: int, b: int) -> int:
""" Ceil division (a + b - 1) // b"""
return (a + b - 1) // b
def read(path: str, kernel_names: Optional[List] = None) -> str:
""" Extracts the source code for `kernel_names` from the given `path` file."""
if kernel_names is None:
kernel_names = []
with open(path, 'r') as f:
source = f.read()
source = _triton.tools.extract_kernels(source, kernel_names)
return source
config = _triton.runtime.config
class kernel:
"""
A class used to represent a Triton kernel.
"""
def __init__(
self,
src: str,
device: torch.device,
defines: Optional[Dict] = None,
num_warps: int = 4,
autotune_configs: Optional[List] = None,
autotune_key: Optional[List] = None
):
"""
:param src: The source code of the kernel.
:param device: The device to compile the kernel for.
:param defines: A dictionary of preprocessor #define for the compiler.
:param num_warps: Optimization flag for the compiler's internal auto-parallelization engine.
:param autotune_configs: A list of triton.config objects for the autotuner to try.
:param autotune_key: A list of kernel argument names whose change in value should trigger the autotuner to re-run.
"""
if defines is None:
defines = {}
if autotune_configs is None:
autotune_configs = []
if autotune_key is None:
autotune_key = []
# check if src is empty
if src == '':
raise ValueError('Kernel source code is empty')
self.src = src
# device
assert device.type in ['cuda', 'cpu']
if device.type == 'cuda':
self.device_id = torch.cuda.current_device() if device.index is None else device.index
self.device = _triton.driver.cu_device(self.device_id, False)
cu_stream = torch.cuda.current_stream(self.device_id).cuda_stream
self.stream = _triton.driver.cu_stream(cu_stream, False)
if device.type == 'cpu':
self.device_id = -1
self.device = _triton.driver.host_device()
self.device = _triton.driver.host_stream()
torch.cuda.set_device(self.device_id)
# function
self.opt = _triton.runtime.options()
self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
self.opt.num_warps = num_warps
# autotune_configs = [({}, 4)]
self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_configs, autotune_key)
self.tys = ''.join([codes[x] for x in self.fn.signature()])
def __call__(self, *args, grid: Callable[[_triton.runtime.options], tuple]):
"""
Runs the kernel on the given arguments and launch grid.
:param args: The arguments to the kernel in the orders that they appear in the Triton-C source.
:param grid: The launch grid for the kernel, i.e., callable that transform compilation options into a tuple of at most 3 integers.
:return: None
"""
# make sure that the executing thread is on the right device
torch.cuda.set_device(self.device_id)
# pack parameters into a byte buffer
params = struct.pack(self.tys, *args)
kernel = self.fn.autotune(params, grid, self.stream)
# run kernel
grid = grid(kernel.opt)
kernel(params, self.stream, grid)

View File

@@ -1,4 +1,4 @@
from .conv import _conv, conv
#from .conv import _conv, conv
from .matmul import _matmul, matmul
from .cross_entropy import _cross_entropy, cross_entropy
from . import blocksparse

View File

@@ -1,199 +0,0 @@
__global__ void NAME(TYPE *A __readonly __noalias,
TYPE *B __readonly __noalias,
TYPE *C __noalias,
int lda,
int ldb,
int ldc,
long stride_za,
long stride_zb,
long stride_zc,
long stride_ha,
long stride_hb,
long stride_hc,
int DS0, int DS1,
int SDD_K,
int SDD_off_width,
int *lut, int *locks, int nlocks) {
/* ---------------- */
/* Prologue */
/* ---------------- */
// program ids
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pidz = get_program_id(2);
#ifdef SDD
// load LUT header
pid1 = pid1 + SDD_off_width;
int blockidm[TM] = (0 ... TM) / BLOCK;
int blockidn[TN] = (0 ... TN) / BLOCK;
int offlutm[TM] = blockidm * (TN / BLOCK) * 4;
int offlutn[TN] = blockidn * 4;
int *header = lut + pid1 * (TM / BLOCK) * (TN / BLOCK) * 4;
int z = *(header + 0);
int i[TM] = *(header + 1 + offlutm);
int j[TN] = *(header + 2 + offlutn);
int AS1 = SDD_K / TZ;
int lockid = select(TZ > 1, 1, 0);
int offka = pid0 * AS1;
int offkb = pid0 * AS1;
int offmc = 0;
int offnc = 0;
int offpa = 0;
int offpb = 0;
int maxid = TZ;
int offhc = 0;
int offha = z;
int offhb = z;
int ram[TM] = i * BLOCK + ((0 ... TM) % BLOCK);
int rbn[TN] = j * BLOCK + ((0 ... TN) % BLOCK);
#else
// load LUT header
int *header = lut + pid0 * 6;
int offset = *(header + 0);
int AS1 = *(header + 1);
int column = *(header + 2);
int depth = *(header + 3);
int lockid = *(header + 4);
int maxid = *(header + 5);
int *pinc = lut + offset;
int offhc = depth;
#ifdef DSD
// output offset
int offnc = pid1 * TN;
int offmc = column * TM;
int offpc = 0;
// dense input offset
int offnb = pid1 * TN;
int offkb __multipleof(8) = *pinc;
int offpb = 0;
// sparse input offset
int offma = 0;
int offka = 0;
long offpa __multipleof(8) = *(pinc + 1);
offpa = offpa * BLOCK * BLOCK;
int offha = 0;
int offhb = depth;
#endif
#ifdef DDS
// output offset
int offmc = pid1 * TM;
int offnc = column * TN;
int offpc = 0;
// dense input offset
int offma = pid1 * TM;
int offka __multipleof(8) = *pinc;
int offpa = 0;
// sparse input offset
int offnb = 0;
int offkb = 0;
long offpb __multipleof(8) = *(pinc + 1);
offpb = offpb * BLOCK * BLOCK;
int offha = depth;
int offhb = 0;
#endif
int ram[TM] = offma + 0 ... TM;
int rbn[TN] = offnb + 0 ... TN;
#endif
// initialize a, b pointers
int rka[TK] = offka + 0 ... TK;
int rkb[TK] = offkb + 0 ... TK;
TYPE *pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka [newaxis, :] * STRIDE_AK;
TYPE *pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn [newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
// pre-fetch
#ifdef DDS
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
bool checkbn[TK, TN] = rbn [newaxis, :] < DS0;
#else
bool checkbn[TK, TN] = AS1 > 0;
#endif
TYPE a[TM, TK] = checkam ? *pa : 0;
TYPE b[TK, TN] = checkbn ? *pb : 0;
/* ---------------- */
/* Inner Loop */
/* ---------------- */
// create result tile
float acc[TM, TN] = 0;
int step = TK;
for (int k = AS1; k > 0; k -= step) {
acc += a @b;
// update pointers
#ifdef SDD
int inc_a = TK * STRIDE_AK;
int inc_b = TK * STRIDE_BK;
#else
pinc += 2;
#ifdef DSD
int inc_b __multipleof(8) = *pinc;
int inc_a __multipleof(8) = *(pinc + 1);
inc_b = inc_b * STRIDE_BK;
#endif
#ifdef DDS
int inc_a __multipleof(8) = *pinc;
int inc_b __multipleof(8) = *(pinc + 1);
inc_a = inc_a * STRIDE_AK;
#endif
#endif
pa += inc_a;
pb += inc_b;
// pre-fetch
bool checkak[TM, TK] = k > TK;
bool checkbk[TK, TN] = k > TK;
bool checka[TM, TK] = checkam && checkak;
bool checkb[TK, TN] = checkbk && checkbn;
a = *? (checka)pa;
b = *? (checkb)pb;
}
TYPE c[TM, TN] = acc;
/* ---------------- */
/* Epilogue */
/* ---------------- */
// initialize c pointers
#ifdef SDD
bool checkc[TM, TN] = 1;
// rematerialize
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
int rr_offlutm[TM] = rr_blockidm * (TN / BLOCK) * 4;
int rr_offlutn[TN] = rr_blockidn * 4;
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn [newaxis, :];
int bkid[TM, TN] = *(header + off_bkid);
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
// range within blocks
int rcm[TM] = (0 ... TM) % BLOCK;
int rcn[TN] = (0 ... TN) % BLOCK;
#else
int rcm[TM] = offmc + 0 ... TM;
int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
bool checkc[TM, TN] = rcn [newaxis, :] < DS0;
#endif
#ifdef DDS
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
TYPE *pc[TM, TN] = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, newaxis] * STRIDE_CM + rcn [newaxis, :] * STRIDE_CN;
// write-back directly
if (lockid == 0) {
*? (checkc)pc = c;
}
// accumulate partial result using spin-locks
else {
int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1;
int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks;
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;
else
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % maxid);
atomic_xchg(plock, 0);
}
}

View File

@@ -4,7 +4,183 @@ import torch
import os
import math
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
@triton.jit
def _kernel(
A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc, stride_hc,
stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta
):
TM = meta['TM']
TN = meta['TN']
TK = meta['TK']
TZ = meta['TZ']
BLOCK = meta['BLOCK']
#------------#
#- Prologue -#
#------------#
pid0 = triton.program_id(0)
pid1 = triton.program_id(1)
pidz = triton.program_id(2)
if meta['SDD']:
pid1 = pid1 + SDD_off_width
blockidm = triton.arange(0, TM) // BLOCK
blockidn = triton.arange(0, TN) // BLOCK
offlutm = blockidm * (TN // BLOCK) * 4
offlutn = blockidn * 4
header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
z = triton.load(header + 0)
i = triton.load(header + 1 + offlutm)
j = triton.load(header + 2 + offlutn)
AS1 = SDD_K // TZ
lockid = triton.where(TZ > 1, 1, 0)
offka = pid0 * AS1
offkb = pid0 * AS1
offmc = 0
offnc = 0
offpa = 0
offpb = 0
maxid = TZ
offhc = 0
offha = z
offhb = z
ram = i * BLOCK + (triton.arange(0, TM) % BLOCK)
rbn = j * BLOCK + (triton.arange(0, TN) % BLOCK)
else:
header = lut + pid0 * 6
offset = triton.load(header + 0)
AS1 = triton.load(header + 1)
column = triton.load(header + 2)
depth = triton.load(header + 3)
lockid = triton.load(header + 4)
maxid = triton.load(header + 5)
pinc = lut + offset
offhc = depth
if meta['DSD']:
# output offset
offnc = pid1 * TN
offmc = column * TM
offpc = 0
# dense input offset
offnb = pid1 * TN
offkb = triton.load(pinc)
offkb = triton.multiple_of(offkb, 8) # compiler hint
offpb = 0
# sparse input offset
offma = 0
offka = 0
offpa = triton.load(pinc + 1)
offpa = triton.multiple_of(offpa, 8) # compiler hint
offpa = offpa * BLOCK * BLOCK
offha = 0
offhb = depth
else:
# output offset
offmc = pid1 * TM
offnc = column * TN
offpc = 0
# dense input offset
offma = pid1 * TM
offka = triton.load(pinc)
offka = triton.multiple_of(offka, 8) # compiler hint
offpa = 0
# sparse input offset
offnb = 0
offkb = 0
offpb = triton.load(pinc + 1)
offpb = triton.multiple_of(offpb, 8) # compiler hint
offpb = offpb * BLOCK * BLOCK
offha = depth
offhb = 0
ram = offma + triton.arange(0, TM)
rbn = offnb + triton.arange(0, TN)
# initialize a, b pointers
rka = offka + triton.arange(0, TK)
rkb = offkb + triton.arange(0, TK)
pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
if meta['DDS']:
checkam = ram[:, None] < DS0
else:
checkam = AS1 > 0
if meta['DSD']:
checkbn = rbn[None, :] < DS0
else:
checkbn = AS1 > 0
a = triton.load(pa, mask=checkam, other=0.)
b = triton.load(pb, mask=checkbn, other=0.)
## ---------------- ##
## Inner Loop ##
## ---------------- ##
acc = triton.zeros((TM, TN), dtype=triton.float32)
for k in range(AS1, 0, -TK):
acc += triton.dot(a, b)
if meta['SDD']:
inc_a = TK * stride_ka
inc_b = TK * stride_kb
else:
pinc += 2
if meta['DSD']:
inc_b = triton.load(pinc)
inc_a = triton.load(pinc + 1)
inc_b = triton.multiple_of(inc_b, 8)
inc_a = triton.multiple_of(inc_a, 8)
inc_b = inc_b * stride_kb
if meta['DDS']:
inc_a = triton.load(pinc)
inc_b = triton.load(pinc + 1)
inc_a = triton.multiple_of(inc_a, 8)
inc_b = triton.multiple_of(inc_b, 8)
inc_a = inc_a * stride_ka
pa += inc_a
pb += inc_b
# pre-fetch
checkak = k > TK
checkbk = k > TK
checka = checkam & checkak
checkb = checkbn & checkbk
a = triton.load(pa, mask=checka)
b = triton.load(pb, mask=checkb)
c = acc.to(C.dtype.element_ty)
if meta['SDD']:
checkc = True
rr_blockidm = triton.arange(0, TM) // BLOCK
rr_blockidn = triton.arange(0, TN) // BLOCK
rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
rr_offlutn = rr_blockidn * 4
off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
bkid = triton.load(header + off_bkid)
offpc = bkid * BLOCK * BLOCK
rcm = triton.arange(0, TM) % BLOCK
rcn = triton.arange(0, TN) % BLOCK
else:
rcm = offmc + triton.arange(0, TM)
rcn = offnc + triton.arange(0, TN)
if meta['DSD']:
checkc = rcn[None, :] < DS0
if meta['DDS']:
checkc = rcm[:, None] < DS0
pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
# write-back directly
if lockid == 0:
triton.store(pc, c, mask=checkc)
# accumulate partial results using spin-locks
else:
plock = locks + triton.program_id(2) * nlocks * triton.num_programs(1) + triton.program_id(1) * nlocks + lockid - 1
pcount = plock + triton.num_programs(2) * triton.num_programs(1) * nlocks
while triton.atomic_cas(plock, 0, 1) == 1:
pass
count = triton.load(pcount)
if count == 0:
triton.store(pc, c, mask=checkc)
else:
d = triton.load(pc, mask=checkc)
triton.store(pc, d + c, mask=checkc)
triton.atomic_xchg(pcount, (count + 1) % maxid)
triton.atomic_xchg(plock, 0)
##############
@@ -118,31 +294,11 @@ class _matmul(torch.autograd.Function):
raise ValueError('Reduction size for SDD must be a multiple of 16')
# create kernel
total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
c = torch.zeros((AS0, total_width, block, block), dtype=dtype, device=device)
for lut, width, pack in zip(luts, widths, packs):
num_lock = 1
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
if key not in _matmul.sdd_cache:
defines = {
'TM': block * pack,
'TN': block * pack,
'TMN': block * block * pack * pack,
'BLOCK': block,
'TK': 32,
'TYPE': dtype,
'STRIDE_AM': '1' if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': 'ldc',
'STRIDE_CN': '1',
'SDD': True,
'TZ': 1,
'NAME': 'sdd_kernel'
}
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
kernel = _matmul.sdd_cache[key]
meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1, \
'SDD': True, 'DSD': False, 'DDS': False}
# create output
locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
# maximum grid size is 65535
@@ -150,27 +306,32 @@ class _matmul(torch.autograd.Function):
# kernel calls
max_width = 49152
for off_width in range(0, width, max_width):
kernel(
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
a.stride(2),
b.stride(2),
block,
grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), AS0]
_kernel[grid](
a,
b,
c,
a.stride(0),
b.stride(0),
c.stride(0),
a.stride(1),
a.stride(3 if trans_a else 2),
a.stride(2 if trans_a else 3),
b.stride(0),
b.stride(1),
b.stride(3 if trans_b else 2),
b.stride(2 if trans_b else 3),
c.stride(0),
c.stride(0),
c.stride(2),
c.stride(3),
AS2,
AS2,
AS3,
off_width,
lut.data_ptr(),
locks.data_ptr(),
lut,
locks,
num_lock,
grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0]
num_warps=4,
**meta
)
# save for backward pass
return c
@@ -282,25 +443,8 @@ class _matmul(torch.autograd.Function):
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dds_cache:
defines = {
'TM': 128,
'TN': block,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else 1,
'STRIDE_BN': block if trans_b else 1,
'STRIDE_BK': 1 if trans_b else block,
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel',
'DDS': True
}
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dds_cache[key]
meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
'SDD': False, 'DSD': False, 'DDS': True}
# output
CS0 = AS0
CS1 = AS1
@@ -308,27 +452,32 @@ class _matmul(torch.autograd.Function):
CS3 = AS2 if trans_c else BS2
locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
a.stride(2),
block,
c.stride(2),
grid = lambda meta: [width, triton.cdiv(AS2, meta['TM']), AS0]
_kernel[grid](
a,
b,
c,
a.stride(0),
b.stride(0),
c.stride(0),
a.stride(1),
a.stride(3 if trans_a else 2),
a.stride(2 if trans_a else 3),
b.stride(0),
b.stride(1),
b.stride(3 if trans_b else 2),
b.stride(2 if trans_b else 3),
c.stride(0),
c.stride(1),
c.stride(3 if trans_c else 2),
c.stride(2 if trans_c else 3),
AS2,
BS2,
0,
0,
lut.data_ptr(),
locks.data_ptr(),
lut,
locks,
num_locks,
grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0]
num_warps=4,
**meta
)
return c
@@ -344,25 +493,8 @@ class _matmul(torch.autograd.Function):
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dsd_cache:
defines = {
'TM': block,
'TN': 128,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else block,
'STRIDE_AK': block if trans_a else 1,
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dsd_kernel',
'DSD': True
}
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dsd_cache[key]
meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
'SDD': False, 'DSD': True, 'DDS': False}
# output
CS0 = BS0
CS1 = BS1
@@ -370,27 +502,32 @@ class _matmul(torch.autograd.Function):
CS3 = AS1 if trans_c else BS3
locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
block,
b.stride(2),
c.stride(2),
grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0]
_kernel[grid](
a,
b,
c,
a.stride(0),
b.stride(0),
c.stride(0),
a.stride(1),
a.stride(3 if trans_a else 2),
a.stride(2 if trans_a else 3),
b.stride(0),
b.stride(1),
b.stride(3 if trans_b else 2),
b.stride(2 if trans_b else 3),
c.stride(0),
c.stride(1),
c.stride(2),
c.stride(3),
BS3,
AS1,
0,
0,
lut.data_ptr(),
locks.data_ptr(),
lut,
locks,
num_locks,
grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0]
num_warps=4,
**meta
)
return c

View File

@@ -1,135 +0,0 @@
__global__ void forward(TYPE *X __readonly __noalias,
float scale,
int *LUT __readonly __noalias,
TYPE *RPE __readonly __noalias,
TYPE *KP_M __readonly __noalias,
TYPE *ATTN_M __readonly __noalias,
int sizemax,
long stride_zx,
long stride_zrpe,
int stride_hrpe,
int stride_srpe,
int stride_zkpm,
int stride_zattnm) {
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
int rxm = pidhm % BLOCK;
int rbm = pidhm / BLOCK;
int rxn[TN] = (0 ... TN) % BLOCK;
int rbn[TN] = (0 ... TN) / BLOCK;
// extract information from look-up table
int *header = LUT + rbm * 2;
int size = *(header + 0);
int offset = *(header + 1);
bool check[TN] = rbn < size;
int rbmn[TN] = check ? rbn : size - 1;
// block id and column id
long blockid[TN] = *(LUT + offset + rbmn * 4 + 0);
long columnid[TN] = *(LUT + offset + rbmn * 4 + 1);
long rowid[TN] = *(LUT + offset + rbmn * 4 + 2);
long headid[TN] = *(LUT + offset + rbmn * 4 + 3);
// pointers to X
TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
#ifdef APPLY_RPE
// pointers to relative position embedding
TYPE *prpe[TN] = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn;
#endif
#ifdef APPLY_KP_MASK
// pointers to key padding mask
TYPE *pkp_m[TN] = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn;
#endif
#ifdef APPLY_ATTN_MASK
// pointers to attention mask
TYPE *pattn_m[TN] = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn;
#endif
// load input
TYPE x[TN] = check ? *px : -INFINITY;
#ifdef APPLY_RPE
// load relative position embedding
TYPE rpe[TN] = check ? *prpe : 0;
#endif
#ifdef APPLY_KP_MASK
// load key-padding mask
TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
#endif
#ifdef APPLY_ATTN_MASK
// load attention mask
TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
#endif
// compute softmax in float
#ifdef APPLY_RPE
float Frpe[TN] = rpe;
#endif
#ifdef APPLY_KP_MASK
float Fkp_m[TN] = kp_m;
#endif
#ifdef APPLY_ATTN_MASK
float Fattn_m[TN] = attn_m;
#endif
#ifdef KP_MASK_MUL
Fkp_m = (Fkp_m == 0) ? (float[TN]) - INFINITY : 0;
#endif
#ifdef ATTN_MASK_MUL
Fattn_m = (Fattn_m == 0) ? (float[TN]) - INFINITY : 0;
#endif
float Fx[TN] = x;
#ifdef APPLY_SCALE
Fx = Fx * scale; // apply scale
#endif
#ifdef APPLY_RPE
Fx = Fx + Frpe; // apply relative position embedding
#endif
#ifdef APPLY_KP_MASK
Fx = Fx + Fkp_m; // apply key padding mask
#endif
#ifdef APPLY_ATTN_MASK
Fx = Fx + Fattn_m; // apply attention mask
#endif
float Fxmax = Fx[max];
float Fy[TN] = exp(Fx - Fxmax);
float Fysum = (check ? Fy : 0)[+];
// write-back in half/float
TYPE y[TN] = Fy;
TYPE ysum = Fysum;
*? (check)px = y / ysum;
}
__global__ void backward(TYPE *X __readonly __noalias,
float scale,
TYPE *DX __readonly __noalias,
int *LUT,
int sizemax,
long stride_zx,
long stride_zdx) {
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
int rxm = pidhm % BLOCK;
int rbm = pidhm / BLOCK;
int rxn[TN] = (0 ... TN) % BLOCK;
int rbn[TN] = (0 ... TN) / BLOCK;
// extract information from look-up table
int *header = LUT + rbm * 2;
int size = *(header + 0);
int offset = *(header + 1);
// bounds checking on lut
bool check[TN] = rbn < size;
int rbmn[TN] = check ? rbn : size - 1;
// initialize pointers to block-sparse input
long blockid[TN] = *(LUT + offset + rbmn * 4);
TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
TYPE *pdx[TN] = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
// compute fused softmax backward
TYPE x[TN] = check ? *px : 0;
TYPE dx[TN] = check ? *pdx : 0;
float Fdx[TN] = dx;
float Fx[TN] = x;
float Fxdx[TN] = Fdx * Fx;
float Fxdxsum = Fxdx[+];
float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
TYPE y[TN] = Fy;
// write-back
*? (check)pdx = y;
}

View File

@@ -2,24 +2,118 @@ import triton
import torch
import os
fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
fwd_kernels = dict()
bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['backward'])
bwd_kernels = dict()
def next_power_of_2(n):
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n += 1
return n
def num_warps(n):
if n < 512:
return 4
if n < 2048:
return 8
return 16
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[6] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[6] * meta['BLOCK'])})
@triton.jit
def _forward(
X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
**meta
):
TN = meta['TN']
BLOCK = meta['BLOCK']
pidhm = triton.program_id(0)
pidz = triton.program_id(1)
# create index ranges
rxm = pidhm % BLOCK
rbm = pidhm // BLOCK
rxn = triton.arange(0, TN) % BLOCK
rbn = triton.arange(0, TN) // BLOCK
# extract information from LUT
header = LUT + rbm * 2
size = triton.load(header + 0)
offset = triton.load(header + 1)
check = rbn < size
rbmn = triton.where(check, rbn, size - 1)
# block id and column id
blockid = triton.load(LUT + offset + rbmn * 4 + 0)
columnid = triton.load(LUT + offset + rbmn * 4 + 1)
rowid = triton.load(LUT + offset + rbmn * 4 + 2)
headid = triton.load(LUT + offset + rbmn * 4 + 3)
# pointers to X
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
x = triton.load(px, mask=check, other=-float('inf'))
x = x.to(triton.float32)
# apply scale
if meta['APPLY_SCALE']:
x = x * scale
# apply RPE
if meta['APPLY_RPE']:
prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
rpe = triton.load(prpe, mask=check, other=0)
x = x + rpe
# apply key-padding mask
if meta['APPLY_KP_MASK']:
pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
kp_m = triton.load(pkp_m, mask=check, other=-float('inf'))
if meta['KP_MASK_MUL']:
kp_m = triton.where(kp_m == 0, -float('inf'), 0.)
x = x + kp_m
# apply attention mask
if meta['APPLY_ATTN_MASK']:
pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
attn_m = triton.load(pattn_m, mask=check, other=-float('inf'))
if meta['ATTN_MASK_MUL']:
attn_m = triton.where(attn_m == 0, -float('inf'), 0.)
x = x + attn_m
# computation
x = triton.softmax(x)
triton.store(px, x, mask=check)
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})
@triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
pidhm = triton.program_id(0)
pidz = triton.program_id(1)
TN = meta['TN']
BLOCK = meta['BLOCK']
# create index ranges
rxm = pidhm % BLOCK
rbm = pidhm // BLOCK
rxn = triton.arange(0, TN) % BLOCK
rbn = triton.arange(0, TN) // BLOCK
# extract information from look-up table
header = LUT + rbm * 2
size = triton.load(header + 0)
offset = triton.load(header + 1)
# bounds checking on lut
check = rbn < size
rbmn = triton.where(check, rbn, size - 1)
# initialize pointers to block-sparse input
blockid = triton.load(LUT + offset + rbmn * 4)
X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
# compute fused softmax backward
x = triton.load(X, mask=check, other=0)
dx = triton.load(DX, mask=check, other=0)
x = x.to(triton.float32)
dx = dx.to(triton.float32)
y = x * (dx - triton.sum(x * dx, 0)) * scale
triton.store(DX, y, mask=check)
class _softmax(torch.autograd.Function):
@staticmethod
def next_power_of_2(n):
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n += 1
return n
@staticmethod
def make_lut(layout, block, device):
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
@@ -43,40 +137,9 @@ class _softmax(torch.autograd.Function):
return lut, int(sizes.max())
@staticmethod
def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
kp_mask_mode, attn_mask_mode):
if max_k >= 32768:
raise NotImplementedError('Reductions larger than 32768 elements '\
'are not yet implemented')
num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
TN = _softmax.next_power_of_2(max_k)
# just-in-time compile kernel
key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
kp_mask_mode, attn_mask_mode)
if key not in cache:
defines = {
'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block, 'INFINITY':
{torch.float32: 'F32_INFINITY', torch.float16: 'F16_INFINITY'}[dtype]
}
if apply_scale:
defines['APPLY_SCALE'] = True
if apply_rpe:
defines['APPLY_RPE'] = True
if apply_kp_mask:
defines['APPLY_KP_MASK'] = True
if kp_mask_mode == 'mul':
defines['KP_MASK_MUL'] = True
if apply_attn_mask:
defines['APPLY_ATTN_MASK'] = True
if attn_mask_mode == 'mul':
defines['ATTN_MASK_MUL'] = True
kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
cache[key] = kernel
return cache[key]
@staticmethod
def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,
maxlut, bench, time):
def forward(
ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut, maxlut, bench, time
):
apply_scale = False if scale == 1.0 else True
# handle None rpe
@@ -107,26 +170,20 @@ class _softmax(torch.autograd.Function):
stride_zattnm = attn_mask.stride(0)
# run kernel
kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut * block, x.device, x.dtype, block, apply_scale,
apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
M = x.shape[0]
meta = {
'BLOCK': block,
'APPLY_SCALE': apply_scale,
'APPLY_RPE': apply_rpe,
'APPLY_KP_MASK': apply_kp_mask,
'APPLY_ATTN_MASK': apply_attn_mask,
'KP_MASK_MUL': kp_mask_mode == 'mul',
'ATTN_MASK_MUL': attn_mask_mode == 'mul',
}
grid = lambda opt: [spdims[0] * spdims[1] * block, M]
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)
# run kernel
kernel(x.data_ptr(),
scale,
lut.data_ptr(),
rpe.data_ptr(),
key_padding_mask.data_ptr(),
attn_mask.data_ptr(),
maxlut,
x.stride(0),
stride_zrpe,
stride_hrpe,
stride_srpe,
stride_zkpm,
stride_zattnm,
grid=grid)
# save to context
ctx.mark_dirty(x)
ctx.save_for_backward(x, lut)
@@ -147,14 +204,12 @@ class _softmax(torch.autograd.Function):
# retrieve from context
x, lut = ctx.saved_tensors
# run kernel
kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut * ctx.block, x.device, x.dtype, ctx.block,
ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask,
ctx.kp_mask_mode, ctx.attn_mask_mode)
M = x.shape[0]
grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
kernel(x.data_ptr(), ctx.scale, dx.data_ptr(), lut.data_ptr(), ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
_backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class softmax:
apply_softmax = _softmax.apply
@@ -172,14 +227,9 @@ class softmax:
self.bench = bench
self.lut_cache = dict()
def __call__(self,
x,
scale=1.,
rpe=None,
key_padding_mask=None,
attn_mask=None,
key_padding_mask_mode='add',
attn_mask_mode='add'):
def __call__(
self, x, scale=1., rpe=None, key_padding_mask=None, attn_mask=None, key_padding_mask_mode='add', attn_mask_mode='add'
):
time_y = [None]
if rpe is not None and rpe.dtype != x.dtype:
raise ValueError('relative position embedding must be %s' % x.dtype)
@@ -188,6 +238,8 @@ class softmax:
if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
raise ValueError('Key padding mask must be %s' % x.dtype)
lut, maxlut = self.make_lut(x.device)
x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode,
self.spdims, self.block, lut, maxlut, self.bench, time_y)
x = softmax.apply_softmax(
x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode, self.spdims, self.block, lut,
maxlut, self.bench, time_y
)
return x

View File

@@ -1,123 +0,0 @@
__global__ void conv(TYPE *A __noalias __readonly,
TYPE *B __noalias __readonly,
TYPE *C __noalias,
float alpha,
// equivalent matmul
int M, int N, int K,
// convolution properties
int pad_h, int pad_w, int stride_h, int stride_w,
// pointer increment
int *ADELTA,
// memory strides
int lda_z, int lda_ci, int lda_h, int lda_w,
int ldb_ci, int ldb_r, int ldb_s, int ldb_co,
int ldc_z, int ldc_co, int ldc_p, int ldc_q)
{
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
// unpack aggregate rows
// m = (z, p, q)
int rq[TM] = rm % QQ;
int rzp[TM] = rm / QQ;
int rp[TM] = rzp % PP;
int rz[TM] = rzp / PP;
// unpack aggregate reduction
// k = (ci, r, s)
int rs[TK] = rk % SS;
int rcir[TK] = rk / SS;
int rr[TK] = rcir % RR;
int rci[TK] = rcir / RR;
// padding / striding
int rh_0[TM] = rp * stride_h - pad_h;
int rw_0[TM] = rq * stride_w - pad_w;
int rh[TM, TK] = rh_0[:, newaxis] + rr [newaxis, :];
int rw[TM, TK] = rw_0[:, newaxis] + rs [newaxis, :];
// pointers to lhs
int offa[TM, TK] = rz[:, newaxis] * lda_z + rci [newaxis, :] * lda_ci +
rh * lda_h + rw * 1;
TYPE *pa[TM, TK] = A + offa;
int *padelta[TK] = ADELTA + rk;
// pointers to rhs
int offb[TK, TN] = rci[:, newaxis] * ldb_ci + rr[:, newaxis] * ldb_r +
rs[:, newaxis] * ldb_s + rn [newaxis, :] * 1;
TYPE *pb[TK, TN] = B + offb;
// prefetches operands
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
int total = 0;
// reduction loop
float acc[TM, TN] = 0;
for (int k = K; k > 0; k -= TK)
{
acc += a @b;
// increment A
int adelta[TK] = *padelta;
padelta += TK;
pa += adelta [newaxis, :];
// bounds-checking A
rk += TK;
rs = rk % SS;
rcir = rk / SS;
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr [newaxis, :];
rw = rw_0[:, newaxis] + rs [newaxis, :];
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkb[TK, TN] = k > TK;
a = checka ? *pa : 0;
b = *? (checkb)pb;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
rm = ridx * TM + 0 ... TM;
rn = ridy * TN + 0 ... TN;
rq = rm % QQ;
rzp = rm / QQ;
rp = rzp % PP;
rz = rzp / PP;
int offc[TM, TN] = rz[:, newaxis] * ldc_z + rn [newaxis, :] * ldc_co +
rp[:, newaxis] * ldc_p + rq[:, newaxis] * 1;
TYPE *pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rm[:, newaxis] < M && rn [newaxis, :] < N;
#if (TZ == 1)
*? (checkc)pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;
else
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}

View File

@@ -1,81 +0,0 @@
import torch
import triton
import os
class _conv(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c'))
kernel = dict()
@staticmethod
def unpack(IDX, CI, R, S):
s = IDX % S
cr = IDX // S
r = cr % R
ci = cr // R
return ci, r, s
@staticmethod
def forward(ctx, a, b, pad, stride):
# create kernel if necessary
dtype = a.dtype
device = a.device
# shapes
Z, CI, H, W = a.shape
_, R, S, CO = b.shape
P = (H + 2 * pad[0] - R) // stride[0] + 1
Q = (W + 2 * pad[1] - S) // stride[1] + 1
# compile kernel
if (dtype, device) not in _conv.kernel:
TK = 16
defines = {
'TYPE': dtype,
'TM': 64,
'TN': 64,
'TK': TK,
'TZ': 1,
'HH': H,
'WW': W,
'PP': P,
'QQ': Q,
'SS': S,
'RR': R,
}
idx = torch.arange(CI * R * S)
ci, r, s = _conv.unpack(idx, CI, R, S)
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (ns - s) * a.stride(3)
delta = delta.type(torch.int32).cuda()
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
delta, kernel = _conv.kernel[dtype]
# allocate output
c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
# enqueue
kernel(
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
1.,
Z * P * Q,
CO,
CI * R * S,
pad[0],
pad[1],
stride[0],
stride[1],
delta.data_ptr(),
a.stride(0),
a.stride(1),
a.stride(2),
a.stride(3),
b.stride(0),
b.stride(1),
b.stride(2),
b.stride(3),
c.stride(0),
c.stride(1),
c.stride(2),
c.stride(3),
grid=lambda opt: [triton.cdiv(Z * P * Q, opt.TM), triton.cdiv(CO, opt.TN)])
return c
conv = _conv.apply

View File

@@ -1,35 +0,0 @@
__global__ void forward(TYPE *logit, TYPE *modified_logit, long *indices, TYPE *result, int n_cols) {
int row = get_program_id(0);
bool check[TILE] = ((0 ... TILE) < n_cols);
int offset[TILE] = row * n_cols + 0 ... TILE;
TYPE *px[TILE] = logit + offset;
TYPE *pmodified[TILE] = modified_logit + offset;
long local_ind = *(indices + row);
TYPE F16[TILE] = check ? *px : -INFINITY;
float shifted_logit[TILE] = F16 - F16[max];
float neg_logprob[TILE] = log(exp(shifted_logit)[+]) - shifted_logit;
*? (check)pmodified = neg_logprob;
__debug_barrier();
*(result + row) = *(modified_logit + (local_ind + n_cols * row));
}
__global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs, int n_cols) {
int row = get_program_id(0);
// pointer arithmetic
bool check[TILE] = ((0 ... TILE) < n_cols);
int offset[TILE] = row * n_cols + 0 ... TILE;
TYPE *px[TILE] = neg_logprobs + offset;
long local_ind = *(indices + row);
TYPE local_dn = *(dneg_logprobs + row);
// We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
// and we have -log(p[k]) stored, so this is easy
TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * px) : 0;
// selected_logit_idx is selected logit index for our token
bool find_one[TILE] = ((0 ... TILE) == local_ind);
intermediate = intermediate - ((TYPE[TILE])find_one);
// multiply by dneg_logprobs
*? (check)px = intermediate * local_dn;
}

View File

@@ -2,6 +2,7 @@ import os
import triton
import torch
def next_power_of_2(n):
n -= 1
n |= n >> 1
@@ -12,34 +13,61 @@ def next_power_of_2(n):
n += 1
return n
def largest_pow2_divisor(N):
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
return 1
def make_kernel(device, dtype, n_cols, cache, name):
rounded = next_power_of_2(n_cols)
div = largest_pow2_divisor(n_cols)
key = (dtype, rounded, div)
if key not in cache:
fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
src = triton.read(fname, kernel_names=[name])
infinities = {
torch.float16: "F16_INFINITY",
torch.float32: "F32_INFINITY",
}
defines = {"TILE": rounded, "TYPE": dtype, "INFINITY": infinities[dtype], "N_COLS_MULT": div}
cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=4)
return cache[key]
def num_warps(N):
if N < 2048:
return 4
elif N < 8192:
return 8
return 16
# forward kernel
fwd_kernels = dict()
make_fwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, fwd_kernels, "forward")
# backward kernel
bwd_kernels = dict()
make_bwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, bwd_kernels, "backward")
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])})
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])})
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
BLOCK = meta['BLOCK']
row = triton.program_id(0)
cols = triton.arange(0, BLOCK)
idx = triton.load(IDX + row)
# pointers to logit and probs
LOGITS = LOGITS + row * N + cols
WRIT_PROBS = PROBS + row * N + cols
READ_PROBS = PROBS + row * N + idx
# write-back negative log-probs
logits = triton.load(LOGITS, mask=cols < N, other=-float('inf'))
logits = logits.to(triton.float32)
logits = logits - triton.max(logits, 0)
probs = triton.log(triton.sum(triton.exp(logits), 0)) - logits
triton.store(WRIT_PROBS, probs, mask=cols < N)
# There is a bug in the compiler, which fails to insert a barrier here.
# We add it explicitly for now. Will be fixed soon.
triton.debug_barrier()
# write-back loss
probs = triton.load(READ_PROBS)
triton.store(LOSS + row, probs)
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])})
@triton.jit
def _backward(PROBS, IDX, DPROBS, N, **meta):
BLOCK = meta['BLOCK']
row = triton.program_id(0)
cols = triton.arange(0, BLOCK)
idx = triton.load(IDX + row)
# pointers to probs
PROBS = PROBS + row * N + cols
# We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
# and we have -log(p[k]) stored in PROBS, so this is easy
probs = -triton.load(PROBS, mask=cols < N, other=float('inf'))
probs = triton.exp(probs.to(triton.float32))
delta = cols == idx
# write result in-place in PROBS
dout = triton.load(DPROBS + row)
din = (probs - delta) * dout
triton.store(PROBS, din.to(triton.float16), mask=cols < N)
class _cross_entropy(torch.autograd.Function):
@classmethod
@@ -49,16 +77,11 @@ class _cross_entropy(torch.autograd.Function):
# make kernel
device, dtype = logits.device, logits.dtype
n_cols = logits.shape[-1]
kernel = make_fwd_kernel(device, dtype, n_cols)
# run the kernel
result = torch.empty_like(indices, dtype=dtype, device=device)
neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
kernel(logits.data_ptr(),
neg_logprobs.data_ptr(),
indices.data_ptr(),
result.data_ptr(),
n_cols,
grid=lambda opt: (logits.numel() // n_cols, ))
grid = lambda opt: (logits.numel() // n_cols, )
_forward[grid](logits, neg_logprobs, indices, result, n_cols)
# save for backward
ctx.save_for_backward(neg_logprobs, indices)
return result
@@ -75,14 +98,11 @@ class _cross_entropy(torch.autograd.Function):
# make kernel
device, dtype = neg_logprobs.device, neg_logprobs.dtype
n_cols = neg_logprobs.shape[-1]
kernel = make_bwd_kernel(device, dtype, n_cols)
# run the kernel
# neg_logprobs will be modified in place to become our gradient:
kernel(neg_logprobs.data_ptr(),
indices.data_ptr(),
dneg_logprobs.data_ptr(),
n_cols,
grid=lambda opt: (neg_logprobs.numel() // n_cols, ))
grid = lambda opt: (neg_logprobs.numel() // n_cols, )
_backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols)
return neg_logprobs, None
cross_entropy = _cross_entropy.apply

View File

@@ -1,94 +0,0 @@
#define STM 1
#define STN 1
__global__ void matmul(TYPE *A __noalias __readonly,
TYPE *B __noalias __readonly,
TYPE *C __noalias,
float alpha,
int M, int N, int K,
int lda, int ldb, int ldc,
int *locks) {
// prologue
int pid = get_program_id(0);
int pidz = get_program_id(2);
int gridm = (M + TM - 1) / TM;
int gridn = (N + TN - 1) / TN;
// swizzle for better L2 performance
int width = STM * gridn;
int stm = pid / width;
int RSTM = min(gridm - stm * STM, STM);
int stn = (pid % width) / (RSTM * STN);
int RSTN = min(gridn - stn * STN, STN);
int laneid = pid % (RSTM * RSTN);
int lanem = laneid / RSTN;
int lanen = laneid % RSTN;
int pidm = stm * STM + lanem;
int pidn = stn * STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// split-k for better parrallelism
K = K / SPLITK;
int rk[TK] = 0 ... TK;
// pointers to operands
int offa[TM, TK] = (pidz * K + rk [newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = (pidz * K + rk[:, newaxis]) * STRIDE_BK + rn [newaxis, :] * STRIDE_BN;
TYPE *pa[TM, TK] = A + offa;
TYPE *pb[TK, TN] = B + offb;
// prefetches operands
bool checka[TM, TK] = rk [newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
// reduction loop
float acc[TM, TN] = 0;
for (int k = K; k > 0; k -= TK) {
#if (IS_TK_DIV_K == 1)
bool checkk[TK] = k > TK;
#else
bool checkk[TK] = rk < k - TK;
#endif
bool checka[TM, TK] = checkk [newaxis, :];
bool checkb[TK, TN] = checkk[:, newaxis];
acc += a @b;
#if (IS_TK_DIV_K == 1)
a = *? (checka)pa;
b = *? (checkb)pb;
#else
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
#endif
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
int rcm[TM] = pidm * TM + 0 ... TM;
int rcn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn [newaxis, :];
TYPE *pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn [newaxis, :] < N;
#if (SPLITK == 1)
*? (checkc)pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0);
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;
else
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % SPLITK);
atomic_xchg(plock, 0);
#endif
}

View File

@@ -1,108 +1,117 @@
import torch
import triton
import os
@triton.heuristics({
'EVEN_K': lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0,
})
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
# triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2),\
# triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2),
],
key=['M', 'N', 'K']
)
@triton.jit
def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, LOCKS, **META):
# extract meta-parameters
BLOCK_M = META['BLOCK_M']
BLOCK_N = META['BLOCK_N']
BLOCK_K = META['BLOCK_K']
GROUP_M = META['GROUP_M']
SPLIT_K = META['SPLIT_K']
# matrix multiplication
pid = triton.program_id(0)
pid_z = triton.program_id(1)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
rk = triton.arange(0, BLOCK_K)
# pointers
K = K // SPLIT_K
A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn)
acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
for k in range(K, 0, -BLOCK_K):
if META['EVEN_K']:
a = triton.load(A)
b = triton.load(B)
else:
a = triton.load(A, mask=rk[None, :] < k, other=0.)
b = triton.load(B, mask=rk[:, None] < k, other=0.)
acc += triton.dot(a, b)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
acc = acc.to(triton.float16)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
triton.store(C, acc, mask=mask)
else:
LOCKS = LOCKS + triton.program_id(0)
COUNT = LOCKS + triton.num_programs(0)
while triton.atomic_cas(LOCKS, 0, 1) == 1:
pass
count = triton.load(COUNT)
if count == 0:
triton.store(C, acc, mask=mask)
else:
curr = triton.load(C, mask=mask, other=0.)
triton.store(C, acc + curr, mask=mask)
triton.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
triton.atomic_xchg(LOCKS, 0)
class _matmul(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))
_DEFAULT_CONFIGS = [
triton.config(defines={"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
triton.config(defines={'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
]
_CONFIGS = _DEFAULT_CONFIGS
@staticmethod
def largest_pow2_divisor(N):
if N % 8 == 0:
return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1
kernel = _kernel
_locks = dict()
_kernels = dict()
@staticmethod
def _call(a, b):
dtype = a.dtype
device = a.device
# allocate output
M, K = a.shape
K, N = b.shape
c = torch.empty((M, N), dtype=dtype, device=device)
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# kernel hash
is_a_row = a.stride(1) == 1
is_b_row = b.stride(1) == 1
lda = a.stride(0) if is_a_row else a.stride(1)
ldb = b.stride(0) if is_b_row else b.stride(1)
ldc = c.stride(0)
lda_pow2_div = _matmul.largest_pow2_divisor(lda)
ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
is_tk_div_k = K % 64 == 0
key = (
device,
dtype,
is_a_row,
is_b_row,
lda_pow2_div,
ldb_pow2_div,
ldc_pow2_div,
is_tk_div_k,
)
if key not in _matmul._kernels:
defines = {
"TYPE": dtype,
"STRIDE_AM": "lda" if is_a_row else "1",
"STRIDE_AK": "1" if is_a_row else "lda",
"STRIDE_BK": "ldb" if is_b_row else "1",
"STRIDE_BN": "1" if is_b_row else "ldb",
"LDA_POW2_DIV": lda_pow2_div,
"LDB_POW2_DIV": ldb_pow2_div,
"LDC_POW2_DIV": ldc_pow2_div,
"IS_TK_DIV_K": int(is_tk_div_k),
}
_matmul._kernels[key] = triton.kernel(
_matmul.src,
device,
defines=defines,
autotune_configs=_matmul._CONFIGS,
autotune_key=["M", "N", "K"],
)
kernel = _matmul._kernels[key]
# # locks for split-k
if device not in _matmul._locks:
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype)
# allocate locks for split-k
if a.device not in _matmul._locks:
_matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
locks = _matmul._locks[device]
# enqueue
alpha = 1.0
args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.SPLITK]
kernel(*args, grid=grid)
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), locks)
# done
return c
@staticmethod
def forward(ctx, a, b):
c = _matmul._call(a, b)
return c
return _matmul._call(a, b)
matmul = _matmul.apply

View File

@@ -47,13 +47,37 @@ def mask_tensor(x, mask, block, value=0):
def allclose(x, y, tol=1e-2):
assert x.dtype == y.dtype
if x.dtype != y.dtype:
raise RuntimeError(f'{x.dtype} did not match with {x.dtype}')
if x.shape != y.shape:
raise RuntimeError(f'{x.shape} did not match with {y.shape}')
if x.dtype == torch.bool:
return torch.sum(x ^ y) == 0
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
tol = 0
diff = abs(x - y)
x_max = torch.max(x)
y_max = torch.max(y)
tol = 1e-2
err = torch.max(diff) / torch.max(x_max, y_max)
return err < tol
return err <= tol
def assert_allclose(x, y, tol=1e-2):
assert x.dtype == y.dtype
assert allclose(x, y, tol)
def random(shape, dtype, device):
if isinstance(shape, int):
shape = (shape, )
if dtype == torch.bool:
return torch.randint(0, 2, shape, dtype=dtype, device=device)
if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
return torch.randint(1, 32, shape, dtype=dtype, device=device)
if dtype in [torch.float16, torch.float32, torch.float64]:
return torch.randn(shape, dtype=dtype, device=device)
raise RuntimeError(f'Unknown dtype {dtype}')
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):

View File

@@ -1,139 +1,71 @@
import torch
import triton
"""
Vector Addition
=================
In this tutorial, you will write a simple vector addition using Triton and learn about:
- The basic syntax of the Triton programming language
- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API
- The basic programming model used by Triton
- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.
- The best practices for validating and benchmarking custom ops against native reference implementations
"""
# %%
# Compute Kernel
# --------------------------
#
# Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel
# on different chunks of data (See the `Single Program, Multiple Data <(https://en.wikipedia.org/wiki/SPMD>`_)
# programming model for more details).
#
# .. code-block:: C
#
# __global__ void add(float* z, float* x, float* y, int N){
# // The `get_program_id(i)` returns the i-th coordinate
# // of the program in the overaching SPMD context
# // (a.k.a launch grid). This is what allows us to process
# // different chunks of data in parallel.
# // For those similar with CUDA, `get_program_id({0,1,2})`
# // is similar to blockIdx.{x,y,z}
# int pid = get_program_id(0);
# // In Triton, arrays are first-class citizen. In other words,
# // they are primitives data-types and are -- contrary to C and
# // CUDA -- not implemented as pointers to contiguous chunks of
# // memory.
# // In the few lines below, we create an array of `BLOCK` pointers
# // whose memory values are, e.g.:
# // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]
# // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time
# int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
# float* pz [BLOCK] = z + offset;
# float* px [BLOCK] = x + offset;
# float* py [BLOCK] = y + offset;
# // Simple element-wise control-flow for load/store operations can
# // be achieved using the the ternary operator `cond ? val_true : val_false`
# // or the conditional dereferencing operator `*?(cond)ptr
# // Here, we make sure that we do not access memory out-of-bounds when we
# // write-back `z`
# bool check[BLOCK] = offset < N;
# *?(check)pz = *?(check)px + *?(check)py;
# }
#
# The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.
@triton.jit
def _add(
X, # *Pointer* to first input vector
Y, # *Pointer* to second input vector
Z, # *Pointer* to output vector
N, # Size of the vector
**meta # Optional meta-parameters for the kernel
):
pid = triton.program_id(0)
# Create an offset for the blocks of pointers to be
# processed by this program instance
offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
# Create a mask to guard memory operations against
# out-of-bounds accesses
mask = offsets < N
# Load x
x = triton.load(X + offsets, mask=mask)
y = triton.load(Y + offsets, mask=mask)
# Write back x + y
z = x + y
triton.store(Z + offsets, z)
# %%
# Torch Bindings
# --------------------------
# The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
#
# - :code:`source: string`: the source-code of the kernel you want to create
# - :code:`device: torch.device`: the device you want to compile this code for
# - :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you
import torch
import triton
# source-code for Triton compute kernel
# here we just copy-paste the above code without the extensive comments.
# you may prefer to store it in a .c file and load it from there instead.
_src = """
__global__ void add(float* z, float* x, float* y, int N){
// program id
int pid = get_program_id(0);
// create arrays of pointers
int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
float* pz[BLOCK] = z + offset;
float* px[BLOCK] = x + offset;
float* py[BLOCK] = y + offset;
// bounds checking
bool check[BLOCK] = offset < N;
// write-back
*?(check)pz = *?(check)px + *?(check)py;
}
"""
# We can also declara a helper function that handles allocating the output vector
# and enqueueing the kernel.
# This function returns a callable `triton.kernel` object created from the above source code.
# For portability, we maintain a cache of kernels for different `torch.device`
# We compile the kernel with -DBLOCK=1024
def make_add_kernel(device):
cache = make_add_kernel.cache
if device not in cache:
defines = {'BLOCK': 1024}
cache[device] = triton.kernel(_src, device=device, defines=defines)
return cache[device]
def add(x, y):
z = torch.empty_like(x)
N = z.shape[0]
# The SPMD launch grid denotes the number of kernel instances that should execute in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
# NOTE:
# - torch.tensor objects are implicitly converted to pointers to their first element.
# - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
_add[grid](x, y, z, N, BLOCK=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously.
return z
make_add_kernel.cache = dict()
# This is a standard torch custom autograd Function;
# The only difference is that we can now use the above kernel in the `forward` and `backward` functions.`
class _add(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
# constraints of the op
assert x.dtype == torch.float32
# *allocate output*
z = torch.empty_like(x)
# *create launch grid*:
# this is a function which takes compilation parameters `opt`
# as input and returns a tuple of int (i.e., launch grid) for the kernel.
# triton.cdiv is a shortcut for ceil division:
# triton.cdiv(a, b) = (a + b - 1) // b
N = z.shape[0]
grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )
# *launch kernel*:
# pointer to the data of torch tensors can be retrieved with
# the `.data_ptr()` method
kernel = make_add_kernel(z.device)
kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)
return z
# Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function
add = _add.apply
# %%
# We can now use the above function to compute the sum of two `torch.tensor` objects:
# %%
# Unit Test
# -----------
#
# Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below:
# We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:
torch.manual_seed(0)
x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
za = x + y
zb = add(x, y)
print(za)

View File

@@ -4,8 +4,7 @@ Fused Softmax
In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- The syntax and usage of reduction operators in Triton.
- The automatic vectorization capabilities of the Triton compiler.
- The reduction operators in Triton.
"""
# %%
@@ -36,79 +35,45 @@ def naive_softmax(x):
# %%
# When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
# In this case, we would be reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
# This solution would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
# In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.
# %%
# Compute Kernel
# ----------------
# Our softmax kernel works as follows: each program loads a row of the input X, normalizes it and writes back the result to the output Y.
# Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
# Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
#
# .. code-block:: C
#
# __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){
# // row index
# int m = get_program_id(0);
# // column indices
# int n [BLOCK] = 0 ... BLOCK;
# // the memory address of all the elements
# // that we want to load can be computed as follows
# float* px [BLOCK] = X + m*stride_xm + n;
# // because BLOCK has to be a power of two
# // (per Triton-C specs), it is important
# // to guard each memory operation with predicates
# // or we will read out of bounds
# bool check[BLOCK] = n < N;
# float x [BLOCK] = check ? *px : -F32_INFINITY;
# // syntax for reduction in Triton is:
# // x[:, :, OPERATOR, :, :]
# // ^
# // index
# // where operator is in {min, max, +}
# // for 1D vectors, this is just x[OPERATOR].
# float z [BLOCK] = x - x[max];
# // Note that exponentials in Triton are fast
# // but approximate (i.e., think __expf in CUDA)
# float num [BLOCK] = exp(z);
# float denom = num[+];
# // The result of the reduction is now stored in y
# float y [BLOCK] = num / denom;
# // We write it back
# float* py [BLOCK] = Y + m*stride_ym + n;
# *?(check)py = y;
# }
# %%
# Torch Bindings
# ---------------
# Here our torch bindings is quite similar to that of the vector addition mentioned in the previous tutorial.
# We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix.
# This means that different values of BLOCK will result in different kernels
import torch
import triton
# Source code for the Triton kernel
_src = """
__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){
int m = get_program_id(0);
int n [BLOCK] = 0 ... BLOCK;
float* px [BLOCK] = X + m*stride_xm + n;
bool check[BLOCK] = n < N;
float x [BLOCK] = check ? *px : -F32_INFINITY;
float z [BLOCK] = x - x[max];
float num [BLOCK] = exp(z);
float denom = num[+];
float y [BLOCK] = num / denom;
float* py [BLOCK] = Y + m*stride_ym + n;
*?(check)py = y;
}
"""
@triton.jit
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
# row index
m = triton.program_id(0)
# col indices
n = triton.arange(0, meta['BLOCK'])
# the memory address of all the elements
# that we want to load can be computed as follows
X = X + m * stride_xm + n
x = triton.load(X, mask=n < N, other=-float('inf'))
# Substract maximum for numerical stability
z = x - triton.max(x, axis=0)
# Note that exponentials in Triton are fast
# but approximate (i.e., think __expf in CUDA)
num = triton.exp(z)
denom = triton.sum(num, axis=0)
y = num / denom
# Write back to Y
Y = Y + m * stride_ym + n
triton.store(Y, y, mask=n < N)
# %%
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
# helper function to get the smaller power-of-two larger than a given number
def next_power_of_2(n):
n -= 1
n |= n >> 1
@@ -120,11 +85,9 @@ def next_power_of_2(n):
return n
# kernel caching mechanism
def make_kernel(N, device):
cache = make_kernel.cache
# Now are kernels are indexed not only by the provided device but also
# by the rounded number of columns in the input matrix
def softmax(x):
M, N = x.shape
# The block size is the smallest power of two greater than the number of columns in `x`
BLOCK = next_power_of_2(N)
# Another trick we can use is to ask the compiler to parallelize each
# row-normalization more aggressively -- i.e., with more warps -- vectors
@@ -134,37 +97,13 @@ def make_kernel(N, device):
num_warps = 4
if BLOCK >= 2048: num_warps = 8
if BLOCK >= 4096: num_warps = 16
# Each (BLOCK, num_warps, device) results in a different kernel
key = (BLOCK, num_warps, device)
if key not in cache:
defines = {'BLOCK': BLOCK}
cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
return cache[key]
# Allocate output
y = torch.empty_like(x)
# Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
_softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, BLOCK=BLOCK)
return y
make_kernel.cache = dict()
class _softmax(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
# constraints of the op
assert x.dtype == torch.float32
y = torch.empty_like(x)
# The launch grid is simple: we have one kernel instance per row of the input matrix
M, N = y.shape
grid = lambda opt: (M, )
# Launch kernel
kernel = make_kernel(N, y.device)
kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)
return y
softmax = _softmax.apply
# %%
# We can use the above softmax function to compute the row-wise softmax of a given matrix.
# %%
# Unit Test
# ----------

View File

@@ -1,10 +1,10 @@
"""
Matrix Multiplication
======================
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLAS's performance.
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs.
You will specifically learn about:
- The block-level matrix multiplication operator `@`
- Block-level matrix multiplications
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
@@ -15,7 +15,7 @@ You will specifically learn about:
# -------------
# Matrix multiplications are a key building block of most modern high-performance computing systems.
# They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
# Unfortunately, these libraries are often proprietary and cannot be customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
# Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
# For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
#
# Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
@@ -23,322 +23,212 @@ You will specifically learn about:
# .. code-block:: python
#
# # do in parallel
# for m in range(0, M, MB):
# for m in range(0, M, BLOCK_M):
# # do in parallel
# for n in range(0, N, NB):
# acc = zeros((MB, NB), dtype=float32)
# for k in range(0, K, KB):
# acc += A[m : m+MB, k : k+KB] @ B[k : k+KB, n : n+NB]
# C[m : m+MB, n : n+NB] = acc;
# for n in range(0, N, BLOCK_N):
# acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
# for k in range(0, K, BLOCK_K):
# a = A[m : m+BLOCK_M, k : k+BLOCK_K]
# b = B[k : k+BLOCK_K, n : n+BLOCK_N]
# acc += dot(a, b)
# C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;
#
# where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.
# where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
# %%
# Compute Kernel
# ----------------
#
# The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the :code:`@` operator for block-level matrix multiplication.
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of :code:`A` and :code:`B` that we need to read in the inner loop.
# The above algorithm is actually fairly straightforward to implement in Triton.
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop.
#
# Pointer Arithmetics
# ~~~~~~~~~~~~~~~~~~~~
#
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = i + X.stride(0) + j`.
# Therefore, blocks of pointers for :code:`A[m : m+MB, k:k+KB]` and :code:`B[k : k+KB, n : n+NB]` can be defined in pseudo-code as:
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:
#
# .. code-block:: python
#
# &A[m : m+MB, k:k+KB] = A + (m : m+MB)[:, newaxis]*A.stride(0) + (k : k+KB)[newaxis, :];
# &B[k : k+KB, n:n+NB] = B + (k : k+KB)[:, newaxis]*B.stride(0) + (n : n+NB)[newaxis, :];
# &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :];
# &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :];
#
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
#
# .. code-block:: C
# .. code-block:: python
# :force:
#
# int rm[MB] = program_id_m * MB + 0 ... MB;
# int rn[NB] = program_id_n * NB + 0 ... NB;
# int rk[KB] = 0 ... KB;
# TYPE *pa[MB, KB] = A + (rm[:, newaxis] * stride_a_0 + rk [newaxis, :] * 1);
# TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
# pid_m = triton.program_id(0)
# pid_n = triton.program_id(1)
# rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
# rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
# rk = triton.arange(0, BLOCK_K)
# // pointer for A operand
# pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
# // pointer for B operand
# pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
#
# These pointers can then be updated in the inner loop as:
#
# .. code-block:: C
# .. code-block:: python
#
# pa += KB * 1;
# pb += KB * ldb;
# pa += BLOCK_K * stride_a_1;
# pb += BLOCK_K * stride_b_0;
#
#
# L2 Cache Optimizations
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# As mentioned above, each program instance computes an :code:`[MB, NB]` block of :code:`C`.
# As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
# However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
# This means that a naive row-major ordering:
#
# .. code-block:: C
# .. code-block:: Python
#
# int program_id = get_program_id(0);
# int grid_m = (M + MB - 1) / MB;
# int grid_n = (N + NB - 1) / NB;
# int program_id_m = program_id / grid_n;
# int program_id_n = program_id % grid_n;
# pid = triton.program_id(0);
# grid_m = (M + BLOCK_M - 1) / BLOCK_M;
# grid_n = (N + BLOCK_N - 1) / BLOCK_N;
# pid_m = pid / grid_n;
# pid_n = pid % grid_n;
#
# is unlikely to result in optimal performance.
#
# One possible solution is to launch blocks in an order that promotes data reuse.
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE` before switching to the next column:
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
#
# .. code-block:: C
#
# int program_id = get_program_id(0);
# int width = GROUP_SIZE * grid_n;
# int group_id = pid / width;
# // we need to handle the case where M % (GROUP_SIZE*BM) != 0
# int group_size = min(grid_m - group_id * GROUP_SIZE, GROUP_SIZE);
# int pid_m = group_id * GROUP_SIZE + (pid % group_size);
# int pid_n = (pid % width) / (group_size);
# pid = triton.program_id(0);
# width = GROUP_M * grid_n;
# group_id = pid / width;
# # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
# group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
# pid_m = group_id * GROUP_M + (pid % group_size);
# pid_n = (pid % width) / (group_size);
#
# In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
#
# Final Result
# ~~~~~~~~~~~~~~
#
# We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
# Note that we rematerialize :code:`rm` and :code:`rn:` after the inner loop to decrease register pressure.
# This is an optimization that provides an additional 5% performance improvement and cannot be currently done by the Triton compiler.
#
# .. code-block:: C
# :force:
#
# #define MAX_GROUP_SIZE 8
#
# __global__ void dot(TYPE* A, TYPE* B, TYPE* C,
# int M, int N, int K,
# int stride_a_0, int stride_b_0, int stride_c_0) {
# // prologue
# int pid = get_program_id(0);
# int grid_m = (M + MB - 1) / MB;
# int grid_n = (N + NB - 1) / NB;
# // re-order program ID for better L2 performance
# int width = MAX_GROUP_SIZE * grid_n;
# int group_id = pid / width;
# int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
# int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
# int pid_n = (pid % width) / (group_size);
# // pointers to operands
# // note the parentheses here; they force the offset
# // computation to happen in typeof(stride_a_0) = int32 rather than
# // typeof(A) = int64
# int rm[MB] = pid_m * MB + 0 ... MB;
# int rn[NB] = pid_n * NB + 0 ... NB;
# int rk[KB] = 0 ... KB;
# TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * stride_a_0);
# TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
# // reduction loop
# float acc[MB, NB] = 0;
# for (int k = K; k > 0; k -= KB) {
# acc += (*pa) @ (*pb);
# pa += KB * 1;
# pb += KB * stride_b_0;
# }
# // pointers to output
# // here we rematerialize `rm` and `rn` so that they are not live through
# // the above reduction loop. In the future, the compiler should be able to
# // do this automatically.
# rm = pid_m * MB + 0 ... MB;
# rn = pid_n * NB + 0 ... NB;
# TYPE *pc[MB, NB] = C + (rm[:, newaxis] * stride_c_0 + rn[newaxis, :]);
# // we write back using *?() operator. `acc` gets casted to `float32` implicitly.
# *? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
# }
#
# Where :code:`TYPE` is the data-type of the input matrices and :code:`MB`, :code:`NB`, :code:`KB` are the block sizes defined in the above pseudo-code.
# Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
# If :code:`TYPE` is :code:`half`, then tensor cores will be used automatically provided that :code:`MB`, :code:`NB` and :code:`KB` are multiples of 16.
#
# %%
# Torch Bindings
# ----------------
# Final Result
# -------------
#
# Auto-Tuning
# ~~~~~~~~~~~~~~
#
# In order to use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects. that can be constructed as follows:
import torch
import triton
autotune_configs = [
triton.config(defines={"MB": "128", "NB": "128", "KB": "32"}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '128', 'KB': '32'}, num_warps=4),
triton.config(defines={'MB': '128', 'NB': '64', 'KB': '32'}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '64', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '32', 'NB': '128', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '128', 'NB': '32', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '32', 'KB': '64'}, num_warps=2),
triton.config(defines={'MB': '32', 'NB': '64', 'KB': '64'}, num_warps=2)
]
# %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs
@triton.jit
def sigmoid(x):
ret_true = 1 / (1 + triton.exp(-x))
ret_false = triton.exp(x) / (1 + triton.exp(x))
return triton.where(x >= 0, ret_true, ret_false)
@triton.jit
def swish(x):
return x * sigmoid(x)
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
],
key=['M', 'N', 'K'],
)
# %
# We can now define our kernel as normal, using all the techniques presented above
@triton.jit
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
# extract meta-parameters
BLOCK_M = META['BLOCK_M']
BLOCK_N = META['BLOCK_N']
BLOCK_K = META['BLOCK_K']
GROUP_M = 8
# matrix multiplication
pid = triton.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
rk = triton.arange(0, BLOCK_K)
A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
for k in range(K, 0, -BLOCK_K):
a = triton.load(A)
b = triton.load(B)
acc += triton.dot(a, b)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# triton can accept arbitrary activation function
# via metaparameters!
if META['ACTIVATION']:
acc = META['ACTIVATION'](acc)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm[:, None] < M) & (rn[None, :] < N)
triton.store(C, acc, mask=mask)
# %%
# we also need to define a list of :code:`string` (i.e., "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
# Here, we want to re-tune our kernel only when the shape of input matrices changes.
autotune_key = ["M", "N", "K"]
# %%
# We can now create an auto-tuned kernel by passing the `autotune_configs` and `autotune_key` lists to the constructor of the :code:`triton.kernel` class.
src = """
#define MAX_GROUP_SIZE 8
__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
int M, int N, int K,
int lda, int ldb, int ldc) {
int pid = get_program_id(0);
int grid_m = (M + MB - 1) / MB;
int grid_n = (N + NB - 1) / NB;
int width = MAX_GROUP_SIZE * grid_n;
int group_id = pid / width;
int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
int pid_n = (pid % width) / (group_size);
int rm[MB] = pid_m * MB + 0 ... MB;
int rn[NB] = pid_n * NB + 0 ... NB;
int rk[KB] = 0 ... KB;
TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda);
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1);
float acc[MB, NB] = 0;
for (int k = K; k > 0; k -= KB) {
acc += (*pa) @ (*pb);
pa += KB * 1;
pb += KB * ldb;
}
rm = pid_m * MB + 0 ... MB;
rn = pid_n * NB + 0 ... NB;
TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]);
*? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
}
"""
# We can also create a convenience wrapper function that only takes two input tensors
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel
def make_kernel(device, dtype):
key = (device, dtype)
cache = make_kernel.cache
if key not in cache:
defines = {'TYPE': dtype}
cache[key] = triton.kernel(
src,
device=device,
defines=defines,
autotune_configs=autotune_configs,
autotune_key=autotune_key,
)
return cache[key]
def matmul(a, b, activation=None):
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
assert a.is_contiguous(), "matrix A must be contiguous"
assert b.is_contiguous(), "matrix B must be contiguous"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
_matmul[grid](
a, b, c, M, N, K, \
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
ACTIVATION = activation
)
# return output
return c
make_kernel.cache = dict()
# %%
# Autograd Function
# ~~~~~~~~~~~~~~~~~~
#
# Now we are ready to expose our auto-tuned kernel as a `torch.autograd.Function`.
# To do so, we just need to define a `forward` function that takes a two tensors as input and returns a tensor as output.
class _dot(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
M, Ka = a.shape
Kb, N = b.shape
assert Ka == Kb, "incompatible dimensions"
assert a.is_contiguous() and b.is_contiguous(), "inputs must be contiguous"
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
kernel = make_kernel(a.device, a.dtype)
grid = lambda opt: (triton.cdiv(M, opt.MB) * triton.cdiv(N, opt.NB), )
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
M, N, Ka, \
a.stride(0), b.stride(0), c.stride(0), \
grid=grid)
return c
dot = _dot.apply
# %%
# Unit Test
# -----------
#
# We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
# Note that we need to modify the :code`atol` and :code:`rtol` parameters of `torch.allclose` to account for the fact that we are comparing FP16 tensors.
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
a = torch.rand((512, 768), device='cuda', dtype=torch.float16)
b = torch.rand((768, 896), device='cuda', dtype=torch.float16)
c_0 = dot(a, b)
c_1 = torch.matmul(a, b)
#torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_0 = matmul(a, b, activation=swish)
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
print(c_0)
print(c_1)
print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
print(triton.testing.allclose(c_0, c_1))
# %%
# Benchmark
# --------------
#
# Installing The CUTLASS Bindings
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The cuBLAS library (used by :code:`torch.matmul`) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
# For this reason, we will instead compare the performance of our kernel against `CUTLASS <https://github.com/NVIDIA/cutlass/>`_ , a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves._
# To install CUTLASS, you need a recent version of cmake:
#
# .. code-block:: bash
#
# cd /path/to/cutlass/
# git clone https://github.com/NVIDIA/cutlass.git
# cd cutlass
# mkdir build
# cd build
# wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz
# tar xzvf *.tar.gz
#
# You can then install CUTLASS as follows for V100
#
# .. code-block:: bash
#
# ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=70 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s884gemm_f16_*_align8
# make -j8 install
#
# Or as follows for A100:
#
# .. code-block:: bash
#
# ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=80 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s16816gemm_*align8
# make -j8 install
#
# Where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance.
# Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables :code:`CUTLASS_INCLUDE_DIR` and :code:`CUTLASS_LIBRARY_DIR` are set during the installation process.
# To re-install Triton with the updated CUTLASS bindings, run the following command:
#
# .. code-block:: bash
#
# export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
# export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
# pip uninstall -y triton
# pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
#
# Which we can test as follows:
import triton
c_2 = triton.testing.cutlass_matmul(a, b)
print(c_2)
print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
# %%
# Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#
@@ -347,29 +237,25 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
x_vals=[8192], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['cublas', 'triton', 'cutlass'], # possible keys for `y_name`
y_lines=["cuBLAS", "Triton", 'CUTLASS'], # label name for the lines
y_vals=['cublas', 'triton'], # possible keys for `y_name`
y_lines=["cuBLAS", "Triton"], # label name for the lines
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={}
)
)
def benchmark(M, N, K, provider):
silu = torch.nn.SiLU()
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
if provider == 'cutlass':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b))
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=True)
# %%
# As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.
benchmark.run(print_data=True)