[BACKEND] Better bf16 support (#588)

Author:    daadaada
Date:      2022-07-20 12:22:37 +08:00
Committer: GitHub
Parent:    86cab58d89
Commit:    9b2bc88d11

6 changed files with 180 additions and 62 deletions

View File

@@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 // types
 #define void_ty builder_->getVoidTy()
 #define f16_ty builder_->getHalfTy()
-#define bf16_ty builder_->getBFloatTy()
+#define bf16_ty builder_->getInt16Ty()
 #define f32_ty builder_->getFloatTy()
 #define i1_ty builder_->getInt1Ty()
 #define i8_ty builder_->getInt8Ty()
@@ -178,7 +178,7 @@ Type *generator::cvt(ir::type *ty) {
   case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
   case ir::type::FP8TyID:  return Type::getInt8Ty(*ctx_);
   case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
-  case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_);
+  case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); // use int16 as storage type
   case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
   case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
   case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
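Note on the two hunks above: bf16 values are now carried through LLVM as plain i16 (a "storage type"), because the LLVM 11 NVPTX backend cannot reliably select native bfloat operations (see the removed `isBFloatTy()` workarounds further down); all bf16 arithmetic is instead emitted as inline PTX in the following hunks. A small PyTorch sketch (not part of the commit) of why a 16-bit integer container is lossless for bf16:

```python
import torch

# bfloat16 is the upper 16 bits of the float32 encoding, so a bf16 value can be
# stored and moved as an opaque 16-bit integer and reinterpreted back bit-exactly.
x = torch.tensor([1.0, -2.5], dtype=torch.bfloat16)
bits = x.view(torch.int16)                        # same bytes, integer "storage type"
assert torch.equal(bits.view(torch.bfloat16), x)  # round-trips losslessly
print([hex(b & 0xffff) for b in bits.tolist()])   # ['0x3f80', '0xc020']
```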
@@ -378,8 +378,8 @@ void generator::visit_launch_inst(ir::launch_inst *launch) {
  */
 void generator::visit_binary_operator(ir::binary_operator*x) {
   using ll = llvm::Instruction::BinaryOps;
+  using tt = ir::binary_op_t;
   auto cvt = [](ir::binary_op_t op){
-    using tt = ir::binary_op_t;
     switch(op) {
     case tt::Add: return ll::Add;
     case tt::FAdd: return ll::FAdd;
@@ -406,20 +406,51 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
   for(indices_t idx: idxs_.at(x)){
     Value *lhs = vals_[x->get_operand(0)][idx];
     Value *rhs = vals_[x->get_operand(1)][idx];
-    auto op = cvt(x->get_op());
-    if(op == ll::Add)
-      vals_[x][idx] = add(lhs, rhs);
-    else if(op == ll::Mul)
-      vals_[x][idx] = mul(lhs, rhs);
-    else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
-            x->get_type()->get_scalar_ty()->is_fp32_ty()){
-      InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
-                                      " div.full.f32 $0, $1, $2;", "=r,r,r", false);
-      vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
+    // manually select bf16 bin op
+    if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) {
+      assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty());
+      if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b
+        InlineAsm *bf16_add_asm =
+            InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
+                           "{ .reg .b16 c;        \n\t"
+                           "  mov.b16 c, 0x3f80U; \n\t" // 1.0
+                           "  fma.rn.bf16 $0, $1, c, $2; } \n\t",
+                           "=h,h,h", false);
+        vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs});
+      } else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a
+        InlineAsm *bf16_sub_asm =
+            InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
+                           "{ .reg .b16 c;        \n\t"
+                           "  mov.b16 c, 0xbf80U; \n\t" // -1.0
+                           "  fma.rn.bf16 $0, $2, c, $1; } \n\t",
+                           "=h,h,h", false);
+        vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs});
+      } else if (x->get_op() == tt::FMul) { // a * b = a*b + 0
+        InlineAsm *bf16_mul_asm =
+            InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
+                           "{ .reg .b16 c;        \n\t"
+                           "  mov.b16 c, 0x8000U; \n\t" // 0.0
+                           "  fma.rn.bf16 $0, $1, $2, c; } \n\t",
+                           "=h,h,h", false);
+        vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs});
+      } else
+        throw std::runtime_error("invalid bin op for bf16");
+    } else { // not bf16
+      auto op = cvt(x->get_op());
+      if(op == ll::Add)
+        vals_[x][idx] = add(lhs, rhs);
+      else if(op == ll::Mul)
+        vals_[x][idx] = mul(lhs, rhs);
+      else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
+              x->get_type()->get_scalar_ty()->is_fp32_ty()){
+        InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
+                                        " div.full.f32 $0, $1, $2;", "=r,r,r", false);
+        vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
+      }
+      else
+        vals_[x][idx] = bin_op(op, lhs, rhs);
     }
-    else
-      vals_[x][idx] = bin_op(op, lhs, rhs);
   }
 }
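Because LLVM never sees a floating-point bf16 type anymore, +, - and * on bf16 are each lowered by hand to a single `fma.rn.bf16` (an sm_80+ instruction, which is why the tests below skip bf16 on older GPUs) with a constant second or third operand. A numpy sketch (not part of the commit) of the identities and of the immediates `0x3f80`, `0xbf80`, `0x8000` used above; strictly speaking `0x8000` is -0.0, which is the safe additive identity since `x + (-0.0) == x` for every x, including +0.0:

```python
import numpy as np

def bf16_bits(x: float) -> int:
    # bfloat16 keeps the upper 16 bits of the IEEE-754 float32 encoding
    return int(np.array(x, dtype=np.float32).view(np.uint32)) >> 16

assert bf16_bits(1.0) == 0x3f80   # fma(a,  1.0, b) == a + b
assert bf16_bits(-1.0) == 0xbf80  # fma(b, -1.0, a) == a - b
assert bf16_bits(-0.0) == 0x8000  # fma(a, b, -0.0) == a * b
```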
@@ -970,8 +1001,6 @@ void generator::visit_store_inst(ir::store_inst * x){
     has_l2_evict_policy = false;
   auto idxs = idxs_.at(val_op);
   Type *ty = cvt(val_op->get_type()->get_scalar_ty());
-  if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
-    ty = f16_ty;
   if(ty->isIntegerTy(1))
     ty = builder_->getInt8Ty();
   for(size_t i = 0; i < idxs.size(); i += vec){
@@ -2830,9 +2859,6 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
   // pointer to temporary shared memory
   Type *ty = cvt(out->get_type()->get_scalar_ty());
-  if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
-    ty = f16_ty;
   // Orders
   analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
   analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
@@ -3229,8 +3255,22 @@ void generator::visit_constant_int(ir::constant_int *x){
 void generator::visit_constant_fp(ir::constant_fp *x){
   Type *ty = cvt(x->get_type()->get_scalar_ty());
-  for(indices_t idx: idxs_.at(x))
-    vals_[x][idx] = ConstantFP::get(ty, x->get_value());
+  for(indices_t idx: idxs_.at(x)) {
+    // manually select bf16 constant
+    if (x->get_type()->get_scalar_ty()->is_bf16_ty()) {
+      // highest 16 bits of fp32
+      float fp32_value = x->get_value();
+      uint16_t bf16_raw = (*reinterpret_cast<uint32_t*>(&fp32_value)
+                           & 0xffff0000) >> 16;
+      std::stringstream const_str;
+      const_str << "0x" << std::hex << bf16_raw << "U"; // unsigned
+      InlineAsm *bf16_const = InlineAsm::get(FunctionType::get(bf16_ty, {}, false),
+                                             " mov.b16 $0, " + const_str.str() + ";",
+                                             "=h", false);
+      vals_[x][idx] = builder_->CreateCall(bf16_const, {});
+    } else
+      vals_[x][idx] = ConstantFP::get(ty, x->get_value());
+  }
 }
 void generator::visit_alloc_const(ir::alloc_const *alloc) {
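Constants need the same manual treatment: `ConstantFP` cannot produce an i16-typed bf16 literal, so the fp32 value is truncated to its upper 16 bits and materialized with a `mov.b16` immediate. A Python sketch (not part of the commit) mirroring the encoding step; note this is truncation, not round-to-nearest:

```python
import numpy as np

def bf16_mov_immediate(value: float) -> str:
    # keep the highest 16 bits of the fp32 encoding, as the C++ above does
    raw = int(np.array(value, dtype=np.float32).view(np.uint32))
    return hex((raw & 0xffff0000) >> 16) + "U"

assert bf16_mov_immediate(1.0) == "0x3f80U"  # emitted as " mov.b16 $0, 0x3f80U;"
```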

View File

@@ -18,6 +18,8 @@ constant *constant::get_null_value(type *ty) {
     return constant_int::get(ty, 0);
   case type::FP16TyID:
     return constant_fp::get(type::get_fp16_ty(ctx), 0);
+  case type::BF16TyID:
+    return constant_fp::get(type::get_bf16_ty(ctx), 0);
   case type::FP32TyID:
     return constant_fp::get(type::get_fp32_ty(ctx), 0);
   case type::FP64TyID:

View File

@@ -33,27 +33,37 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h
         shape = (shape, )
     if rs is None:
         rs = RandomState(seed=17)
-    dtype = getattr(np, dtype_str)
     if dtype_str in int_dtypes + uint_dtypes:
         iinfo = np.iinfo(getattr(np, dtype_str))
         low = iinfo.min if low is None else max(low, iinfo.min)
         high = iinfo.max if high is None else min(high, iinfo.max)
+        dtype = getattr(np, dtype_str)
         x = rs.randint(low, high, shape, dtype=dtype)
         x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
         return x
     elif dtype_str in float_dtypes:
-        return rs.normal(0, 1, shape).astype(dtype)
+        return rs.normal(0, 1, shape).astype(dtype_str)
+    elif dtype_str == 'bfloat16':
+        return (rs.normal(0, 1, shape).astype('float32').view('uint32')
+                & np.uint32(0xffff0000)).view('float32')
     else:
         raise RuntimeError(f'Unknown dtype {dtype_str}')


-def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]:
+def to_triton(x: np.ndarray, device='cuda', dst_type=None) -> Union[TensorWrapper, torch.Tensor]:
+    '''
+    Note: We need dst_type because the type of x can be different from dst_type.
+          For example: x is of type `float32`, dst_type is `bfloat16`.
+          If dst_type is None, we infer dst_type from x.
+    '''
     t = x.dtype.name
     if t in uint_dtypes:
         signed_type_name = t.lstrip('u')  # e.g. "uint16" -> "int16"
         x_signed = x.astype(getattr(np, signed_type_name))
         return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
     else:
+        if t == 'float32' and dst_type == 'bfloat16':
+            return torch.tensor(x, device=device).bfloat16()
         return torch.tensor(x, device=device)
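On the test side, numpy has no bfloat16 dtype, so the new `'bfloat16'` branch returns float32 data whose low 16 mantissa bits are zeroed; every such value is exactly representable in bf16, which is what makes the float32 reference computations meaningful, and `to_triton(..., dst_type='bfloat16')` then converts to a real bf16 tensor. A standalone sketch (not from the test file) of the trick:

```python
import numpy as np
from numpy.random import RandomState

rs = RandomState(seed=17)
x = (rs.normal(0, 1, (8,)).astype('float32').view('uint32')
     & np.uint32(0xffff0000)).view('float32')
# every element is bf16-exact, so float32 reference math matches the bf16 kernel
# up to the kernel's own rounding
assert np.all((x.view('uint32') & np.uint32(0xffff)) == 0)
```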
@@ -72,6 +82,8 @@ def to_numpy(x):
     if isinstance(x, TensorWrapper):
         return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
     elif isinstance(x, torch.Tensor):
+        if x.dtype is torch.bfloat16:
+            return x.cpu().float().numpy()
         return x.cpu().numpy()
     else:
         raise ValueError(f"Not a triton-compatible tensor: {x}")
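The `to_numpy` change is needed because torch bf16 tensors cannot be handed to numpy directly; results are upcast to float32 on the way out. A sketch (not from the test file):

```python
import torch

z = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
# z.numpy() fails because numpy has no bfloat16 dtype,
# hence the .float() upcast before any numpy-based comparison
z_np = z.float().numpy()
```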
@@ -84,19 +96,30 @@ def patch_kernel(template, to_replace):
     return kernel


-@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
+def check_type_supported(dtype):
+    '''
+    skip test if dtype is not supported on the current device
+    '''
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
+        pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
+
+
+@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
 def test_empty_kernel(dtype_x, device='cuda'):
     SIZE = 128

     @triton.jit
     def kernel(X, SIZE: tl.constexpr):
         pass
-    x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device)
+    check_type_supported(dtype_x)
+    x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x)
     kernel[(1, )](x, SIZE=SIZE, num_warps=4)


 # generic test functions
 def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
+    check_type_supported(dtype_x)  # early return if dtype_x is not supported
     SIZE = 128
     # define the kernel / launch-grid
@@ -115,8 +138,8 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
     # reference result
     z_ref = eval(expr if numpy_expr is None else numpy_expr)
     # triton result
-    x_tri = to_triton(x, device=device)
-    z_tri = to_triton(np.empty_like(z_ref), device=device)
+    x_tri = to_triton(x, device=device, dst_type=dtype_x)
+    z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
     kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
     # compare
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
@@ -154,6 +177,8 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
 def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
+    check_type_supported(dtype_x)  # early return if dtype_x is not supported
+    check_type_supported(dtype_y)
     SIZE = 128
     # define the kernel / launch-grid
@@ -180,8 +205,8 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y=
     if dtype_z is not None:
         z_ref = z_ref.astype(dtype_z)
     # triton result
-    x_tri = to_triton(x, device=device)
-    y_tri = to_triton(y, device=device)
+    x_tri = to_triton(x, device=device, dst_type=dtype_x)
+    y_tri = to_triton(y, device=device, dst_type=dtype_y)
     z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
     kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01)
@@ -193,15 +218,20 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
     # remainders than stock LLVM. We currently don't expect to match it
     # bit-for-bit.
     return (dtype_x, dtype_y) in [
+        ('int32', 'bfloat16'),
         ('int32', 'float16'),
         ('int32', 'float32'),
+        ('int64', 'bfloat16'),
         ('int64', 'float16'),
         ('int64', 'float32'),
         ('int64', 'float64'),
+        ('uint16', 'bfloat16'),
         ('uint16', 'float16'),
         ('uint16', 'float32'),
+        ('uint32', 'bfloat16'),
         ('uint32', 'float16'),
         ('uint32', 'float32'),
+        ('uint64', 'bfloat16'),
         ('uint64', 'float16'),
         ('uint64', 'float32'),
         ('uint64', 'float64'),
@@ -215,15 +245,15 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['+', '-', '*', '/', '%']
-    for dtype_x in dtypes
-    for dtype_y in dtypes
+    for dtype_x in dtypes + ['bfloat16']
+    for dtype_y in dtypes + ['bfloat16']
 ])
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f' x {op} y'
     if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
         # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
         numpy_expr = 'np.fmod(x, y)'
-    elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
+    elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'):
         # Triton promotes 16-bit floating-point / and % to 32-bit because there
         # are no native div or FRem operations on float16. Since we have to
         # convert anyway, we may as well take the accuracy bump.
@@ -266,8 +296,8 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['&', '|', '^']
-    for dtype_x in dtypes
-    for dtype_y in dtypes
+    for dtype_x in dtypes + ['bfloat16']
+    for dtype_y in dtypes + ['bfloat16']
 ])
 def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f'x {op} y'
@@ -337,7 +367,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
 # test unary ops
 # ---------------
 @pytest.mark.parametrize("dtype_x, expr", [
-    (dtype_x, ' -x') for dtype_x in dtypes
+    (dtype_x, ' -x') for dtype_x in dtypes + ['bfloat16']
 ] + [
     (dtype_x, ' ~x') for dtype_x in int_dtypes
 ])
@@ -732,9 +762,10 @@ def test_f16_to_f8_rounding():
 @pytest.mark.parametrize("op, dtype_str, shape",
                          [(op, dtype, shape)
                           for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-                          for dtype in dtypes
+                          for dtype in dtypes + ['bfloat16']
                           for shape in [32, 64, 128, 512]])
 def test_reduce1d(op, dtype_str, shape, device='cuda'):
+    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
     # triton kernel
     @triton.jit
@@ -752,9 +783,18 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
                 'argmin': np.argmin, 'argmax': np.argmax}[op]
     # numpy result
     z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
-    z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
+    z_tri_dtype_str = z_dtype_str
+    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
+        z_dtype_str = 'float32'
+        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+        z_tri_dtype_str = 'bfloat16'
+    else:
+        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
     # triton result
-    z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device)
+    z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
+                      device=device, dst_type=z_tri_dtype_str)
     kernel[(1,)](x_tri, z_tri, BLOCK=shape)
     z_tri = to_numpy(z_tri)
     # compare
@@ -770,7 +810,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
 reduce_configs1 = [
-    (op, dtype, (1, 1024), axis) for dtype in dtypes
+    (op, dtype, (1, 1024), axis) for dtype in dtypes + ['bfloat16']
     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
     for axis in [1]
 ]
@@ -805,11 +845,19 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
     numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
                 'argmin': np.argmin, 'argmax': np.argmax}[op]
     z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
+    z_tri_dtype_str = z_dtype_str
     # numpy result
-    z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
+    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
+        z_dtype_str = 'float32'
+        z_tri_dtype_str = 'bfloat16'
+        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+    else:
+        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
     # triton result
     z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
-                      device=device)
+                      device=device, dst_type=z_tri_dtype_str)
     kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
     z_tri = to_numpy(z_tri)
     # compare
@@ -834,10 +882,11 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
 @pytest.mark.parametrize("dtype_str, shape, perm",
                          [(dtype, shape, perm)
-                          for dtype in ['float16', 'float32']
+                          for dtype in ['bfloat16', 'float16', 'float32']
                           for shape in [(64, 64), (128, 128)]
                           for perm in [(1, 0)]])
 def test_permute(dtype_str, shape, perm, device='cuda'):
+    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
     # triton kernel
     @triton.jit
@@ -852,16 +901,16 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
     # input
     x = numpy_random(shape, dtype_str=dtype_str)
     # triton result
-    z_tri = to_triton(np.empty_like(x), device=device)
-    z_tri_contiguous = to_triton(np.empty_like(x), device=device)
-    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
+    z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
+    x_tri = to_triton(x, device=device, dst_type=dtype_str)
     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                          z_tri, z_tri.stride(1), z_tri.stride(0),
                          BLOCK_M=shape[0], BLOCK_N=shape[1])
     pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
                                     z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
                                     BLOCK_M=shape[0], BLOCK_N=shape[1])
-    # torch result
+    # numpy result
     z_ref = x.transpose(*perm)
     # compare
     triton.testing.assert_almost_equal(z_tri, z_ref)
@@ -1038,8 +1087,10 @@ def test_arange(start, device='cuda'):
 # Testing masked loads with an intermediate copy to shared memory run.

-@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_masked_load_shared_memory(dtype, device='cuda'):
+    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
     M = 32
     N = 32
     K = 16

View File

@@ -2,18 +2,22 @@ import pytest
 import torch

 import triton
+import triton._C.libtriton.triton as _triton


 @pytest.mark.parametrize("M, N, dtype, mode",
                          [
                              (M, N, dtype, mode) for M in [1024, 821]
                              for N in [512, 857, 1871, 2089, 8573, 31000]
-                             for dtype in ['float16', 'float32']
+                             for dtype in ['bfloat16', 'float16', 'float32']
                              for mode in ['forward', 'backward']
                          ]
                          )
 def test_op(M, N, dtype, mode):
-    dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80 and dtype == "bfloat16":
+        pytest.skip("Only test bfloat16 on devices with sm >= 80")
+    dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
     # create inputs
     x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
     idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')

View File

@@ -58,14 +58,22 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
         return tl.float32
     # 3 ) if one operand is half, the other is implicitly converted to half
     #     unless we're doing / or %, which do not exist natively in PTX for fp16.
+    #     Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
     if a_ty.is_fp16() or b_ty.is_fp16():
         if div_or_mod:
             return tl.float32
         else:
             return tl.float16
+    # 4) return bf16 only if both operands are of bf16
+    if a_ty.is_bf16() or b_ty.is_bf16():
+        if div_or_mod:
+            return tl.float32
+        if a_ty.is_bf16() and b_ty.is_bf16():
+            return tl.bfloat16
+        return tl.float32
     if not a_ty.is_int() or not b_ty.is_int():
         assert False
-    # 4 ) both operands are integer and undergo
+    # 5 ) both operands are integer and undergo
     #     integer promotion
     if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
         raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;"
@@ -768,16 +776,25 @@ def atomic_cas(ptr: tl.tensor,
                cmp: tl.tensor,
                val: tl.tensor,
                builder: ir.builder) -> tl.tensor:
-    # TODO: type checking
+    element_ty = ptr.type.scalar.element_ty
+    if element_ty.primitive_bitwidth not in [16, 32, 64]:
+        raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
     return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type)


 def atom_red_typechecking_impl(ptr: tl.tensor,
                                val: tl.tensor,
                                mask: tl.tensor,
+                               op: str,
                                builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
     if not ptr.type.scalar.is_ptr():
         raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
+    element_ty = ptr.type.scalar.element_ty
+    if element_ty is tl.float16 and op != 'add':
+        raise ValueError("atomic_" + op + " does not support fp16")
+    if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
+        raise ValueError("atomic_" + op + " does not support " + str(element_ty))
     if ptr.type.is_block():
         if mask:
             mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
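With the op name threaded into the type checker, unsupported atomics are rejected up front: fp16 is only accepted for atomic_add, and int1/int8/int16/bfloat16 are rejected for all RMW ops. A hedged usage sketch (not from the commit; kernel and tensor names are illustrative):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def bf16_atomic_kernel(ptr):
    # with the check above, a bf16 pointer here is rejected with a ValueError
    # during JIT compilation instead of emitting invalid PTX
    tl.atomic_add(ptr, 1.0)

# x = torch.zeros(16, dtype=torch.bfloat16, device='cuda')
# bf16_atomic_kernel[(1,)](x)  # <- raises when the kernel is compiled
```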
@@ -798,7 +815,7 @@ def atomic_max(ptr: tl.tensor,
                val: tl.tensor,
                mask: tl.tensor,
                builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
     sca_ty = val.type.scalar
     # direct call to atomic_max for integers
     if sca_ty.is_int():
@@ -830,7 +847,7 @@ def atomic_min(ptr: tl.tensor,
                val: tl.tensor,
                mask: tl.tensor,
                builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
     sca_ty = val.type.scalar
     # direct call to atomic_min for integers
     if sca_ty.is_int():
@@ -870,7 +887,7 @@ def atomic_add(ptr: tl.tensor,
                val: tl.tensor,
                mask: tl.tensor,
                builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
     sca_ty = val.type.scalar
     op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
     return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type)
@@ -880,7 +897,7 @@ def atomic_and(ptr: tl.tensor,
                val: tl.tensor,
                mask: tl.tensor,
                builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type)
@@ -888,7 +905,7 @@ def atomic_or(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type)
@@ -896,7 +913,7 @@ def atomic_xor(ptr: tl.tensor,
                val: tl.tensor,
                mask: tl.tensor,
                builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type)
@@ -904,7 +921,7 @@ def atomic_xchg(ptr: tl.tensor,
                 val: tl.tensor,
                 mask: tl.tensor,
                 builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type)

 # ===----------------------------------------------------------------------===//
@@ -978,6 +995,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
     if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
         input = cast(input, tl.int32, builder)
+    # hardware doesn't support FMAX, FMIN, CMP for bfloat16
+    if scalar_ty is tl.bfloat16:
+        input = cast(input, tl.float32, builder)
     # choose the right unsigned operation
     if scalar_ty.is_int_unsigned():
         int_op_to_unit = {

View File

@@ -65,7 +65,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
     # write result in-place in PROBS
     dout = tl.load(DPROBS + row)
     din = (probs - delta) * dout
-    tl.store(PROBS, din.to(tl.float16), mask=cols < N)
+    tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)


 class _cross_entropy(torch.autograd.Function):