From e517b58d59ba96357d042d8fa5819a690d00d749 Mon Sep 17 00:00:00 2001 From: Qingyi Liu Date: Wed, 9 Nov 2022 10:18:56 +0800 Subject: [PATCH] [Triton-MLIR] Minor fixes to enable fused-softmax and layer-norm tutorials (#835) --- include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h | 2 +- lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp | 2 +- lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp | 2 +- unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h index 133c7eaf3..966da36c2 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h @@ -320,7 +320,7 @@ struct PTXInstrExecution { // Prefix a !predicate to the instruction. PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) { pred = instr->builder->newOperand(value, constraint); - pred->repr = [](int idx) { return "@!%" + std::to_string(idx); }; + pred->repr = [](int idx) { return "@!$" + std::to_string(idx); }; return *this; } diff --git a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp index 3e54b0388..7b257fd3c 100644 --- a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp +++ b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp @@ -125,7 +125,7 @@ std::string PTXBuilder::dump() const { lines.push_back(exec->dump()); } - return strJoin(lines, "\r\n"); + return strJoin(lines, "\n\t"); } PTXInstrExecution &PTXInstrCommon::call(ArrayRef oprs) { diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 64dd0d76d..ec16e9aef 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -1051,7 +1051,7 @@ struct LoadOpConversion if (other) { for (size_t ii = 0; ii < nWords; ++ii) { PTXInstr &mov = *ptxBuilder.create<>("mov"); - mov.o("u", width); + mov.o("u" + std::to_string(width)); size_t size = width / valueElemNbits; diff --git a/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp b/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp index efce89a67..d4024b5c1 100644 --- a/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp +++ b/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp @@ -112,8 +112,8 @@ TEST_F(PtxAsmFormatTest, MultiLinePTX) { mov(valVal1, constVal); mov(valVal1, valVal0); - EXPECT_EQ(builder.dump(), "mov $0, 0x1;\r\n" - "mov $1, 0x1;\r\n" + EXPECT_EQ(builder.dump(), "mov $0, 0x1;\n\t" + "mov $1, 0x1;\n\t" "mov $1, $0;"); auto values = builder.getAllMLIRArgs();