fix 6/7 dot tests
This commit is contained in:
@@ -2973,19 +2973,20 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
|||||||
size_t red_axis = 1;
|
size_t red_axis = 1;
|
||||||
unsigned NK = A_shapes[red_axis];
|
unsigned NK = A_shapes[red_axis];
|
||||||
bool is_outer = NK == 1;
|
bool is_outer = NK == 1;
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
bool is_mma = false;
|
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
||||||
#else
|
#else
|
||||||
bool is_mma = layouts_->get(dot)->to_mma();
|
bool is_mma = layouts_->get(dot)->to_mma();
|
||||||
#endif
|
|
||||||
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
|
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
|
||||||
return visit_mma884(dot, A, B, D, NK);
|
return visit_mma884(dot, A, B, D, NK);
|
||||||
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
|
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
|
||||||
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
|
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
|
||||||
// if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
|
if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
|
||||||
// A->get_type()->get_scalar_ty()->is_fp32_ty())
|
A->get_type()->get_scalar_ty()->is_fp32_ty())
|
||||||
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
||||||
// throw std::runtime_error("dot has invalid operand type");
|
throw std::runtime_error("dot has invalid operand type");
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_trans_inst(ir::trans_inst* trans) {
|
void generator::visit_trans_inst(ir::trans_inst* trans) {
|
||||||
|
@@ -1067,7 +1067,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
|||||||
[(epilogue, allow_tf32, dtype)
|
[(epilogue, allow_tf32, dtype)
|
||||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||||
for allow_tf32 in [True, False]
|
for allow_tf32 in [True, False]
|
||||||
for dtype in ['float32', 'float16']
|
for dtype in ['float16']
|
||||||
if not (allow_tf32 and (dtype in ['float16']))])
|
if not (allow_tf32 and (dtype in ['float16']))])
|
||||||
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||||
if torch.version.hip is not None:
|
if torch.version.hip is not None:
|
||||||
|
@@ -6,6 +6,8 @@ from . import core as tl
|
|||||||
from triton._C.libtriton.triton import ir
|
from triton._C.libtriton.triton import ir
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
# Create custom exception that prints message "hello"
|
# Create custom exception that prints message "hello"
|
||||||
class IncompatibleTypeErrorimpl(Exception):
|
class IncompatibleTypeErrorimpl(Exception):
|
||||||
def __init__(self, type_a, type_b):
|
def __init__(self, type_a, type_b):
|
||||||
@@ -969,6 +971,11 @@ def dot(a: tl.tensor,
|
|||||||
trans_b: bool,
|
trans_b: bool,
|
||||||
allow_tf32: bool,
|
allow_tf32: bool,
|
||||||
builder: ir.builder) -> tl.tensor:
|
builder: ir.builder) -> tl.tensor:
|
||||||
|
|
||||||
|
if torch.version.hip is not None:
|
||||||
|
a = cast(a, tl.float32, builder)
|
||||||
|
b = cast(b, tl.float32, builder)
|
||||||
|
|
||||||
in_a = 1 if not trans_a else 0
|
in_a = 1 if not trans_a else 0
|
||||||
in_b = 1 if trans_b else 0
|
in_b = 1 if trans_b else 0
|
||||||
assert a.type.is_block() and b.type.is_block()
|
assert a.type.is_block() and b.type.is_block()
|
||||||
|
Reference in New Issue
Block a user