From 4fb9d4904e71d20742904567ff03e8db7759e2da Mon Sep 17 00:00:00 2001
From: Michael Melesse
Date: Tue, 1 Nov 2022 14:18:06 +0000
Subject: [PATCH] fix 6/7 dot tests

---
 lib/codegen/selection/generator.cc     | 13 +++++++------
 python/test/unit/language/test_core.py |  2 +-
 python/triton/language/semantic.py     |  7 +++++++
 3 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 3335156ab..60d31de73 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -2973,19 +2973,20 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
   size_t red_axis = 1;
   unsigned NK = A_shapes[red_axis];
   bool is_outer = NK == 1;
+
 #ifdef USE_ROCM
-  bool is_mma = false;
+  return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
 #else
   bool is_mma = layouts_->get(dot)->to_mma();
-#endif
   if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
     return visit_mma884(dot, A, B, D, NK);
   if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
     return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
-  // if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
-  //     A->get_type()->get_scalar_ty()->is_fp32_ty())
-  return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
-  // throw std::runtime_error("dot has invalid operand type");
+  if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
+      A->get_type()->get_scalar_ty()->is_fp32_ty())
+    return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
+  throw std::runtime_error("dot has invalid operand type");
+#endif
 }
 
 void generator::visit_trans_inst(ir::trans_inst* trans) {
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index caed06351..00241403c 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1067,7 +1067,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
                          [(epilogue, allow_tf32, dtype)
                           for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
                           for allow_tf32 in [True, False]
-                          for dtype in ['float32', 'float16']
+                          for dtype in ['float16']
                           if not (allow_tf32 and (dtype in ['float16']))])
 def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
     if torch.version.hip is not None:
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 62e4a30bd..ef55d25f2 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -6,6 +6,8 @@ from . import core as tl
 
 from triton._C.libtriton.triton import ir
 
+import torch
+
 # Create custom exception that prints message "hello"
 class IncompatibleTypeErrorimpl(Exception):
     def __init__(self, type_a, type_b):
@@ -969,6 +971,11 @@ def dot(a: tl.tensor,
         trans_b: bool,
         allow_tf32: bool,
         builder: ir.builder) -> tl.tensor:
+
+    if torch.version.hip is not None:
+        a = cast(a, tl.float32, builder)
+        b = cast(b, tl.float32, builder)
+
     in_a = 1 if not trans_a else 0
     in_b = 1 if trans_b else 0
     assert a.type.is_block() and b.type.is_block()
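
A quick way to sanity-check the new HIP behavior is a minimal fp16 tl.dot
kernel: on ROCm, semantic.dot() now upcasts both operands to fp32, so codegen
reaches visit_fmadot instead of an MMA path. The sketch below is illustrative
and not part of the patch; the kernel name, tile sizes, and tolerances are
assumptions.

import torch
import triton
import triton.language as tl

@triton.jit
def dot_kernel(a_ptr, b_ptr, c_ptr,
               M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    # Load one M x K tile of A and one K x N tile of B (row-major).
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # fp16 inputs: on HIP this patch casts them to fp32 before lowering,
    # so the dot compiles through the FMA path.
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

if torch.version.hip is not None:
    M, N, K = 16, 16, 16
    a = torch.randn((M, K), dtype=torch.float16, device='cuda')
    b = torch.randn((K, N), dtype=torch.float16, device='cuda')
    c = torch.empty((M, N), dtype=torch.float32, device='cuda')
    dot_kernel[(1,)](a, b, c, M=M, N=N, K=K)
    torch.testing.assert_close(c, a.float() @ b.float(), rtol=1e-2, atol=1e-2)

Note that the upcast trades matrix-core lowering for scalar FMAs, so this is a
correctness fix rather than a performance path; fp16 dots accumulate in fp32
either way.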