Add argmin argmax (#552)
This commit is contained in:
@@ -353,9 +353,6 @@ ir::value *sqrt(ir::value *input, ir::builder *builder) {
|
||||
return builder->create_sqrt(input);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.min
|
||||
----------------------------------------------*/
|
||||
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
|
||||
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
@@ -367,6 +364,9 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
|
||||
throw_not_int_or_float(name);
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.min
|
||||
----------------------------------------------*/
|
||||
std::string min_docstr = R"pbdoc(
|
||||
Returns the minimum value of `input`.
|
||||
)pbdoc";
|
||||
@@ -374,6 +374,16 @@ ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.argmin
|
||||
----------------------------------------------*/
|
||||
// Docstring text for the argmin reduction. It has its own identifier because
// `min_docstr` is already defined for the plain min reduction; defining that
// name a second time at namespace scope is a redefinition error.
std::string argmin_docstr = R"pbdoc(
Returns the minimum value's index of `input`.
)pbdoc";
|
||||
ir::value *argmin(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
return reduce_impl(input, axis, builder, "argmin", ir::reduce_inst::ARGFMIN, ir::reduce_inst::ARGMIN);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.max
|
||||
----------------------------------------------*/
|
||||
@@ -384,6 +394,16 @@ ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.argmax
|
||||
----------------------------------------------*/
|
||||
// Docstring text for the argmax reduction. It has its own identifier so it
// cannot collide with the `max_docstr` name used for the plain max reduction
// (NOTE(review): that definition is outside the visible hunk — confirm).
std::string argmax_docstr = R"pbdoc(
Returns the maximum value's index of `input`.
)pbdoc";
|
||||
ir::value *argmax(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
return reduce_impl(input, axis, builder, "argmax", ir::reduce_inst::ARGFMAX, ir::reduce_inst::ARGMAX);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.sum
|
||||
----------------------------------------------*/
|
||||
|
@@ -573,8 +573,14 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("MAX", ir::reduce_inst::MAX)
|
||||
.value("UMIN", ir::reduce_inst::UMIN)
|
||||
.value("UMAX", ir::reduce_inst::UMAX)
|
||||
.value("ARGMIN", ir::reduce_inst::ARGMIN)
|
||||
.value("ARGMAX", ir::reduce_inst::ARGMAX)
|
||||
.value("ARGUMIN", ir::reduce_inst::ARGUMIN)
|
||||
.value("ARGUMAX", ir::reduce_inst::ARGUMAX)
|
||||
.value("FMIN", ir::reduce_inst::FMIN)
|
||||
.value("FMAX", ir::reduce_inst::FMAX)
|
||||
.value("ARGFMIN", ir::reduce_inst::ARGFMIN)
|
||||
.value("ARGFMAX", ir::reduce_inst::ARGFMAX)
|
||||
.value("XOR", ir::reduce_inst::XOR);
|
||||
|
||||
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
|
||||
|
@@ -690,7 +690,7 @@ def test_f16_to_f8_rounding():
|
||||
|
||||
@pytest.mark.parametrize("op, dtype_str, shape",
|
||||
[(op, dtype, shape)
|
||||
for op in ['min', 'max', 'sum']
|
||||
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
|
||||
for dtype in dtypes
|
||||
for shape in [32, 64, 128, 512]])
|
||||
def test_reduce1d(op, dtype_str, shape, device='cuda'):
|
||||
@@ -707,28 +707,37 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
|
||||
x_tri = to_triton(x, device=device)
|
||||
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
|
||||
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
|
||||
'argmin': np.argmin, 'argmax': np.argmax}[op]
|
||||
# numpy result
|
||||
z_ref = numpy_op(x).astype(getattr(np, dtype_str))
|
||||
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
|
||||
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
|
||||
# triton result
|
||||
z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device)
|
||||
z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device)
|
||||
kernel[(1,)](x_tri, z_tri, BLOCK=shape)
|
||||
z_tri = to_numpy(z_tri)
|
||||
# compare
|
||||
if op == 'sum':
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
|
||||
else:
|
||||
np.testing.assert_equal(z_ref, to_numpy(z_tri))
|
||||
if op == 'argmin' or op == 'argmax':
|
||||
# argmin and argmax can have multiple valid indices.
|
||||
# so instead we compare the values pointed by indices
|
||||
np.testing.assert_equal(x[z_ref], x[z_tri])
|
||||
else:
|
||||
np.testing.assert_equal(z_ref, z_tri)
|
||||
|
||||
|
||||
reduce_configs1 = [
|
||||
(op, dtype, (1, 1024), axis) for dtype in dtypes
|
||||
for op in ['min', 'max', 'sum']
|
||||
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
|
||||
for axis in [1]
|
||||
]
|
||||
reduce_configs2 = [
|
||||
(op, 'float32', shape, 1)
|
||||
for op in ['min', 'max', 'sum']
|
||||
(op, 'float32', shape, axis)
|
||||
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
|
||||
for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
|
||||
for axis in [0, 1]
|
||||
]
|
||||
|
||||
|
||||
@@ -741,7 +750,10 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
range_n = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
|
||||
z = GENERATE_TEST_HERE
|
||||
tl.store(Z + range_m, z)
|
||||
if AXIS == 1:
|
||||
tl.store(Z + range_m, z)
|
||||
else:
|
||||
tl.store(Z + range_n, z)
|
||||
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
|
||||
# input
|
||||
@@ -749,17 +761,30 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
|
||||
x_tri = to_triton(x)
|
||||
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
|
||||
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
|
||||
'argmin': np.argmin, 'argmax': np.argmax}[op]
|
||||
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
|
||||
# numpy result
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, dtype_str))
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
|
||||
# triton result
|
||||
z_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device)
|
||||
binary = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
|
||||
z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
|
||||
device=device)
|
||||
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
|
||||
z_tri = to_numpy(z_tri)
|
||||
# compare
|
||||
if op == 'sum':
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
|
||||
else:
|
||||
np.testing.assert_equal(z_ref, to_numpy(z_tri))
|
||||
if op == 'argmin' or op == 'argmax':
|
||||
# argmin and argmax can have multiple valid indices.
|
||||
# so instead we compare the values pointed by indices
|
||||
z_ref_index = np.expand_dims(z_ref, axis=axis)
|
||||
z_tri_index = np.expand_dims(z_tri, axis=axis)
|
||||
z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
|
||||
z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
|
||||
np.testing.assert_equal(z_ref_value, z_tri_value)
|
||||
else:
|
||||
np.testing.assert_equal(z_ref, z_tri)
|
||||
|
||||
# ---------------
|
||||
# test permute
|
||||
|
@@ -1000,6 +1000,13 @@ def max(input, axis, _builder=None):
|
||||
return semantic.max(input, axis, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("maximum index")
|
||||
def argmax(input, axis, _builder=None):
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.argmax(input, axis, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("minimum")
|
||||
def min(input, axis, _builder=None):
|
||||
@@ -1007,6 +1014,13 @@ def min(input, axis, _builder=None):
|
||||
return semantic.min(input, axis, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("minimum index")
|
||||
def argmin(input, axis, _builder=None):
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.argmin(input, axis, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("sum")
|
||||
def sum(input, axis, _builder=None):
|
||||
|
@@ -961,10 +961,14 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
|
||||
|
||||
# choose the right unsigned operation
|
||||
if scalar_ty.is_int_unsigned():
|
||||
if INT_OP is ir.REDUCE_OP.MIN:
|
||||
INT_OP = ir.REDUCE_OP.UMIN
|
||||
elif INT_OP is ir.REDUCE_OP.MAX:
|
||||
INT_OP = ir.REDUCE_OP.UMAX
|
||||
int_op_to_unit = {
|
||||
ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN,
|
||||
ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX,
|
||||
ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN,
|
||||
ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX,
|
||||
}
|
||||
if INT_OP in int_op_to_unit:
|
||||
INT_OP = int_op_to_unit[INT_OP]
|
||||
|
||||
# get result type
|
||||
shape = input.type.shape
|
||||
@@ -988,10 +992,18 @@ def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)
|
||||
|
||||
|
||||
def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)
|
||||
|
||||
|
||||
def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)
|
||||
|
||||
|
||||
def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)
|
||||
|
||||
|
||||
def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)
|
||||
|
||||
|
Reference in New Issue
Block a user