[CODEGEN] Bugfixes with FP32 async copy
This commit is contained in:
@@ -1467,26 +1467,26 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
|
||||
}
|
||||
shared.push_back({tmp[key], off});
|
||||
}
|
||||
|
||||
size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
|
||||
auto idx = idxs_[arg][i];
|
||||
// input ptr info
|
||||
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
|
||||
Value *in_base = in_gep->getPointerOperand();
|
||||
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
|
||||
size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0;
|
||||
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize*in_vec : 0;
|
||||
in_base = cst ? in_base : in_gep;
|
||||
// output ptr info
|
||||
Value* out_base = shared[i].first;
|
||||
int out_off = shared[i].second*2;
|
||||
int out_off = shared[i].second*dtsize;
|
||||
// asm
|
||||
std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca";
|
||||
std::string mod = (in_vec*dtsize == 16) ? ".cg" : ".ca";
|
||||
// Value* false_value = vals_[x->get_false_value_operand()][idx];
|
||||
// bool is_zero_false_value = false;
|
||||
// if(Constant* cst = dyn_cast<Constant>(false_value))
|
||||
// is_zero_false_value = cst->isZeroValue();
|
||||
Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*2), i32(0));
|
||||
std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ", $2;";
|
||||
Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0));
|
||||
std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;";
|
||||
FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType(), builder_->getInt32Ty()}, false);
|
||||
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r", true);
|
||||
call(iasm, {out_base, in_base, src_size});
|
||||
|
@@ -3,11 +3,12 @@ import triton
|
||||
import pytest
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"MODE, TRANS_A, TRANS_B, BLOCK",
|
||||
[(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
|
||||
for block in [16, 32, 64]],
|
||||
"MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
|
||||
[(mode, at, bt, block, dtype) for dtype in ["float16", "float32"] for mode in ["sdd", "dsd", "dds"]
|
||||
for at in [False, True] for bt in [False, True] for block in [16, 32, 64]],
|
||||
)
|
||||
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
|
||||
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=128, N=256, K=384):
|
||||
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
# create inputs
|
||||
|
@@ -44,7 +44,7 @@ import torch
|
||||
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
|
||||
] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
|
||||
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
|
||||
]),
|
||||
)
|
||||
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
|
||||
|
Reference in New Issue
Block a user