diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index b30283ced..31ecfacba 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ // types #define void_ty builder_->getVoidTy() #define f16_ty builder_->getHalfTy() -#define bf16_ty builder_->getBFloatTy() +#define bf16_ty builder_->getInt16Ty() #define f32_ty builder_->getFloatTy() #define i1_ty builder_->getInt1Ty() #define i8_ty builder_->getInt8Ty() @@ -178,7 +178,7 @@ Type *generator::cvt(ir::type *ty) { case ir::type::VoidTyID: return Type::getVoidTy(*ctx_); case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_); case ir::type::FP16TyID: return Type::getHalfTy(*ctx_); - case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_); + case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); // use int16 as storage type case ir::type::FP32TyID: return Type::getFloatTy(*ctx_); case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_); case ir::type::LabelTyID: return Type::getLabelTy(*ctx_); @@ -378,8 +378,8 @@ void generator::visit_launch_inst(ir::launch_inst *launch) { */ void generator::visit_binary_operator(ir::binary_operator*x) { using ll = llvm::Instruction::BinaryOps; + using tt = ir::binary_op_t; auto cvt = [](ir::binary_op_t op){ - using tt = ir::binary_op_t; switch(op) { case tt::Add: return ll::Add; case tt::FAdd: return ll::FAdd; @@ -406,20 +406,51 @@ void generator::visit_binary_operator(ir::binary_operator*x) { for(indices_t idx: idxs_.at(x)){ Value *lhs = vals_[x->get_operand(0)][idx]; Value *rhs = vals_[x->get_operand(1)][idx]; - auto op = cvt(x->get_op()); - if(op == ll::Add) - vals_[x][idx] = add(lhs, rhs); - else if(op == ll::Mul) - vals_[x][idx] = mul(lhs, rhs); - else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() && - x->get_type()->get_scalar_ty()->is_fp32_ty()){ - InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false), - " div.full.f32 $0, $1, $2;", "=r,r,r", false); - vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs}); + // manually select bf16 bin op + if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) { + assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty()); + if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b + InlineAsm *bf16_add_asm = + InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), + "{ .reg .b16 c; \n\t" + " mov.b16 c, 0x3f80U; \n\t" // 1.0 + " fma.rn.bf16 $0, $1, c, $2; } \n\t", + "=h,h,h", false); + vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs}); + } else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a + InlineAsm *bf16_sub_asm = + InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), + " { .reg .b16 c; \n\t" + " mov.b16 c, 0xbf80U; \n\t" // -1.0 + " fma.rn.bf16 $0, $2, c, $1;} \n\t", + "=h,h,h", false); + vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs}); + } else if (x->get_op() == tt::FMul) { // a * b = a*b + 0 + InlineAsm *bf16_mul_asm = + InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), + " { .reg .b16 c; \n\t" + " mov.b16 c, 0x8000U; \n\t" // 0.0 + " fma.rn.bf16 $0, $1, $2, c;} \n\t", + "=h,h,h", false); + vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs}); + } else + throw std::runtime_error("invalid bin op for bf16"); + } else { // not bf16 + auto op = cvt(x->get_op()); + if(op == ll::Add) + vals_[x][idx] = add(lhs, rhs); + else if(op == ll::Mul) + vals_[x][idx] = 
mul(lhs, rhs); + else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() && + x->get_type()->get_scalar_ty()->is_fp32_ty()){ + InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false), + " div.full.f32 $0, $1, $2;", "=r,r,r", false); + vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs}); - } - else - vals_[x][idx] = bin_op(op, lhs, rhs); + } + else + vals_[x][idx] = bin_op(op, lhs, rhs); + } } } @@ -970,8 +1001,6 @@ void generator::visit_store_inst(ir::store_inst * x){ has_l2_evict_policy = false; auto idxs = idxs_.at(val_op); Type *ty = cvt(val_op->get_type()->get_scalar_ty()); - if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store - ty = f16_ty; if(ty->isIntegerTy(1)) ty = builder_->getInt8Ty(); for(size_t i = 0; i < idxs.size(); i += vec){ @@ -2830,9 +2859,6 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){ // pointer to temporary shared memory Type *ty = cvt(out->get_type()->get_scalar_ty()); - if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store - ty = f16_ty; - // Orders analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(in)); analysis::distributed_layout* out_layout = dynamic_cast(layouts_->get(out)); @@ -3229,8 +3255,22 @@ void generator::visit_constant_int(ir::constant_int *x){ void generator::visit_constant_fp(ir::constant_fp *x){ Type *ty = cvt(x->get_type()->get_scalar_ty()); - for(indices_t idx: idxs_.at(x)) - vals_[x][idx] = ConstantFP::get(ty, x->get_value()); + for(indices_t idx: idxs_.at(x)) { + // manually select bf16 constant + if (x->get_type()->get_scalar_ty()->is_bf16_ty()) { + // highest 16 bits of fp32 + float fp32_value = x->get_value(); + uint16_t bf16_raw = (*reinterpret_cast(&fp32_value) + & 0xffff0000) >> 16; + std::stringstream const_str; + const_str << "0x" << std::hex << bf16_raw << "U"; // unsigned + InlineAsm *bf16_const = InlineAsm::get(FunctionType::get(bf16_ty, {}, false), + " mov.b16 $0, " + const_str.str() + ";", + "=h", false); + vals_[x][idx] = builder_->CreateCall(bf16_const, {}); + } else + vals_[x][idx] = ConstantFP::get(ty, x->get_value()); + } } void generator::visit_alloc_const(ir::alloc_const *alloc) { diff --git a/lib/ir/constant.cc b/lib/ir/constant.cc index ab1f6f497..417626c92 100644 --- a/lib/ir/constant.cc +++ b/lib/ir/constant.cc @@ -18,6 +18,8 @@ constant *constant::get_null_value(type *ty) { return constant_int::get(ty, 0); case type::FP16TyID: return constant_fp::get(type::get_fp16_ty(ctx), 0); + case type::BF16TyID: + return constant_fp::get(type::get_bf16_ty(ctx), 0); case type::FP32TyID: return constant_fp::get(type::get_fp32_ty(ctx), 0); case type::FP64TyID: diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index cb2cb9c33..561ed6af5 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -33,27 +33,37 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h shape = (shape, ) if rs is None: rs = RandomState(seed=17) - dtype = getattr(np, dtype_str) if dtype_str in int_dtypes + uint_dtypes: iinfo = np.iinfo(getattr(np, dtype_str)) low = iinfo.min if low is None else max(low, iinfo.min) high = iinfo.max if high is None else min(high, iinfo.max) + dtype = getattr(np, dtype_str) x = rs.randint(low, high, shape, dtype=dtype) x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out. 
return x elif dtype_str in float_dtypes: - return rs.normal(0, 1, shape).astype(dtype) + return rs.normal(0, 1, shape).astype(dtype_str) + elif dtype_str == 'bfloat16': + return (rs.normal(0, 1, shape).astype('float32').view('uint32') + & np.uint32(0xffff0000)).view('float32') else: raise RuntimeError(f'Unknown dtype {dtype_str}') -def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]: +def to_triton(x: np.ndarray, device='cuda', dst_type=None) -> Union[TensorWrapper, torch.Tensor]: + ''' + Note: We need dst_type becasue the type of x can be different from dst_type. + For example: x is of type `float32`, dst_type is `bfloat16`. + If dst_type is None, we infer dst_type from x. + ''' t = x.dtype.name if t in uint_dtypes: signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" x_signed = x.astype(getattr(np, signed_type_name)) return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) else: + if t == 'float32' and dst_type == 'bfloat16': + return torch.tensor(x, device=device).bfloat16() return torch.tensor(x, device=device) @@ -72,6 +82,8 @@ def to_numpy(x): if isinstance(x, TensorWrapper): return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) elif isinstance(x, torch.Tensor): + if x.dtype is torch.bfloat16: + return x.cpu().float().numpy() return x.cpu().numpy() else: raise ValueError(f"Not a triton-compatible tensor: {x}") @@ -84,19 +96,30 @@ def patch_kernel(template, to_replace): return kernel -@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes]) +def check_type_supported(dtype): + ''' + skip test if dtype is not supported on the current device + ''' + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): + pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") + + +@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"]) def test_empty_kernel(dtype_x, device='cuda'): SIZE = 128 @triton.jit def kernel(X, SIZE: tl.constexpr): pass - x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device) + check_type_supported(dtype_x) + x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) kernel[(1, )](x, SIZE=SIZE, num_warps=4) # generic test functions def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): + check_type_supported(dtype_x) # early return if dtype_x is not supported SIZE = 128 # define the kernel / launch-grid @@ -115,8 +138,8 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): # reference result z_ref = eval(expr if numpy_expr is None else numpy_expr) # triton result - x_tri = to_triton(x, device=device) - z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4) # compare np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) @@ -154,6 +177,8 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None): + check_type_supported(dtype_x) # early return if dtype_x is not supported + check_type_supported(dtype_y) SIZE = 128 # define the kernel / launch-grid @@ -180,8 +205,8 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y= if dtype_z 
is not None: z_ref = z_ref.astype(dtype_z) # triton result - x_tri = to_triton(x, device=device) - y_tri = to_triton(y, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4) np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01) @@ -193,15 +218,20 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: # remainders than stock LLVM. We currently don't expect to match it # bit-for-bit. return (dtype_x, dtype_y) in [ + ('int32', 'bfloat16'), ('int32', 'float16'), ('int32', 'float32'), + ('int64', 'bfloat16'), ('int64', 'float16'), ('int64', 'float32'), ('int64', 'float64'), + ('uint16', 'bfloat16'), ('uint16', 'float16'), ('uint16', 'float32'), + ('uint32', 'bfloat16'), ('uint32', 'float16'), ('uint32', 'float32'), + ('uint64', 'bfloat16'), ('uint64', 'float16'), ('uint64', 'float32'), ('uint64', 'float64'), @@ -215,15 +245,15 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: @pytest.mark.parametrize("dtype_x, dtype_y, op", [ (dtype_x, dtype_y, op) for op in ['+', '-', '*', '/', '%'] - for dtype_x in dtypes - for dtype_y in dtypes + for dtype_x in dtypes + ['bfloat16'] + for dtype_y in dtypes + ['bfloat16'] ]) def test_bin_op(dtype_x, dtype_y, op, device='cuda'): expr = f' x {op} y' if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. numpy_expr = 'np.fmod(x, y)' - elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'): + elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'): # Triton promotes 16-bit floating-point / and % to 32-bit because there # are no native div or FRem operations on float16. Since we have to # convert anyway, we may as well take the accuracy bump. 
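
For reference, the codegen above never emits a native bf16 add/sub/mul: each op is folded into fma.rn.bf16 with a fixed bf16 immediate, and a bf16 value is simply the upper 16 bits of an fp32, which is also how visit_constant_fp and numpy_random's 'bfloat16' branch build their bit patterns. The following standalone Python sketch (illustrative only, not part of the patch) reproduces that truncation and checks the immediates used in the inline PTX; 0x8000 is the bit pattern of negative zero, the additive identity that preserves the sign of a zero product.

import numpy as np

def fp32_to_bf16_bits(x: float) -> int:
    # bf16 is the upper 16 bits of the fp32 representation -- the same
    # `& 0xffff0000 >> 16` truncation done in the hunks above.
    return int(np.array(x, dtype=np.float32).view(np.uint32)) >> 16

assert hex(fp32_to_bf16_bits(1.0)) == '0x3f80'   # FAdd:  a + b = fma(a, 1.0, b)
assert hex(fp32_to_bf16_bits(-1.0)) == '0xbf80'  # FSub:  a - b = fma(b, -1.0, a)
assert hex(fp32_to_bf16_bits(-0.0)) == '0x8000'  # FMul:  a * b = fma(a, b, -0.0)
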
@@ -266,8 +296,8 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'): @pytest.mark.parametrize("dtype_x, dtype_y, op", [ (dtype_x, dtype_y, op) for op in ['&', '|', '^'] - for dtype_x in dtypes - for dtype_y in dtypes + for dtype_x in dtypes + ['bfloat16'] + for dtype_y in dtypes + ['bfloat16'] ]) def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): expr = f'x {op} y' @@ -337,7 +367,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): # test unary ops # --------------- @pytest.mark.parametrize("dtype_x, expr", [ - (dtype_x, ' -x') for dtype_x in dtypes + (dtype_x, ' -x') for dtype_x in dtypes + ['bfloat16'] ] + [ (dtype_x, ' ~x') for dtype_x in int_dtypes ]) @@ -732,9 +762,10 @@ def test_f16_to_f8_rounding(): @pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in ['min', 'max', 'argmin', 'argmax', 'sum'] - for dtype in dtypes + for dtype in dtypes + ['bfloat16'] for shape in [32, 64, 128, 512]]) def test_reduce1d(op, dtype_str, shape, device='cuda'): + check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested # triton kernel @triton.jit @@ -752,9 +783,18 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'): 'argmin': np.argmin, 'argmax': np.argmax}[op] # numpy result z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str - z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + z_tri_dtype_str = z_dtype_str + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + z_tri_dtype_str = 'bfloat16' + else: + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) # triton result - z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device) + z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), + device=device, dst_type=z_tri_dtype_str) kernel[(1,)](x_tri, z_tri, BLOCK=shape) z_tri = to_numpy(z_tri) # compare @@ -770,7 +810,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'): reduce_configs1 = [ - (op, dtype, (1, 1024), axis) for dtype in dtypes + (op, dtype, (1, 1024), axis) for dtype in dtypes + ['bfloat16'] for op in ['min', 'max', 'argmin', 'argmax', 'sum'] for axis in [1] ] @@ -805,11 +845,19 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax}[op] z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str + z_tri_dtype_str = z_dtype_str # numpy result - z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_tri_dtype_str = 'bfloat16' + z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + else: + z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) # triton result z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs), - device=device) + device=device, dst_type=z_tri_dtype_str) kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) z_tri = to_numpy(z_tri) # compare @@ -834,10 +882,11 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): @pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) - for 
dtype in ['float16', 'float32'] + for dtype in ['bfloat16', 'float16', 'float32'] for shape in [(64, 64), (128, 128)] for perm in [(1, 0)]]) def test_permute(dtype_str, shape, perm, device='cuda'): + check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested # triton kernel @triton.jit @@ -852,16 +901,16 @@ def test_permute(dtype_str, shape, perm, device='cuda'): # input x = numpy_random(shape, dtype_str=dtype_str) # triton result - z_tri = to_triton(np.empty_like(x), device=device) - z_tri_contiguous = to_triton(np.empty_like(x), device=device) - x_tri = to_triton(x, device=device) + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + x_tri = to_triton(x, device=device, dst_type=dtype_str) pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), BLOCK_M=shape[0], BLOCK_N=shape[1]) pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1), BLOCK_M=shape[0], BLOCK_N=shape[1]) - # torch result + # numpy result z_ref = x.transpose(*perm) # compare triton.testing.assert_almost_equal(z_tri, z_ref) @@ -1038,8 +1087,10 @@ def test_arange(start, device='cuda'): # Testing masked loads with an intermate copy to shared memory run. -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_masked_load_shared_memory(dtype, device='cuda'): + check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested + M = 32 N = 32 K = 16 diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index 08516257b..e28db4815 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -2,18 +2,22 @@ import pytest import torch import triton +import triton._C.libtriton.triton as _triton @pytest.mark.parametrize("M, N, dtype, mode", [ (M, N, dtype, mode) for M in [1024, 821] for N in [512, 857, 1871, 2089, 8573, 31000] - for dtype in ['float16', 'float32'] + for dtype in ['bfloat16', 'float16', 'float32'] for mode in ['forward', 'backward'] ] ) def test_op(M, N, dtype, mode): - dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype] + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 80 and dtype == "bfloat16": + pytest.skip("Only test bfloat16 on devices with sm >= 80") + dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype] # create inputs x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True) idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda') diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index a31fec384..8878a8195 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -58,14 +58,22 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t return tl.float32 # 3 ) if one operand is half, the other is implicitly converted to half # unless we're doing / or %, which do not exist natively in PTX for fp16. 
+ # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp if a_ty.is_fp16() or b_ty.is_fp16(): if div_or_mod: return tl.float32 else: return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() or b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + if a_ty.is_bf16() and b_ty.is_bf16(): + return tl.bfloat16 + return tl.float32 if not a_ty.is_int() or not b_ty.is_int(): assert False - # 4 ) both operands are integer and undergo + # 5 ) both operands are integer and undergo # integer promotion if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;" @@ -768,16 +776,25 @@ def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, builder: ir.builder) -> tl.tensor: - # TODO: type checking + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type) def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, + op: str, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: if not ptr.type.scalar.is_ptr(): raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + + element_ty = ptr.type.scalar.element_ty + if element_ty is tl.float16 and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16") + if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]: + raise ValueError("atomic_" + op + " does not support " + element_ty) if ptr.type.is_block(): if mask: mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) @@ -798,7 +815,7 @@ def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) sca_ty = val.type.scalar # direct call to atomic_max for integers if sca_ty.is_int(): @@ -830,7 +847,7 @@ def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) sca_ty = val.type.scalar # direct call to atomic_min for integers if sca_ty.is_int(): @@ -870,7 +887,7 @@ def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) sca_ty = val.type.scalar op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type) @@ -880,7 +897,7 @@ def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type) @@ -888,7 +905,7 @@ def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 
builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type) @@ -896,7 +913,7 @@ def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type) @@ -904,7 +921,7 @@ def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type) # ===----------------------------------------------------------------------===// @@ -978,6 +995,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32: input = cast(input, tl.int32, builder) + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is tl.bfloat16: + input = cast(input, tl.float32, builder) + # choose the right unsigned operation if scalar_ty.is_int_unsigned(): int_op_to_unit = { diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 910417d2c..63ce81074 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -65,7 +65,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): # write result in-place in PROBS dout = tl.load(DPROBS + row) din = (probs - delta) * dout - tl.store(PROBS, din.to(tl.float16), mask=cols < N) + tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N) class _cross_entropy(torch.autograd.Function):
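
To make the new promotion rule in computation_type_impl concrete, here is a small standalone sketch (plain strings instead of the real tl.dtype objects, so the names are illustrative): bf16 is kept only when both operands are bf16 and the operator has a native bf16 lowering; / and % fall back to fp32, and bf16 mixed with an integer type also computes in fp32. Operands involving fp64, fp32, or fp16 never reach this branch because steps 1-3 handle them first.

def bf16_result_type(a: str, b: str, div_or_mod: bool) -> str:
    # Mirrors step 4) above; assumes steps 1-3 (fp64/fp32/fp16 operands)
    # have already returned earlier in computation_type_impl.
    assert 'bfloat16' in (a, b)
    if div_or_mod:
        return 'float32'          # no native bf16 div/rem in PTX
    if a == 'bfloat16' and b == 'bfloat16':
        return 'bfloat16'         # add/sub/mul stay in bf16 (fma.rn.bf16)
    return 'float32'              # bf16 mixed with an integer type

assert bf16_result_type('bfloat16', 'bfloat16', div_or_mod=False) == 'bfloat16'
assert bf16_result_type('bfloat16', 'bfloat16', div_or_mod=True) == 'float32'
assert bf16_result_type('bfloat16', 'int32', div_or_mod=False) == 'float32'
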
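
End to end, the pieces above let kernels run directly on torch.bfloat16 tensors on sm_80+ devices. A minimal usage sketch (the kernel name, sizes, and launch grid are mine, not taken from the test suite; it assumes a GPU with cc >= 80):

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(X, Y, Z, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    y = tl.load(Y + offs, mask=mask)
    # bf16 + bf16 stays in bf16 and lowers to the fma.rn.bf16 path added above
    tl.store(Z + offs, x + y, mask=mask)

x = torch.randn(1024, device='cuda').bfloat16()
y = torch.randn(1024, device='cuda').bfloat16()
z = torch.empty_like(x)
add_kernel[(1024 // 256,)](x, y, z, 1024, BLOCK=256)
assert torch.allclose(z.float(), x.float() + y.float(), rtol=1e-2, atol=1e-2)
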