[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)
@@ -1195,10 +1195,11 @@ void init_triton_ir(py::module &&m) {
                 operand.getType().dyn_cast<mlir::RankedTensorType>();
             std::vector<int64_t> shape = inputTensorType.getShape();
             shape.erase(shape.begin() + axis);
-            mlir::Type resType = inputTensorType.getElementType();
+            bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
+            mlir::Type resType = withIndex ? self.getI32Type()
+                                           : inputTensorType.getElementType();
             if (!shape.empty()) {
-              resType = mlir::RankedTensorType::get(
-                  shape, inputTensorType.getElementType());
+              resType = mlir::RankedTensorType::get(shape, resType);
             }
             return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
                                                        operand, axis);
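For argmin / argmax the reduction returns the index of the extremum rather than a value of the input's element type, so the binding above switches the result element type to i32 whenever ReduceOp::withIndex(redOp) is true, while the reduced axis is still dropped from the shape. A plain-Python restatement of that result-type rule follows; it is illustrative only, not part of the Triton sources, and the helper name is hypothetical:

# Hypothetical helper that restates the type rule implemented in C++ above.
def reduce_result_type(input_shape, input_dtype, axis, op):
    # argmin/argmax produce positions, so the element type becomes int32;
    # every other reduction keeps the input element type.
    elem_dtype = 'int32' if op in ('argmin', 'argmax') else input_dtype
    # The reduced axis is dropped; an empty shape means a scalar result.
    out_shape = tuple(d for i, d in enumerate(input_shape) if i != axis)
    return out_shape, elem_dtype


assert reduce_result_type((4, 32), 'float32', 1, 'argmax') == ((4,), 'int32')
assert reduce_result_type((4, 32), 'float32', 1, 'max') == ((4,), 'float32')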
@@ -1,4 +1,5 @@
 import pytest
+import numpy as np
 import torch
 from torch.testing import assert_close
@@ -13,7 +14,9 @@ dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
 dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}


-def get_reduced_dtype(dtype):
+def get_reduced_dtype(op, dtype):
+    if op in ['argmin', 'argmax']:
+        return torch.int32
     if dtype in [torch.int8, torch.int16, torch.uint8]:
         return torch.int32
     if dtype in [torch.bfloat16]:
@@ -48,7 +51,7 @@ def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, blo

 reduce1d_configs = [
     (op, dtype, shape)
-    for op in ['sum', 'min', 'max']
+    for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
     for dtype in dtypes
     for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
 ]
@@ -56,8 +59,11 @@ reduce1d_configs = [

 @pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
 def test_reduce1d(op, dtype, shape):
+    if op == 'xor_sum' and dtype in float_dtypes:
+        return
+
     dtype = dtype_mapping[dtype]
-    reduced_dtype = get_reduced_dtype(dtype)
+    reduced_dtype = get_reduced_dtype(op, dtype)

     if dtype.is_floating_point:
         x = torch.randn((shape,), device='cuda', dtype=dtype)
@@ -79,8 +85,17 @@ def test_reduce1d(op, dtype, shape):
         golden_z = torch.sum(x, dtype=reduced_dtype)
     elif op == 'min':
         golden_z = torch.min(x).to(reduced_dtype)
-    else:
+    elif op == 'max':
         golden_z = torch.max(x).to(reduced_dtype)
+    elif op == 'argmin':
+        golden_z = torch.argmin(x).to(reduced_dtype)
+    elif op == 'argmax':
+        golden_z = torch.argmax(x).to(reduced_dtype)
+    elif op == 'xor_sum':
+        sum_npy = np.bitwise_xor.reduce(x.cpu().numpy())
+        golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
+    else:
+        raise RuntimeError(f'Unknown reduce op {op}')

     if dtype.is_floating_point and op == 'sum':
         if shape >= 256:
@@ -95,7 +110,7 @@ def test_reduce1d(op, dtype, shape):

 reduce2d_configs = [
     (op, dtype, shape, axis)
-    for op in ['sum', 'min', 'max']
+    for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
     for dtype in dtypes
     for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
     for axis in [0, 1]
@@ -104,8 +119,11 @@ reduce2d_configs = [

 @pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
 def test_reduce2d(op, dtype, shape, axis):
+    if op == 'xor_sum' and dtype in float_dtypes:
+        return
+
     dtype = dtype_mapping[dtype]
-    reduced_dtype = get_reduced_dtype(dtype)
+    reduced_dtype = get_reduced_dtype(op, dtype)
     reduced_shape = (shape[1 - axis],)

     if dtype.is_floating_point:
@@ -123,8 +141,18 @@ def test_reduce2d(op, dtype, shape, axis):
         golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
     elif op == 'min':
         golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
-    else:
+    elif op == 'max':
         golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
+    elif op == 'argmin':
+        golden_z = torch.argmin(x, dim=axis, keepdim=False).to(reduced_dtype)
+    elif op == 'argmax':
+        golden_z = torch.argmax(x, dim=axis, keepdim=False).to(reduced_dtype)
+    elif op == 'xor_sum':
+        sum_npy = np.bitwise_xor.reduce(x.cpu().numpy(), axis=axis, keepdims=False)
+        golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
+    else:
+        raise RuntimeError(f'Unknown reduce op {op}')

     if dtype.is_floating_point and op == 'sum':
         if shape[axis] >= 256:
+            assert_close(z, golden_z, rtol=0.05, atol=0.1)
@@ -1041,6 +1041,13 @@ def max(input, axis, _builder=None):
     return semantic.max(input, axis, _builder)


+@builtin
+@_add_reduction_docstr("maximum index")
+def argmax(input, axis, _builder=None):
+    axis = _constexpr_to_value(axis)
+    return semantic.argmax(input, axis, _builder)
+
+
 @builtin
 @_add_reduction_docstr("minimum")
 def min(input, axis, _builder=None):
@@ -1048,6 +1055,13 @@ def min(input, axis, _builder=None):
     return semantic.min(input, axis, _builder)


+@builtin
+@_add_reduction_docstr("minimum index")
+def argmin(input, axis, _builder=None):
+    axis = _constexpr_to_value(axis)
+    return semantic.argmin(input, axis, _builder)
+
+
 @builtin
 @_add_reduction_docstr("sum")
 def sum(input, axis, _builder=None):
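The two builtins above are the user-facing entry points. A minimal usage sketch is shown below; it assumes a CUDA device, the usual tl.load / tl.store API, and a power-of-two block size, and the kernel is illustrative rather than taken from this commit's tests:

import torch
import triton
import triton.language as tl


@triton.jit
def argmax_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    x = tl.load(x_ptr + tl.arange(0, BLOCK))
    # tl.argmax reduces the whole 1-D block and yields an int32 index.
    tl.store(out_ptr, tl.argmax(x, axis=0))


x = torch.randn(1024, device='cuda')
out = torch.empty((1,), device='cuda', dtype=torch.int32)
argmax_kernel[(1,)](x, out, BLOCK=1024)
assert out.item() == torch.argmax(x).item()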
@@ -1061,10 +1061,18 @@ def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)


+def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)
+
+
 def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)


+def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)
+
+
 def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)
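As with min / max / sum, each new semantic entry hands reduce_impl both a floating-point and an integer REDUCE_OP variant (ARGFMIN / ARGMIN and ARGFMAX / ARGMAX); reduce_impl itself is not shown in this diff and is presumably what chooses between the pair based on the input's element type. A rough sketch of that selection rule, with hypothetical names, purely for illustration:

# Illustrative only; `is_floating` stands in for whatever scalar-type check
# reduce_impl actually performs on the input tensor.
def pick_reduce_op(float_op, int_op, is_floating):
    return float_op if is_floating else int_op

# e.g. argmax over float32 data would lower to ARGFMAX, over int32 data to ARGMAX.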