[Triton-MLIR][BACKEND] Make mmav1 works on basic cases (#944)

TODO:

- Add more cases
- Currently, we just set vec to 4 to make the basic cases pass

Issue:

- the vec in shared layout is different compared to master branch
- when vec=1, it encounters CUDA misalignment error, it doesn't work in
master branch as well
- when setting vec to the value identical to master branch, the MMA
works
This commit is contained in:
Yan Chunwei
2022-12-06 10:57:08 +08:00
committed by GitHub
parent 189491727a
commit e419781978
8 changed files with 134 additions and 100 deletions

View File

@@ -62,12 +62,11 @@ namespace LLVM {
static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
// A helper function for using printf in LLVM conversion.
void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter);
void vprintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter);
// Helper function
#define tid_val() getThreadId(rewriter, loc)
#define llprintf(fmt, ...) LLVM::llPrintf(fmt, {__VA_ARGS__}, rewriter)
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
std::string elem_repr, ConversionPatternRewriter &builder);
} // namespace LLVM
} // namespace mlir
@@ -3537,8 +3536,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
SmallVector<Value> resVals(resSize);
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has[{m, k}];
auto hb = hbs[{n, k}];
auto ha = has.at({m, k});
auto hb = hbs.at({n, k});
std::vector<size_t> idx{{
(m * 2 + 0) + (n * 4 + 0) * numM, // row0
(m * 2 + 0) + (n * 4 + 1) * numM,
@@ -3554,13 +3553,13 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto *resOprs = builder.newListOperand(8, "=f");
auto *AOprs = builder.newListOperand({
{ha.first, "f"},
{ha.second, "f"},
{ha.first, "r"},
{ha.second, "r"},
});
auto *BOprs = builder.newListOperand({
{hb.first, "f"},
{hb.second, "f"},
{hb.first, "r"},
{hb.second, "r"},
});
auto *COprs = builder.newListOperand();
for (int i = 0; i < 8; ++i)
@@ -4806,11 +4805,23 @@ namespace mlir {
namespace LLVM {
void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
void vprintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
PrintfOpConversion::llPrintf(msg, args, rewriter);
}
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
std::string elem_repr, ConversionPatternRewriter &builder) {
std::string fmt = info + " t-%d ";
std::vector<Value> new_arr({thread});
for (int i = 0; i < arr.size(); ++i) {
fmt += elem_repr + ((i == arr.size() - 1) ? "" : ", ");
new_arr.push_back(arr[i]);
}
vprintf(fmt, new_arr, builder);
}
} // namespace LLVM
TritonLLVMConversionTarget::TritonLLVMConversionTarget(