[Triton-MLIR] Minor fixes to enable fused-softmax and layer-norm tutorials (#835)
This commit is contained in:
@@ -320,7 +320,7 @@ struct PTXInstrExecution {
|
|||||||
// Prefix a !predicate to the instruction.
|
// Prefix a !predicate to the instruction.
|
||||||
PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
|
PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
|
||||||
pred = instr->builder->newOperand(value, constraint);
|
pred = instr->builder->newOperand(value, constraint);
|
||||||
pred->repr = [](int idx) { return "@!%" + std::to_string(idx); };
|
pred->repr = [](int idx) { return "@!$" + std::to_string(idx); };
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -125,7 +125,7 @@ std::string PTXBuilder::dump() const {
|
|||||||
lines.push_back(exec->dump());
|
lines.push_back(exec->dump());
|
||||||
}
|
}
|
||||||
|
|
||||||
return strJoin(lines, "\r\n");
|
return strJoin(lines, "\n\t");
|
||||||
}
|
}
|
||||||
|
|
||||||
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs) {
|
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs) {
|
||||||
|
@@ -1051,7 +1051,7 @@ struct LoadOpConversion
|
|||||||
if (other) {
|
if (other) {
|
||||||
for (size_t ii = 0; ii < nWords; ++ii) {
|
for (size_t ii = 0; ii < nWords; ++ii) {
|
||||||
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
||||||
mov.o("u", width);
|
mov.o("u" + std::to_string(width));
|
||||||
|
|
||||||
size_t size = width / valueElemNbits;
|
size_t size = width / valueElemNbits;
|
||||||
|
|
||||||
|
@@ -112,8 +112,8 @@ TEST_F(PtxAsmFormatTest, MultiLinePTX) {
|
|||||||
mov(valVal1, constVal);
|
mov(valVal1, constVal);
|
||||||
mov(valVal1, valVal0);
|
mov(valVal1, valVal0);
|
||||||
|
|
||||||
EXPECT_EQ(builder.dump(), "mov $0, 0x1;\r\n"
|
EXPECT_EQ(builder.dump(), "mov $0, 0x1;\n\t"
|
||||||
"mov $1, 0x1;\r\n"
|
"mov $1, 0x1;\n\t"
|
||||||
"mov $1, $0;");
|
"mov $1, $0;");
|
||||||
|
|
||||||
auto values = builder.getAllMLIRArgs();
|
auto values = builder.getAllMLIRArgs();
|
||||||
|
Reference in New Issue
Block a user