Deprecation of Triton-C and Replacement by decorated Python functions (#86)
This PR implements a major overhaul of Triton's frontend: Triton-C is replaced with a pure Python API in which kernels are defined as @triton.jit-decorated functions. The documentation and tutorials have been updated to reflect these changes; see the documentation for more information on the new API.
committed by Philippe Tillet
parent 1fdb465b71
commit 39f4730305
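For context, the shape of the new API this commit introduces — the kernel below is an illustrative sketch written in the style of the new tests, not code from this diff:

import torch
import triton

@triton.jit
def add_kernel(Z, X, Y, **meta):
    # each program instance handles one BLOCK-sized chunk
    off = triton.program_id(0) * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
    x = triton.load(X + off)
    y = triton.load(Y + off)
    triton.store(Z + off, x + y)

x = torch.randn(1024, device='cuda')
y = torch.randn(1024, device='cuda')
z = torch.empty_like(x)
add_kernel[(4, )](z, x, y, BLOCK=256)   # launched with a grid and meta-kwargs, as in the new tests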
@@ -49,11 +49,11 @@ class CMakeBuild(build_ext):
            self.build_extension(ext)

    def build_extension(self, ext):
        # self.debug = True
        self.debug = False
        #self.debug = True
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
        # create build directories
        llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm")
        build_suffix = 'debug' if self.debug else 'release'
        llvm_build_dir = os.path.join(tempfile.gettempdir(), f"llvm-{build_suffix}")
        if not os.path.exists(self.build_temp):
            os.makedirs(self.build_temp)
        if not os.path.exists(llvm_build_dir):
python/src/functions.h — new file, 676 lines
@@ -0,0 +1,676 @@
#include "triton/ir/builder.h"
#include <functional>
#include <iostream>
#include <pybind11/pybind11.h>

namespace ir = triton::ir;
namespace py = pybind11;

static const std::string _builder_doc = R"pbdoc(
    :param builder: IR builder to generate code into, optional, set automatically when called inside a @triton.jit function
    :type builder: triton.ir.builder
)pbdoc";

#define VA_ARGS(...) , ##__VA_ARGS__
#define DEF_FUNC(MOD, PY_NAME, C_FUNC, ...) \
  MOD.def(PY_NAME, C_FUNC, (C_FUNC##_docstr + _builder_doc).c_str(), \
          ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)

void throw_not_implemented(std::string key) {
  throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. This is likely a bug on our side.");
}

void throw_not_int_or_float(std::string key) {
  throw std::runtime_error("`" + key + "` only supported for integer and floating point types.");
}

enum type_code {
  _bool,
  int8,
  int16,
  int32,
  int64,
  float16,
  float32,
  float64
};

ir::type *make_ir(type_code ty, ir::builder *builder) {
  switch (ty) {
  case float16:
    return builder->get_half_ty();
  case float32:
    return builder->get_float_ty();
  default:
    throw_not_implemented("make_ir");
  }
}

type_code from_ir(ir::type *ty) {
  if (ty->is_half_ty())
    return float16;
  if (ty->is_float_ty())
    return float32;
  throw_not_implemented("from_ir");
}

/*----------------------------------------------
 definition of triton.cast / triton.ir.value.to
 ----------------------------------------------*/
std::string cast_docstr = R"pbdoc(
    Tries to cast a block to a new data type.

    :param input: The input block.
    :type input: triton.ir.value
)pbdoc";

ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) {
  ir::type *src_ty = input->get_type();
  ir::type *dst_ty = make_ir(_dtype, builder);
  if (src_ty->is_block_ty())
    dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
  ir::type *src_sca_ty = src_ty->get_scalar_ty();
  ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
  // FP Truncation
  bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
                     dst_sca_ty->is_floating_point_ty() &&
                     src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
  if (truncate_fp)
    return builder->create_fp_trunc(input, dst_ty);
  // FP Extension
  bool ext_fp = src_sca_ty->is_floating_point_ty() &&
                dst_sca_ty->is_floating_point_ty() &&
                src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
  if (ext_fp)
    return builder->create_fp_ext(input, dst_ty);
  // Int cast
  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
      src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
    return builder->create_int_cast(input, dst_ty, true);
  // Float -> Int
  if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty())
    return builder->create_fp_to_si(input, dst_ty);
  // int -> Float
  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty())
    return builder->create_si_to_fp(input, dst_ty);
  // Ptr -> Ptr
  if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
    return builder->create_cast(ir::BitCast, input, dst_ty);
  // * -> Bool
  if (dst_sca_ty->is_bool_ty()) {
    if (src_sca_ty->is_pointer_ty())
      input = cast(input, int64, builder);
    ir::value *other = builder->get_int64(0);
    if (src_ty->is_bool_ty())
      other = builder->create_splat(other, src_ty->get_block_shapes());
    return builder->create_icmpNE(input, other);
  }
  throw_not_implemented("cast");
}
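As the header comment above notes, this is exposed both as triton.cast and as a .to() method on values; a minimal usage sketch inside a @triton.jit kernel, assuming the .to() spelling:

x = triton.load(X + off)        # say, a float32 block
y = x.to(triton.float16)        # down-cast: lowers to create_fp_trunc above
z = y.to(triton.float32)        # up-cast: lowers to create_fp_ext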

/*----------------------------------------------
 definition of triton.broadcast_check
 ----------------------------------------------*/
std::string try_broadcast_docstr = R"pbdoc(
    Tries to broadcast two blocks to a common compatible shape.

    :param input: The first input block.
    :type input: triton.ir.value
    :param other: The second input block.
    :type other: triton.ir.value
)pbdoc";

std::tuple<ir::value *, ir::value *> try_broadcast(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
  ir::type *lhs_ty = lhs->get_type();
  ir::type *rhs_ty = rhs->get_type();
  // make_shape_compatible(block, scalar)
  if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
    rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
  // make_shape_compatible(scalar, block)
  else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
    lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
  // make_shape_compatible(block, block)
  else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
    auto lhs_shape = lhs_ty->get_block_shapes();
    auto rhs_shape = rhs_ty->get_block_shapes();
    if (lhs_shape.size() != rhs_shape.size())
      throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
    ir::type::block_shapes_t ret_shape;
    for (size_t i = 0; i < lhs_shape.size(); ++i) {
      unsigned left = lhs_shape[i];
      unsigned right = rhs_shape[i];
      if (left == 1)
        ret_shape.push_back(right);
      else if (right == 1)
        ret_shape.push_back(left);
      else if (left == right)
        ret_shape.push_back(left);
      else
        throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
                                 ": " + std::to_string(left) + " and " + std::to_string(right));
    }
    if (lhs_shape != ret_shape)
      lhs = builder->create_broadcast(lhs, ret_shape);
    if (rhs_shape != ret_shape)
      rhs = builder->create_broadcast(rhs, ret_shape);
  }
  return std::make_tuple(lhs, rhs);
}
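The rule implemented above is NumPy broadcasting: a size-1 dimension stretches to match the other operand, equal sizes pass through, anything else is an error. A sketch of how this surfaces in kernel code:

a = triton.arange(0, 128)[:, None]   # block shape (128, 1)
b = triton.arange(0, 64)[None, :]    # block shape (1, 64)
c = a + b                            # both operands broadcast to (128, 64)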

/*----------------------------------------------
 definition of triton.broadcast_to
 ----------------------------------------------*/
std::string broadcast_to_docstr = R"pbdoc(
    Tries to broadcast a block to a new shape.

    :param input: The input block.
    :type input: triton.value
    :param shape: The new shape.
    :type shape: tuple of int
)pbdoc";

ir::value *broadcast_to(ir::value *input, const ir::type::block_shapes_t &shape, ir::builder *builder) {
  if (!input->get_type()->is_block_ty())
    return builder->create_splat(input, shape);
  auto src_shape = input->get_type()->get_block_shapes();
  if (src_shape.size() != shape.size())
    throw std::runtime_error("Cannot broadcast");
  return builder->create_broadcast(input, shape);
}

/*----------------------------------------------
 definition of triton.load
 ----------------------------------------------*/
std::string load_docstr = R"pbdoc(
    Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`.

    :param pointer: Pointer to the data to be loaded.
    :type pointer: Block of triton.pointer
    :param mask: If mask[idx] is false, do not load the data at `pointer[idx]`.
    :type mask: Block of triton.bool, optional
    :param other: If mask[idx] is false, return `other[idx]` instead of `pointer[idx]`.
    :type other: Block of triton.value, optional
)pbdoc";

ir::value *load(ir::value *pointer, std::optional<ir::value *> _mask, std::optional<ir::value *> _other, ir::builder *builder) {
  if (!_mask.has_value() && !_other.has_value())
    return builder->create_load(pointer);
  if (!_mask.has_value())
    throw std::runtime_error("`other` cannot be provided without `mask`");
  ir::value *mask = _mask.value();
  ir::type *elt_ty = pointer->get_type()->get_scalar_ty()->get_pointer_element_ty();
  auto shape = pointer->get_type()->get_block_shapes();
  ir::value *other = _other.has_value() ? _other.value() : ir::undef_value::get(elt_ty);
  other = cast(other, from_ir(elt_ty), builder);
  other = broadcast_to(other, shape, builder);
  mask = broadcast_to(mask, shape, builder);
  return builder->create_masked_load(pointer, mask, other);
}

/*----------------------------------------------
 definition of triton.store
 ----------------------------------------------*/
std::string store_docstr = R"pbdoc(
    Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.

    :param pointer: The memory locations where the elements of `value` are stored.
    :type pointer: Block of triton.pointer
    :param value: The block of elements to be stored.
    :type value: Block of triton.value
    :param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
    :type mask: Block of triton.bool, optional
)pbdoc";
ir::value *store(ir::value *ptr, ir::value *val, std::optional<ir::value *> _mask, ir::builder *builder) {
  if (!_mask.has_value())
    return builder->create_store(ptr, val);
  ir::value *mask = _mask.value();
  return builder->create_masked_store(ptr, val, mask);
}
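Masked loads and stores are the idiomatic way to handle sizes that are not a multiple of the block size; a sketch, assuming `mask` and `other` are exposed as keywords as the docstrings suggest:

off = triton.program_id(0) * BLOCK + triton.arange(0, BLOCK)
mask = off < N                                # guard the tail of the array
x = triton.load(X + off, mask=mask, other=0)  # out-of-bounds lanes read 0
triton.store(Z + off, x, mask=mask)           # out-of-bounds lanes write nothing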

/*----------------------------------------------
 definition of triton.dot
 ----------------------------------------------*/
std::string dot_docstr = R"pbdoc(
    Returns the matrix product of two blocks.
    The two blocks must be two-dimensional and have compatible inner dimensions.

    :param input: The first block to be multiplied.
    :type input: 2D block of scalar-type in {`float16`, `float32`}
    :param other: The second block to be multiplied.
    :type other: 2D block of scalar-type in {`float16`, `float32`}
)pbdoc";
ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
  ir::value *_0 = builder->get_float32(0);
  unsigned M = lhs->get_type()->get_block_shapes()[0];
  unsigned N = rhs->get_type()->get_block_shapes()[1];
  _0 = builder->create_splat(_0, {M, N});
  return builder->create_dot(lhs, rhs, _0);
}
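This is the building block of a tiled matmul; a sketch of the accumulation loop (the meta-parameters, stride names, and in-kernel loop support here are assumptions for illustration, not code from this diff):

acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
for k in range(0, K, BLOCK_K):
    a = triton.load(A_ptrs)          # (BLOCK_M, BLOCK_K) tile
    b = triton.load(B_ptrs)          # (BLOCK_K, BLOCK_N) tile
    acc = acc + triton.dot(a, b)     # create_dot with a zero accumulator
    A_ptrs = A_ptrs + BLOCK_K * stride_ak
    B_ptrs = B_ptrs + BLOCK_K * stride_bk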

/*----------------------------------------------
 definition of triton.where
 ----------------------------------------------*/
std::string where_docstr = R"pbdoc(
    Returns a block of elements from either `x` or `y`, depending on `condition`.
    Note that `x` and `y` are always evaluated regardless of the value of `condition`.
    If you want to avoid unintended memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.

    :param condition: When True (nonzero), yield x, otherwise yield y.
    :type condition: Block of triton.bool
    :param x: values selected at indices where condition is True.
    :param y: values selected at indices where condition is False.
)pbdoc";
ir::value *where(ir::value *condition, ir::value *x, ir::value *y, ir::builder *builder) {
  return builder->create_select(condition, x, y);
};
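A one-line usage sketch, e.g. an element-wise absolute value (remember both branches are evaluated, per the docstring):

z = triton.where(x > 0, x, -x)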

/*----------------------------------------------
 definition of triton.arange
 ----------------------------------------------*/
std::string arange_docstr = R"pbdoc(
    Returns contiguous values within the half-open interval [start, end).

    :param start: Start of the interval.
    :type start: int
    :param end: End of the interval.
    :type end: int
)pbdoc";
ir::value *arange(int start, int end, ir::builder *builder) {
  return builder->get_range(start, end);
};

/*----------------------------------------------
 definition of triton.program_id
 ----------------------------------------------*/
std::string program_id_docstr = R"pbdoc(
    Returns the id of the current program instance along the given `axis`.
    Triton uses an SPMD model in which instances of the same @triton.jit function run in parallel, each with a different `program_id`.

    :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
    :type axis: int
)pbdoc";
ir::value *program_id(int axis, ir::builder *builder) {
  return builder->create_get_program_id(axis);
};

/*----------------------------------------------
 definition of triton.num_programs
 ----------------------------------------------*/
std::string num_programs_docstr = R"pbdoc(
    Returns the number of program instances launched along the given `axis`.

    :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
    :type axis: int
)pbdoc";
ir::value *num_programs(int axis, ir::builder *builder) {
  return builder->create_get_num_programs(axis);
};
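Together these implement the usual SPMD work-distribution idiom (sketch):

pid = triton.program_id(0)          # which instance am I?
npid = triton.num_programs(0)       # how many instances were launched?
off = pid * BLOCK + triton.arange(0, BLOCK)
# a grid-stride variant would advance `off` by npid * BLOCK per step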

/*----------------------------------------------
 definition of triton.zeros
 ----------------------------------------------*/
std::string zeros_docstr = R"pbdoc(
    Returns a block filled with the scalar value 0 and the given shape.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., triton.float16
    :type dtype: triton.ir.dtype
)pbdoc";
ir::value *zeros(ir::type::block_shapes_t shape, type_code _dtype, ir::builder *builder) {
  ir::type *dtype = make_ir(_dtype, builder);
  ir::value *_0 = ir::constant::get_null_value(dtype);
  return builder->create_splat(_0, shape);
};

/*----------------------------------------------
 definition of triton.exp
 ----------------------------------------------*/
std::string _exp_docstr = R"pbdoc(
    Returns the element-wise exponential of `input`.
)pbdoc";
ir::value *_exp(ir::value *input, ir::builder *builder) {
  return builder->create_exp(input);
};

/*----------------------------------------------
 definition of triton.log
 ----------------------------------------------*/
std::string _log_docstr = R"pbdoc(
    Returns the element-wise natural logarithm of `input`.
)pbdoc";
ir::value *_log(ir::value *input, ir::builder *builder) {
  return builder->create_log(input);
};

/*----------------------------------------------
 definition of triton.sqrt
 ----------------------------------------------*/
std::string sqrt_docstr = R"pbdoc(
    Returns the element-wise square root of `input`.
)pbdoc";
ir::value *sqrt(ir::value *input, ir::builder *builder) {
  return builder->create_sqrt(input);
};

/*----------------------------------------------
 definition of triton.min
 ----------------------------------------------*/
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
                       ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  if (scalar_ty->is_floating_point_ty())
    return builder->create_reduce(input, FLOAT_OP, axis);
  else if (scalar_ty->is_integer_ty())
    return builder->create_reduce(input, INT_OP, axis);
  else
    throw_not_int_or_float(name);
}

std::string min_docstr = R"pbdoc(
    Returns the minimum value of `input` along the given `axis`.
)pbdoc";
ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
  return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
};

/*----------------------------------------------
 definition of triton.max
 ----------------------------------------------*/
std::string max_docstr = R"pbdoc(
    Returns the maximum value of `input` along the given `axis`.
)pbdoc";
ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
  return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
};

/*----------------------------------------------
 definition of triton.sum
 ----------------------------------------------*/
std::string sum_docstr = R"pbdoc(
    Returns the sum of `input` along the given `axis`.
)pbdoc";
ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder) {
  return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
};
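A usage sketch: per-row statistics of a 2D block, softmax-style, with the axis passed as the positional int argument above:

x = triton.load(ptrs)                  # (BLOCK_M, BLOCK_N) block
x_max = triton.max(x, 1)               # (BLOCK_M,) per-row maximum
num = triton.exp(x - x_max[:, None])   # broadcast the row max back
den = triton.sum(num, 1)               # (BLOCK_M,) per-row normalizer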

/*----------------------------------------------
 definition of triton.atomic_cas
 ----------------------------------------------*/
std::string atomic_cas_docstr = R"pbdoc(
    Atomic compare-and-swap.
)pbdoc";
ir::value *atomic_cas(ir::value *ptr, ir::value *cmp, ir::value *val, ir::builder *builder) {
  return builder->create_atomic_cas(ptr, cmp, val);
};

/*----------------------------------------------
 definition of triton.atomic_xchg
 ----------------------------------------------*/
std::string atomic_xchg_docstr = R"pbdoc(
    Atomic exchange.
)pbdoc";
ir::value *atomic_xchg(ir::value *ptr, ir::value *val, ir::builder *builder) {
  return builder->create_atomic_exch(ptr, val);
};
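A classic use of these two primitives is a cross-instance spin lock, e.g. to serialize a reduction epilogue. A hypothetical sketch — `Lock` is an assumed int32 pointer argument, and note that `while` support is still listed as untested in the new test file below:

while triton.atomic_cas(Lock, 0, 1) == 1:   # acquire: spin until the 0 -> 1 swap succeeds
    pass
# ... critical read-modify-write section ...
triton.atomic_xchg(Lock, 0)                 # release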

/*----------------------------------------------
 debug barrier
 ----------------------------------------------*/
std::string debug_barrier_docstr = R"pbdoc(
    Temporary hacky fixup for when the compiler forgets to insert sync barriers
)pbdoc";
ir::value *debug_barrier(ir::builder *builder) {
  return builder->create_barrier();
}

#define DEF_BINARY_OP(MOD, PY_NAME, C_FUNC, ...) \
  MOD.def(PY_NAME, binary_op(C_FUNC), (C_FUNC##_docstr + _builder_doc).c_str(), \
          ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)

template <class FN>
std::function<ir::value *(ir::value *, ir::value *, ir::builder *builder)>
binary_op(const FN &fn) {
  auto ret = [&fn](ir::value *self, ir::value *other, ir::builder *builder) {
    //std::tie(self, other) = try_broadcast(self, other, builder);
    return fn(self, other, builder);
  };
  return ret;
}

/*----------------------------------------------
 definition of self + other
 ----------------------------------------------*/
std::string add_docstr = R"pbdoc(
    Returns self + other, element-wise.
)pbdoc";
ir::value *add(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // ptr + offset
  if (scalar_ty->is_pointer_ty())
    return builder->create_gep(self, {other});
  // float + float
  else if (scalar_ty->is_floating_point_ty())
    return builder->create_fadd(self, other);
  // int + int
  else if (scalar_ty->is_integer_ty())
    return builder->create_add(self, other);
  throw_not_implemented("add");
}

/*----------------------------------------------
 definition of self - other
 ----------------------------------------------*/
std::string sub_docstr = R"pbdoc(
    Returns self - other, element-wise.
)pbdoc";
ir::value *sub(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // ptr - offset
  if (scalar_ty->is_pointer_ty())
    return builder->create_gep(self, {other});
  // float - float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fsub(self, other);
  // int - int
  else if (scalar_ty->is_integer_ty())
    return builder->create_sub(self, other);
  throw_not_implemented("sub");
}

/*----------------------------------------------
 definition of self * other
 ----------------------------------------------*/
std::string mul_docstr = R"pbdoc(
    Returns self * other, element-wise.
)pbdoc";
ir::value *mul(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float * float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fmul(self, other);
  // int * int
  else if (scalar_ty->is_integer_ty())
    return builder->create_mul(self, other);
  throw_not_implemented("mul");
}

/*----------------------------------------------
 definition of self > other
 ----------------------------------------------*/
std::string greater_than_docstr = R"pbdoc(
    Returns self > other, element-wise.
)pbdoc";
ir::value *greater_than(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float > float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOGT(self, other);
  // int > int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSGT(self, other);
  throw_not_implemented("greater_than");
}

/*----------------------------------------------
 definition of self >= other
 ----------------------------------------------*/
std::string greater_equal_docstr = R"pbdoc(
    Returns self >= other, element-wise.
)pbdoc";
ir::value *greater_equal(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float >= float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOGE(self, other);
  // int >= int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSGE(self, other);
  throw_not_implemented("greater_equal");
}

/*----------------------------------------------
 definition of self < other
 ----------------------------------------------*/
std::string less_than_docstr = R"pbdoc(
    Returns self < other, element-wise.
)pbdoc";
ir::value *less_than(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float < float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOLT(self, other);
  // int < int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSLT(self, other);
  throw_not_implemented("less_than");
}

/*----------------------------------------------
 definition of self <= other
 ----------------------------------------------*/
std::string less_equal_docstr = R"pbdoc(
    Returns self <= other, element-wise.
)pbdoc";
ir::value *less_equal(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float <= float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOLE(self, other);
  // int <= int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSLE(self, other);
  throw_not_implemented("less_equal");
}

/*----------------------------------------------
 definition of self == other
 ----------------------------------------------*/
std::string equal_docstr = R"pbdoc(
    Returns self == other, element-wise.
)pbdoc";
ir::value *equal(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float == float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOEQ(self, other);
  // int == int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpEQ(self, other);
  throw_not_implemented("equal");
}

/*----------------------------------------------
 definition of self / other
 ----------------------------------------------*/
std::string _div_docstr = R"pbdoc(
    Returns self / other, element-wise.
)pbdoc";
ir::value *_div(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float / float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fdiv(self, other);
  // int / int
  else if (scalar_ty->is_integer_ty())
    return builder->create_sdiv(self, other);
  throw_not_implemented("div");
}

/*----------------------------------------------
 definition of self % other
 ----------------------------------------------*/
std::string mod_docstr = R"pbdoc(
    Returns self % other, element-wise.
)pbdoc";
ir::value *mod(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float % float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_frem(self, other);
  // int % int
  else if (scalar_ty->is_integer_ty())
    return builder->create_srem(self, other);
  throw_not_implemented("mod");
}

/*----------------------------------------------
 definition of self & other
 ----------------------------------------------*/
std::string _and_docstr = R"pbdoc(
    Returns self & other, element-wise.
)pbdoc";
ir::value *_and(ir::value *self, ir::value *other, ir::builder *builder) {
  return builder->create_and(self, other);
}

/*----------------------------------------------
 definition of minimum(self, other)
 ----------------------------------------------*/
std::string minimum_docstr = R"pbdoc(
    Returns the element-wise minimum of self and other.
)pbdoc";
ir::value *minimum(ir::value *self, ir::value *other, ir::builder *builder) {
  return where(less_than(self, other, builder), self, other, builder);
}

/*----------------------------------------------
 definition of self[slices]
 ----------------------------------------------*/

enum slice_mode_t {
  NEWAXIS,
  ALL
};

std::string subscript_docstr = R"pbdoc(
    Returns self[slices].

    :param slices: The slices to subscript with.
    :type slices: List of `None` or `:` slices.
)pbdoc";
ir::value *subscript(ir::value *self, std::vector<py::object> slices, ir::builder *builder) {
  std::vector<slice_mode_t> modes;
  for (py::object slice : slices) {
    py::object none = py::none();
    py::object all = py::make_tuple(none, none, none);
    if (slice.is(none))
      modes.push_back(NEWAXIS);
    else if (all.attr("__eq__")(slice))
      modes.push_back(ALL);
    else
      throw std::runtime_error("slice must be None or (None, None, None)");
  }

  ir::type::block_shapes_t shape;
  size_t curr = 0;
  for (slice_mode_t mode : modes) {
    if (mode == NEWAXIS)
      shape.push_back(1);
    else {
      assert(mode == ALL);
      shape.push_back(self->get_type()->get_block_shapes()[curr++]);
    }
  }
  return builder->create_reshape(self, shape);
}
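This is what makes NumPy-style `None`/`:` indexing work inside kernels; for instance, building a 2D grid of pointers from 1D ranges (the same pattern the new tests generate via make_ptr_str below):

rm = triton.arange(0, 32)                   # (32,)
rn = triton.arange(0, 32)                   # (32,)
ptrs = X + rm[:, None] * 32 + rn[None, :]   # (32, 32) block of pointers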
@@ -1,5 +1,13 @@
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/codegen/pass.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/driver/stream.h"
#include "triton/ir/builder.h"
#include "triton/ir/dispatch.h"
#include "triton/ir/enums.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
@@ -8,78 +16,9 @@
#include <string>

namespace py = pybind11;

using namespace triton;
namespace rt = triton::runtime;
namespace ir = triton::ir;
namespace drv = triton::driver;

/*****************************************************************************/
/* Python bindings for triton::tools                                         */
/*****************************************************************************/

/*!
  @brief Function for extracting kernels out of a given source-string

  This can be important to enable pre-processor macros (or tunable parameters) that should only
  be defined within the scope of a single kernel function
*/
std::string extract_kernels(const std::string &str, const std::vector<std::string> &names) {
  if (names.empty())
    return str;
  // search for all regex matches of kernel_regex in str
  std::smatch matches;
  std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
  std::sregex_iterator it(str.begin(), str.end(), regex);
  std::sregex_iterator end;
  std::vector<std::tuple<std::string, int, int>> kernels;
  for (; it != end; ++it) {
    int pos = it->position();
    int len = it->length();
    std::string name = it->str(1);
    kernels.push_back(std::make_tuple(name, pos, len));
  }
  // check that all the kernels provided actually exist
  for (const std::string &name : names) {
    auto pred = [&name](const std::tuple<std::string, int, int> &t) { return std::get<0>(t) == name; };
    bool found = std::any_of(kernels.begin(), kernels.end(), pred);
    if (!found)
      throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
  }
  // simple parsing logic to extract the declaration and body of each specified kernel
  std::string ret;
  for (const auto &k : kernels) {
    std::string name;
    int pos, len;
    std::tie(name, pos, len) = k;
    if (std::find(names.begin(), names.end(), name) == names.end())
      continue;
    std::string def = str.substr(pos, str.size() - pos);
    // skip over declaration
    // by finding matching ')' for first '('
    int count = 1;
    pos = def.find('(');
    while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
      count += def[pos] == '(';
      count -= def[pos] == ')';
    }
    // skip over definition
    // by finding matching '}' for first '{'
    count = 1;
    pos = def.find('{', pos);
    while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
      count += def[pos] == '{';
      count -= def[pos] == '}';
    }
    ret += def.substr(0, pos);
    ret += '\n';
  }
  return ret;
}

void init_triton_tools(py::module &&m) {
  m.def("extract_kernels", &extract_kernels);
}

/*****************************************************************************/
/* Python bindings for triton::driver                                        */
/*****************************************************************************/
@@ -88,14 +27,14 @@ void init_triton_driver(py::module &&m) {
  // base device
  py::class_<drv::device>(m, "device");
  // cuda device
  py::class_<drv::cu_device, driver::device>(m, "cu_device")
  py::class_<drv::cu_device, drv::device>(m, "cu_device")
      .def(py::init([](int dev_id, bool take_ownership) {
        CUdevice handle;
        drv::dispatch::cuDeviceGet(&handle, dev_id);
        return new drv::cu_device(handle, take_ownership);
      }));
  // host device
  py::class_<drv::host_device, driver::device>(m, "host_device")
  py::class_<drv::host_device, drv::device>(m, "host_device")
      .def(py::init<>());

  // base stream
@@ -108,54 +47,236 @@ void init_triton_driver(py::module &&m) {
      // py doesn't support opaque pointer (e.g., CUstream) so
      // we assume it has been converted to uint64_t
      .def(py::init([](uint64_t handle, bool take_ownership) {
        return std::unique_ptr<driver::cu_stream>(new driver::cu_stream((CUstream)handle, take_ownership));
      }));
        return std::unique_ptr<drv::cu_stream>(new drv::cu_stream((CUstream)handle, take_ownership));
      }))
      .def("enqueue", [](drv::cu_stream *self, drv::kernel *kernel,
                         size_t grid_0, size_t grid_1, size_t grid_2,
                         size_t block_0, size_t block_1, size_t block_2,
                         const std::string &args,
                         size_t shared_mem) {
        return self->enqueue(kernel, {grid_0, grid_1, grid_2}, {block_0, block_1, block_2},
                             (void *)args.data(), args.size(), shared_mem);
      });

  py::class_<drv::module>(m, "module");
  //py::class_<drv::cu_module, drv::module>(m, "cu_module");

  py::class_<drv::kernel>(m, "kernel");
}

/*****************************************************************************/
/* Python bindings for triton::runtime                                       */
/* Python bindings for triton::codegen                                       */
/*****************************************************************************/
void init_triton_runtime(py::module &&m) {
  // argument type
  py::enum_<rt::arg_type>(m, "arg_type")
      .value("int1", rt::INT1_T)
      .value("int8", rt::INT8_T)
      .value("int16", rt::INT16_T)
      .value("int32", rt::INT32_T)
      .value("int64", rt::INT64_T)
      .value("half", rt::HALF_T)
      .value("float", rt::FLOAT_T)
      .value("double", rt::DOUBLE_T)
      .value("buffer", rt::BUFFER_T);
  // compilation options
  py::class_<rt::options_t>(m, "options", py::dynamic_attr())
      .def(py::init<>())
      .def_readwrite("defines", &rt::options_t::defines)
      .def_readwrite("num_warps", &rt::options_t::num_warps)
      .def("__getattr__", [](rt::options_t *opt, const std::string &name) {
        return opt->D<int>(name);
      });
  // kernel
  py::class_<rt::kernel>(m, "kernel")
      .def("__call__", &rt::kernel::operator())
      .def_readonly("opt", &rt::kernel::opt)
      .def("asm", &rt::kernel::get_asm);
  // tune conf
  py::class_<rt::config>(m, "config")
      .def(py::init<std::map<std::string, std::string>, int>(),
           py::arg("defines") = std::map<std::string, std::string>(),
           py::arg("num_warps"));

  // function
  py::class_<rt::function>(m, "function")
      .def(py::init<const std::string &, const rt::options_t &, driver::device *, const std::vector<rt::config> &, const std::vector<std::string> &>())
      .def("autotune", &rt::function::autotune, py::return_value_policy::reference_internal)
      .def("signature", &rt::function::get_signature);
void init_triton_codegen(py::module &&m) {
  m.def(
      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps) {
        drv::module *mod;
        drv::kernel *ker;
        size_t shared_mem;
        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, mod, ker, shared_mem);
        return std::make_tuple(mod, ker, shared_mem);
      },
      py::return_value_policy::take_ownership);
}

/*****************************************************************************/
/* User-facing language features                                             */
/*****************************************************************************/

void init_triton_frontend(py::module &&m) {
  using ret = py::return_value_policy;

  // programming model
  m.def("program_id", &ir::dispatch::program_id, ret::reference);
  m.def("num_programs", &ir::dispatch::num_programs, ret::reference);
  // binary
  m.def("add", &ir::dispatch::add, ret::reference);
  m.def("sub", &ir::dispatch::sub, ret::reference);
  m.def("mul", &ir::dispatch::mul, ret::reference);
  m.def("truediv", &ir::dispatch::truediv, ret::reference);
  m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
  m.def("mod", &ir::dispatch::mod, ret::reference);
  m.def("and_", &ir::dispatch::and_, ret::reference);
  m.def("or_", &ir::dispatch::or_, ret::reference);
  m.def("xor_", &ir::dispatch::xor_, ret::reference);
  m.def("lshr", &ir::dispatch::lshr, ret::reference);
  m.def("shl", &ir::dispatch::shl, ret::reference);
  // unary
  m.def("plus", &ir::dispatch::plus, ret::reference);
  m.def("minus", &ir::dispatch::minus, ret::reference);
  m.def("invert", &ir::dispatch::invert, ret::reference);
  // comparison
  m.def("greater_than", &ir::dispatch::greater_than, ret::reference);
  m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference);
  m.def("less_than", &ir::dispatch::less_than, ret::reference);
  m.def("less_equal", &ir::dispatch::less_equal, ret::reference);
  m.def("equal", &ir::dispatch::equal, ret::reference);
  m.def("not_equal", &ir::dispatch::not_equal, ret::reference);
  // block creation
  m.def("arange", &ir::dispatch::arange, ret::reference);
  m.def("zeros", &ir::dispatch::zeros, ret::reference);
  // type manipulation
  m.def("reshape", &ir::dispatch::reshape, ret::reference);
  typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
  typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
  m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
  m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
  m.def("cast", &ir::dispatch::cast, ret::reference);
  // memory
  m.def("load", &ir::dispatch::load, ret::reference);
  m.def("store", &ir::dispatch::store, ret::reference);
  m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
  m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
  // linear algebra
  m.def("dot", &ir::dispatch::dot, ret::reference);
  // indexing
  m.def("where", &ir::dispatch::where, ret::reference);
  // reduction
  m.def("min", &ir::dispatch::min, ret::reference);
  m.def("max", &ir::dispatch::max, ret::reference);
  m.def("sum", &ir::dispatch::sum, ret::reference);
  // math
  m.def("exp", &ir::dispatch::exp, ret::reference);
  m.def("log", &ir::dispatch::log, ret::reference);
  m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
  // internal (debugging only)
  m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
  m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference);
}
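These bindings are what Python-level syntax resolves to when the @triton.jit compiler walks a kernel's AST; conceptually (a sketch of the mapping, not the actual visitor code):

# x + y       -> frontend.add(x, y, builder)        # gep for pointers, fadd/add otherwise
# x <= y      -> frontend.less_equal(x, y, builder)
# x[None, :]  -> frontend.reshape(...) with a unit dimension inserted
# triton.load -> frontend.load(pointer, mask, other, builder)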

/*****************************************************************************/
/* Python bindings for triton::ir                                            */
/*****************************************************************************/

void init_triton_ir(py::module &&m) {
  using ret = py::return_value_policy;
  using namespace pybind11::literals;

  py::class_<ir::context>(m, "context")
      .def(py::init<>());

  auto value = py::class_<ir::value>(m, "value");
  value.def_property("name", &ir::value::get_name, &ir::value::set_name);
  value.def_property_readonly("type", &ir::value::get_type);

  py::class_<ir::user, ir::value>(m, "user");

  py::class_<ir::constant, ir::user>(m, "constant");

  py::class_<ir::undef_value, ir::constant>(m, "undef")
      .def("get", &ir::undef_value::get, ret::reference);

  py::class_<ir::constant_int, ir::constant>(m, "constant_int")
      .def_property_readonly("value", &ir::constant_int::get_value)
      .def("__int__", [](ir::constant_int *self) { return self->get_value(); });

  py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
      .def_property_readonly("value", &ir::constant_fp::get_value);

  py::class_<ir::type>(m, "type")
      .def("is_ptr", &ir::type::is_pointer_ty)
      .def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
      .def("is_floating", &ir::type::is_floating_point_ty)
      .def("is_block", &ir::type::is_block_ty)
      .def("make_ptr", &ir::pointer_type::get, ret::reference)
      .def("make_function", &ir::function_type::get, ret::reference)
      .def("make_block", &ir::block_type::get, ret::reference)
      .def("get_void", &ir::type::get_void_ty, ret::reference)
      .def("get_fp16", &ir::type::get_half_ty, ret::reference)
      .def("get_fp32", &ir::type::get_float_ty, ret::reference)
      .def("get_fp64", &ir::type::get_double_ty, ret::reference)
      .def("get_int1", &ir::type::get_int1_ty, ret::reference)
      .def("get_int8", &ir::type::get_int8_ty, ret::reference)
      .def("get_int16", &ir::type::get_int16_ty, ret::reference)
      .def("get_int32", &ir::type::get_int32_ty, ret::reference)
      .def("get_int64", &ir::type::get_int64_ty, ret::reference)

      .def("is_void", &ir::type::is_void_ty)
      .def("is_fp16", &ir::type::is_half_ty)
      .def("is_fp32", &ir::type::is_float_ty)
      .def("is_fp64", &ir::type::is_double_ty)
      .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
      .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
      .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
      .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
      .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })

      .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
      .def_property_readonly("scalar", &ir::type::get_scalar_ty)
      .def_property_readonly("context", &ir::type::get_context, ret::reference);

  py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
      .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);

  py::class_<ir::function_type, ir::type>(m, "function_type");
  py::class_<ir::integer_type, ir::type>(m, "integer_type");
  py::class_<ir::block_type, ir::type>(m, "block_type")
      .def_property_readonly("shape", &ir::block_type::get_shapes)
      .def_property_readonly("numel", &ir::type::get_tile_num_elements);

  py::class_<ir::scope>(m, "scope")
      .def(py::init<>())
      .def_property_readonly("values", &ir::scope::get_values)
      .def("set_type", &ir::scope::set_type);

  py::class_<ir::module>(m, "module")
      .def(py::init<std::string, ir::builder &>())
      .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
      .def("add_new_scope", &ir::module::add_new_scope, ret::reference)
      .def("seal_block", &ir::module::seal_block)
      .def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
      .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
      .def("pop_scope", &ir::module::pop_scope)
      .def_property_readonly("scope", &ir::module::get_scope, ret::reference)
      .def_property_readonly("builder", &ir::module::get_builder, ret::reference);

  using eattr = ir::attribute_kind_t;
  py::enum_<eattr>(m, "attribute_kind")
      .value("readonly", eattr::readonly)
      .value("writeonly", eattr::writeonly)
      .value("noalias", eattr::noalias)
      .value("aligned", eattr::aligned)
      .value("multiple_of", eattr::multiple_of)
      .value("retune", eattr::retune)
      .value("not_implemented", eattr::not_implemented);

  py::class_<ir::attribute>(m, "attribute")
      .def(py::init<eattr, int>());

  py::class_<ir::function>(m, "function")
      .def_property_readonly("args", &ir::function::args)
      .def_property_readonly("attrs", &ir::function::attrs)
      .def("add_attr", &ir::function::add_attr);

  py::class_<ir::argument, ir::value>(m, "argument");

  py::class_<ir::basic_block, ir::value>(m, "basic_block")
      .def("create", &ir::basic_block::create, ret::reference)
      .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);

  py::class_<ir::builder>(m, "builder", py::dynamic_attr())
      .def(py::init<ir::context &>())
      // getters
      .def_property_readonly("context", &ir::builder::get_context, ret::reference)
      // control flow
      .def("br", &ir::builder::create_br, ret::reference)
      .def("cond_br", &ir::builder::create_cond_br, ret::reference)
      .def("ret_void", &ir::builder::create_ret_void, ret::reference)
      .def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
      .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
      // constants
      .def("get_int1", &ir::builder::get_int1, ret::reference)
      .def("get_int32", &ir::builder::get_int32, ret::reference)
      .def("get_float16", &ir::builder::get_float16, ret::reference)
      .def("get_float32", &ir::builder::get_float32, ret::reference)
      .def("get_range", &ir::builder::get_range, ret::reference);
}

void init_triton(py::module &m) {
  py::module subm = m.def_submodule("triton");
  init_triton_codegen(std::move(subm.def_submodule("code_gen")));
  init_triton_driver(std::move(subm.def_submodule("driver")));
  init_triton_runtime(std::move(subm.def_submodule("runtime")));
  init_triton_tools(std::move(subm.def_submodule("tools")));
  init_triton_ir(std::move(subm.def_submodule("ir")));
  init_triton_frontend(std::move(subm.def_submodule("frontend")));
}
python/test/test_code_gen.py — new file, 209 lines
@@ -0,0 +1,209 @@
import torch
import triton
import copy
import pytest
import ast

torch.manual_seed(0)

# convert from string to torch.dtype
# necessary because pytest doesn't print torch.dtype properly
cvt = {
    'bool': torch.bool,
    'int8': torch.int8,
    'int16': torch.int16,
    'int32': torch.int32,
    'int64': torch.int64,
    'float16': torch.float16,
    'float32': torch.float32,
    'float64': torch.float64,
}

int_dtypes = ['int8', 'int16', 'int32', 'int64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes


def patch_kernel(template, to_replace):
    kernel = copy.deepcopy(template)
    for key, value in to_replace.items():
        kernel.src = kernel.src.replace(key, value)
    return kernel
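patch_kernel works because a @triton.jit function keeps its source text in .src; the tests splice an expression into a placeholder token before compilation. For instance (sketch):

# the template kernel contains the literal token GENERATE_TEST_HERE
kernel = patch_kernel(template, {'GENERATE_TEST_HERE': 'x + y'})
# the patched source now compiles as if 'x + y' had been written in place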


# generic test functions
def _test_unary(dtype_x, expr, device='cuda'):
    SIZE = 128
    # define the kernel / launch-grid
    @triton.jit
    def kernel(Z, X, **meta):
        off = triton.arange(0, meta['SIZE'])
        x = triton.load(X + off)
        z = GENERATE_TEST_HERE
        triton.store(Z + off, z)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
    # inputs
    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
    # reference result
    z_ref = eval(expr)
    # triton result
    z_tri = torch.empty_like(z_ref)
    kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
    # compare
    triton.testing.assert_allclose(z_ref, z_tri)


def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
    SIZE = 128
    # define the kernel / launch-grid
    @triton.jit
    def kernel(Z, X, Y, **meta):
        off = triton.arange(0, meta['SIZE'])
        x = triton.load(X + off)
        y = triton.load(Y + off)
        z = GENERATE_TEST_HERE
        triton.store(Z + off, z)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
    # inputs
    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
    y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
    # reference result
    z_ref = eval(expr)
    # triton result
    z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
    kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
    # compare
    triton.testing.assert_allclose(z_ref, z_tri)


# ---------------
# test binary ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
    (dtype_x, dtype_y, f' x {op} y') \
    for op in ['+', '-', '*', '/', '%'] \
    for dtype_x in dtypes \
    for dtype_y in dtypes
])
def test_bin_op(dtype_x, dtype_y, expr, device='cuda'):
    _test_binary(dtype_x, dtype_y, expr, device=device)


# ---------------
# test bitwise ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
    (dtype_x, dtype_y, f' x {op} y') \
    for op in ['&', '|', '^'] \
    for dtype_x in dtypes \
    for dtype_y in dtypes
])
def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
    if 'float' in dtype_x + dtype_y:
        with pytest.raises(RuntimeError):
            _test_binary(dtype_x, dtype_y, expr, device=device)
    else:
        _test_binary(dtype_x, dtype_y, expr, device=device)


# ---------------
# test compare ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
    (dtype_x, dtype_y, f' x {op} y') \
    for op in ['==', '!=', '>', '<', '>=', '<='] \
    for dtype_x in dtypes \
    for dtype_y in dtypes
])
def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
    _test_binary(dtype_x, dtype_y, expr, device=device)


# ---------------
# test unary ops
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
    (dtype_x, f' -x') for dtype_x in float_dtypes
] + [\
    (dtype_x, f' ~x') for dtype_x in int_dtypes
])
def test_unary_op(dtype_x, expr, device='cuda'):
    _test_unary(dtype_x, expr, device=device)


# ----------------
# test indexing
# ----------------


def make_ptr_str(name, shape):
    rank = len(shape)
    offsets = []
    stride = 1
    for i in reversed(range(rank)):
        idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
        offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}']
        stride *= shape[i]
    return f"{name} + {' + '.join(offsets)}"
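For example, make_ptr_str('X', [4, 8]) walks the shape from the innermost dimension out (stride 1, then 8) and returns:

'X + triton.arange(0, 8)[None, :]*1 + triton.arange(0, 4)[:, None]*8'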


@pytest.mark.parametrize("expr", [f'x[{s}]' for s in
                                  ['None, :', ':, None',\
                                   'None, :, :', ':, :, None']\
                                  ])
def test_index1d(expr, device='cuda'):
    dtype = torch.int32
    rank_x = expr.count(':')
    rank_y = expr.count(',') + 1
    shape_x = [32 for _ in range(rank_x)]
    shape_z = [32 for _ in range(rank_y)]

    # Triton kernel
    @triton.jit
    def kernel(Z, X, **meta):
        SIZE = meta['SIZE']
        m = triton.arange(0, SIZE)
        n = triton.arange(0, SIZE)
        x = triton.load(X_PTR_EXPR)
        z = GENERATE_TEST_HERE
        triton.store(Z_PTR_EXPR, z)

    to_replace = {
        'X_PTR_EXPR': make_ptr_str('X', shape_x),
        'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
        'GENERATE_TEST_HERE': expr,
    }
    kernel = patch_kernel(kernel, to_replace)

    # torch result
    x = triton.testing.random(shape_x, dtype=dtype, device=device)
    y = torch.zeros(shape_z, dtype=dtype, device=device)
    z_ref = eval(expr) + y
    # triton result
    z_tri = torch.empty_like(z_ref)
    kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0])
    # compare
    triton.testing.assert_allclose(z_ref, z_tri)


# ---------------
# test load
# ---------------

# ---------------
# test store
# ---------------

# ---------------
# test if
# ---------------

# ---------------
# test for
# ---------------

# ---------------
# test while
# ---------------
@@ -1,17 +0,0 @@
import torch
import triton


def test_op():
    torch.manual_seed(0)
    DTYPE = torch.float16
    N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
    pad, stride, = (1, 1), (1, 1)
    dilation = (1, 1)
    a = torch.rand((N , CI, H, W ), dtype=DTYPE, device='cuda') / CI**.5
    b = torch.rand((CI, R , S, CO), dtype=DTYPE, device='cuda') / CI**.5
    th_c = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, dilation)
    tt_c = triton.ops.conv(a, b, pad, stride)
    rtol, atol = {torch.float32: (1e-4, 1e-5),
                  torch.float16: (1e-2, 1e-3)}[DTYPE]
    assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol)
@@ -3,66 +3,74 @@ import itertools
import triton
import torch


@pytest.mark.parametrize(
    "TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
    itertools.chain(*[
        [
            # 1 warp
            (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
            (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
            (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            # # 2 warp
            (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
            (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
            (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
            # # 4 warp
            (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
            (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
            (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
            (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
            (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
            (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
            # 8 warp
            # (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
            # (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
            # (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
            # split-k
            (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
            (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
            (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
            # variable input
            (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
        ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
    ]),
    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE",
    itertools.chain(
        *[
            [
                # 1 warp
                (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
                (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
                (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
                (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
                (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
                (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
                (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
                (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
                (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
                # 2 warp
                (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
                (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
                (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
                (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
                # 4 warp
                (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
                (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
                (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
                (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
                (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
                (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
                # 8 warp
                (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
                (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
                (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
                # # split-k
                (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
                (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
                (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
                # # variable input
                (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
            ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
        ]
    ),
)
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE):
    torch.manual_seed(0)
    defines = {"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}
    triton.ops._matmul._kernels = dict()
    triton.ops._matmul._CONFIGS = [triton.config(defines=defines, num_warps=NWARP)]
    if M is None:
        M = TM
    if N is None:
        N = TN
    if K is None:
        K = TK * SPLITK
    # nuke kernel decorators -- will set meta-parameters manually
    META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K, 'GROUP_M': 8}
    configs = [triton.Config(meta=META, num_warps=NWARP)]
    kernel = triton.ops._matmul.kernel
    decorators = kernel.kernel_decorators
    kernel.kernel_decorators = []
    triton.autotune(configs, [])(kernel)
    kernel.kernel_decorators += decorators[1:]
    # get matrix shape
    M = BLOCK_M if M is None else M
    N = BLOCK_N if N is None else N
    K = BLOCK_K * SPLIT_K if K is None else K
    # allocate/transpose inputs
    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
    a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
    b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
    a = a.t() if AT else a
    b = b.t() if BT else b
    # run test
    th_c = torch.matmul(a, b)
    tt_c = triton.ops.matmul(a, b)
    assert triton.testing.allclose(th_c, tt_c)
@@ -2,9 +2,10 @@
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from . import testing
from .kernel import *
from . import ops
from .code_gen import jit, autotune, heuristics, Config, Autotuner
from .core import *

from . import testing
from . import ops
# version
__version__ = '1.0.0'
648
python/triton/code_gen.py
Normal file
@@ -0,0 +1,648 @@
import inspect
import struct
import enum
import types
import torch
import ast
import builtins
import triton._C.libtriton.triton as _triton
import triton
import sys
import textwrap
from abc import ABC, abstractmethod


class CodeGenerator(ast.NodeVisitor):
    def get_value(self, name):
        # search node.id in local scope
        ret = None
        if name in self.lscope:
            ret = self.lscope[name]
        # search node.id in global scope
        elif name in self.gscope:
            ret = self.gscope[name]
        # search node.id in builtins
        elif name in self.builtins:
            ret = self.builtins[name]
        else:
            raise ValueError(f'{name} is not defined')
        if isinstance(ret, triton.block):
            handle = self.module.get_value(name)
            return triton.block(handle)
        return ret

    def set_value(self, name, value):
        if isinstance(value, _triton.ir.value):
            value = triton.block(value)
        if isinstance(value, triton.block):
            self.module.set_value(name, value.handle)
            self.module.scope.set_type(name, value.handle.type)
        self.lscope[name] = value

    def is_triton_object(self, value):
        return isinstance(value, triton.block)

    def visit_compound_statement(self, stmts, add_scope=False):
        if add_scope:
            self.module.add_new_scope()
        for stmt in stmts:
            self.last_ret = self.visit(stmt)
            if isinstance(stmt, ast.Return):
                break
        if add_scope:
            self.module.pop_scope()
        return self.last_ret

    def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
        self.builder = _triton.ir.builder(context)
        self.module = _triton.ir.module('', self.builder)
        self.prototype = prototype
        self.gscope = gscope
        self.lscope = dict()
        self.attributes = attributes
        self.constants = constants
        self.kwargs = kwargs
        self.last_node = None
        self.builtins = {'range': range, 'min': triton.minimum, 'float': float, 'int': int, 'print': print, 'getattr': getattr}

    def visit_Module(self, node):
        self.module.add_new_scope()
        ast.NodeVisitor.generic_visit(self, node)
        self.module.pop_scope()

    def visit_List(self, node):
        ctx = self.visit(node.ctx)
        assert ctx is None
        elts = [self.visit(elt) for elt in node.elts]
        return elts

    # By design, only non-kernel functions can return
    def visit_Return(self, node):
        return self.visit(node.value)

    def visit_FunctionDef(self, node, inline=False, arg_values=None):
        arg_names, kwarg_names = self.visit(node.args)
        # store keyword arguments in local scope
        self.lscope[kwarg_names] = self.kwargs
        # initialize function
        if inline:
            pass
        else:
            fn = self.module.get_or_insert_function(node.name, self.prototype)
            arg_values = []
            for i, arg_name in enumerate(arg_names):
                if i in self.constants:
                    arg_values.append(self.constants[i])
                else:
                    if i in self.attributes:
                        is_ptr = fn.args[i].type.is_ptr()
                        attr = 'aligned' if is_ptr else 'multiple_of'
                        attr = getattr(_triton.ir.attribute_kind, attr)
                        attr = _triton.ir.attribute(attr, self.attributes[i])
                        fn.add_attr(i + 1, attr)
                    fn.args[i].name = arg_name
                    arg_values.append(fn.args[i])
        for arg_name, arg_value in zip(arg_names, arg_values):
            self.set_value(arg_name, arg_value)
        if inline:
            return self.visit_compound_statement(node.body, add_scope=True)
        else:
            entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
            self.module.seal_block(entry)
            self.builder.set_insert_block(entry)
            # visit function body
            self.visit_compound_statement(node.body, add_scope=True)
            # finalize function
            self.builder.ret_void()

    def visit_arguments(self, node):
        arg_names = []
        for arg in node.args:
            arg_names += [self.visit(arg)]
        kwarg_names = self.visit(node.kwarg)
        return arg_names, kwarg_names

    def visit_arg(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        return node.arg

    def visit_Assign(self, node):
        names = []
        for target in node.targets:
            names += [self.visit(target)]
        assert len(names) == 1
        name = names[0]
        value = self.visit(node.value)
        self.set_value(name, value)

    def visit_AugAssign(self, node):
        name = node.target.id
        lhs = ast.Name(id=name, ctx=ast.Load())
        rhs = ast.BinOp(lhs, node.op, node.value)
        assign = ast.Assign(targets=[node.target], value=rhs)
        self.visit(assign)
        return self.get_value(name)

    def visit_Name(self, node):
        if type(node.ctx) == ast.Store:
            return node.id
        return self.get_value(node.id)

    def visit_Store(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Load(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Tuple(self, node):
        args = [self.visit(x) for x in node.elts]
        return tuple(args)

    def visit_BinOp(self, node):
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        fn = {
            ast.Add: '__add__',
            ast.Sub: '__sub__',
            ast.Mult: '__mul__',
            ast.Div: '__truediv__',
            ast.FloorDiv: '__floordiv__',
            ast.Mod: '__mod__',
            ast.Pow: '__pow__',
            ast.LShift: '__lshift__',
            ast.RShift: '__rshift__',
            ast.BitAnd: '__and__',
            ast.BitOr: '__or__',
            ast.BitXor: '__xor__',
        }[type(node.op)]
        kws = dict()

        if self.is_triton_object(lhs):
            kws['builder'] = self.builder
        ret = getattr(lhs, fn)(rhs, **kws)
        if ret is NotImplemented:
            if self.is_triton_object(rhs):
                kws['builder'] = self.builder
            fn = fn[:2] + 'r' + fn[2:]
            ret = getattr(rhs, fn)(lhs, **kws)
        return ret

    def visit_If(self, node):
        cond = self.visit(node.test)
        if self.is_triton_object(cond):
            current_bb = self.builder.get_insert_block()
            then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
            else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
            endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
            self.module.seal_block(then_bb)
            if else_bb:
                self.module.seal_block(else_bb)
                self.builder.cond_br(cond.handle, then_bb, else_bb)
            else:
                self.builder.cond_br(cond.handle, then_bb, endif_bb)
            self.builder.set_insert_block(then_bb)
            self.visit_compound_statement(node.body, add_scope=True)
            # TODO: last statement is a terminator?
            self.builder.br(endif_bb)
            if else_bb:
                self.builder.set_insert_block(else_bb)
                self.visit_compound_statement(node.orelse, add_scope=True)
                # TODO: last statement is a terminator?
                self.builder.br(endif_bb)
            self.module.seal_block(endif_bb)
            self.builder.set_insert_block(endif_bb)
        else:
            if cond:
                self.visit_compound_statement(node.body)
            else:
                self.visit_compound_statement(node.orelse)

    def visit_IfExp(self, node):
        cond = self.visit(node.test)
        if cond:
            return self.visit(node.body)
        else:
            return self.visit(node.orelse)

    def visit_Pass(self, node):
        pass

    def visit_Compare(self, node):
        assert len(node.comparators) == 1
        assert len(node.ops) == 1
        lhs = self.visit(node.left)
        rhs = self.visit(node.comparators[0])
        fn = {
            ast.Eq: '__eq__',
            ast.NotEq: '__ne__',
            ast.Lt: '__lt__',
            ast.LtE: '__le__',
            ast.Gt: '__gt__',
            ast.GtE: '__ge__',
            ast.Is: '__eq__',
            ast.IsNot: '__ne__',
        }[type(node.ops[0])]
        if self.is_triton_object(lhs) or self.is_triton_object(rhs):
            return getattr(lhs, fn)(rhs, builder=self.builder)
        return getattr(lhs, fn)(rhs)

    def visit_UnaryOp(self, node):
        op = self.visit(node.operand)
        fn = {
            ast.USub: '__neg__',
            ast.UAdd: '__pos__',
            ast.Invert: '__invert__',
        }[type(node.op)]
        if self.is_triton_object(op):
            return getattr(op, fn)(builder=self.builder)
        return getattr(op, fn)()

    def visit_While(self, node):
        current_bb = self.builder.get_insert_block()
        loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
        next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)

        def continue_fn():
            cond = self.visit(node.test)
            return self.builder.cond_br(cond.handle, loop_bb, next_bb)

        continue_fn()
        self.builder.set_insert_block(loop_bb)
        self.visit_compound_statement(node.body, add_scope=True)
        continue_fn()
        stop_bb = self.builder.get_insert_block()
        self.module.seal_block(stop_bb)
        self.module.seal_block(loop_bb)
        self.module.seal_block(next_bb)
        self.builder.set_insert_block(next_bb)

        for stmt in node.orelse:
            ast.NodeVisitor.generic_visit(self, stmt)

    def visit_Str(self, node):
        return ast.literal_eval(node)

    def visit_Subscript(self, node):
        assert node.ctx.__class__.__name__ == "Load"
        lhs = self.visit(node.value)
        slices = self.visit(node.slice)
        if self.is_triton_object(lhs):
            return lhs.__getitem__(slices, builder=self.builder)
        return lhs[slices]

    def visit_ExtSlice(self, node):
        return [self.visit(dim) for dim in node.dims]

    def visit_For(self, node):
        iterator = self.visit(node.iter.func)
        assert iterator == self.builtins['range']
        # create nodes
        st_target = ast.Name(id=node.target.id, ctx=ast.Store())
        ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
        init_node = ast.Assign(targets=[st_target], value=node.iter.args[0])
        pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
        neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
        pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
        build_cond = lambda: triton.where(self.visit(pos_step_node),
                                          self.visit(pos_cond_node),
                                          self.visit(neg_cond_node),
                                          builder=self.builder)
        # cond_node = neg_cond_node
        step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=node.iter.args[2])
        # code generation
        current_bb = self.builder.get_insert_block()
        loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
        next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)

        def continue_fn():
            self.visit(step_node)
            cond = build_cond()
            return self.builder.cond_br(cond.handle, loop_bb, next_bb)

        self.visit(init_node)
        cond = build_cond()
        self.builder.cond_br(cond.handle, loop_bb, next_bb)
        self.builder.set_insert_block(loop_bb)
        self.visit_compound_statement(node.body, add_scope=True)
        # TODO: handle case where body breaks control flow
        continue_fn()
        stop_bb = self.builder.get_insert_block()
        self.module.seal_block(stop_bb)
        self.module.seal_block(loop_bb)
        self.module.seal_block(next_bb)
        self.builder.set_insert_block(next_bb)

        for stmt in node.orelse:
            ast.NodeVisitor.generic_visit(self, stmt)

    def visit_Slice(self, node):
        lower = self.visit(node.lower)
        upper = self.visit(node.upper)
        step = self.visit(node.step)
        return slice(lower, upper, step)

    def visit_Index(self, node):
        return self.visit(node.value)

    def visit_NameConstant(self, node):
        return node.value

    def visit_keyword(self, node):
        return {node.arg: self.visit(node.value)}

    def visit_Call(self, node):
        fn = self.visit(node.func)
        kws = dict()
        for keyword in node.keywords:
            kws.update(self.visit(keyword))
        args = [self.visit(arg) for arg in node.args]
        if isinstance(fn, JITFunction):
            return fn(*args, generator=self, **kws)
        if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
           sys.modules[fn.__module__] is triton.core:
            return fn(*args, builder=self.builder, **kws)
        return fn(*args, **kws)

    def visit_Num(self, node):
        return node.n

    def visit_Attribute(self, node):
        lhs = self.visit(node.value)
        return getattr(lhs, node.attr)

    def visit_Expr(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_NoneType(self, node):
        return None

    def visit(self, node):
        if node is not None:
            self.last_node = node
        return super().visit(node)

    def generic_visit(self, node):
        typename = type(node).__name__
        raise NotImplementedError("Unsupported node: {}".format(typename))


class Binary:
    def __init__(self, module, kernel, num_warps, shared_mem):
        self.module = module
        self.kernel = kernel
        self.shared_mem = shared_mem
        self.num_warps = num_warps

    def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
        stream.enqueue(self.kernel, grid_0, grid_1, grid_2, self.num_warps * 32, 1, 1, args, self.shared_mem)


class CompilationError(Exception):
    def __init__(self, src, node, err):
        self.message = '\n'.join(src.split('\n')[:node.lineno])
        self.message += '\n' + ' ' * node.col_offset + '^'
        self.message += '\n Error: ' + str(err)
        super().__init__(self.message)


class Kernel:

    type_names = {
        int: 'I',
        float: 'f',
        bool: 'B',
        torch.float16: 'f16',
        torch.float32: 'f32',
        torch.float64: 'f64',
        torch.bool: 'i1',
        torch.int8: 'i8',
        torch.int16: 'i16',
        torch.int32: 'i32',
        torch.int64: 'i64',
    }

    @staticmethod
    def _to_triton_ir(context, obj):
        type_map = {
            'I': _triton.ir.type.get_int32,
            'f': _triton.ir.type.get_fp32,
            'B': _triton.ir.type.get_int1,
            'f16': _triton.ir.type.get_fp16,
            'f32': _triton.ir.type.get_fp32,
            'f64': _triton.ir.type.get_fp64,
            'i1': _triton.ir.type.get_int1,
            'i8': _triton.ir.type.get_int8,
            'i16': _triton.ir.type.get_int16,
            'i32': _triton.ir.type.get_int32,
            'i64': _triton.ir.type.get_int64,
        }
        # convert torch.Tensor to Triton IR pointers
        if isinstance(obj, torch.Tensor):
            name = Kernel.type_names[obj.dtype]
            elt_ty = type_map[name](context)
            return _triton.ir.type.make_ptr(elt_ty, 1)
        # default path returns triton.ir.type directly
        name = Kernel.type_names[obj.__class__]
        return type_map[name](context)

    @staticmethod
    def _types_key(*wargs, tensor_idxs):
        # type inference
        types_key = [None] * len(wargs)
        for i, arg in enumerate(wargs):
            prefix = 'P' if i in tensor_idxs else ''
            suffix = Kernel.type_names[arg.dtype] if i in tensor_idxs else Kernel.type_names[arg.__class__]
            types_key[i] = prefix + suffix
        return tuple(types_key)

    @staticmethod
    def pow2_divisor(N):
        if N % 16 == 0: return 16
        if N % 8 == 0: return 8
        if N % 4 == 0: return 4
        if N % 2 == 0: return 2
        return 1
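`pow2_divisor` returns the largest power of two (capped at 16) that divides an integer argument; `Kernel.__call__` below feeds the result into the `aligned`/`multiple_of` attributes attached in `visit_FunctionDef`. A few concrete values, for illustration:

assert Kernel.pow2_divisor(24) == 8     # 24 = 8 * 3
assert Kernel.pow2_divisor(17) == 1     # odd sizes carry no alignment guarantee
assert Kernel.pow2_divisor(4096) == 16  # capped at 16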
    def __init__(self, fn):
        self.fn = fn

    def _compile(self, *wargs, device, attributes, constants, num_warps, **meta):
        # explicitly set device
        torch.cuda.set_device(device.index)
        # create IR module
        context = _triton.ir.context()
        # get just-in-time proto-type of kernel
        arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs]
        ret_type = _triton.ir.type.get_void(context)
        prototype = _triton.ir.type.make_function(ret_type, arg_types)
        # generate Triton-IR
        # export symbols visible from self.fn into code-generator object
        gscope = sys.modules[self.fn.module].__dict__
        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta)
        try:
            generator.visit(self.fn.parse())
        except Exception as e:
            node = generator.last_node
            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                raise e
            raise CompilationError(self.fn.src, node, e)
        tt_device = _triton.driver.cu_device(device.index, False)
        # Compile to machine code
        mod, ker, shared_mem = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
        return Binary(mod, ker, num_warps, shared_mem)

    def __call__(self, *wargs, grid, num_warps=4, **meta):
        # device inference
        tensor_idxs = [i for i, arg in enumerate(wargs) if isinstance(arg, torch.Tensor)]
        if len(tensor_idxs) == 0:
            raise ValueError("No Tensor argument found.")
        device = wargs[tensor_idxs[0]].device
        # attributes
        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
        # transforms ints whose value is one into constants for just-in-time compilation
        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
        # determine if we need to re-compile
        types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
        attr_key = frozenset(attributes.items())
        meta_key = frozenset(meta.items())
        const_key = frozenset(constants.items())
        key = (device.type, device.index, types_key, attr_key, num_warps, meta_key, const_key)
        cache = self.fn.cache
        if key not in cache:
            # compile and cache configuration if necessary
            cache[key] = self._compile(
                *wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
            )
        # pack arguments
        fmt = ''.join(['P' if i in tensor_idxs else Kernel.type_names[arg.__class__] for i, arg in enumerate(wargs)])
        params = struct.pack(fmt, *args)
        # enqueue cached function into stream
        binary = cache[key]
        cu_stream = torch.cuda.current_stream(device.index).cuda_stream
        stream = _triton.driver.cu_stream(cu_stream, False)
        grid = grid(meta) if hasattr(grid, '__call__') else grid
        binary(stream, params, *grid)


class Launcher:
    def __init__(self, kernel, grid):
        self.kernel = kernel
        self.grid = grid

    def __call__(self, *wargs, **kwargs):
        self.kernel(*wargs, **kwargs, grid=self.grid)


class Autotuner:
    def __init__(self, kernel, arg_names, configs, key):
        if not configs:
            self.configs = [Config(dict(), num_warps=4)]
        else:
            self.configs = configs
        self.key_idx = [arg_names.index(k) for k in key]
        self.cache = dict()
        self.kernel = kernel

    def _bench(self, *args, config, **meta):
        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.meta.keys()
        if conflicts:
            raise ValueError(
                f"Conflicting meta-parameters: {', '.join(conflicts)}."
                " Make sure that you don't re-define auto-tuned symbols."
            )
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.meta)
        kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, **current)
        return triton.testing.do_bench(kernel_call)

    def __call__(self, *args, **meta):
        if len(self.configs) > 1:
            key = tuple([args[i] for i in self.key_idx])
            if key not in self.cache:
                timings = {config: self._bench(*args, config=config, **meta)
                           for config in self.configs}
                self.cache[key] = builtins.min(timings, key=timings.get)
            config = self.cache[key]
        else:
            config = self.configs[0]
        self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)


class JITFunction:
    def __init__(self, fn):
        self.module = fn.__module__
        self.arg_names = inspect.getfullargspec(fn).args
        self.cache = dict()
        self.kernel_decorators = []
        self.src = textwrap.dedent(inspect.getsource(fn))
        self.kernel = None

    # we do not parse in the constructor because
    # the user might want to monkey-patch self.src dynamically.
    # Some unit tests do this, for example.
    def parse(self):
        tree = ast.parse(self.src)
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1
        assert isinstance(tree.body[0], ast.FunctionDef)
        return tree

    def __call__(self, *args, generator: CodeGenerator, **meta):
        try:
            return generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
        except Exception as e:
            node = generator.last_node
            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                raise e
            raise CompilationError(self.src, node, e)

    def _init_kernel(self):
        if self.kernel is None:
            self.kernel = Kernel(self)
            for decorator in reversed(self.kernel_decorators):
                self.kernel = decorator(self.kernel)
        return self.kernel

    def __getitem__(self, grid):
        return Launcher(self._init_kernel(), grid)


class Config:
    def __init__(self, meta, num_warps=4):
        self.meta = meta
        self.num_warps = num_warps


def autotune(configs, key):
    def decorator(fn):
        def wrapper(kernel):
            return Autotuner(kernel, fn.arg_names, configs, key)

        fn.kernel_decorators.append(wrapper)
        return fn

    return decorator
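A hedged usage sketch of the decorator stack this enables (the kernel below is illustrative, not part of this diff). Note that `@triton.autotune` must sit above `@triton.jit`, since it expects the `JITFunction` wrapper:

@triton.autotune(
    configs=[
        triton.Config(meta={'BLOCK': 512}, num_warps=4),
        triton.Config(meta={'BLOCK': 1024}, num_warps=8),
    ],
    key=['N'],  # re-tune whenever the value of argument `N` changes
)
@triton.jit
def _double(X, Y, N, **meta):
    offs = triton.program_id(0) * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
    mask = offs < N
    x = triton.load(X + offs, mask=mask)
    triton.store(Y + offs, x + x, mask=mask)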

def heuristics(values):
    def decorator(fn):
        def wrapper(kernel):
            def fun(*args, **meta):
                for v, heur in values.items():
                    assert v not in meta
                    meta[v] = heur(*args, **meta)
                return kernel(*args, **meta)

            return fun

        fn.kernel_decorators.append(wrapper)
        return fn

    return decorator
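`heuristics` plays the same role as `autotune`, except each value is computed from the actual launch arguments rather than benchmarked. A hypothetical sketch (the `EVEN_N` meta-parameter and kernel are illustrative): each heuristic receives the positional launch arguments plus the already-known meta-parameters and returns the derived value:

@triton.heuristics({'EVEN_N': lambda *args, **meta: args[2] % meta['BLOCK'] == 0})
@triton.jit
def _kernel(X, Y, N, **meta):
    pass  # kernel body elided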

def jit(fn):
    return JITFunction(fn)
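Putting the pieces together, the new frontend reduces to: decorate, index with a grid, call. A minimal end-to-end sketch (kernel and sizes are illustrative):

import torch
import triton

@triton.jit
def _add(X, Y, Z, N, **meta):
    pid = triton.program_id(0)
    offs = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
    mask = offs < N
    x = triton.load(X + offs, mask=mask)
    y = triton.load(Y + offs, mask=mask)
    triton.store(Z + offs, x + y, mask=mask)

x = torch.randn(98432, device='cuda')
y = torch.randn_like(x)
z = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK']), )
_add[grid](x, y, z, x.numel(), BLOCK=1024)  # __getitem__ -> Launcher -> Kernel.__call__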
499
python/triton/core.py
Normal file
@@ -0,0 +1,499 @@
from triton._C.libtriton.triton import ir
from triton._C.libtriton.triton import frontend
import triton
from functools import wraps


def _patch(fn):

    # convert block/dtype to ir values
    def _to_ir(x, builder):
        if isinstance(x, bool):
            return builder.get_int1(x)
        elif isinstance(x, int):
            return builder.get_int32(x)
        elif isinstance(x, float):
            return builder.get_float32(x)
        if isinstance(x, block):
            return x.handle
        if isinstance(x, dtype):
            return x.handle(builder)
        return x

    def _from_ir(x):
        if isinstance(x, ir.value):
            if x.type.is_void():
                return None
            return block(x)
        return x

    def wrapper(*args, **kwargs):
        builder = args[-1]
        assert isinstance(builder, ir.builder)
        args = [_to_ir(x, builder) for x in args]
        kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()}
        ret = fn(*args, **kwargs)
        if isinstance(ret, tuple):
            return map(_from_ir, ret)
        return _from_ir(ret)

    return wrapper


for name in dir(frontend):
    fn = getattr(frontend, name)
    if callable(fn):
        setattr(frontend, name, _patch(fn))


def builtin(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if 'builder' not in kwargs or \
           kwargs['builder'] is None:
            raise ValueError("Builder argument must be provided outside of JIT functions")
        return fn(*args, **kwargs)

    if wrapper.__doc__:
        wrapper.__doc__ += """\
:param builder: IR builder to generate code into, optional from within @triton.jit functions
:type builder: triton.ir.builder
"""
    return wrapper


class dtype:
    def __init__(self, init):
        self.init = init

    def handle(self, builder):
        ctx = builder.context
        return self.init(ctx)


class pointer_dtype:
    def __init__(self, element_ty):
        self.element_ty = element_ty

    def handle(self, builder):
        return ir.type.make_ptr(self.element_ty, 1)


int1 = dtype(ir.type.get_int1)
int8 = dtype(ir.type.get_int8)
int16 = dtype(ir.type.get_int16)
int32 = dtype(ir.type.get_int32)
int64 = dtype(ir.type.get_int64)
float16 = dtype(ir.type.get_fp16)
float32 = dtype(ir.type.get_fp32)
float64 = dtype(ir.type.get_fp64)


class block:
    @staticmethod
    def _init_dtype(ir_type):
        # primitive type
        if ir_type.is_int1(): return int1
        if ir_type.is_int8(): return int8
        if ir_type.is_int16(): return int16
        if ir_type.is_int32(): return int32
        if ir_type.is_int64(): return int64
        if ir_type.is_fp16(): return float16
        if ir_type.is_fp32(): return float32
        if ir_type.is_fp64(): return float64
        # pointer type
        if ir_type.is_ptr():
            element_ty = block._init_dtype(ir_type.element)
            return pointer_dtype(element_ty)
        raise ValueError(f"Unsupported type {ir_type}")

    def __init__(self, handle):
        # IR handle
        self.handle = handle
        # Block shape
        self.shape = (1, )
        if self.handle.type.is_block():
            self.shape = self.handle.type.shape
        # Data-type wrapper
        self.dtype = block._init_dtype(self.handle.type.scalar)

    @builtin
    def __add__(self, other, builder=None):
        return frontend.add(self, other, builder)

    def __radd__(self, other, builder=None):
        return self.__add__(other, builder=builder)

    @builtin
    def __sub__(self, other, builder=None):
        return frontend.sub(self, other, builder)

    @builtin
    def __mul__(self, other, builder=None):
        return frontend.mul(self, other, builder)

    def __rmul__(self, other, builder=None):
        return self.__mul__(other, builder=builder)

    @builtin
    def __truediv__(self, other, builder=None):
        return frontend.truediv(self, other, builder)

    def __rtruediv__(self, other, builder=None):
        return frontend.truediv(other, self, builder)

    @builtin
    def __floordiv__(self, other, builder=None):
        return frontend.floordiv(self, other, builder)

    @builtin
    def __mod__(self, other, builder=None):
        return frontend.mod(self, other, builder)

    # unary operators
    @builtin
    def __neg__(self, builder=None):
        return frontend.minus(self, builder)

    @builtin
    def __invert__(self, builder=None):
        return frontend.invert(self, builder)

    # bitwise operators

    @builtin
    def __and__(self, other, builder=None):
        return frontend.and_(self, other, builder)

    @builtin
    def __or__(self, other, builder=None):
        return frontend.or_(self, other, builder)

    @builtin
    def __xor__(self, other, builder=None):
        return frontend.xor_(self, other, builder)

    @builtin
    def __lshift__(self, other, builder=None):
        return frontend.shl(self, other, builder)

    @builtin
    def __rshift__(self, other, builder=None):
        return frontend.lshr(self, other, builder)

    # comparison operators

    @builtin
    def __gt__(self, other, builder=None):
        return frontend.greater_than(self, other, builder)

    @builtin
    def __ge__(self, other, builder=None):
        return frontend.greater_equal(self, other, builder)

    @builtin
    def __lt__(self, other, builder=None):
        return frontend.less_than(self, other, builder)

    @builtin
    def __le__(self, other, builder=None):
        return frontend.less_equal(self, other, builder)

    @builtin
    def __eq__(self, other, builder=None):
        return frontend.equal(self, other, builder)

    @builtin
    def __ne__(self, other, builder=None):
        return frontend.not_equal(self, other, builder)

    @builtin
    def __getitem__(self, slices, builder=None):
        if isinstance(slices, slice):
            slices = [slices]
        src_shape = self.shape
        dst_shape = []
        curr = 0
        for sl in slices:
            if sl is None:
                dst_shape.append(1)
            elif sl == slice(None, None, None):
                dst_shape.append(src_shape[curr])
                curr += 1
        ret = frontend.reshape(self, dst_shape, builder)
        return ret

    @builtin
    def to(self, dtype, builder=None):
        return frontend.cast(self, dtype.handle(builder), builder)


# -----------------------
# SPMD Programming Model
# -----------------------


@builtin
def program_id(axis, builder=None):
    """
    Returns the id of the current program instance along the given `axis`.
    Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s.

    :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
    :type axis: int
    """
    return frontend.program_id(axis, builder)


@builtin
def num_programs(axis, builder=None):
    """
    Returns the number of program instances launched along the given `axis`.

    :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
    :type axis: int
    """
    return frontend.num_programs(axis, builder)


# -----------------------
# Block Initialization
# -----------------------


@builtin
def arange(start, end, builder=None):
    """
    Returns contiguous values within the half-open interval [start, end).

    :param start: Start of the interval.
    :type start: int
    :param end: End of the interval.
    :type end: int
    """
    return frontend.arange(start, end, builder)


@builtin
def zeros(shape, dtype, builder=None):
    """
    Returns a block filled with the scalar value 0 and the given shape.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., triton.float16
    :type dtype: triton.ir.dtype
    """
    return frontend.zeros(shape, dtype, builder)


# -----------------------
# Shape Manipulation
# -----------------------


@builtin
def broadcast(input, other, builder=None):
    """
    Tries to broadcast two blocks to a common compatible shape.

    :param input: The first input block.
    :type input: triton.ir.value
    :param other: The second input block.
    :type other: triton.ir.value
    """
    return frontend.broadcast(input, other, builder)


@builtin
def broadcast_to(input, shape, builder=None):
    """
    Tries to broadcast a block to a new shape.

    :param input: The input block.
    :type input: triton.value
    :param shape: The new shape.
    :type shape: tuple of int
    """
    return frontend.broadcast_to(input, shape, builder)


@builtin
def reshape(input, shape, builder=None):
    """
    Reshapes a block to a new shape.
    """
    return frontend.reshape(input, shape, builder)


# -----------------------
# Linear Algebra
# -----------------------


@builtin
def dot(input, other, builder=None):
    """
    Returns the matrix product of two blocks.
    The two blocks must be two-dimensional and have compatible inner dimensions.

    :param input: The first block to be multiplied.
    :type input: 2D block of scalar-type in {`float16`, `float32`}
    :param other: The second block to be multiplied.
    :type other: 2D block of scalar-type in {`float16`, `float32`}
    """
    return frontend.dot(input, other, builder)
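In practice `dot` is the accumulation step of a blocked matrix multiply, as in the blocksparse kernel later in this diff. A fragment (inside a @triton.jit function; the names are illustrative):

acc = triton.zeros((TM, TN), dtype=triton.float32)  # float32 accumulator
for k in range(K, 0, -TK):
    acc += triton.dot(a, b)  # a: [TM, TK] block, b: [TK, TN] block
    # ... advance the a/b pointers and re-load them here ...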

# -----------------------
# Memory Operations
# -----------------------


@builtin
def load(pointer, mask=None, other=None, builder=None):
    """
    Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`.

    :param pointer: Pointer to the data to be loaded.
    :type pointer: Block of triton.pointer
    :param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
    :type mask: Block of triton.bool, optional
    :param other: if mask[idx] is false, return other[idx] instead of `pointer[idx]`.
    :type other: Block of triton.value, optional
    """
    return frontend.load(pointer, mask, other, builder)


@builtin
def store(pointer, value, mask=None, builder=None):
    """
    Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.

    :param pointer: The memory locations where the elements of `value` are stored.
    :type pointer: Block of triton.pointer
    :param value: The block of elements to be stored.
    :type value: Block of triton.value
    :param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
    :type mask: Block of triton.bool, optional
    """
    return frontend.store(pointer, value, mask, builder)
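Together, `mask` and `other` are how kernels handle block sizes that do not divide the problem size. A fragment (inside a @triton.jit function; names are illustrative):

offs = triton.program_id(0) * BLOCK + triton.arange(0, BLOCK)
x = triton.load(X + offs, mask=offs < N, other=0.)  # out-of-bounds lanes read 0.
triton.store(Y + offs, x, mask=offs < N)            # and are never written back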

@builtin
def atomic_cas(ptr, cmp, val, builder=None):
    return frontend.atomic_cas(ptr, cmp, val, builder)


@builtin
def atomic_xchg(ptr, val, builder=None):
    return frontend.atomic_xchg(ptr, val, builder)
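These two primitives compose into the spin-lock that the blocksparse matmul at the end of this diff uses to accumulate split-K partial results. A fragment (inside a @triton.jit function; `lock_ptr` is illustrative):

while triton.atomic_cas(lock_ptr, 0, 1) == 1:  # spin until the lock is acquired
    pass
# ... read-modify-write the shared output here ...
triton.atomic_xchg(lock_ptr, 0)                # release the lock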

# -----------------------
# Conditioning
# -----------------------


@builtin
def where(condition, x, y, builder=None):
    """
    Returns a block of elements from either `x` or `y`, depending on `condition`.
    Note that `x` and `y` are always evaluated regardless of the value of `condition`.
    If you want to avoid unintended memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
    The shapes of `x` and `y` are both broadcast to the shape of `condition`.
    `x` and `y` must have the same data type.

    :param condition: When True (nonzero), yield x, otherwise yield y.
    :type condition: Block of triton.bool
    :param x: values selected at indices where condition is True.
    :param y: values selected at indices where condition is False.
    """
    return frontend.where(condition, x, y, builder)


# -----------------------
# Math
# -----------------------


@builtin
def exp(x, builder=None):
    return frontend.exp(x, builder)


@builtin
def log(x, builder=None):
    return frontend.log(x, builder)


# -----------------------
# Reductions
# -----------------------


@builtin
def max(input, axis, builder=None):
    return frontend.max(input, axis, builder)


@builtin
def min(input, axis, builder=None):
    return frontend.min(input, axis, builder)


@builtin
def sum(input, axis, builder=None):
    return frontend.sum(input, axis, builder)


# -----------------------
# Internal for debugging
# -----------------------


@builtin
def debug_barrier(builder=None):
    return frontend.debug_barrier(builder)


@builtin
def multiple_of(x, value, builder=None):
    return frontend.multiple_of(x, value, builder)


# -----------------------
# Standard library
# -----------------------


@triton.jit
def minimum(x, y):
    return triton.where(x < y, x, y)


@triton.jit
def maximum(x, y):
    return triton.where(x > y, x, y)


@triton.jit
def sigmoid(x):
    return 1 / (1 + triton.exp(-x))


@triton.jit
def ravel(x):
    return triton.reshape(x, [x.type.numel])


@triton.jit
def softmax(x):
    z = x - triton.max(x, 0)
    num = triton.exp(z)
    den = triton.sum(num, 0)
    return num / den
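Since @triton.jit functions can call other @triton.jit functions (they are inlined by `visit_Call` in code_gen.py), these helpers compose directly. A hedged sketch of a row-softmax kernel built on the helper above (names are illustrative):

@triton.jit
def _softmax_kernel(X, Y, N, **meta):
    offs = triton.arange(0, meta['BLOCK'])
    x = triton.load(X + offs, mask=offs < N, other=-float('inf'))  # -inf lanes do not affect max/sum
    triton.store(Y + offs, triton.softmax(x), mask=offs < N)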

def cdiv(x, y):
    return (x + y - 1) // y
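`cdiv` is the ceiling division used when sizing launch grids; typical use (illustrative):

grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )  # one program instance per block of N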
@@ -1,119 +0,0 @@
import os
import struct
from typing import Optional, Dict, List, Callable
import torch
import triton._C.libtriton.triton as _triton

codes = {
    _triton.runtime.arg_type.int1: 'B',
    _triton.runtime.arg_type.int8: 'B',
    _triton.runtime.arg_type.int32: 'I',
    _triton.runtime.arg_type.int64: 'Q',
    _triton.runtime.arg_type.half: 'H',
    _triton.runtime.arg_type.float: 'f',
    _triton.runtime.arg_type.double: 'd',
    _triton.runtime.arg_type.buffer: 'P',
}


def th_to_triton(obj):
    """ Convert a `torch.dtype` to a Triton-C type string. """
    tys = {
        torch.int8: 'char',
        torch.int16: 'short',
        torch.int32: 'int',
        torch.int64: 'long',
        torch.float16: 'half',
        torch.float32: 'float',
        torch.float64: 'double',
    }
    if isinstance(obj, torch.dtype):
        return tys[obj]
    return str(obj)


def cdiv(a: int, b: int) -> int:
    """ Ceil division (a + b - 1) // b """
    return (a + b - 1) // b


def read(path: str, kernel_names: Optional[List] = None) -> str:
    """ Extracts the source code for `kernel_names` from the given `path` file. """
    if kernel_names is None:
        kernel_names = []
    with open(path, 'r') as f:
        source = f.read()
    source = _triton.tools.extract_kernels(source, kernel_names)
    return source


config = _triton.runtime.config


class kernel:
    """
    A class used to represent a Triton kernel.
    """
    def __init__(
        self,
        src: str,
        device: torch.device,
        defines: Optional[Dict] = None,
        num_warps: int = 4,
        autotune_configs: Optional[List] = None,
        autotune_key: Optional[List] = None
    ):
        """
        :param src: The source code of the kernel.
        :param device: The device to compile the kernel for.
        :param defines: A dictionary of preprocessor #define for the compiler.
        :param num_warps: Optimization flag for the compiler's internal auto-parallelization engine.
        :param autotune_configs: A list of triton.config objects for the autotuner to try.
        :param autotune_key: A list of kernel argument names whose change in value should trigger the autotuner to re-run.
        """

        if defines is None:
            defines = {}
        if autotune_configs is None:
            autotune_configs = []
        if autotune_key is None:
            autotune_key = []
        # check if src is empty
        if src == '':
            raise ValueError('Kernel source code is empty')
        self.src = src
        # device
        assert device.type in ['cuda', 'cpu']
        if device.type == 'cuda':
            self.device_id = torch.cuda.current_device() if device.index is None else device.index
            self.device = _triton.driver.cu_device(self.device_id, False)
            cu_stream = torch.cuda.current_stream(self.device_id).cuda_stream
            self.stream = _triton.driver.cu_stream(cu_stream, False)
        if device.type == 'cpu':
            self.device_id = -1
            self.device = _triton.driver.host_device()
            self.stream = _triton.driver.host_stream()
        torch.cuda.set_device(self.device_id)
        # function
        self.opt = _triton.runtime.options()
        self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
        self.opt.num_warps = num_warps
        # autotune_configs = [({}, 4)]
        self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_configs, autotune_key)
        self.tys = ''.join([codes[x] for x in self.fn.signature()])

    def __call__(self, *args, grid: Callable[[_triton.runtime.options], tuple]):
        """
        Runs the kernel on the given arguments and launch grid.
        :param args: The arguments to the kernel in the order that they appear in the Triton-C source.
        :param grid: The launch grid for the kernel, i.e., a callable that transforms compilation options into a tuple of at most 3 integers.
        :return: None
        """
        # make sure that the executing thread is on the right device
        torch.cuda.set_device(self.device_id)
        # pack parameters into a byte buffer
        params = struct.pack(self.tys, *args)
        kernel = self.fn.autotune(params, grid, self.stream)
        # run kernel
        grid = grid(kernel.opt)
        kernel(params, self.stream, grid)
@@ -1,4 +1,4 @@
from .conv import _conv, conv
#from .conv import _conv, conv
from .matmul import _matmul, matmul
from .cross_entropy import _cross_entropy, cross_entropy
from . import blocksparse
@@ -1,199 +0,0 @@
__global__ void NAME(TYPE *A __readonly __noalias,
                     TYPE *B __readonly __noalias,
                     TYPE *C __noalias,
                     int lda,
                     int ldb,
                     int ldc,
                     long stride_za,
                     long stride_zb,
                     long stride_zc,
                     long stride_ha,
                     long stride_hb,
                     long stride_hc,
                     int DS0, int DS1,
                     int SDD_K,
                     int SDD_off_width,
                     int *lut, int *locks, int nlocks) {
  /* ---------------- */
  /*     Prologue     */
  /* ---------------- */
  // program ids
  int pid0 = get_program_id(0);
  int pid1 = get_program_id(1);
  int pidz = get_program_id(2);
#ifdef SDD
  // load LUT header
  pid1 = pid1 + SDD_off_width;
  int blockidm[TM] = (0 ... TM) / BLOCK;
  int blockidn[TN] = (0 ... TN) / BLOCK;
  int offlutm[TM] = blockidm * (TN / BLOCK) * 4;
  int offlutn[TN] = blockidn * 4;
  int *header = lut + pid1 * (TM / BLOCK) * (TN / BLOCK) * 4;
  int z = *(header + 0);
  int i[TM] = *(header + 1 + offlutm);
  int j[TN] = *(header + 2 + offlutn);
  int AS1 = SDD_K / TZ;
  int lockid = select(TZ > 1, 1, 0);
  int offka = pid0 * AS1;
  int offkb = pid0 * AS1;
  int offmc = 0;
  int offnc = 0;
  int offpa = 0;
  int offpb = 0;
  int maxid = TZ;
  int offhc = 0;
  int offha = z;
  int offhb = z;
  int ram[TM] = i * BLOCK + ((0 ... TM) % BLOCK);
  int rbn[TN] = j * BLOCK + ((0 ... TN) % BLOCK);
#else
  // load LUT header
  int *header = lut + pid0 * 6;
  int offset = *(header + 0);
  int AS1 = *(header + 1);
  int column = *(header + 2);
  int depth = *(header + 3);
  int lockid = *(header + 4);
  int maxid = *(header + 5);
  int *pinc = lut + offset;
  int offhc = depth;
#ifdef DSD
  // output offset
  int offnc = pid1 * TN;
  int offmc = column * TM;
  int offpc = 0;
  // dense input offset
  int offnb = pid1 * TN;
  int offkb __multipleof(8) = *pinc;
  int offpb = 0;
  // sparse input offset
  int offma = 0;
  int offka = 0;
  long offpa __multipleof(8) = *(pinc + 1);
  offpa = offpa * BLOCK * BLOCK;
  int offha = 0;
  int offhb = depth;
#endif
#ifdef DDS
  // output offset
  int offmc = pid1 * TM;
  int offnc = column * TN;
  int offpc = 0;
  // dense input offset
  int offma = pid1 * TM;
  int offka __multipleof(8) = *pinc;
  int offpa = 0;
  // sparse input offset
  int offnb = 0;
  int offkb = 0;
  long offpb __multipleof(8) = *(pinc + 1);
  offpb = offpb * BLOCK * BLOCK;
  int offha = depth;
  int offhb = 0;
#endif
  int ram[TM] = offma + 0 ... TM;
  int rbn[TN] = offnb + 0 ... TN;
#endif
  // initialize a, b pointers
  int rka[TK] = offka + 0 ... TK;
  int rkb[TK] = offkb + 0 ... TK;
  TYPE *pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
  TYPE *pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
  // pre-fetch
#ifdef DDS
  bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
  bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
  bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
#else
  bool checkbn[TK, TN] = AS1 > 0;
#endif
  TYPE a[TM, TK] = checkam ? *pa : 0;
  TYPE b[TK, TN] = checkbn ? *pb : 0;

  /* ---------------- */
  /*    Inner Loop    */
  /* ---------------- */
  // create result tile
  float acc[TM, TN] = 0;
  int step = TK;
  for (int k = AS1; k > 0; k -= step) {
    acc += a @ b;
    // update pointers
#ifdef SDD
    int inc_a = TK * STRIDE_AK;
    int inc_b = TK * STRIDE_BK;
#else
    pinc += 2;
#ifdef DSD
    int inc_b __multipleof(8) = *pinc;
    int inc_a __multipleof(8) = *(pinc + 1);
    inc_b = inc_b * STRIDE_BK;
#endif
#ifdef DDS
    int inc_a __multipleof(8) = *pinc;
    int inc_b __multipleof(8) = *(pinc + 1);
    inc_a = inc_a * STRIDE_AK;
#endif
#endif
    pa += inc_a;
    pb += inc_b;
    // pre-fetch
    bool checkak[TM, TK] = k > TK;
    bool checkbk[TK, TN] = k > TK;
    bool checka[TM, TK] = checkam && checkak;
    bool checkb[TK, TN] = checkbk && checkbn;
    a = *?(checka)pa;
    b = *?(checkb)pb;
  }
  TYPE c[TM, TN] = acc;

  /* ---------------- */
  /*     Epilogue     */
  /* ---------------- */
  // initialize c pointers
#ifdef SDD
  bool checkc[TM, TN] = 1;
  // rematerialize
  int rr_blockidm[TM] = (0 ... TM) / BLOCK;
  int rr_blockidn[TN] = (0 ... TN) / BLOCK;
  int rr_offlutm[TM] = rr_blockidm * (TN / BLOCK) * 4;
  int rr_offlutn[TN] = rr_blockidn * 4;
  int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
  int bkid[TM, TN] = *(header + off_bkid);
  long offpc[TM, TN] = bkid * BLOCK * BLOCK;
  // range within blocks
  int rcm[TM] = (0 ... TM) % BLOCK;
  int rcn[TN] = (0 ... TN) % BLOCK;
#else
  int rcm[TM] = offmc + 0 ... TM;
  int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
  bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
#endif
#ifdef DDS
  bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
  TYPE *pc[TM, TN] = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, newaxis] * STRIDE_CM + rcn[newaxis, :] * STRIDE_CN;
  // write-back directly
  if (lockid == 0) {
    *?(checkc)pc = c;
  }
  // accumulate partial result using spin-locks
  else {
    int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1;
    int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks;
    for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
      ;
    int count = *pcount;
    if (count == 0)
      *?(checkc)pc = c;
    else
      *?(checkc)pc = c + *?(checkc)pc;
    atomic_xchg(pcount, (count + 1) % maxid);
    atomic_xchg(plock, 0);
  }
}
@@ -4,7 +4,183 @@ import torch
import os
import math

src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))

@triton.jit
def _kernel(
    A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc, stride_hc,
    stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta
):
    TM = meta['TM']
    TN = meta['TN']
    TK = meta['TK']
    TZ = meta['TZ']
    BLOCK = meta['BLOCK']
    #------------#
    #- Prologue -#
    #------------#
    pid0 = triton.program_id(0)
    pid1 = triton.program_id(1)
    pidz = triton.program_id(2)
    if meta['SDD']:
        pid1 = pid1 + SDD_off_width
        blockidm = triton.arange(0, TM) // BLOCK
        blockidn = triton.arange(0, TN) // BLOCK
        offlutm = blockidm * (TN // BLOCK) * 4
        offlutn = blockidn * 4
        header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
        z = triton.load(header + 0)
        i = triton.load(header + 1 + offlutm)
        j = triton.load(header + 2 + offlutn)
        AS1 = SDD_K // TZ
        lockid = triton.where(TZ > 1, 1, 0)
        offka = pid0 * AS1
        offkb = pid0 * AS1
        offmc = 0
        offnc = 0
        offpa = 0
        offpb = 0
        maxid = TZ
        offhc = 0
        offha = z
        offhb = z
        ram = i * BLOCK + (triton.arange(0, TM) % BLOCK)
        rbn = j * BLOCK + (triton.arange(0, TN) % BLOCK)
    else:
        header = lut + pid0 * 6
        offset = triton.load(header + 0)
        AS1 = triton.load(header + 1)
        column = triton.load(header + 2)
        depth = triton.load(header + 3)
        lockid = triton.load(header + 4)
        maxid = triton.load(header + 5)
        pinc = lut + offset
        offhc = depth
        if meta['DSD']:
            # output offset
            offnc = pid1 * TN
            offmc = column * TM
            offpc = 0
            # dense input offset
            offnb = pid1 * TN
            offkb = triton.load(pinc)
            offkb = triton.multiple_of(offkb, 8)  # compiler hint
            offpb = 0
            # sparse input offset
            offma = 0
            offka = 0
            offpa = triton.load(pinc + 1)
            offpa = triton.multiple_of(offpa, 8)  # compiler hint
            offpa = offpa * BLOCK * BLOCK
            offha = 0
            offhb = depth
        else:
            # output offset
            offmc = pid1 * TM
            offnc = column * TN
            offpc = 0
            # dense input offset
            offma = pid1 * TM
            offka = triton.load(pinc)
            offka = triton.multiple_of(offka, 8)  # compiler hint
            offpa = 0
            # sparse input offset
            offnb = 0
            offkb = 0
            offpb = triton.load(pinc + 1)
            offpb = triton.multiple_of(offpb, 8)  # compiler hint
            offpb = offpb * BLOCK * BLOCK
            offha = depth
            offhb = 0
        ram = offma + triton.arange(0, TM)
        rbn = offnb + triton.arange(0, TN)

    # initialize a, b pointers
    rka = offka + triton.arange(0, TK)
    rkb = offkb + triton.arange(0, TK)
    pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
    pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
    if meta['DDS']:
        checkam = ram[:, None] < DS0
    else:
        checkam = AS1 > 0
    if meta['DSD']:
        checkbn = rbn[None, :] < DS0
    else:
        checkbn = AS1 > 0
    a = triton.load(pa, mask=checkam, other=0.)
    b = triton.load(pb, mask=checkbn, other=0.)

    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
    acc = triton.zeros((TM, TN), dtype=triton.float32)
    for k in range(AS1, 0, -TK):
        acc += triton.dot(a, b)
        if meta['SDD']:
            inc_a = TK * stride_ka
            inc_b = TK * stride_kb
        else:
            pinc += 2
            if meta['DSD']:
                inc_b = triton.load(pinc)
                inc_a = triton.load(pinc + 1)
                inc_b = triton.multiple_of(inc_b, 8)
                inc_a = triton.multiple_of(inc_a, 8)
                inc_b = inc_b * stride_kb
            if meta['DDS']:
                inc_a = triton.load(pinc)
                inc_b = triton.load(pinc + 1)
                inc_a = triton.multiple_of(inc_a, 8)
                inc_b = triton.multiple_of(inc_b, 8)
                inc_a = inc_a * stride_ka
        pa += inc_a
        pb += inc_b
        # pre-fetch
        checkak = k > TK
        checkbk = k > TK
        checka = checkam & checkak
        checkb = checkbn & checkbk
        a = triton.load(pa, mask=checka)
        b = triton.load(pb, mask=checkb)
    c = acc.to(C.dtype.element_ty)

    if meta['SDD']:
        checkc = True
        rr_blockidm = triton.arange(0, TM) // BLOCK
        rr_blockidn = triton.arange(0, TN) // BLOCK
        rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
        rr_offlutn = rr_blockidn * 4
        off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
        bkid = triton.load(header + off_bkid)
        offpc = bkid * BLOCK * BLOCK
        rcm = triton.arange(0, TM) % BLOCK
        rcn = triton.arange(0, TN) % BLOCK
    else:
        rcm = offmc + triton.arange(0, TM)
        rcn = offnc + triton.arange(0, TN)
        if meta['DSD']:
            checkc = rcn[None, :] < DS0
        if meta['DDS']:
            checkc = rcm[:, None] < DS0

    pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
    # write-back directly
    if lockid == 0:
        triton.store(pc, c, mask=checkc)
    # accumulate partial results using spin-locks
    else:
        plock = locks + triton.program_id(2) * nlocks * triton.num_programs(1) + triton.program_id(1) * nlocks + lockid - 1
        pcount = plock + triton.num_programs(2) * triton.num_programs(1) * nlocks
        while triton.atomic_cas(plock, 0, 1) == 1:
|
||||
pass
|
||||
count = triton.load(pcount)
|
||||
if count == 0:
|
||||
triton.store(pc, c, mask=checkc)
|
||||
else:
|
||||
d = triton.load(pc, mask=checkc)
|
||||
triton.store(pc, d + c, mask=checkc)
|
||||
triton.atomic_xchg(pcount, (count + 1) % maxid)
|
||||
triton.atomic_xchg(plock, 0)
|
||||
|
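# The spin-lock write-back above, in outline (illustrative sketch, not part of
# this diff): the first program to win the lock overwrites C, the remaining
# maxid - 1 programs accumulate into it, and pcount wraps back to 0 so the
# buffers are clean for the next launch:
#   while atomic_cas(plock, 0, 1) == 1: pass       # acquire
#   count == 0 ? store(pc, c) : store(pc, c + load(pc))
#   atomic_xchg(pcount, (count + 1) % maxid)       # bump the turn counter
#   atomic_xchg(plock, 0)                          # release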
||||
|
||||
##############
|
||||
@@ -118,31 +294,11 @@ class _matmul(torch.autograd.Function):
|
||||
raise ValueError('Reduction size for SDD must be a multiple of 16')
|
||||
# create kernel
|
||||
total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
|
||||
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
|
||||
c = torch.zeros((AS0, total_width, block, block), dtype=dtype, device=device)
|
||||
for lut, width, pack in zip(luts, widths, packs):
|
||||
num_lock = 1
|
||||
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
|
||||
if key not in _matmul.sdd_cache:
|
||||
defines = {
|
||||
'TM': block * pack,
|
||||
'TN': block * pack,
|
||||
'TMN': block * block * pack * pack,
|
||||
'BLOCK': block,
|
||||
'TK': 32,
|
||||
'TYPE': dtype,
|
||||
'STRIDE_AM': '1' if trans_a else 'lda',
|
||||
'STRIDE_AK': 'lda' if trans_a else '1',
|
||||
'STRIDE_BN': 'ldb' if trans_b else '1',
|
||||
'STRIDE_BK': '1' if trans_b else 'ldb',
|
||||
'STRIDE_CM': 'ldc',
|
||||
'STRIDE_CN': '1',
|
||||
'SDD': True,
|
||||
'TZ': 1,
|
||||
'NAME': 'sdd_kernel'
|
||||
}
|
||||
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
|
||||
|
||||
kernel = _matmul.sdd_cache[key]
|
||||
meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1, \
|
||||
'SDD': True, 'DSD': False, 'DDS': False}
|
||||
# create output
|
||||
locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
|
||||
# maximum grid size is 65535
|
||||
@@ -150,27 +306,32 @@ class _matmul(torch.autograd.Function):
|
||||
# kernel calls
|
||||
max_width = 49152
|
||||
for off_width in range(0, width, max_width):
|
||||
kernel(
|
||||
a.data_ptr(),
|
||||
b.data_ptr(),
|
||||
c.data_ptr(),
|
||||
a.stride(2),
|
||||
b.stride(2),
|
||||
block,
|
||||
grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), AS0]
|
||||
_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
a.stride(0),
|
||||
b.stride(0),
|
||||
c.stride(0),
|
||||
a.stride(1),
|
||||
a.stride(3 if trans_a else 2),
|
||||
a.stride(2 if trans_a else 3),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
b.stride(3 if trans_b else 2),
|
||||
b.stride(2 if trans_b else 3),
|
||||
c.stride(0),
|
||||
c.stride(0),
|
||||
c.stride(2),
|
||||
c.stride(3),
|
||||
AS2,
|
||||
AS2,
|
||||
AS3,
|
||||
off_width,
|
||||
lut.data_ptr(),
|
||||
locks.data_ptr(),
|
||||
lut,
|
||||
locks,
|
||||
num_lock,
|
||||
grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0]
|
||||
num_warps=4,
|
||||
**meta
|
||||
)
|
||||
# save for backward pass
|
||||
return c
|
||||
@@ -282,25 +443,8 @@ class _matmul(torch.autograd.Function):
|
||||
BS2 = block * spdims[1 if trans_b else 2]
|
||||
dtype = a.dtype
|
||||
# kernel
|
||||
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
|
||||
if key not in _matmul.dds_cache:
|
||||
defines = {
|
||||
'TM': 128,
|
||||
'TN': block,
|
||||
'TK': 16,
|
||||
'BLOCK': block,
|
||||
'TYPE': dtype,
|
||||
'STRIDE_AM': 1 if trans_a else 'lda',
|
||||
'STRIDE_AK': 'lda' if trans_a else 1,
|
||||
'STRIDE_BN': block if trans_b else 1,
|
||||
'STRIDE_BK': 1 if trans_b else block,
|
||||
'STRIDE_CM': '1' if trans_c else 'ldc',
|
||||
'STRIDE_CN': 'ldc' if trans_c else '1',
|
||||
'NAME': 'dds_kernel',
|
||||
'DDS': True
|
||||
}
|
||||
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
|
||||
kernel = _matmul.dds_cache[key]
|
||||
meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
|
||||
'SDD': False, 'DSD': False, 'DDS': True}
|
||||
# output
|
||||
CS0 = AS0
|
||||
CS1 = AS1
|
||||
@@ -308,27 +452,32 @@ class _matmul(torch.autograd.Function):
|
||||
CS3 = AS2 if trans_c else BS2
|
||||
locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
|
||||
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
||||
kernel(
|
||||
a.data_ptr(),
|
||||
b.data_ptr(),
|
||||
c.data_ptr(),
|
||||
a.stride(2),
|
||||
block,
|
||||
c.stride(2),
|
||||
grid = lambda meta: [width, triton.cdiv(AS2, meta['TM']), AS0]
|
||||
_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
a.stride(0),
|
||||
b.stride(0),
|
||||
c.stride(0),
|
||||
a.stride(1),
|
||||
a.stride(3 if trans_a else 2),
|
||||
a.stride(2 if trans_a else 3),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
b.stride(3 if trans_b else 2),
|
||||
b.stride(2 if trans_b else 3),
|
||||
c.stride(0),
|
||||
c.stride(1),
|
||||
c.stride(3 if trans_c else 2),
|
||||
c.stride(2 if trans_c else 3),
|
||||
AS2,
|
||||
BS2,
|
||||
0,
|
||||
0,
|
||||
lut.data_ptr(),
|
||||
locks.data_ptr(),
|
||||
lut,
|
||||
locks,
|
||||
num_locks,
|
||||
grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0]
|
||||
num_warps=4,
|
||||
**meta
|
||||
)
|
||||
return c
|
||||
|
||||
@@ -344,25 +493,8 @@ class _matmul(torch.autograd.Function):
|
||||
BS3 = b.size(2 if trans_b else 3)
|
||||
dtype = a.dtype
|
||||
# kernel
|
||||
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
|
||||
if key not in _matmul.dsd_cache:
|
||||
defines = {
|
||||
'TM': block,
|
||||
'TN': 128,
|
||||
'TK': 16,
|
||||
'BLOCK': block,
|
||||
'TYPE': dtype,
|
||||
'STRIDE_AM': 1 if trans_a else block,
|
||||
'STRIDE_AK': block if trans_a else 1,
|
||||
'STRIDE_BN': 'ldb' if trans_b else '1',
|
||||
'STRIDE_BK': '1' if trans_b else 'ldb',
|
||||
'STRIDE_CM': '1' if trans_c else 'ldc',
|
||||
'STRIDE_CN': 'ldc' if trans_c else '1',
|
||||
'NAME': 'dsd_kernel',
|
||||
'DSD': True
|
||||
}
|
||||
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
|
||||
kernel = _matmul.dsd_cache[key]
|
||||
meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
|
||||
'SDD': False, 'DSD': True, 'DDS': False}
|
||||
# output
|
||||
CS0 = BS0
|
||||
CS1 = BS1
|
||||
@@ -370,27 +502,32 @@ class _matmul(torch.autograd.Function):
|
||||
CS3 = AS1 if trans_c else BS3
|
||||
locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
|
||||
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
||||
kernel(
|
||||
a.data_ptr(),
|
||||
b.data_ptr(),
|
||||
c.data_ptr(),
|
||||
block,
|
||||
b.stride(2),
|
||||
c.stride(2),
|
||||
grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0]
|
||||
_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
a.stride(0),
|
||||
b.stride(0),
|
||||
c.stride(0),
|
||||
a.stride(1),
|
||||
a.stride(3 if trans_a else 2),
|
||||
a.stride(2 if trans_a else 3),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
b.stride(3 if trans_b else 2),
|
||||
b.stride(2 if trans_b else 3),
|
||||
c.stride(0),
|
||||
c.stride(1),
|
||||
c.stride(2),
|
||||
c.stride(3),
|
||||
BS3,
|
||||
AS1,
|
||||
0,
|
||||
0,
|
||||
lut.data_ptr(),
|
||||
locks.data_ptr(),
|
||||
lut,
|
||||
locks,
|
||||
num_locks,
|
||||
grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0]
|
||||
num_warps=4,
|
||||
**meta
|
||||
)
|
||||
return c
|
||||
|
||||
|
@@ -1,135 +0,0 @@
|
||||
__global__ void forward(TYPE *X __readonly __noalias,
|
||||
float scale,
|
||||
int *LUT __readonly __noalias,
|
||||
TYPE *RPE __readonly __noalias,
|
||||
TYPE *KP_M __readonly __noalias,
|
||||
TYPE *ATTN_M __readonly __noalias,
|
||||
int sizemax,
|
||||
long stride_zx,
|
||||
long stride_zrpe,
|
||||
int stride_hrpe,
|
||||
int stride_srpe,
|
||||
int stride_zkpm,
|
||||
int stride_zattnm) {
|
||||
int pidhm = get_program_id(0);
|
||||
int pidz = get_program_id(1);
|
||||
// create index ranges
|
||||
int rxm = pidhm % BLOCK;
|
||||
int rbm = pidhm / BLOCK;
|
||||
int rxn[TN] = (0 ... TN) % BLOCK;
|
||||
int rbn[TN] = (0 ... TN) / BLOCK;
|
||||
// extract information from look-up table
|
||||
int *header = LUT + rbm * 2;
|
||||
int size = *(header + 0);
|
||||
int offset = *(header + 1);
|
||||
bool check[TN] = rbn < size;
|
||||
int rbmn[TN] = check ? rbn : size - 1;
|
||||
// block id and column id
|
||||
long blockid[TN] = *(LUT + offset + rbmn * 4 + 0);
|
||||
long columnid[TN] = *(LUT + offset + rbmn * 4 + 1);
|
||||
long rowid[TN] = *(LUT + offset + rbmn * 4 + 2);
|
||||
long headid[TN] = *(LUT + offset + rbmn * 4 + 3);
|
||||
// pointers to X
|
||||
TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
|
||||
#ifdef APPLY_RPE
|
||||
// pointers to relative position embedding
|
||||
TYPE *prpe[TN] = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn;
|
||||
#endif
|
||||
#ifdef APPLY_KP_MASK
|
||||
// pointers to key padding mask
|
||||
TYPE *pkp_m[TN] = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn;
|
||||
#endif
|
||||
#ifdef APPLY_ATTN_MASK
|
||||
// pointers to attention mask
|
||||
TYPE *pattn_m[TN] = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn;
|
||||
#endif
|
||||
|
||||
// load input
|
||||
TYPE x[TN] = check ? *px : -INFINITY;
|
||||
#ifdef APPLY_RPE
|
||||
// load relative position embedding
|
||||
TYPE rpe[TN] = check ? *prpe : 0;
|
||||
#endif
|
||||
#ifdef APPLY_KP_MASK
|
||||
// load key-padding mask
|
||||
TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
|
||||
#endif
|
||||
#ifdef APPLY_ATTN_MASK
|
||||
// load attention mask
|
||||
TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
|
||||
#endif
|
||||
// compute softmax in float
|
||||
#ifdef APPLY_RPE
|
||||
float Frpe[TN] = rpe;
|
||||
#endif
|
||||
#ifdef APPLY_KP_MASK
|
||||
float Fkp_m[TN] = kp_m;
|
||||
#endif
|
||||
#ifdef APPLY_ATTN_MASK
|
||||
float Fattn_m[TN] = attn_m;
|
||||
#endif
|
||||
#ifdef KP_MASK_MUL
|
||||
Fkp_m = (Fkp_m == 0) ? (float[TN]) - INFINITY : 0;
|
||||
#endif
|
||||
#ifdef ATTN_MASK_MUL
|
||||
Fattn_m = (Fattn_m == 0) ? (float[TN]) - INFINITY : 0;
|
||||
#endif
|
||||
float Fx[TN] = x;
|
||||
#ifdef APPLY_SCALE
|
||||
Fx = Fx * scale; // apply scale
|
||||
#endif
|
||||
#ifdef APPLY_RPE
|
||||
Fx = Fx + Frpe; // apply relative position embedding
|
||||
#endif
|
||||
#ifdef APPLY_KP_MASK
|
||||
Fx = Fx + Fkp_m; // apply key padding mask
|
||||
#endif
|
||||
#ifdef APPLY_ATTN_MASK
|
||||
Fx = Fx + Fattn_m; // apply attention mask
|
||||
#endif
|
||||
float Fxmax = Fx[max];
|
||||
float Fy[TN] = exp(Fx - Fxmax);
|
||||
float Fysum = (check ? Fy : 0)[+];
|
||||
// write-back in half/float
|
||||
TYPE y[TN] = Fy;
|
||||
TYPE ysum = Fysum;
|
||||
*? (check)px = y / ysum;
|
||||
}
|
||||
|
||||
__global__ void backward(TYPE *X __readonly __noalias,
|
||||
float scale,
|
||||
TYPE *DX __readonly __noalias,
|
||||
int *LUT,
|
||||
int sizemax,
|
||||
long stride_zx,
|
||||
long stride_zdx) {
|
||||
int pidhm = get_program_id(0);
|
||||
int pidz = get_program_id(1);
|
||||
// create index ranges
|
||||
int rxm = pidhm % BLOCK;
|
||||
int rbm = pidhm / BLOCK;
|
||||
int rxn[TN] = (0 ... TN) % BLOCK;
|
||||
int rbn[TN] = (0 ... TN) / BLOCK;
|
||||
// extract information from look-up table
|
||||
int *header = LUT + rbm * 2;
|
||||
int size = *(header + 0);
|
||||
int offset = *(header + 1);
|
||||
// bounds checking on lut
|
||||
bool check[TN] = rbn < size;
|
||||
int rbmn[TN] = check ? rbn : size - 1;
|
||||
// initialize pointers to block-sparse input
|
||||
long blockid[TN] = *(LUT + offset + rbmn * 4);
|
||||
TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
|
||||
TYPE *pdx[TN] = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
|
||||
// compute fused softmax backward
|
||||
TYPE x[TN] = check ? *px : 0;
|
||||
TYPE dx[TN] = check ? *pdx : 0;
|
||||
float Fdx[TN] = dx;
|
||||
float Fx[TN] = x;
|
||||
float Fxdx[TN] = Fdx * Fx;
|
||||
float Fxdxsum = Fxdx[+];
|
||||
float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
|
||||
TYPE y[TN] = Fy;
|
||||
// write-back
|
||||
*? (check)pdx = y;
|
||||
}
|
@@ -2,24 +2,118 @@ import triton
|
||||
import torch
|
||||
import os
|
||||
|
||||
fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
|
||||
fwd_kernels = dict()
|
||||
|
||||
bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['backward'])
|
||||
bwd_kernels = dict()
|
||||
def next_power_of_2(n):
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
n |= n >> 2
|
||||
n |= n >> 4
|
||||
n |= n >> 8
|
||||
n |= n >> 16
|
||||
n += 1
|
||||
return n
|
||||
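# The bit trick above is easy to sanity-check (illustrative, plain Python): the
# shifted ORs smear the highest set bit of n - 1 into every lower bit, so the
# final + 1 rounds up to the next power of two.
assert [next_power_of_2(n) for n in (1, 3, 17, 512, 513)] == [1, 4, 32, 512, 1024]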
|
||||
|
||||
def num_warps(n):
|
||||
if n < 512:
|
||||
return 4
|
||||
if n < 2048:
|
||||
return 8
|
||||
return 16
|
||||
|
||||
|
||||
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[6] * meta['BLOCK'])})
|
||||
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[6] * meta['BLOCK'])})
|
||||
@triton.jit
|
||||
def _forward(
|
||||
X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
|
||||
**meta
|
||||
):
|
||||
TN = meta['TN']
|
||||
BLOCK = meta['BLOCK']
|
||||
pidhm = triton.program_id(0)
|
||||
pidz = triton.program_id(1)
|
||||
# create index ranges
|
||||
rxm = pidhm % BLOCK
|
||||
rbm = pidhm // BLOCK
|
||||
rxn = triton.arange(0, TN) % BLOCK
|
||||
rbn = triton.arange(0, TN) // BLOCK
|
||||
# extract information from LUT
|
||||
header = LUT + rbm * 2
|
||||
size = triton.load(header + 0)
|
||||
offset = triton.load(header + 1)
|
||||
check = rbn < size
|
||||
rbmn = triton.where(check, rbn, size - 1)
|
||||
# block id and column id
|
||||
blockid = triton.load(LUT + offset + rbmn * 4 + 0)
|
||||
columnid = triton.load(LUT + offset + rbmn * 4 + 1)
|
||||
rowid = triton.load(LUT + offset + rbmn * 4 + 2)
|
||||
headid = triton.load(LUT + offset + rbmn * 4 + 3)
|
||||
# pointers to X
|
||||
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
|
||||
x = triton.load(px, mask=check, other=-float('inf'))
|
||||
x = x.to(triton.float32)
|
||||
# apply scale
|
||||
if meta['APPLY_SCALE']:
|
||||
x = x * scale
|
||||
# apply RPE
|
||||
if meta['APPLY_RPE']:
|
||||
prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
|
||||
rpe = triton.load(prpe, mask=check, other=0)
|
||||
x = x + rpe
|
||||
# apply key-padding mask
|
||||
if meta['APPLY_KP_MASK']:
|
||||
pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
|
||||
kp_m = triton.load(pkp_m, mask=check, other=-float('inf'))
|
||||
if meta['KP_MASK_MUL']:
|
||||
kp_m = triton.where(kp_m == 0, -float('inf'), 0.)
|
||||
x = x + kp_m
|
||||
# apply attention mask
|
||||
if meta['APPLY_ATTN_MASK']:
|
||||
pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
|
||||
attn_m = triton.load(pattn_m, mask=check, other=-float('inf'))
|
||||
if meta['ATTN_MASK_MUL']:
|
||||
attn_m = triton.where(attn_m == 0, -float('inf'), 0.)
|
||||
x = x + attn_m
|
||||
# computation
|
||||
x = triton.softmax(x)
|
||||
triton.store(px, x, mask=check)
|
||||
|
||||
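# How the heuristics decorating _forward above resolve (illustrative sketch;
# argument index 6 in _forward's signature is sizemax): for a layout whose
# widest row spans 40 blocks of size BLOCK = 16, the kernel is specialized with
#   TN = next_power_of_2(40 * 16)  # 1024: one power-of-two tile covers the row
#   num_warps(40 * 16)             # 8: wider rows get more warps
# and compiled once per distinct (TN, num_warps) pair.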
|
||||
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
|
||||
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})
|
||||
@triton.jit
|
||||
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
|
||||
pidhm = triton.program_id(0)
|
||||
pidz = triton.program_id(1)
|
||||
TN = meta['TN']
|
||||
BLOCK = meta['BLOCK']
|
||||
# create index ranges
|
||||
rxm = pidhm % BLOCK
|
||||
rbm = pidhm // BLOCK
|
||||
rxn = triton.arange(0, TN) % BLOCK
|
||||
rbn = triton.arange(0, TN) // BLOCK
|
||||
# extract information from look-up table
|
||||
header = LUT + rbm * 2
|
||||
size = triton.load(header + 0)
|
||||
offset = triton.load(header + 1)
|
||||
# bounds checking on lut
|
||||
check = rbn < size
|
||||
rbmn = triton.where(check, rbn, size - 1)
|
||||
# initialize pointers to block-sparse input
|
||||
blockid = triton.load(LUT + offset + rbmn * 4)
|
||||
X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
|
||||
DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
|
||||
# compute fused softmax backward
|
||||
x = triton.load(X, mask=check, other=0)
|
||||
dx = triton.load(DX, mask=check, other=0)
|
||||
x = x.to(triton.float32)
|
||||
dx = dx.to(triton.float32)
|
||||
y = x * (dx - triton.sum(x * dx, 0)) * scale
|
||||
triton.store(DX, y, mask=check)
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def next_power_of_2(n):
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
n |= n >> 2
|
||||
n |= n >> 4
|
||||
n |= n >> 8
|
||||
n |= n >> 16
|
||||
n += 1
|
||||
return n
|
||||
|
||||
@staticmethod
|
||||
def make_lut(layout, block, device):
|
||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
@@ -43,40 +137,9 @@ class _softmax(torch.autograd.Function):
|
||||
return lut, int(sizes.max())
|
||||
|
||||
@staticmethod
|
||||
def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
|
||||
kp_mask_mode, attn_mask_mode):
|
||||
if max_k >= 32768:
|
||||
raise NotImplementedError('Reductions larger than 32768 elements '\
|
||||
'are not yet implemented')
|
||||
num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
|
||||
TN = _softmax.next_power_of_2(max_k)
|
||||
# just-in-time compile kernel
|
||||
key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
|
||||
kp_mask_mode, attn_mask_mode)
|
||||
if key not in cache:
|
||||
defines = {
|
||||
'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block, 'INFINITY':
|
||||
{torch.float32: 'F32_INFINITY', torch.float16: 'F16_INFINITY'}[dtype]
|
||||
}
|
||||
if apply_scale:
|
||||
defines['APPLY_SCALE'] = True
|
||||
if apply_rpe:
|
||||
defines['APPLY_RPE'] = True
|
||||
if apply_kp_mask:
|
||||
defines['APPLY_KP_MASK'] = True
|
||||
if kp_mask_mode == 'mul':
|
||||
defines['KP_MASK_MUL'] = True
|
||||
if apply_attn_mask:
|
||||
defines['APPLY_ATTN_MASK'] = True
|
||||
if attn_mask_mode == 'mul':
|
||||
defines['ATTN_MASK_MUL'] = True
|
||||
kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
|
||||
cache[key] = kernel
|
||||
return cache[key]
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,
|
||||
maxlut, bench, time):
|
||||
def forward(
|
||||
ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut, maxlut, bench, time
|
||||
):
|
||||
apply_scale = False if scale == 1.0 else True
|
||||
|
||||
# handle None rpe
|
||||
@@ -107,26 +170,20 @@ class _softmax(torch.autograd.Function):
|
||||
stride_zattnm = attn_mask.stride(0)
|
||||
|
||||
# run kernel
|
||||
kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut * block, x.device, x.dtype, block, apply_scale,
|
||||
apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
|
||||
M = x.shape[0]
|
||||
meta = {
|
||||
'BLOCK': block,
|
||||
'APPLY_SCALE': apply_scale,
|
||||
'APPLY_RPE': apply_rpe,
|
||||
'APPLY_KP_MASK': apply_kp_mask,
|
||||
'APPLY_ATTN_MASK': apply_attn_mask,
|
||||
'KP_MASK_MUL': kp_mask_mode == 'mul',
|
||||
'ATTN_MASK_MUL': attn_mask_mode == 'mul',
|
||||
}
|
||||
grid = lambda opt: [spdims[0] * spdims[1] * block, M]
|
||||
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
|
||||
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)
|
||||
|
||||
# run kernel
|
||||
kernel(x.data_ptr(),
|
||||
scale,
|
||||
lut.data_ptr(),
|
||||
rpe.data_ptr(),
|
||||
key_padding_mask.data_ptr(),
|
||||
attn_mask.data_ptr(),
|
||||
maxlut,
|
||||
x.stride(0),
|
||||
stride_zrpe,
|
||||
stride_hrpe,
|
||||
stride_srpe,
|
||||
stride_zkpm,
|
||||
stride_zattnm,
|
||||
grid=grid)
|
||||
# save to context
|
||||
ctx.mark_dirty(x)
|
||||
ctx.save_for_backward(x, lut)
|
||||
@@ -147,14 +204,12 @@ class _softmax(torch.autograd.Function):
|
||||
# retrieve from context
|
||||
x, lut = ctx.saved_tensors
|
||||
# run kernel
|
||||
kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut * ctx.block, x.device, x.dtype, ctx.block,
|
||||
ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask,
|
||||
ctx.kp_mask_mode, ctx.attn_mask_mode)
|
||||
M = x.shape[0]
|
||||
grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
|
||||
kernel(x.data_ptr(), ctx.scale, dx.data_ptr(), lut.data_ptr(), ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
|
||||
_backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
|
||||
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class softmax:
|
||||
|
||||
apply_softmax = _softmax.apply
|
||||
@@ -172,14 +227,9 @@ class softmax:
|
||||
self.bench = bench
|
||||
self.lut_cache = dict()
|
||||
|
||||
def __call__(self,
|
||||
x,
|
||||
scale=1.,
|
||||
rpe=None,
|
||||
key_padding_mask=None,
|
||||
attn_mask=None,
|
||||
key_padding_mask_mode='add',
|
||||
attn_mask_mode='add'):
|
||||
def __call__(
|
||||
self, x, scale=1., rpe=None, key_padding_mask=None, attn_mask=None, key_padding_mask_mode='add', attn_mask_mode='add'
|
||||
):
|
||||
time_y = [None]
|
||||
if rpe is not None and rpe.dtype != x.dtype:
|
||||
raise ValueError('relative position embedding must be %s' % x.dtype)
|
||||
@@ -188,6 +238,8 @@ class softmax:
|
||||
if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
|
||||
raise ValueError('Key padding mask must be %s' % x.dtype)
|
||||
lut, maxlut = self.make_lut(x.device)
|
||||
x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode,
|
||||
self.spdims, self.block, lut, maxlut, self.bench, time_y)
|
||||
x = softmax.apply_softmax(
|
||||
x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode, self.spdims, self.block, lut,
|
||||
maxlut, self.bench, time_y
|
||||
)
|
||||
return x
|
@@ -1,123 +0,0 @@
|
||||
__global__ void conv(TYPE *A __noalias __readonly,
|
||||
TYPE *B __noalias __readonly,
|
||||
TYPE *C __noalias,
|
||||
float alpha,
|
||||
// equivalent matmul
|
||||
int M, int N, int K,
|
||||
// convolution properties
|
||||
int pad_h, int pad_w, int stride_h, int stride_w,
|
||||
// pointer increment
|
||||
int *ADELTA,
|
||||
// memory strides
|
||||
int lda_z, int lda_ci, int lda_h, int lda_w,
|
||||
int ldb_ci, int ldb_r, int ldb_s, int ldb_co,
|
||||
int ldc_z, int ldc_co, int ldc_p, int ldc_q)
|
||||
{
|
||||
// prologue
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int ridz = get_program_id(2);
|
||||
int gridx = M / TM;
|
||||
int gridy = N / TN;
|
||||
int rid = ridx + ridy * gridx;
|
||||
ridx = rid / gridy;
|
||||
ridy = rid % gridy;
|
||||
int rm[TM] = ridx * TM + 0 ... TM;
|
||||
int rn[TN] = ridy * TN + 0 ... TN;
|
||||
// reduction splitting
|
||||
K = K / TZ;
|
||||
int rk[TK] = ridz * K + 0 ... TK;
|
||||
|
||||
// unpack aggregate rows
|
||||
// m = (z, p, q)
|
||||
int rq[TM] = rm % QQ;
|
||||
int rzp[TM] = rm / QQ;
|
||||
int rp[TM] = rzp % PP;
|
||||
int rz[TM] = rzp / PP;
|
||||
// unpack aggregate reduction
|
||||
// k = (ci, r, s)
|
||||
int rs[TK] = rk % SS;
|
||||
int rcir[TK] = rk / SS;
|
||||
int rr[TK] = rcir % RR;
|
||||
int rci[TK] = rcir / RR;
|
||||
|
||||
// padding / striding
|
||||
int rh_0[TM] = rp * stride_h - pad_h;
|
||||
int rw_0[TM] = rq * stride_w - pad_w;
|
||||
int rh[TM, TK] = rh_0[:, newaxis] + rr [newaxis, :];
|
||||
int rw[TM, TK] = rw_0[:, newaxis] + rs [newaxis, :];
|
||||
|
||||
// pointers to lhs
|
||||
int offa[TM, TK] = rz[:, newaxis] * lda_z + rci [newaxis, :] * lda_ci +
|
||||
rh * lda_h + rw * 1;
|
||||
TYPE *pa[TM, TK] = A + offa;
|
||||
int *padelta[TK] = ADELTA + rk;
|
||||
// pointers to rhs
|
||||
int offb[TK, TN] = rci[:, newaxis] * ldb_ci + rr[:, newaxis] * ldb_r +
|
||||
rs[:, newaxis] * ldb_s + rn [newaxis, :] * 1;
|
||||
TYPE *pb[TK, TN] = B + offb;
|
||||
|
||||
// prefetches operands
|
||||
bool checkam[TM, TK] = rm[:, newaxis] < M;
|
||||
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
|
||||
bool checkb[TK, TN] = rk[:, newaxis] < K;
|
||||
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||
int total = 0;
|
||||
|
||||
// reduction loop
|
||||
float acc[TM, TN] = 0;
|
||||
for (int k = K; k > 0; k -= TK)
|
||||
{
|
||||
acc += a @b;
|
||||
// increment A
|
||||
int adelta[TK] = *padelta;
|
||||
padelta += TK;
|
||||
pa += adelta [newaxis, :];
|
||||
// bounds-checking A
|
||||
rk += TK;
|
||||
rs = rk % SS;
|
||||
rcir = rk / SS;
|
||||
rr = rcir % RR;
|
||||
rh = rh_0[:, newaxis] + rr [newaxis, :];
|
||||
rw = rw_0[:, newaxis] + rs [newaxis, :];
|
||||
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
|
||||
// increment B
|
||||
pb += TK * ldb_s;
|
||||
// bounds-checking B
|
||||
bool checkb[TK, TN] = k > TK;
|
||||
a = checka ? *pa : 0;
|
||||
b = *? (checkb)pb;
|
||||
}
|
||||
acc = acc * alpha;
|
||||
TYPE c[TM, TN] = acc;
|
||||
|
||||
// epilogue
|
||||
rm = ridx * TM + 0 ... TM;
|
||||
rn = ridy * TN + 0 ... TN;
|
||||
rq = rm % QQ;
|
||||
rzp = rm / QQ;
|
||||
rp = rzp % PP;
|
||||
rz = rzp / PP;
|
||||
int offc[TM, TN] = rz[:, newaxis] * ldc_z + rn [newaxis, :] * ldc_co +
|
||||
rp[:, newaxis] * ldc_p + rq[:, newaxis] * 1;
|
||||
TYPE *pc[TM, TN] = C + offc;
|
||||
bool checkc[TM, TN] = rm[:, newaxis] < M && rn [newaxis, :] < N;
|
||||
|
||||
#if (TZ == 1)
|
||||
*? (checkc)pc = c;
|
||||
#else
|
||||
// accumulate partial result using spin-locks
|
||||
int *plock = locks + rid;
|
||||
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
|
||||
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
|
||||
;
|
||||
int count = *pcount;
|
||||
if (count == 0)
|
||||
*? (checkc)pc = c;
|
||||
else
|
||||
*? (checkc)pc = c + *? (checkc)pc;
|
||||
atomic_xchg(pcount, (count + 1) % TZ);
|
||||
atomic_xchg(plock, 0);
|
||||
#endif
|
||||
}
|
@@ -1,81 +0,0 @@
|
||||
import torch
|
||||
import triton
|
||||
import os
|
||||
|
||||
class _conv(torch.autograd.Function):
|
||||
src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c'))
|
||||
kernel = dict()
|
||||
|
||||
@staticmethod
|
||||
def unpack(IDX, CI, R, S):
|
||||
s = IDX % S
|
||||
cr = IDX // S
|
||||
r = cr % R
|
||||
ci = cr // R
|
||||
return ci, r, s
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, pad, stride):
|
||||
# create kernel if necessary
|
||||
dtype = a.dtype
|
||||
device = a.device
|
||||
# shapes
|
||||
Z, CI, H, W = a.shape
|
||||
_, R, S, CO = b.shape
|
||||
P = (H + 2 * pad[0] - R) // stride[0] + 1
|
||||
Q = (W + 2 * pad[1] - S) // stride[1] + 1
|
||||
# compile kernel
|
||||
if (dtype, device) not in _conv.kernel:
|
||||
TK = 16
|
||||
defines = {
|
||||
'TYPE': dtype,
|
||||
'TM': 64,
|
||||
'TN': 64,
|
||||
'TK': TK,
|
||||
'TZ': 1,
|
||||
'HH': H,
|
||||
'WW': W,
|
||||
'PP': P,
|
||||
'QQ': Q,
|
||||
'SS': S,
|
||||
'RR': R,
|
||||
}
|
||||
idx = torch.arange(CI * R * S)
|
||||
ci, r, s = _conv.unpack(idx, CI, R, S)
|
||||
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
|
||||
delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (ns - s) * a.stride(3)
|
||||
delta = delta.type(torch.int32).cuda()
|
||||
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
|
||||
delta, kernel = _conv.kernel[dtype]
|
||||
# allocate output
|
||||
c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
|
||||
# enqueue
|
||||
kernel(
|
||||
a.data_ptr(),
|
||||
b.data_ptr(),
|
||||
c.data_ptr(),
|
||||
1.,
|
||||
Z * P * Q,
|
||||
CO,
|
||||
CI * R * S,
|
||||
pad[0],
|
||||
pad[1],
|
||||
stride[0],
|
||||
stride[1],
|
||||
delta.data_ptr(),
|
||||
a.stride(0),
|
||||
a.stride(1),
|
||||
a.stride(2),
|
||||
a.stride(3),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
b.stride(2),
|
||||
b.stride(3),
|
||||
c.stride(0),
|
||||
c.stride(1),
|
||||
c.stride(2),
|
||||
c.stride(3),
|
||||
grid=lambda opt: [triton.cdiv(Z * P * Q, opt.TM), triton.cdiv(CO, opt.TN)])
|
||||
return c
|
||||
|
||||
conv = _conv.apply
|
@@ -1,35 +0,0 @@
|
||||
__global__ void forward(TYPE *logit, TYPE *modified_logit, long *indices, TYPE *result, int n_cols) {
|
||||
int row = get_program_id(0);
|
||||
|
||||
bool check[TILE] = ((0 ... TILE) < n_cols);
|
||||
int offset[TILE] = row * n_cols + 0 ... TILE;
|
||||
TYPE *px[TILE] = logit + offset;
|
||||
TYPE *pmodified[TILE] = modified_logit + offset;
|
||||
long local_ind = *(indices + row);
|
||||
|
||||
TYPE F16[TILE] = check ? *px : -INFINITY;
|
||||
float shifted_logit[TILE] = F16 - F16[max];
|
||||
float neg_logprob[TILE] = log(exp(shifted_logit)[+]) - shifted_logit;
|
||||
*? (check)pmodified = neg_logprob;
|
||||
__debug_barrier();
|
||||
*(result + row) = *(modified_logit + (local_ind + n_cols * row));
|
||||
}
|
||||
|
||||
__global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs, int n_cols) {
|
||||
|
||||
int row = get_program_id(0);
|
||||
// pointer arithmetic
|
||||
bool check[TILE] = ((0 ... TILE) < n_cols);
|
||||
int offset[TILE] = row * n_cols + 0 ... TILE;
|
||||
TYPE *px[TILE] = neg_logprobs + offset;
|
||||
long local_ind = *(indices + row);
|
||||
TYPE local_dn = *(dneg_logprobs + row);
|
||||
// We know d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k]
|
||||
// and we have -log(p[k]) stored, so this is easy
|
||||
TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * px) : 0;
|
||||
// selected_logit_idx is selected logit index for our token
|
||||
bool find_one[TILE] = ((0 ... TILE) == local_ind);
|
||||
intermediate = intermediate - ((TYPE[TILE])find_one);
|
||||
// multiply by dneg_logprobs
|
||||
*? (check)px = intermediate * local_dn;
|
||||
}
|
@@ -2,6 +2,7 @@ import os
|
||||
import triton
|
||||
import torch
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
@@ -12,34 +13,61 @@ def next_power_of_2(n):
|
||||
n += 1
|
||||
return n
|
||||
|
||||
def largest_pow2_divisor(N):
|
||||
if N % 8 == 0: return 8
|
||||
if N % 4 == 0: return 4
|
||||
if N % 2 == 0: return 2
|
||||
return 1
|
||||
|
||||
def make_kernel(device, dtype, n_cols, cache, name):
|
||||
rounded = next_power_of_2(n_cols)
|
||||
div = largest_pow2_divisor(n_cols)
|
||||
key = (dtype, rounded, div)
|
||||
if key not in cache:
|
||||
fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
|
||||
src = triton.read(fname, kernel_names=[name])
|
||||
infinities = {
|
||||
torch.float16: "F16_INFINITY",
|
||||
torch.float32: "F32_INFINITY",
|
||||
}
|
||||
defines = {"TILE": rounded, "TYPE": dtype, "INFINITY": infinities[dtype], "N_COLS_MULT": div}
|
||||
cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=4)
|
||||
return cache[key]
|
||||
def num_warps(N):
|
||||
if N < 2048:
|
||||
return 4
|
||||
elif N < 8192:
|
||||
return 8
|
||||
return 16
|
||||
|
||||
# forward kernel
|
||||
fwd_kernels = dict()
|
||||
make_fwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, fwd_kernels, "forward")
|
||||
|
||||
# backward kernel
|
||||
bwd_kernels = dict()
|
||||
make_bwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, bwd_kernels, "backward")
|
||||
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])})
|
||||
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])})
|
||||
@triton.jit
|
||||
def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
|
||||
BLOCK = meta['BLOCK']
|
||||
row = triton.program_id(0)
|
||||
cols = triton.arange(0, BLOCK)
|
||||
idx = triton.load(IDX + row)
|
||||
# pointers to logit and probs
|
||||
LOGITS = LOGITS + row * N + cols
|
||||
WRIT_PROBS = PROBS + row * N + cols
|
||||
READ_PROBS = PROBS + row * N + idx
|
||||
# write-back negative log-probs
|
||||
logits = triton.load(LOGITS, mask=cols < N, other=-float('inf'))
|
||||
logits = logits.to(triton.float32)
|
||||
logits = logits - triton.max(logits, 0)
|
||||
probs = triton.log(triton.sum(triton.exp(logits), 0)) - logits
|
||||
triton.store(WRIT_PROBS, probs, mask=cols < N)
|
||||
# There is a bug in the compiler, which fails to insert a barrier here.
|
||||
# We add it explicitly for now. Will be fixed soon.
|
||||
triton.debug_barrier()
|
||||
# write-back loss
|
||||
probs = triton.load(READ_PROBS)
|
||||
triton.store(LOSS + row, probs)
|
||||
|
||||
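# The arithmetic above is the shifted log-sum-exp identity (illustrative check,
# plain PyTorch, not part of this diff): after subtracting the row max m,
#   log(sum(exp(x - m))) - (x - m) == -log_softmax(x),
# so WRIT_PROBS holds negative log-probabilities and LOSS[row] is the entry at idx.
x = torch.randn(16)
ref = torch.log(torch.sum(torch.exp(x - x.max()))) - (x - x.max())
assert torch.allclose(ref, -torch.log_softmax(x, dim=0), atol=1e-6)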
|
||||
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
|
||||
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])})
|
||||
@triton.jit
|
||||
def _backward(PROBS, IDX, DPROBS, N, **meta):
|
||||
BLOCK = meta['BLOCK']
|
||||
row = triton.program_id(0)
|
||||
cols = triton.arange(0, BLOCK)
|
||||
idx = triton.load(IDX + row)
|
||||
# pointers to probs
|
||||
PROBS = PROBS + row * N + cols
|
||||
# We know d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k]
|
||||
# and we have -log(p[k]) stored in PROBS, so this is easy
|
||||
probs = -triton.load(PROBS, mask=cols < N, other=float('inf'))
|
||||
probs = triton.exp(probs.to(triton.float32))
|
||||
delta = cols == idx
|
||||
# write result in-place in PROBS
|
||||
dout = triton.load(DPROBS + row)
|
||||
din = (probs - delta) * dout
|
||||
triton.store(PROBS, din.to(triton.float16), mask=cols < N)
|
||||
|
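# The gradient identity used above can be cross-checked against autograd
# (illustrative, plain PyTorch): d(-log p[idx]) / d logits[k] = p[k] - 1{k == idx}.
logits = torch.randn(8, requires_grad=True)
idx = 3
(-torch.log_softmax(logits, dim=0)[idx]).backward()
p = torch.softmax(logits.detach(), dim=0)
delta = torch.zeros(8)
delta[idx] = 1.
assert torch.allclose(logits.grad, p - delta, atol=1e-6)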
||||
|
||||
class _cross_entropy(torch.autograd.Function):
|
||||
@classmethod
|
||||
@@ -49,16 +77,11 @@ class _cross_entropy(torch.autograd.Function):
|
||||
# make kernel
|
||||
device, dtype = logits.device, logits.dtype
|
||||
n_cols = logits.shape[-1]
|
||||
kernel = make_fwd_kernel(device, dtype, n_cols)
|
||||
# run the kernel
|
||||
result = torch.empty_like(indices, dtype=dtype, device=device)
|
||||
neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
|
||||
kernel(logits.data_ptr(),
|
||||
neg_logprobs.data_ptr(),
|
||||
indices.data_ptr(),
|
||||
result.data_ptr(),
|
||||
n_cols,
|
||||
grid=lambda opt: (logits.numel() // n_cols, ))
|
||||
grid = lambda opt: (logits.numel() // n_cols, )
|
||||
_forward[grid](logits, neg_logprobs, indices, result, n_cols)
|
||||
# save for backward
|
||||
ctx.save_for_backward(neg_logprobs, indices)
|
||||
return result
|
||||
@@ -75,14 +98,11 @@ class _cross_entropy(torch.autograd.Function):
|
||||
# make kernel
|
||||
device, dtype = neg_logprobs.device, neg_logprobs.dtype
|
||||
n_cols = neg_logprobs.shape[-1]
|
||||
kernel = make_bwd_kernel(device, dtype, n_cols)
|
||||
# run the kernel
|
||||
# neg_logprobs will be modified in place to become our gradient:
|
||||
kernel(neg_logprobs.data_ptr(),
|
||||
indices.data_ptr(),
|
||||
dneg_logprobs.data_ptr(),
|
||||
n_cols,
|
||||
grid=lambda opt: (neg_logprobs.numel() // n_cols, ))
|
||||
grid = lambda opt: (neg_logprobs.numel() // n_cols, )
|
||||
_backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols)
|
||||
return neg_logprobs, None
|
||||
|
||||
|
||||
cross_entropy = _cross_entropy.apply
|
@@ -1,94 +0,0 @@
|
||||
#define STM 1
|
||||
#define STN 1
|
||||
|
||||
__global__ void matmul(TYPE *A __noalias __readonly,
|
||||
TYPE *B __noalias __readonly,
|
||||
TYPE *C __noalias,
|
||||
float alpha,
|
||||
int M, int N, int K,
|
||||
int lda, int ldb, int ldc,
|
||||
int *locks) {
|
||||
// prologue
|
||||
int pid = get_program_id(0);
|
||||
int pidz = get_program_id(2);
|
||||
int gridm = (M + TM - 1) / TM;
|
||||
int gridn = (N + TN - 1) / TN;
|
||||
|
||||
// swizzle for better L2 performance
|
||||
int width = STM * gridn;
|
||||
int stm = pid / width;
|
||||
int RSTM = min(gridm - stm * STM, STM);
|
||||
int stn = (pid % width) / (RSTM * STN);
|
||||
int RSTN = min(gridn - stn * STN, STN);
|
||||
int laneid = pid % (RSTM * RSTN);
|
||||
int lanem = laneid / RSTN;
|
||||
int lanen = laneid % RSTN;
|
||||
int pidm = stm * STM + lanem;
|
||||
int pidn = stn * STN + lanen;
|
||||
int rm[TM] = pidm * TM + 0 ... TM;
|
||||
int rn[TN] = pidn * TN + 0 ... TN;
|
||||
|
||||
// split-k for better parallelism
|
||||
K = K / SPLITK;
|
||||
int rk[TK] = 0 ... TK;
|
||||
// pointers to operands
|
||||
int offa[TM, TK] = (pidz * K + rk [newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
|
||||
int offb[TK, TN] = (pidz * K + rk[:, newaxis]) * STRIDE_BK + rn [newaxis, :] * STRIDE_BN;
|
||||
TYPE *pa[TM, TK] = A + offa;
|
||||
TYPE *pb[TK, TN] = B + offb;
|
||||
|
||||
// prefetches operands
|
||||
bool checka[TM, TK] = rk [newaxis, :] < K;
|
||||
bool checkb[TK, TN] = rk[:, newaxis] < K;
|
||||
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||
pa += TK * STRIDE_AK;
|
||||
pb += TK * STRIDE_BK;
|
||||
|
||||
// reduction loop
|
||||
float acc[TM, TN] = 0;
|
||||
for (int k = K; k > 0; k -= TK) {
|
||||
#if (IS_TK_DIV_K == 1)
|
||||
bool checkk[TK] = k > TK;
|
||||
#else
|
||||
bool checkk[TK] = rk < k - TK;
|
||||
#endif
|
||||
bool checka[TM, TK] = checkk [newaxis, :];
|
||||
bool checkb[TK, TN] = checkk[:, newaxis];
|
||||
acc += a @b;
|
||||
#if (IS_TK_DIV_K == 1)
|
||||
a = *? (checka)pa;
|
||||
b = *? (checkb)pb;
|
||||
#else
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
#endif
|
||||
pa += TK * STRIDE_AK;
|
||||
pb += TK * STRIDE_BK;
|
||||
}
|
||||
acc = acc * alpha;
|
||||
TYPE c[TM, TN] = acc;
|
||||
|
||||
// epilogue
|
||||
int rcm[TM] = pidm * TM + 0 ... TM;
|
||||
int rcn[TN] = pidn * TN + 0 ... TN;
|
||||
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn [newaxis, :];
|
||||
TYPE *pc[TM, TN] = C + offc;
|
||||
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn [newaxis, :] < N;
|
||||
#if (SPLITK == 1)
|
||||
*? (checkc)pc = c;
|
||||
#else
|
||||
// accumulate partial result using spin-locks
|
||||
int *plock = locks + pid;
|
||||
int *pcount = plock + get_num_programs(0);
|
||||
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
|
||||
;
|
||||
int count = *pcount;
|
||||
if (count == 0)
|
||||
*? (checkc)pc = c;
|
||||
else
|
||||
*? (checkc)pc = c + *? (checkc)pc;
|
||||
atomic_xchg(pcount, (count + 1) % SPLITK);
|
||||
atomic_xchg(plock, 0);
|
||||
#endif
|
||||
}
|
@@ -1,108 +1,117 @@
|
||||
import torch
|
||||
import triton
|
||||
import os
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
|
||||
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
|
||||
# triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
|
||||
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2),\
|
||||
# triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2),
|
||||
],
|
||||
key=['M', 'N', 'K']
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, LOCKS, **META):
|
||||
# extract meta-parameters
|
||||
BLOCK_M = META['BLOCK_M']
|
||||
BLOCK_N = META['BLOCK_N']
|
||||
BLOCK_K = META['BLOCK_K']
|
||||
GROUP_M = META['GROUP_M']
|
||||
SPLIT_K = META['SPLIT_K']
|
||||
# matrix multiplication
|
||||
pid = triton.program_id(0)
|
||||
pid_z = triton.program_id(1)
|
||||
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
||||
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
|
||||
rk = triton.arange(0, BLOCK_K)
|
||||
# pointers
|
||||
K = K // SPLIT_K
|
||||
A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn)
|
||||
acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
|
||||
for k in range(K, 0, -BLOCK_K):
|
||||
if META['EVEN_K']:
|
||||
a = triton.load(A)
|
||||
b = triton.load(B)
|
||||
else:
|
||||
a = triton.load(A, mask=rk[None, :] < k, other=0.)
|
||||
b = triton.load(B, mask=rk[:, None] < k, other=0.)
|
||||
acc += triton.dot(a, b)
|
||||
A += BLOCK_K * stride_ak
|
||||
B += BLOCK_K * stride_bk
|
||||
acc = acc.to(triton.float16)
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
triton.store(C, acc, mask=mask)
|
||||
else:
|
||||
LOCKS = LOCKS + triton.program_id(0)
|
||||
COUNT = LOCKS + triton.num_programs(0)
|
||||
while triton.atomic_cas(LOCKS, 0, 1) == 1:
|
||||
pass
|
||||
count = triton.load(COUNT)
|
||||
if count == 0:
|
||||
triton.store(C, acc, mask=mask)
|
||||
else:
|
||||
curr = triton.load(C, mask=mask, other=0.)
|
||||
triton.store(C, acc + curr, mask=mask)
|
||||
triton.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
|
||||
triton.atomic_xchg(LOCKS, 0)
|
||||
|
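# Tracing the program-id swizzle at the top of _kernel with concrete numbers
# (illustrative sketch; `swizzle` is a hypothetical helper mirroring those lines):
def swizzle(pid, grid_m=4, grid_n=4, GROUP_M=2):
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    return group_id * GROUP_M + (pid % group_size), (pid % width) // group_size
# pids 0..7 -> (0,0) (1,0) (0,1) (1,1) (0,2) (1,2) (0,3) (1,3): consecutive
# programs walk down a GROUP_M-tall column of tiles, so the tiles of A and B
# they touch stay resident in L2.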
||||
|
||||
class _matmul(torch.autograd.Function):
|
||||
src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))
|
||||
|
||||
_DEFAULT_CONFIGS = [
|
||||
triton.config(defines={"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, num_warps=4),
|
||||
triton.config(defines={'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
|
||||
triton.config(defines={'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
|
||||
triton.config(defines={'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
|
||||
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
|
||||
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
|
||||
triton.config(defines={'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
|
||||
triton.config(defines={'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
|
||||
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
|
||||
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
|
||||
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
|
||||
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
|
||||
]
|
||||
_CONFIGS = _DEFAULT_CONFIGS
|
||||
|
||||
@staticmethod
|
||||
def largest_pow2_divisor(N):
|
||||
if N % 8 == 0:
|
||||
return 8
|
||||
if N % 4 == 0:
|
||||
return 4
|
||||
if N % 2 == 0:
|
||||
return 2
|
||||
return 1
|
||||
kernel = _kernel
|
||||
|
||||
_locks = dict()
|
||||
_kernels = dict()
|
||||
|
||||
@staticmethod
|
||||
def _call(a, b):
|
||||
dtype = a.dtype
|
||||
device = a.device
|
||||
# allocate output
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
c = torch.empty((M, N), dtype=dtype, device=device)
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||
b = b.contiguous()
|
||||
# kernel hash
|
||||
is_a_row = a.stride(1) == 1
|
||||
is_b_row = b.stride(1) == 1
|
||||
lda = a.stride(0) if is_a_row else a.stride(1)
|
||||
ldb = b.stride(0) if is_b_row else b.stride(1)
|
||||
ldc = c.stride(0)
|
||||
lda_pow2_div = _matmul.largest_pow2_divisor(lda)
|
||||
ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
|
||||
ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
|
||||
is_tk_div_k = K % 64 == 0
|
||||
key = (
|
||||
device,
|
||||
dtype,
|
||||
is_a_row,
|
||||
is_b_row,
|
||||
lda_pow2_div,
|
||||
ldb_pow2_div,
|
||||
ldc_pow2_div,
|
||||
is_tk_div_k,
|
||||
)
|
||||
if key not in _matmul._kernels:
|
||||
defines = {
|
||||
"TYPE": dtype,
|
||||
"STRIDE_AM": "lda" if is_a_row else "1",
|
||||
"STRIDE_AK": "1" if is_a_row else "lda",
|
||||
"STRIDE_BK": "ldb" if is_b_row else "1",
|
||||
"STRIDE_BN": "1" if is_b_row else "ldb",
|
||||
"LDA_POW2_DIV": lda_pow2_div,
|
||||
"LDB_POW2_DIV": ldb_pow2_div,
|
||||
"LDC_POW2_DIV": ldc_pow2_div,
|
||||
"IS_TK_DIV_K": int(is_tk_div_k),
|
||||
}
|
||||
_matmul._kernels[key] = triton.kernel(
|
||||
_matmul.src,
|
||||
device,
|
||||
defines=defines,
|
||||
autotune_configs=_matmul._CONFIGS,
|
||||
autotune_key=["M", "N", "K"],
|
||||
)
|
||||
kernel = _matmul._kernels[key]
|
||||
# # locks for split-k
|
||||
if device not in _matmul._locks:
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=device, dtype=a.dtype)
|
||||
# allocate locks for split-k
|
||||
if a.device not in _matmul._locks:
|
||||
_matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
|
||||
locks = _matmul._locks[device]
|
||||
# enqueue
|
||||
alpha = 1.0
|
||||
args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
|
||||
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.SPLITK]
|
||||
kernel(*args, grid=grid)
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), locks)
|
||||
# done
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b):
|
||||
c = _matmul._call(a, b)
|
||||
return c
|
||||
return _matmul._call(a, b)
|
||||
|
||||
|
||||
matmul = _matmul.apply
|
||||
|
@@ -47,13 +47,37 @@ def mask_tensor(x, mask, block, value=0):
|
||||
|
||||
|
||||
def allclose(x, y, tol=1e-2):
|
||||
assert x.dtype == y.dtype
|
||||
if x.dtype != y.dtype:
|
||||
raise RuntimeError(f'{x.dtype} did not match with {y.dtype}')
|
||||
if x.shape != y.shape:
|
||||
raise RuntimeError(f'{x.shape} did not match with {y.shape}')
|
||||
if x.dtype == torch.bool:
|
||||
return torch.sum(x ^ y) == 0
|
||||
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||
tol = 0
|
||||
diff = abs(x - y)
|
||||
x_max = torch.max(x)
|
||||
y_max = torch.max(y)
|
||||
tol = 1e-2
|
||||
err = torch.max(diff) / torch.max(x_max, y_max)
|
||||
return err < tol
|
||||
return err <= tol
|
||||
|
||||
|
||||
def assert_allclose(x, y, tol=1e-2):
|
||||
assert x.dtype == y.dtype
|
||||
assert allclose(x, y, tol)
|
||||
|
||||
|
||||
def random(shape, dtype, device):
|
||||
if isinstance(shape, int):
|
||||
shape = (shape, )
|
||||
if dtype == torch.bool:
|
||||
return torch.randint(0, 2, shape, dtype=dtype, device=device)
|
||||
if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||
return torch.randint(1, 32, shape, dtype=dtype, device=device)
|
||||
if dtype in [torch.float16, torch.float32, torch.float64]:
|
||||
return torch.randn(shape, dtype=dtype, device=device)
|
||||
raise RuntimeError(f'Unknown dtype {dtype}')
|
||||
|
||||
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
|
||||
|
@@ -1,139 +1,71 @@
|
||||
import torch
|
||||
import triton
|
||||
"""
|
||||
Vector Addition
|
||||
=================
|
||||
In this tutorial, you will write a simple vector addition using Triton and learn about:
|
||||
|
||||
- The basic syntax of the Triton programming language
|
||||
- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API
|
||||
- The basic programming model used by Triton
|
||||
- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.
|
||||
- The best practices for validating and benchmarking custom ops against native reference implementations
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Compute Kernel
|
||||
# --------------------------
|
||||
#
|
||||
# Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel
|
||||
# on different chunks of data (see the `Single Program, Multiple Data <https://en.wikipedia.org/wiki/SPMD>`_
|
||||
# programming model for more details).
|
||||
#
|
||||
# .. code-block:: C
|
||||
#
|
||||
# __global__ void add(float* z, float* x, float* y, int N){
|
||||
# // The `get_program_id(i)` returns the i-th coordinate
|
||||
# // of the program in the overarching SPMD context
|
||||
# // (a.k.a launch grid). This is what allows us to process
|
||||
# // different chunks of data in parallel.
|
||||
# // For those familiar with CUDA, `get_program_id({0,1,2})`
|
||||
# // is similar to blockIdx.{x,y,z}
|
||||
# int pid = get_program_id(0);
|
||||
# // In Triton, arrays are first-class citizens. In other words,
|
||||
# // they are primitive data-types and are -- contrary to C and
|
||||
# // CUDA -- not implemented as pointers to contiguous chunks of
|
||||
# // memory.
|
||||
# // In the few lines below, we create an array of `BLOCK` pointers
|
||||
# // whose memory values are, e.g.:
|
||||
# // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]
|
||||
# // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time
|
||||
# int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
|
||||
# float* pz [BLOCK] = z + offset;
|
||||
# float* px [BLOCK] = x + offset;
|
||||
# float* py [BLOCK] = y + offset;
|
||||
# // Simple element-wise control-flow for load/store operations can
|
||||
# // be achieved using the the ternary operator `cond ? val_true : val_false`
|
||||
# // or the conditional dereferencing operator `*?(cond)ptr
|
||||
# // Here, we make sure that we do not access memory out-of-bounds when we
|
||||
# // write-back `z`
|
||||
# bool check[BLOCK] = offset < N;
|
||||
# *?(check)pz = *?(check)px + *?(check)py;
|
||||
# }
|
||||
#
|
||||
# The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _add(
|
||||
X, # *Pointer* to first input vector
|
||||
Y, # *Pointer* to second input vector
|
||||
Z, # *Pointer* to output vector
|
||||
N, # Size of the vector
|
||||
**meta # Optional meta-parameters for the kernel
|
||||
):
|
||||
pid = triton.program_id(0)
|
||||
# Create an offset for the blocks of pointers to be
|
||||
# processed by this program instance
|
||||
offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
|
||||
# Create a mask to guard memory operations against
|
||||
# out-of-bounds accesses
|
||||
mask = offsets < N
|
||||
# Load x
|
||||
x = triton.load(X + offsets, mask=mask)
|
||||
y = triton.load(Y + offsets, mask=mask)
|
||||
# Write back x + y
|
||||
z = x + y
|
||||
triton.store(Z + offsets, z, mask=mask)
|
||||
|
||||
|
||||
# %%
|
||||
# Torch Bindings
|
||||
# --------------------------
|
||||
# The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
|
||||
#
|
||||
# - :code:`source: string`: the source-code of the kernel you want to create
|
||||
# - :code:`device: torch.device`: the device you want to compile this code for
|
||||
# - :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you
|
||||
import torch
import triton

# source-code for Triton compute kernel
# here we just copy-paste the above code without the extensive comments.
# you may prefer to store it in a .c file and load it from there instead.
_src = """
__global__ void add(float* z, float* x, float* y, int N){
    // program id
    int pid = get_program_id(0);
    // create arrays of pointers
    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
    float* pz[BLOCK] = z + offset;
    float* px[BLOCK] = x + offset;
    float* py[BLOCK] = y + offset;
    // bounds checking
    bool check[BLOCK] = offset < N;
    // write-back
    *?(check)pz = *?(check)px + *?(check)py;
}
"""
# We can also declare a helper function that handles allocating the output vector
# and enqueueing the kernel.


# This function returns a callable `triton.kernel` object created from the above source code.
# For portability, we maintain a cache of kernels for different `torch.device`.
# We compile the kernel with -DBLOCK=1024
def make_add_kernel(device):
    cache = make_add_kernel.cache
    if device not in cache:
        defines = {'BLOCK': 1024}
        cache[device] = triton.kernel(_src, device=device, defines=defines)
    return cache[device]

def add(x, y):
    z = torch.empty_like(x)
    N = z.shape[0]
    # The SPMD launch grid denotes the number of kernel instances that should execute in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
    # NOTE:
    # - torch.tensor objects are implicitly converted to pointers to their first element.
    # - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel
    # - don't forget to pass meta-parameters as keyword arguments
    _add[grid](x, y, z, N, BLOCK=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously.
    return z


make_add_kernel.cache = dict()
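
# %%
# Because the launch above is asynchronous, any wall-clock timing must first
# synchronize with the device. A minimal timing sketch (illustrative only):

import time

x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
torch.cuda.synchronize()
start = time.time()
z = add(x, y)
torch.cuda.synchronize()  # wait until the kernel has actually finished
print(f'elapsed: {time.time() - start:.6f}s')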

# This is a standard torch custom autograd Function;
# the only difference is that we can now use the above kernel in the `forward` and `backward` functions.
class _add(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # constraints of the op
        assert x.dtype == torch.float32
        # *allocate output*
        z = torch.empty_like(x)
        # *create launch grid*:
        # this is a function which takes compilation parameters `opt`
        # as input and returns a tuple of int (i.e., launch grid) for the kernel.
        # triton.cdiv is a shortcut for ceil division:
        # triton.cdiv(a, b) = (a + b - 1) // b
        N = z.shape[0]
        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )
        # *launch kernel*:
        # pointers to the data of torch tensors can be retrieved with
        # the `.data_ptr()` method
        kernel = make_add_kernel(z.device)
        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)
        return z


# Just like for standard PyTorch ops, we use the :code:`.apply` method to create a callable object for our function
add = _add.apply

# %%
# We can now use the above function to compute the sum of two `torch.tensor` objects:

# %%
# Unit Test
# -----------
#
# Of course, the first thing that we should check is whether the kernel is correct. This is pretty easy to test, as shown below:
# We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:

torch.manual_seed(0)
x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
za = x + y
zb = add(x, y)
print(za)
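# A minimal continuation that prints the Triton result and compares it against
# torch (a sketch; the original comparison code is elided by the diff):
print(zb)
print(f'The maximum difference between torch and triton is {torch.max(torch.abs(za - zb))}')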
@@ -4,8 +4,7 @@ Fused Softmax
In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:

- The benefits of kernel fusion for bandwidth-bound operations.
- The syntax and usage of reduction operators in Triton.
- The automatic vectorization capabilities of the Triton compiler.
- The reduction operators in Triton.
"""

# %%
@@ -36,79 +35,45 @@ def naive_softmax(x):
# %%
# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
# In this case, we would be reading and writing back only :math:`MN` elements, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
# This solution would require reading and writing back only :math:`MN` elements, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
# In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.
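
# %%
# A quick numeric check of the ~5x bound (an illustrative sketch, not part of the kernel):

M, N = 1024, 1024
naive_traffic = 7 * M * N + 3 * M * N + 2 * M  # elements read + written by naive_softmax
fused_traffic = 2 * M * N                      # fused kernel reads X once and writes Y once
print(naive_traffic / fused_traffic)           # ~5.0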

# %%
# Compute Kernel
# ----------------
# Our softmax kernel works as follows: each program loads a row of the input X, normalizes it and writes back the result to the output Y.
# Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
# Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
#
# .. code-block:: C
#
#    __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){
#        // row index
#        int m = get_program_id(0);
#        // column indices
#        int n [BLOCK] = 0 ... BLOCK;
#        // the memory address of all the elements
#        // that we want to load can be computed as follows
#        float* px [BLOCK] = X + m*stride_xm + n;
#        // because BLOCK has to be a power of two
#        // (per Triton-C specs), it is important
#        // to guard each memory operation with predicates
#        // or we will read out of bounds
#        bool check[BLOCK] = n < N;
#        float x [BLOCK] = check ? *px : -F32_INFINITY;
#        // syntax for reduction in Triton is:
#        //     x[:, :, OPERATOR, :, :]
#        //               ^
#        //             index
#        // where operator is in {min, max, +}
#        // for 1D vectors, this is just x[OPERATOR].
#        float z [BLOCK] = x - x[max];
#        // Note that exponentials in Triton are fast
#        // but approximate (i.e., think __expf in CUDA)
#        float num [BLOCK] = exp(z);
#        float denom = num[+];
#        // The result of the reduction is now stored in y
#        float y [BLOCK] = num / denom;
#        // We write it back
#        float* py [BLOCK] = Y + m*stride_ym + n;
#        *?(check)py = y;
#    }

# %%
# Torch Bindings
# ---------------
# Here our torch bindings are quite similar to those of the vector addition from the previous tutorial.
# We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix.
# This means that different values of BLOCK will result in different kernels.

import torch
import triton

# Source code for the Triton kernel
_src = """
__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){
    int m = get_program_id(0);
    int n [BLOCK] = 0 ... BLOCK;
    float* px [BLOCK] = X + m*stride_xm + n;
    bool check[BLOCK] = n < N;
    float x [BLOCK] = check ? *px : -F32_INFINITY;
    float z [BLOCK] = x - x[max];
    float num [BLOCK] = exp(z);
    float denom = num[+];
    float y [BLOCK] = num / denom;
    float* py [BLOCK] = Y + m*stride_ym + n;
    *?(check)py = y;
}
"""

@triton.jit
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
    # row index
    m = triton.program_id(0)
    # col indices
    n = triton.arange(0, meta['BLOCK'])
    # the memory address of all the elements
    # that we want to load can be computed as follows
    X = X + m * stride_xm + n
    x = triton.load(X, mask=n < N, other=-float('inf'))
    # Subtract maximum for numerical stability
    z = x - triton.max(x, axis=0)
    # Note that exponentials in Triton are fast
    # but approximate (i.e., think __expf in CUDA)
    num = triton.exp(z)
    denom = triton.sum(num, axis=0)
    y = num / denom
    # Write back to Y
    Y = Y + m * stride_ym + n
    triton.store(Y, y, mask=n < N)

# %%
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.


# helper function to get the smallest power-of-two larger than a given number
def next_power_of_2(n):
    n -= 1
    n |= n >> 1
@@ -120,11 +85,9 @@ def next_power_of_2(n):
    return n
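
# The hunk above elides the middle of `next_power_of_2` as unchanged context;
# the standard bit-twiddling routine presumably reads in full:
def next_power_of_2(n):
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n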

# kernel caching mechanism
def make_kernel(N, device):
    cache = make_kernel.cache
    # Now our kernels are indexed not only by the provided device but also
    # by the rounded number of columns in the input matrix
def softmax(x):
    M, N = x.shape
    # The block size is the smallest power of two greater than the number of columns in `x`
    BLOCK = next_power_of_2(N)
    # Another trick we can use is to ask the compiler to parallelize each
    # row-normalization more aggressively -- i.e., with more warps -- vectors
@@ -134,37 +97,13 @@ def make_kernel(N, device):
    num_warps = 4
    if BLOCK >= 2048: num_warps = 8
    if BLOCK >= 4096: num_warps = 16
    # Each (BLOCK, num_warps, device) results in a different kernel
    key = (BLOCK, num_warps, device)
    if key not in cache:
        defines = {'BLOCK': BLOCK}
        cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
    return cache[key]
    # Allocate output
    y = torch.empty_like(x)
    # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
    _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, BLOCK=BLOCK)
    return y


make_kernel.cache = dict()

class _softmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # constraints of the op
        assert x.dtype == torch.float32
        y = torch.empty_like(x)
        # The launch grid is simple: we have one kernel instance per row of the input matrix
        M, N = y.shape
        grid = lambda opt: (M, )
        # Launch kernel
        kernel = make_kernel(N, y.device)
        kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)
        return y


softmax = _softmax.apply

# %%
# We can use the above softmax function to compute the row-wise softmax of a given matrix.

# %%
# Unit Test
# ----------
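
# The diff elides the test body; a minimal correctness check might look as
# follows (an irregular number of columns exercises the padding logic):
x = torch.randn(1823, 781, device='cuda')
y_tri = softmax(x)
y_ref = torch.softmax(x, dim=1)
print(torch.allclose(y_tri, y_ref))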
@@ -1,10 +1,10 @@
"""
Matrix Multiplication
======================
In this tutorial, you will write a 25-line high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLAS's performance.
In this tutorial, you will write a 25-line high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs.
You will specifically learn about:

- The block-level matrix multiplication operator `@`
- Block-level matrix multiplications
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
@@ -15,7 +15,7 @@ You will specifically learn about:
# -------------
# Matrix multiplications are a key building block of most modern high-performance computing systems.
# They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
# Unfortunately, these libraries are often proprietary and cannot be customized to accommodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
# Unfortunately, these libraries are often proprietary and cannot be easily customized to accommodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
# For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
#
# Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
@@ -23,322 +23,212 @@ You will specifically learn about:
# .. code-block:: python
#
#    # do in parallel
#    for m in range(0, M, MB):
#    for m in range(0, M, BLOCK_M):
#        # do in parallel
#        for n in range(0, N, NB):
#            acc = zeros((MB, NB), dtype=float32)
#            for k in range(0, K, KB):
#                acc += A[m : m+MB, k : k+KB] @ B[k : k+KB, n : n+NB]
#            C[m : m+MB, n : n+NB] = acc;
#        for n in range(0, N, BLOCK_N):
#            acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
#            for k in range(0, K, BLOCK_K):
#                a = A[m : m+BLOCK_M, k : k+BLOCK_K]
#                b = B[k : k+BLOCK_K, n : n+BLOCK_N]
#                acc += dot(a, b)
#            C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;
#
# where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.
# where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
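
# %%
# A runnable NumPy transcription of the blocked algorithm can help check one's
# understanding (a sketch; it assumes M, N and K are multiples of the block sizes):

import numpy as np


def blocked_matmul(A, B, BLOCK_M=16, BLOCK_N=16, BLOCK_K=16):
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=np.float32)
    for m in range(0, M, BLOCK_M):  # do in parallel
        for n in range(0, N, BLOCK_N):  # do in parallel
            acc = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)
            for k in range(0, K, BLOCK_K):
                acc += A[m:m + BLOCK_M, k:k + BLOCK_K] @ B[k:k + BLOCK_K, n:n + BLOCK_N]
            C[m:m + BLOCK_M, n:n + BLOCK_N] = acc
    return C


A = np.random.rand(64, 64).astype(np.float32)
B = np.random.rand(64, 64).astype(np.float32)
print(np.allclose(blocked_matmul(A, B), A @ B, atol=1e-4))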

# %%
# Compute Kernel
# ----------------
#
# The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the :code:`@` operator for block-level matrix multiplication.
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of :code:`A` and :code:`B` that we need to read in the inner loop.
# The above algorithm is actually fairly straightforward to implement in Triton.
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop.
#
# Pointer Arithmetics
# ~~~~~~~~~~~~~~~~~~~~
#
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = i + X.stride(0) + j`.
# Therefore, blocks of pointers for :code:`A[m : m+MB, k:k+KB]` and :code:`B[k : k+KB, n : n+NB]` can be defined in pseudo-code as:
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:
#
# .. code-block:: python
#
#    &A[m : m+MB, k:k+KB] = A + (m : m+MB)[:, newaxis]*A.stride(0) + (k : k+KB)[newaxis, :];
#    &B[k : k+KB, n:n+NB] = B + (k : k+KB)[:, newaxis]*B.stride(0) + (n : n+NB)[newaxis, :];
#    &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :];
#    &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :];
#
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
#
# .. code-block:: C
# .. code-block:: python
#    :force:
#
#    int rm[MB] = program_id_m * MB + 0 ... MB;
#    int rn[NB] = program_id_n * NB + 0 ... NB;
#    int rk[KB] = 0 ... KB;
#    TYPE *pa[MB, KB] = A + (rm[:, newaxis] * stride_a_0 + rk [newaxis, :] * 1);
#    TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
#    pid_m = triton.program_id(0)
#    pid_n = triton.program_id(1)
#    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
#    rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
#    rk = triton.arange(0, BLOCK_K)
#    # pointer for A operand
#    pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1)
#    # pointer for B operand
#    pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1)
#
# These pointers can then be updated in the inner loop as:
#
# .. code-block:: C
# .. code-block:: python
#
#    pa += KB * 1;
#    pb += KB * ldb;
#    pa += BLOCK_K * stride_a_1
#    pb += BLOCK_K * stride_b_0
#
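
# %%
# A quick NumPy sanity check of the stride formula above (illustrative only;
# NumPy strides are expressed in bytes, hence the division by itemsize):

import numpy as np

X = np.arange(12, dtype=np.float32).reshape(3, 4)
stride_x_0 = X.strides[0] // X.itemsize  # elements to skip per step along axis 0
stride_x_1 = X.strides[1] // X.itemsize  # elements to skip per step along axis 1
i, j = 2, 1
assert X.flat[i * stride_x_0 + j * stride_x_1] == X[i, j]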
#
# L2 Cache Optimizations
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# As mentioned above, each program instance computes an :code:`[MB, NB]` block of :code:`C`.
# As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
# However, the order in which these blocks are computed matters, since it affects the L2 cache hit rate of our program.
# This means that a naive row-major ordering:
#
# .. code-block:: C
# .. code-block:: python
#
#    int program_id = get_program_id(0);
#    int grid_m = (M + MB - 1) / MB;
#    int grid_n = (N + NB - 1) / NB;
#    int program_id_m = program_id / grid_n;
#    int program_id_n = program_id % grid_n;
#    pid = triton.program_id(0)
#    grid_m = (M + BLOCK_M - 1) // BLOCK_M
#    grid_n = (N + BLOCK_N - 1) // BLOCK_N
#    pid_m = pid // grid_n
#    pid_n = pid % grid_n
#
# is unlikely to result in optimal performance.
#
# One possible solution is to launch blocks in an order that promotes data reuse.
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE` before switching to the next column:
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
#
# .. code-block:: C
#
#    int program_id = get_program_id(0);
#    int width = GROUP_SIZE * grid_n;
#    int group_id = pid / width;
#    // we need to handle the case where M % (GROUP_SIZE*BM) != 0
#    int group_size = min(grid_m - group_id * GROUP_SIZE, GROUP_SIZE);
#    int pid_m = group_id * GROUP_SIZE + (pid % group_size);
#    int pid_n = (pid % width) / (group_size);
#    pid = triton.program_id(0)
#    width = GROUP_M * grid_n
#    group_id = pid // width
#    # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
#    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
#    pid_m = group_id * GROUP_M + (pid % group_size)
#    pid_n = (pid % width) // group_size
#
# In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).
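
# %%
# A hypothetical pure-Python helper makes the grouped ordering concrete by
# listing which (pid_m, pid_n) pair each pid visits (a sketch, not part of the kernel):

def grouped_launch_order(grid_m, grid_n, GROUP_M):
    order = []
    for pid in range(grid_m * grid_n):
        width = GROUP_M * grid_n
        group_id = pid // width
        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
        pid_m = group_id * GROUP_M + (pid % group_size)
        pid_n = (pid % width) // group_size
        order.append((pid_m, pid_n))
    return order

# With a 4x4 grid of blocks and GROUP_M=2, the first eight programs sweep two
# rows of blocks column by column, reusing the loaded blocks of A:
print(grouped_launch_order(4, 4, 2)[:8])
# -> [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]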
#
# Final Result
# ~~~~~~~~~~~~~~
#
# We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
# Note that we rematerialize :code:`rm` and :code:`rn` after the inner loop to decrease register pressure.
# This is an optimization that provides an additional 5% performance improvement and currently cannot be done automatically by the Triton compiler.
#
# .. code-block:: C
#    :force:
#
#    #define MAX_GROUP_SIZE 8
#
#    __global__ void dot(TYPE* A, TYPE* B, TYPE* C,
#                        int M, int N, int K,
#                        int stride_a_0, int stride_b_0, int stride_c_0) {
#        // prologue
#        int pid = get_program_id(0);
#        int grid_m = (M + MB - 1) / MB;
#        int grid_n = (N + NB - 1) / NB;
#        // re-order program ID for better L2 performance
#        int width = MAX_GROUP_SIZE * grid_n;
#        int group_id = pid / width;
#        int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
#        int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
#        int pid_n = (pid % width) / (group_size);
#        // pointers to operands
#        // note the parentheses here; they force the offset
#        // computation to happen in typeof(stride_a_0) = int32 rather than
#        // typeof(A) = int64
#        int rm[MB] = pid_m * MB + 0 ... MB;
#        int rn[NB] = pid_n * NB + 0 ... NB;
#        int rk[KB] = 0 ... KB;
#        TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * stride_a_0);
#        TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
#        // reduction loop
#        float acc[MB, NB] = 0;
#        for (int k = K; k > 0; k -= KB) {
#            acc += (*pa) @ (*pb);
#            pa += KB * 1;
#            pb += KB * stride_b_0;
#        }
#        // pointers to output
#        // here we rematerialize `rm` and `rn` so that they are not live through
#        // the above reduction loop. In the future, the compiler should be able to
#        // do this automatically.
#        rm = pid_m * MB + 0 ... MB;
#        rn = pid_n * NB + 0 ... NB;
#        TYPE *pc[MB, NB] = C + (rm[:, newaxis] * stride_c_0 + rn[newaxis, :]);
#        // we write back using the *?() operator. `acc` gets cast to `float32` implicitly.
#        *? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
#    }
#
# Where :code:`TYPE` is the data-type of the input matrices and :code:`MB`, :code:`NB`, :code:`KB` are the block sizes defined in the above pseudo-code.
# Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
# If :code:`TYPE` is :code:`half`, then tensor cores will be used automatically provided that :code:`MB`, :code:`NB` and :code:`KB` are multiples of 16.
#

# %%
# Torch Bindings
# ----------------
# Final Result
# -------------
#
# Auto-Tuning
# ~~~~~~~~~~~~~~
#
# In order to use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects that can be constructed as follows:

import torch
import triton

autotune_configs = [
    triton.config(defines={'MB': '128', 'NB': '128', 'KB': '32'}, num_warps=4),
    triton.config(defines={'MB': '64', 'NB': '128', 'KB': '32'}, num_warps=4),
    triton.config(defines={'MB': '128', 'NB': '64', 'KB': '32'}, num_warps=4),
    triton.config(defines={'MB': '64', 'NB': '64', 'KB': '64'}, num_warps=4),
    triton.config(defines={'MB': '32', 'NB': '128', 'KB': '64'}, num_warps=4),
    triton.config(defines={'MB': '128', 'NB': '32', 'KB': '64'}, num_warps=4),
    triton.config(defines={'MB': '64', 'NB': '32', 'KB': '64'}, num_warps=2),
    triton.config(defines={'MB': '32', 'NB': '64', 'KB': '64'}, num_warps=2)
]
# %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
# - An autotuning *key* whose change in values will trigger evaluation of all the provided configs

@triton.jit
def sigmoid(x):
    # numerically stable sigmoid: for x >= 0 use 1 / (1 + exp(-x)),
    # for x < 0 use exp(x) / (1 + exp(x)), so that exp never overflows
    ret_true = 1 / (1 + triton.exp(-x))
    ret_false = triton.exp(x) / (1 + triton.exp(x))
    return triton.where(x >= 0, ret_true, ret_false)


@triton.jit
def swish(x):
    return x * sigmoid(x)


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
# %
# We can now define our kernel as normal, using all the techniques presented above
@triton.jit
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = 8
    # matrix multiplication
    pid = triton.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    # do matrix multiplication
    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
    rk = triton.arange(0, BLOCK_K)
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
    for k in range(K, 0, -BLOCK_K):
        a = triton.load(A)
        b = triton.load(B)
        acc += triton.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # Triton can accept an arbitrary activation function
    # via meta-parameters!
    if META['ACTIVATION']:
        acc = META['ACTIVATION'](acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    triton.store(C, acc, mask=mask)

# %%
# We also need to define a list of :code:`string` (i.e., "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
# Here, we want to re-tune our kernel only when the shape of the input matrices changes.

autotune_key = ["M", "N", "K"]

# %%
# We can now create an auto-tuned kernel by passing the `autotune_configs` and `autotune_key` lists to the constructor of the :code:`triton.kernel` class.

src = """
#define MAX_GROUP_SIZE 8

__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
                    int M, int N, int K,
                    int lda, int ldb, int ldc) {
    int pid = get_program_id(0);
    int grid_m = (M + MB - 1) / MB;
    int grid_n = (N + NB - 1) / NB;
    int width = MAX_GROUP_SIZE * grid_n;
    int group_id = pid / width;
    int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
    int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
    int pid_n = (pid % width) / (group_size);
    int rm[MB] = pid_m * MB + 0 ... MB;
    int rn[NB] = pid_n * NB + 0 ... NB;
    int rk[KB] = 0 ... KB;
    TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda);
    TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1);
    float acc[MB, NB] = 0;
    for (int k = K; k > 0; k -= KB) {
        acc += (*pa) @ (*pb);
        pa += KB * 1;
        pb += KB * ldb;
    }
    rm = pid_m * MB + 0 ... MB;
    rn = pid_n * NB + 0 ... NB;
    TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]);
    *? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
}
"""
# We can also create a convenience wrapper function that only takes two input tensors
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel


def make_kernel(device, dtype):
    key = (device, dtype)
    cache = make_kernel.cache
    if key not in cache:
        defines = {'TYPE': dtype}
        cache[key] = triton.kernel(
            src,
            device=device,
            defines=defines,
            autotune_configs=autotune_configs,
            autotune_key=autotune_key,
        )
    return cache[key]


def matmul(a, b, activation=None):
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    assert a.is_contiguous(), "matrix A must be contiguous"
    assert b.is_contiguous(), "matrix B must be contiguous"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # launch kernel
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
    _matmul[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
        ACTIVATION=activation
    )
    # return output
    return c


make_kernel.cache = dict()

# %%
# Autograd Function
# ~~~~~~~~~~~~~~~~~~
#
# Now we are ready to expose our auto-tuned kernel as a `torch.autograd.Function`.
# To do so, we just need to define a `forward` function that takes two tensors as input and returns a tensor as output.


class _dot(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        M, Ka = a.shape
        Kb, N = b.shape
        assert Ka == Kb, "incompatible dimensions"
        assert a.is_contiguous() and b.is_contiguous(), "inputs must be contiguous"
        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
        kernel = make_kernel(a.device, a.dtype)
        grid = lambda opt: (triton.cdiv(M, opt.MB) * triton.cdiv(N, opt.NB), )
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
               M, N, Ka,
               a.stride(0), b.stride(0), c.stride(0),
               grid=grid)
        return c


dot = _dot.apply

# %%
# Unit Test
# -----------
#
# We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
# Note that we need to modify the :code:`atol` and :code:`rtol` parameters of `torch.allclose` to account for the fact that we are comparing FP16 tensors.
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel).

a = torch.rand((512, 768), device='cuda', dtype=torch.float16)
b = torch.rand((768, 896), device='cuda', dtype=torch.float16)
c_0 = dot(a, b)
c_1 = torch.matmul(a, b)
# torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_0 = matmul(a, b, activation=swish)
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
print(c_0)
print(c_1)
print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
print(triton.testing.allclose(c_0, c_1))

# %%
# Benchmark
# --------------
#
# Installing The CUTLASS Bindings
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The cuBLAS library (used by :code:`torch.matmul`) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
# For this reason, we will instead compare the performance of our kernel against `CUTLASS <https://github.com/NVIDIA/cutlass/>`_, a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves.
# To install CUTLASS, you need a recent version of cmake:
#
# .. code-block:: bash
#
#    cd /path/to/cutlass/
#    git clone https://github.com/NVIDIA/cutlass.git
#    cd cutlass
#    mkdir build
#    cd build
#    wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz
#    tar xzvf *.tar.gz
#
# You can then install CUTLASS as follows for V100:
#
# .. code-block:: bash
#
#    ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=70 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s884gemm_f16_*_align8
#    make -j8 install
#
# Or as follows for A100:
#
# .. code-block:: bash
#
#    ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=80 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s16816gemm_*align8
#    make -j8 install
#
# Where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance.
# Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables :code:`CUTLASS_INCLUDE_DIR` and :code:`CUTLASS_LIBRARY_DIR` are set during the installation process.
# To re-install Triton with the updated CUTLASS bindings, run the following command:
#
# .. code-block:: bash
#
#    export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
#    export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
#    pip uninstall -y triton
#    pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
#
# Which we can test as follows:

import triton

c_2 = triton.testing.cutlass_matmul(a, b)
print(c_2)
print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))

# %%
# Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to rearrange the script as you wish to benchmark any other matrix shape.
@@ -347,29 +237,25 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
        x_vals=[256 * i for i in range(2, 33)],  # different possible values for `x_name`
        x_vals=[8192],  # different possible values for `x_name`
        y_name='provider',  # argument name whose value corresponds to a different line in the plot
        y_vals=['cublas', 'triton', 'cutlass'],  # possible keys for `y_name`
        y_lines=["cuBLAS", "Triton", 'CUTLASS'],  # label name for the lines
        y_vals=['cublas', 'triton'],  # possible keys for `y_name`
        y_lines=["cuBLAS", "Triton"],  # label name for the lines
        ylabel="TFLOPS",  # label name for the y-axis
        plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={}
    )
)
def benchmark(M, N, K, provider):
    silu = torch.nn.SiLU()
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    if provider == 'cublas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
    if provider == 'cutlass':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b))
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
    # a matmul performs 2 * M * N * K FLOPs; convert ms to s and FLOPs to TFLOPs
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


benchmark.run(show_plots=True)

# %%
# As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.
benchmark.run(print_data=True)