[Triton-MLIR][BACKEND] Make mmav1 work on basic cases (#944)
TODO: (1) Add more cases. (2) Currently, vec is simply set to 4 so that the basic cases pass. Issues: (1) The vec in the shared layout differs from the one used on the master branch. (2) When vec=1, a CUDA misalignment error occurs; this case does not work on the master branch either. (3) When vec is set to the same value as on the master branch, the MMA works.
This commit is contained in:
@@ -62,12 +62,11 @@ namespace LLVM {
|
||||
static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
|
||||
|
||||
// A helper function for using printf in LLVM conversion.
|
||||
void llPrintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
void vprintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
// Helper function
|
||||
#define tid_val() getThreadId(rewriter, loc)
|
||||
#define llprintf(fmt, ...) LLVM::llPrintf(fmt, {__VA_ARGS__}, rewriter)
|
||||
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
std::string elem_repr, ConversionPatternRewriter &builder);
|
||||
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
@@ -3537,8 +3536,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
SmallVector<Value> resVals(resSize);
|
||||
|
||||
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
|
||||
auto ha = has[{m, k}];
|
||||
auto hb = hbs[{n, k}];
|
||||
auto ha = has.at({m, k});
|
||||
auto hb = hbs.at({n, k});
|
||||
std::vector<size_t> idx{{
|
||||
(m * 2 + 0) + (n * 4 + 0) * numM, // row0
|
||||
(m * 2 + 0) + (n * 4 + 1) * numM,
|
||||
@@ -3554,13 +3553,13 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
|
||||
auto *resOprs = builder.newListOperand(8, "=f");
|
||||
auto *AOprs = builder.newListOperand({
|
||||
{ha.first, "f"},
|
||||
{ha.second, "f"},
|
||||
{ha.first, "r"},
|
||||
{ha.second, "r"},
|
||||
});
|
||||
|
||||
auto *BOprs = builder.newListOperand({
|
||||
{hb.first, "f"},
|
||||
{hb.second, "f"},
|
||||
{hb.first, "r"},
|
||||
{hb.second, "r"},
|
||||
});
|
||||
auto *COprs = builder.newListOperand();
|
||||
for (int i = 0; i < 8; ++i)
|
||||
@@ -4806,11 +4805,23 @@ namespace mlir {
|
||||
|
||||
namespace LLVM {
|
||||
|
||||
void llPrintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
void vprintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
PrintfOpConversion::llPrintf(msg, args, rewriter);
|
||||
}
|
||||
|
||||
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
std::string elem_repr, ConversionPatternRewriter &builder) {
|
||||
std::string fmt = info + " t-%d ";
|
||||
std::vector<Value> new_arr({thread});
|
||||
for (int i = 0; i < arr.size(); ++i) {
|
||||
fmt += elem_repr + ((i == arr.size() - 1) ? "" : ", ");
|
||||
new_arr.push_back(arr[i]);
|
||||
}
|
||||
|
||||
vprintf(fmt, new_arr, builder);
|
||||
}
|
||||
|
||||
} // namespace LLVM
|
||||
|
||||
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
|
||||
|
Reference in New Issue
Block a user