[CODEGEN] Bugfixes with FP32 async copy

This commit is contained in:
Philippe Tillet
2021-02-24 13:36:26 -05:00
parent 11215f0f03
commit 567a1a3d17
3 changed files with 12 additions and 11 deletions

View File

@@ -1467,26 +1467,26 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
}
shared.push_back({tmp[key], off});
}
size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
auto idx = idxs_[arg][i];
// input ptr info
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
Value *in_base = in_gep->getPointerOperand();
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0;
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize*in_vec : 0;
in_base = cst ? in_base : in_gep;
// output ptr info
Value* out_base = shared[i].first;
int out_off = shared[i].second*2;
int out_off = shared[i].second*dtsize;
// asm
std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca";
std::string mod = (in_vec*dtsize == 16) ? ".cg" : ".ca";
// Value* false_value = vals_[x->get_false_value_operand()][idx];
// bool is_zero_false_value = false;
// if(Constant* cst = dyn_cast<Constant>(false_value))
// is_zero_false_value = cst->isZeroValue();
Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*2), i32(0));
std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ", $2;";
Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0));
std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;";
FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType(), builder_->getInt32Ty()}, false);
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r", true);
call(iasm, {out_base, in_base, src_size});

View File

@@ -3,11 +3,12 @@ import triton
import pytest
@pytest.mark.parametrize(
"MODE, TRANS_A, TRANS_B, BLOCK",
[(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
for block in [16, 32, 64]],
"MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
[(mode, at, bt, block, dtype) for dtype in ["float16", "float32"] for mode in ["sdd", "dsd", "dds"]
for at in [False, True] for bt in [False, True] for block in [16, 32, 64]],
)
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=128, N=256, K=384):
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
# set seed
torch.random.manual_seed(0)
# create inputs

View File

@@ -44,7 +44,7 @@ import torch
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
]),
)
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):