diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index a9eeb2d55..a51ffc645 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1467,26 +1467,26 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ } shared.push_back({tmp[key], off}); } - + size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){ auto idx = idxs_[arg][i]; // input ptr info GetElementPtrInst *in_gep = dyn_cast(vals_[arg][idx]); Value *in_base = in_gep->getPointerOperand(); ConstantInt* cst = dyn_cast(in_gep->idx_begin()); - size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0; + size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize*in_vec : 0; in_base = cst ? in_base : in_gep; // output ptr info Value* out_base = shared[i].first; - int out_off = shared[i].second*2; + int out_off = shared[i].second*dtsize; // asm - std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca"; + std::string mod = (in_vec*dtsize == 16) ? ".cg" : ".ca"; // Value* false_value = vals_[x->get_false_value_operand()][idx]; // bool is_zero_false_value = false; // if(Constant* cst = dyn_cast(false_value)) // is_zero_false_value = cst->isZeroValue(); - Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*2), i32(0)); - std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ", $2;"; + Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0)); + std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;"; FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType(), builder_->getInt32Ty()}, false); InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r", true); call(iasm, {out_base, in_base, src_size}); diff --git a/python/test/test_blocksparse.py b/python/test/test_blocksparse.py index 70ae1b8ce..fde63b29d 100644 --- a/python/test/test_blocksparse.py +++ b/python/test/test_blocksparse.py @@ -3,11 +3,12 @@ import triton import pytest @pytest.mark.parametrize( - "MODE, TRANS_A, TRANS_B, BLOCK", - [(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True] - for block in [16, 32, 64]], + "MODE, TRANS_A, TRANS_B, BLOCK, DTYPE", + [(mode, at, bt, block, dtype) for dtype in ["float16", "float32"] for mode in ["sdd", "dsd", "dds"] + for at in [False, True] for bt in [False, True] for block in [16, 32, 64]], ) -def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384): +def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=128, N=256, K=384): + DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] # set seed torch.random.manual_seed(0) # create inputs diff --git a/python/test/test_matmul.py b/python/test/test_matmul.py index b3ee58370..aaf738acc 100644 --- a/python/test/test_matmul.py +++ b/python/test/test_matmul.py @@ -44,7 +44,7 @@ import torch (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE), (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE), (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE), - ] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True] + ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] ]), ) def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):