[BACKEND] Better bf16 support (#588)
This commit is contained in:
@@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
// types
|
||||
#define void_ty builder_->getVoidTy()
|
||||
#define f16_ty builder_->getHalfTy()
|
||||
#define bf16_ty builder_->getBFloatTy()
|
||||
#define bf16_ty builder_->getInt16Ty()
|
||||
#define f32_ty builder_->getFloatTy()
|
||||
#define i1_ty builder_->getInt1Ty()
|
||||
#define i8_ty builder_->getInt8Ty()
|
||||
@@ -178,7 +178,7 @@ Type *generator::cvt(ir::type *ty) {
|
||||
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
|
||||
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
|
||||
case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
|
||||
case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_);
|
||||
case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); // use int16 as storage type
|
||||
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
|
||||
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
|
||||
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
|
||||
@@ -378,8 +378,8 @@ void generator::visit_launch_inst(ir::launch_inst *launch) {
|
||||
*/
|
||||
void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
using ll = llvm::Instruction::BinaryOps;
|
||||
using tt = ir::binary_op_t;
|
||||
auto cvt = [](ir::binary_op_t op){
|
||||
using tt = ir::binary_op_t;
|
||||
switch(op) {
|
||||
case tt::Add: return ll::Add;
|
||||
case tt::FAdd: return ll::FAdd;
|
||||
@@ -406,20 +406,51 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
Value *lhs = vals_[x->get_operand(0)][idx];
|
||||
Value *rhs = vals_[x->get_operand(1)][idx];
|
||||
auto op = cvt(x->get_op());
|
||||
if(op == ll::Add)
|
||||
vals_[x][idx] = add(lhs, rhs);
|
||||
else if(op == ll::Mul)
|
||||
vals_[x][idx] = mul(lhs, rhs);
|
||||
else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
|
||||
x->get_type()->get_scalar_ty()->is_fp32_ty()){
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
|
||||
" div.full.f32 $0, $1, $2;", "=r,r,r", false);
|
||||
vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
|
||||
// manually select bf16 bin op
|
||||
if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) {
|
||||
assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty());
|
||||
if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b
|
||||
InlineAsm *bf16_add_asm =
|
||||
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
|
||||
"{ .reg .b16 c; \n\t"
|
||||
" mov.b16 c, 0x3f80U; \n\t" // 1.0
|
||||
" fma.rn.bf16 $0, $1, c, $2; } \n\t",
|
||||
"=h,h,h", false);
|
||||
vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs});
|
||||
} else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a
|
||||
InlineAsm *bf16_sub_asm =
|
||||
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
|
||||
" { .reg .b16 c; \n\t"
|
||||
" mov.b16 c, 0xbf80U; \n\t" // -1.0
|
||||
" fma.rn.bf16 $0, $2, c, $1;} \n\t",
|
||||
"=h,h,h", false);
|
||||
vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs});
|
||||
} else if (x->get_op() == tt::FMul) { // a * b = a*b + 0
|
||||
InlineAsm *bf16_mul_asm =
|
||||
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
|
||||
" { .reg .b16 c; \n\t"
|
||||
" mov.b16 c, 0x8000U; \n\t" // 0.0
|
||||
" fma.rn.bf16 $0, $1, $2, c;} \n\t",
|
||||
"=h,h,h", false);
|
||||
vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs});
|
||||
} else
|
||||
throw std::runtime_error("invalid bin op for bf16");
|
||||
} else { // not bf16
|
||||
auto op = cvt(x->get_op());
|
||||
if(op == ll::Add)
|
||||
vals_[x][idx] = add(lhs, rhs);
|
||||
else if(op == ll::Mul)
|
||||
vals_[x][idx] = mul(lhs, rhs);
|
||||
else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
|
||||
x->get_type()->get_scalar_ty()->is_fp32_ty()){
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
|
||||
" div.full.f32 $0, $1, $2;", "=r,r,r", false);
|
||||
vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
|
||||
|
||||
}
|
||||
else
|
||||
vals_[x][idx] = bin_op(op, lhs, rhs);
|
||||
}
|
||||
else
|
||||
vals_[x][idx] = bin_op(op, lhs, rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -970,8 +1001,6 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
has_l2_evict_policy = false;
|
||||
auto idxs = idxs_.at(val_op);
|
||||
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
|
||||
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
|
||||
ty = f16_ty;
|
||||
if(ty->isIntegerTy(1))
|
||||
ty = builder_->getInt8Ty();
|
||||
for(size_t i = 0; i < idxs.size(); i += vec){
|
||||
@@ -2830,9 +2859,6 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
|
||||
// pointer to temporary shared memory
|
||||
Type *ty = cvt(out->get_type()->get_scalar_ty());
|
||||
|
||||
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
|
||||
ty = f16_ty;
|
||||
|
||||
// Orders
|
||||
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
|
||||
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
|
||||
@@ -3229,8 +3255,22 @@ void generator::visit_constant_int(ir::constant_int *x){
|
||||
|
||||
void generator::visit_constant_fp(ir::constant_fp *x){
|
||||
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
||||
for(indices_t idx: idxs_.at(x))
|
||||
vals_[x][idx] = ConstantFP::get(ty, x->get_value());
|
||||
for(indices_t idx: idxs_.at(x)) {
|
||||
// manually select bf16 constant
|
||||
if (x->get_type()->get_scalar_ty()->is_bf16_ty()) {
|
||||
// highest 16 bits of fp32
|
||||
float fp32_value = x->get_value();
|
||||
uint16_t bf16_raw = (*reinterpret_cast<uint32_t*>(&fp32_value)
|
||||
& 0xffff0000) >> 16;
|
||||
std::stringstream const_str;
|
||||
const_str << "0x" << std::hex << bf16_raw << "U"; // unsigned
|
||||
InlineAsm *bf16_const = InlineAsm::get(FunctionType::get(bf16_ty, {}, false),
|
||||
" mov.b16 $0, " + const_str.str() + ";",
|
||||
"=h", false);
|
||||
vals_[x][idx] = builder_->CreateCall(bf16_const, {});
|
||||
} else
|
||||
vals_[x][idx] = ConstantFP::get(ty, x->get_value());
|
||||
}
|
||||
}
|
||||
|
||||
void generator::visit_alloc_const(ir::alloc_const *alloc) {
|
||||
|
@@ -18,6 +18,8 @@ constant *constant::get_null_value(type *ty) {
|
||||
return constant_int::get(ty, 0);
|
||||
case type::FP16TyID:
|
||||
return constant_fp::get(type::get_fp16_ty(ctx), 0);
|
||||
case type::BF16TyID:
|
||||
return constant_fp::get(type::get_bf16_ty(ctx), 0);
|
||||
case type::FP32TyID:
|
||||
return constant_fp::get(type::get_fp32_ty(ctx), 0);
|
||||
case type::FP64TyID:
|
||||
|
@@ -33,27 +33,37 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h
|
||||
shape = (shape, )
|
||||
if rs is None:
|
||||
rs = RandomState(seed=17)
|
||||
dtype = getattr(np, dtype_str)
|
||||
if dtype_str in int_dtypes + uint_dtypes:
|
||||
iinfo = np.iinfo(getattr(np, dtype_str))
|
||||
low = iinfo.min if low is None else max(low, iinfo.min)
|
||||
high = iinfo.max if high is None else min(high, iinfo.max)
|
||||
dtype = getattr(np, dtype_str)
|
||||
x = rs.randint(low, high, shape, dtype=dtype)
|
||||
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
|
||||
return x
|
||||
elif dtype_str in float_dtypes:
|
||||
return rs.normal(0, 1, shape).astype(dtype)
|
||||
return rs.normal(0, 1, shape).astype(dtype_str)
|
||||
elif dtype_str == 'bfloat16':
|
||||
return (rs.normal(0, 1, shape).astype('float32').view('uint32')
|
||||
& np.uint32(0xffff0000)).view('float32')
|
||||
else:
|
||||
raise RuntimeError(f'Unknown dtype {dtype_str}')
|
||||
|
||||
|
||||
def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]:
|
||||
def to_triton(x: np.ndarray, device='cuda', dst_type=None) -> Union[TensorWrapper, torch.Tensor]:
|
||||
'''
|
||||
Note: We need dst_type becasue the type of x can be different from dst_type.
|
||||
For example: x is of type `float32`, dst_type is `bfloat16`.
|
||||
If dst_type is None, we infer dst_type from x.
|
||||
'''
|
||||
t = x.dtype.name
|
||||
if t in uint_dtypes:
|
||||
signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16"
|
||||
x_signed = x.astype(getattr(np, signed_type_name))
|
||||
return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
|
||||
else:
|
||||
if t == 'float32' and dst_type == 'bfloat16':
|
||||
return torch.tensor(x, device=device).bfloat16()
|
||||
return torch.tensor(x, device=device)
|
||||
|
||||
|
||||
@@ -72,6 +82,8 @@ def to_numpy(x):
|
||||
if isinstance(x, TensorWrapper):
|
||||
return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
|
||||
elif isinstance(x, torch.Tensor):
|
||||
if x.dtype is torch.bfloat16:
|
||||
return x.cpu().float().numpy()
|
||||
return x.cpu().numpy()
|
||||
else:
|
||||
raise ValueError(f"Not a triton-compatible tensor: {x}")
|
||||
@@ -84,19 +96,30 @@ def patch_kernel(template, to_replace):
|
||||
return kernel
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
|
||||
def check_type_supported(dtype):
|
||||
'''
|
||||
skip test if dtype is not supported on the current device
|
||||
'''
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
|
||||
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
|
||||
def test_empty_kernel(dtype_x, device='cuda'):
|
||||
SIZE = 128
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, SIZE: tl.constexpr):
|
||||
pass
|
||||
x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device)
|
||||
check_type_supported(dtype_x)
|
||||
x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x)
|
||||
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
|
||||
|
||||
|
||||
# generic test functions
|
||||
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
|
||||
check_type_supported(dtype_x) # early return if dtype_x is not supported
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
|
||||
@@ -115,8 +138,8 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
|
||||
# reference result
|
||||
z_ref = eval(expr if numpy_expr is None else numpy_expr)
|
||||
# triton result
|
||||
x_tri = to_triton(x, device=device)
|
||||
z_tri = to_triton(np.empty_like(z_ref), device=device)
|
||||
x_tri = to_triton(x, device=device, dst_type=dtype_x)
|
||||
z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
|
||||
kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
|
||||
# compare
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
@@ -154,6 +177,8 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
|
||||
|
||||
|
||||
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
|
||||
check_type_supported(dtype_x) # early return if dtype_x is not supported
|
||||
check_type_supported(dtype_y)
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
|
||||
@@ -180,8 +205,8 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y=
|
||||
if dtype_z is not None:
|
||||
z_ref = z_ref.astype(dtype_z)
|
||||
# triton result
|
||||
x_tri = to_triton(x, device=device)
|
||||
y_tri = to_triton(y, device=device)
|
||||
x_tri = to_triton(x, device=device, dst_type=dtype_x)
|
||||
y_tri = to_triton(y, device=device, dst_type=dtype_y)
|
||||
z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
|
||||
kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4)
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01)
|
||||
@@ -193,15 +218,20 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
|
||||
# remainders than stock LLVM. We currently don't expect to match it
|
||||
# bit-for-bit.
|
||||
return (dtype_x, dtype_y) in [
|
||||
('int32', 'bfloat16'),
|
||||
('int32', 'float16'),
|
||||
('int32', 'float32'),
|
||||
('int64', 'bfloat16'),
|
||||
('int64', 'float16'),
|
||||
('int64', 'float32'),
|
||||
('int64', 'float64'),
|
||||
('uint16', 'bfloat16'),
|
||||
('uint16', 'float16'),
|
||||
('uint16', 'float32'),
|
||||
('uint32', 'bfloat16'),
|
||||
('uint32', 'float16'),
|
||||
('uint32', 'float32'),
|
||||
('uint64', 'bfloat16'),
|
||||
('uint64', 'float16'),
|
||||
('uint64', 'float32'),
|
||||
('uint64', 'float64'),
|
||||
@@ -215,15 +245,15 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
|
||||
(dtype_x, dtype_y, op)
|
||||
for op in ['+', '-', '*', '/', '%']
|
||||
for dtype_x in dtypes
|
||||
for dtype_y in dtypes
|
||||
for dtype_x in dtypes + ['bfloat16']
|
||||
for dtype_y in dtypes + ['bfloat16']
|
||||
])
|
||||
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
|
||||
expr = f' x {op} y'
|
||||
if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
|
||||
# LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
|
||||
numpy_expr = 'np.fmod(x, y)'
|
||||
elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
|
||||
elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'):
|
||||
# Triton promotes 16-bit floating-point / and % to 32-bit because there
|
||||
# are no native div or FRem operations on float16. Since we have to
|
||||
# convert anyway, we may as well take the accuracy bump.
|
||||
@@ -266,8 +296,8 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
|
||||
(dtype_x, dtype_y, op)
|
||||
for op in ['&', '|', '^']
|
||||
for dtype_x in dtypes
|
||||
for dtype_y in dtypes
|
||||
for dtype_x in dtypes + ['bfloat16']
|
||||
for dtype_y in dtypes + ['bfloat16']
|
||||
])
|
||||
def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
|
||||
expr = f'x {op} y'
|
||||
@@ -337,7 +367,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
|
||||
# test unary ops
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype_x, expr", [
|
||||
(dtype_x, ' -x') for dtype_x in dtypes
|
||||
(dtype_x, ' -x') for dtype_x in dtypes + ['bfloat16']
|
||||
] + [
|
||||
(dtype_x, ' ~x') for dtype_x in int_dtypes
|
||||
])
|
||||
@@ -732,9 +762,10 @@ def test_f16_to_f8_rounding():
|
||||
@pytest.mark.parametrize("op, dtype_str, shape",
|
||||
[(op, dtype, shape)
|
||||
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
|
||||
for dtype in dtypes
|
||||
for dtype in dtypes + ['bfloat16']
|
||||
for shape in [32, 64, 128, 512]])
|
||||
def test_reduce1d(op, dtype_str, shape, device='cuda'):
|
||||
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
@@ -752,9 +783,18 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
|
||||
'argmin': np.argmin, 'argmax': np.argmax}[op]
|
||||
# numpy result
|
||||
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
|
||||
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
|
||||
z_tri_dtype_str = z_dtype_str
|
||||
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
|
||||
z_dtype_str = 'float32'
|
||||
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
|
||||
# trunc mantissa for a fair comparison of accuracy
|
||||
z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
|
||||
z_tri_dtype_str = 'bfloat16'
|
||||
else:
|
||||
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
|
||||
# triton result
|
||||
z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device)
|
||||
z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
|
||||
device=device, dst_type=z_tri_dtype_str)
|
||||
kernel[(1,)](x_tri, z_tri, BLOCK=shape)
|
||||
z_tri = to_numpy(z_tri)
|
||||
# compare
|
||||
@@ -770,7 +810,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
|
||||
|
||||
|
||||
reduce_configs1 = [
|
||||
(op, dtype, (1, 1024), axis) for dtype in dtypes
|
||||
(op, dtype, (1, 1024), axis) for dtype in dtypes + ['bfloat16']
|
||||
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
|
||||
for axis in [1]
|
||||
]
|
||||
@@ -805,11 +845,19 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
|
||||
'argmin': np.argmin, 'argmax': np.argmax}[op]
|
||||
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
|
||||
z_tri_dtype_str = z_dtype_str
|
||||
# numpy result
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
|
||||
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
|
||||
z_dtype_str = 'float32'
|
||||
z_tri_dtype_str = 'bfloat16'
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
|
||||
# trunc mantissa for a fair comparison of accuracy
|
||||
z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
|
||||
else:
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
|
||||
# triton result
|
||||
z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
|
||||
device=device)
|
||||
device=device, dst_type=z_tri_dtype_str)
|
||||
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
|
||||
z_tri = to_numpy(z_tri)
|
||||
# compare
|
||||
@@ -834,10 +882,11 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, shape, perm",
|
||||
[(dtype, shape, perm)
|
||||
for dtype in ['float16', 'float32']
|
||||
for dtype in ['bfloat16', 'float16', 'float32']
|
||||
for shape in [(64, 64), (128, 128)]
|
||||
for perm in [(1, 0)]])
|
||||
def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
@@ -852,16 +901,16 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# input
|
||||
x = numpy_random(shape, dtype_str=dtype_str)
|
||||
# triton result
|
||||
z_tri = to_triton(np.empty_like(x), device=device)
|
||||
z_tri_contiguous = to_triton(np.empty_like(x), device=device)
|
||||
x_tri = to_triton(x, device=device)
|
||||
z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
|
||||
z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
|
||||
x_tri = to_triton(x, device=device, dst_type=dtype_str)
|
||||
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
||||
z_tri, z_tri.stride(1), z_tri.stride(0),
|
||||
BLOCK_M=shape[0], BLOCK_N=shape[1])
|
||||
pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
|
||||
z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
|
||||
BLOCK_M=shape[0], BLOCK_N=shape[1])
|
||||
# torch result
|
||||
# numpy result
|
||||
z_ref = x.transpose(*perm)
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
@@ -1038,8 +1087,10 @@ def test_arange(start, device='cuda'):
|
||||
# Testing masked loads with an intermate copy to shared memory run.
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
|
||||
def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
M = 32
|
||||
N = 32
|
||||
K = 16
|
||||
|
@@ -2,18 +2,22 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, dtype, mode",
|
||||
[
|
||||
(M, N, dtype, mode) for M in [1024, 821]
|
||||
for N in [512, 857, 1871, 2089, 8573, 31000]
|
||||
for dtype in ['float16', 'float32']
|
||||
for dtype in ['bfloat16', 'float16', 'float32']
|
||||
for mode in ['forward', 'backward']
|
||||
]
|
||||
)
|
||||
def test_op(M, N, dtype, mode):
|
||||
dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 80 and dtype == "bfloat16":
|
||||
pytest.skip("Only test bfloat16 on devices with sm >= 80")
|
||||
dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
|
||||
# create inputs
|
||||
x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
|
||||
idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
|
||||
|
@@ -58,14 +58,22 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
|
||||
return tl.float32
|
||||
# 3 ) if one operand is half, the other is implicitly converted to half
|
||||
# unless we're doing / or %, which do not exist natively in PTX for fp16.
|
||||
# Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
|
||||
if a_ty.is_fp16() or b_ty.is_fp16():
|
||||
if div_or_mod:
|
||||
return tl.float32
|
||||
else:
|
||||
return tl.float16
|
||||
# 4) return bf16 only if both operands are of bf16
|
||||
if a_ty.is_bf16() or b_ty.is_bf16():
|
||||
if div_or_mod:
|
||||
return tl.float32
|
||||
if a_ty.is_bf16() and b_ty.is_bf16():
|
||||
return tl.bfloat16
|
||||
return tl.float32
|
||||
if not a_ty.is_int() or not b_ty.is_int():
|
||||
assert False
|
||||
# 4 ) both operands are integer and undergo
|
||||
# 5 ) both operands are integer and undergo
|
||||
# integer promotion
|
||||
if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
|
||||
raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;"
|
||||
@@ -768,16 +776,25 @@ def atomic_cas(ptr: tl.tensor,
|
||||
cmp: tl.tensor,
|
||||
val: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
# TODO: type checking
|
||||
element_ty = ptr.type.scalar.element_ty
|
||||
if element_ty.primitive_bitwidth not in [16, 32, 64]:
|
||||
raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
|
||||
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type)
|
||||
|
||||
|
||||
def atom_red_typechecking_impl(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
op: str,
|
||||
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
|
||||
if not ptr.type.scalar.is_ptr():
|
||||
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
|
||||
|
||||
element_ty = ptr.type.scalar.element_ty
|
||||
if element_ty is tl.float16 and op != 'add':
|
||||
raise ValueError("atomic_" + op + " does not support fp16")
|
||||
if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
|
||||
raise ValueError("atomic_" + op + " does not support " + element_ty)
|
||||
if ptr.type.is_block():
|
||||
if mask:
|
||||
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
|
||||
@@ -798,7 +815,7 @@ def atomic_max(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
|
||||
sca_ty = val.type.scalar
|
||||
# direct call to atomic_max for integers
|
||||
if sca_ty.is_int():
|
||||
@@ -830,7 +847,7 @@ def atomic_min(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
|
||||
sca_ty = val.type.scalar
|
||||
# direct call to atomic_min for integers
|
||||
if sca_ty.is_int():
|
||||
@@ -870,7 +887,7 @@ def atomic_add(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
|
||||
sca_ty = val.type.scalar
|
||||
op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
|
||||
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type)
|
||||
@@ -880,7 +897,7 @@ def atomic_and(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type)
|
||||
|
||||
|
||||
@@ -888,7 +905,7 @@ def atomic_or(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type)
|
||||
|
||||
|
||||
@@ -896,7 +913,7 @@ def atomic_xor(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type)
|
||||
|
||||
|
||||
@@ -904,7 +921,7 @@ def atomic_xchg(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type)
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -978,6 +995,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
|
||||
if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
|
||||
input = cast(input, tl.int32, builder)
|
||||
|
||||
# hardware doesn't support FMAX, FMIN, CMP for bfloat16
|
||||
if scalar_ty is tl.bfloat16:
|
||||
input = cast(input, tl.float32, builder)
|
||||
|
||||
# choose the right unsigned operation
|
||||
if scalar_ty.is_int_unsigned():
|
||||
int_op_to_unit = {
|
||||
|
@@ -65,7 +65,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
|
||||
# write result in-place in PROBS
|
||||
dout = tl.load(DPROBS + row)
|
||||
din = (probs - delta) * dout
|
||||
tl.store(PROBS, din.to(tl.float16), mask=cols < N)
|
||||
tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)
|
||||
|
||||
|
||||
class _cross_entropy(torch.autograd.Function):
|
||||
|
Reference in New Issue
Block a user