[BACKEND] Add C++ tests for PTXFormat and some tiny refinement (#109)
This PR does 1. Add some C++ tests for `PTXFormat` 2. Enhance the functionality of `PTXFormat`, make a `PTXInstr` instance can be called multiple times similar as a C function.
This commit is contained in:
@@ -4,8 +4,6 @@
|
|||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/Format.h"
|
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
@@ -15,6 +13,7 @@ using llvm::StringRef;
|
|||||||
|
|
||||||
class PTXInstr;
|
class PTXInstr;
|
||||||
class PTXInstrCommon;
|
class PTXInstrCommon;
|
||||||
|
class PTXInstrExecution;
|
||||||
|
|
||||||
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
|
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
|
||||||
// instructions.
|
// instructions.
|
||||||
@@ -25,7 +24,8 @@ class PTXInstrCommon;
|
|||||||
// string and C++ if-else code.
|
// string and C++ if-else code.
|
||||||
//
|
//
|
||||||
// Usage:
|
// Usage:
|
||||||
// To build: asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k), "b"(p));
|
// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k),
|
||||||
|
// "b"(p));
|
||||||
//
|
//
|
||||||
// PTXBuilder builder;
|
// PTXBuilder builder;
|
||||||
// auto& add = builder.create<>();
|
// auto& add = builder.create<>();
|
||||||
@@ -35,7 +35,7 @@ class PTXInstrCommon;
|
|||||||
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
|
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
|
||||||
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
|
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
|
||||||
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
|
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
|
||||||
// add(iOpr, jOpr, kOpr); // set operands
|
// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate
|
||||||
//
|
//
|
||||||
// To get the asm code:
|
// To get the asm code:
|
||||||
// builder.dump()
|
// builder.dump()
|
||||||
@@ -45,16 +45,25 @@ class PTXInstrCommon;
|
|||||||
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
|
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
|
||||||
//
|
//
|
||||||
// To get the string containing all the contraints with "," seperated,
|
// To get the string containing all the contraints with "," seperated,
|
||||||
// builder.getConstrains() // get "=r,r,k"
|
// builder.getConstraints() // get "=r,r,k"
|
||||||
//
|
//
|
||||||
// PTXBuilder can build a PTX asm with multiple instructions, sample code:
|
// PTXBuilder can build a PTX asm with multiple instructions, sample code:
|
||||||
//
|
//
|
||||||
// PTXBuilder builder;
|
// PTXBuilder builder;
|
||||||
// auto& instr0 = builder.create<>();
|
// auto& mov = builder.create("mov");
|
||||||
// auto& instr1 = builder.create<>();
|
// auto& cp = builder.create("cp");
|
||||||
// auto& instr2 = builder.create<>();
|
// mov(...);
|
||||||
|
// cp(...);
|
||||||
|
// This will get a PTX code with two instructions.
|
||||||
//
|
//
|
||||||
// NOTE, the instructions will be serialized in the order of creation.
|
// Similar to a C function, a declared PTXInstr instance can be launched
|
||||||
|
// multiple times with different operands, e.g.
|
||||||
|
//
|
||||||
|
// auto& mov = builder.create("mov");
|
||||||
|
// mov(... some operands ...);
|
||||||
|
// mov(... some different operands ...);
|
||||||
|
//
|
||||||
|
// Finally, we will get a PTX code with two mov instructions.
|
||||||
//
|
//
|
||||||
// There are several derived instruction type for typical instructions, for
|
// There are several derived instruction type for typical instructions, for
|
||||||
// example, the PtxIOInstr for ld and st instructions.
|
// example, the PtxIOInstr for ld and st instructions.
|
||||||
@@ -68,6 +77,7 @@ struct PTXBuilder {
|
|||||||
|
|
||||||
// for list
|
// for list
|
||||||
Operand() = default;
|
Operand() = default;
|
||||||
|
Operand(const Operation &) = delete;
|
||||||
Operand(Value value, StringRef constraint)
|
Operand(Value value, StringRef constraint)
|
||||||
: value(value), constraint(constraint) {}
|
: value(value), constraint(constraint) {}
|
||||||
|
|
||||||
@@ -117,7 +127,7 @@ struct PTXBuilder {
|
|||||||
|
|
||||||
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
|
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
|
||||||
|
|
||||||
std::string getConstrains() const;
|
std::string getConstraints() const;
|
||||||
|
|
||||||
std::string dump() const;
|
std::string dump() const;
|
||||||
|
|
||||||
@@ -128,10 +138,12 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
friend class PTXInstr;
|
friend class PTXInstr;
|
||||||
|
friend class PTXInstrCommon;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
||||||
llvm::SmallVector<std::unique_ptr<PTXInstrCommon>, 2> instrs;
|
llvm::SmallVector<std::unique_ptr<PTXInstrCommon>, 2> instrs;
|
||||||
|
llvm::SmallVector<std::unique_ptr<PTXInstrExecution>, 4> executions;
|
||||||
int oprCounter{};
|
int oprCounter{};
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -142,36 +154,26 @@ struct PTXInstrCommon {
|
|||||||
|
|
||||||
using Operand = PTXBuilder::Operand;
|
using Operand = PTXBuilder::Operand;
|
||||||
|
|
||||||
llvm::SmallVector<Operand *> getArgList() const;
|
|
||||||
|
|
||||||
std::string dump() const;
|
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
void operator()(Operand* a) { operator()({a}); }
|
PTXInstrExecution& operator()(Operand* a) { return call({a}); }
|
||||||
void operator()(Operand* a, Operand* b) { operator()({a, b}); }
|
PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); }
|
||||||
void operator()(Operand* a, Operand* b, Operand* c) { operator()({a, b, c}); }
|
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); }
|
||||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d) { operator()({a, b, c, d}); }
|
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}); }
|
||||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { operator()({a, b, c, d, e}); }
|
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}); }
|
||||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { operator()({a, b, c, d, e, f}); }
|
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}); }
|
||||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { operator()({a, b, c, d, e, f, g}); }
|
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}); }
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
// Set operands of this instruction.
|
// Set operands of this instruction.
|
||||||
void operator()(llvm::ArrayRef<Operand *> oprs);
|
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Append the operand to the instruction's operand list.
|
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs);
|
||||||
Operand *addOperand(Operand *opr) {
|
|
||||||
assert(std::find(argsInOrder.begin(), argsInOrder.end(), opr) ==
|
|
||||||
argsInOrder.end());
|
|
||||||
argsInOrder.push_back(opr);
|
|
||||||
return opr;
|
|
||||||
}
|
|
||||||
|
|
||||||
PTXBuilder *builder{};
|
PTXBuilder *builder{};
|
||||||
Operand *pred{};
|
|
||||||
llvm::SmallVector<std::string, 4> instrParts;
|
llvm::SmallVector<std::string, 4> instrParts;
|
||||||
llvm::SmallVector<Operand *> argsInOrder;
|
|
||||||
|
friend class PTXInstrExecution;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
||||||
@@ -192,19 +194,6 @@ template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
|||||||
instrParts.push_back(suffix);
|
instrParts.push_back(suffix);
|
||||||
return *static_cast<ConcreteT *>(this);
|
return *static_cast<ConcreteT *>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prefix a predicate to the instruction.
|
|
||||||
ConcreteT &predicate(mlir::Value value, StringRef constraint) {
|
|
||||||
pred = builder->newOperand(value, constraint);
|
|
||||||
return *static_cast<ConcreteT *>(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prefix a !predicate to the instruction.
|
|
||||||
ConcreteT &predicateNot(mlir::Value value, StringRef constraint) {
|
|
||||||
pred = builder->newOperand(value, constraint);
|
|
||||||
pred->repr = [](int idx) { return llvm::formatv("@!%{0}", idx); };
|
|
||||||
return *static_cast<ConcreteT *>(this);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
||||||
@@ -228,18 +217,50 @@ struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
|
|||||||
// Add ".v" suffix to instruction
|
// Add ".v" suffix to instruction
|
||||||
PtxIOInstr &v(int vecWidth, bool predicate = true) {
|
PtxIOInstr &v(int vecWidth, bool predicate = true) {
|
||||||
if (vecWidth > 1) {
|
if (vecWidth > 1) {
|
||||||
o(llvm::formatv("v{0}", vecWidth), predicate);
|
o("v" + std::to_string(vecWidth), predicate);
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add ".b" suffix to instruction
|
// Add ".b" suffix to instruction
|
||||||
PtxIOInstr &b(int width) {
|
PtxIOInstr &b(int width) {
|
||||||
o(llvm::formatv("b{0}", width));
|
o("b" + std::to_string(width));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Record the operands and context for "launching" a PtxInstr.
|
||||||
|
struct PTXInstrExecution {
|
||||||
|
using Operand = PTXBuilder::Operand;
|
||||||
|
|
||||||
|
llvm::SmallVector<Operand *> argsInOrder;
|
||||||
|
|
||||||
|
PTXInstrExecution() = default;
|
||||||
|
explicit PTXInstrExecution(PTXInstrCommon *instr,
|
||||||
|
llvm::ArrayRef<Operand *> oprs)
|
||||||
|
: instr(instr), argsInOrder(oprs.begin(), oprs.end()) {}
|
||||||
|
|
||||||
|
// Prefix a predicate to the instruction.
|
||||||
|
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
|
||||||
|
pred = instr->builder->newOperand(value, constraint);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefix a !predicate to the instruction.
|
||||||
|
PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
|
||||||
|
pred = instr->builder->newOperand(value, constraint);
|
||||||
|
pred->repr = [](int idx) { return "@!%" + std::to_string(idx); };
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string dump() const;
|
||||||
|
|
||||||
|
SmallVector<Operand *> getArgList() const;
|
||||||
|
|
||||||
|
PTXInstrCommon *instr{};
|
||||||
|
Operand *pred{};
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
add_mlir_conversion_library(TritonGPUToLLVM
|
add_mlir_conversion_library(TritonGPUToLLVM
|
||||||
TritonGPUToLLVM.cpp
|
TritonGPUToLLVM.cpp
|
||||||
PtxAsmFormat.cpp
|
PtxAsmFormat.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
|
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
|
||||||
|
@@ -49,7 +49,7 @@ PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) {
|
|||||||
return newConstantOperand(ss.str());
|
return newConstantOperand(ss.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PTXBuilder::getConstrains() const {
|
std::string PTXBuilder::getConstraints() const {
|
||||||
auto args = getAllArgs();
|
auto args = getAllArgs();
|
||||||
llvm::SmallVector<std::string, 4> argReprs;
|
llvm::SmallVector<std::string, 4> argReprs;
|
||||||
for (auto arg : args)
|
for (auto arg : args)
|
||||||
@@ -78,7 +78,7 @@ std::string PTXInstr::Operand::dump() const {
|
|||||||
if (repr)
|
if (repr)
|
||||||
return repr(idx);
|
return repr(idx);
|
||||||
if (!isList())
|
if (!isList())
|
||||||
return llvm::formatv("${0}", idx);
|
return "$" + std::to_string(idx);
|
||||||
|
|
||||||
llvm::SmallVector<std::string> oprs;
|
llvm::SmallVector<std::string> oprs;
|
||||||
for (auto *opr : list)
|
for (auto *opr : list)
|
||||||
@@ -90,7 +90,9 @@ PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr,
|
|||||||
StringRef constraint, int off) {
|
StringRef constraint, int off) {
|
||||||
auto *opr = newOperand(addr, constraint);
|
auto *opr = newOperand(addr, constraint);
|
||||||
opr->repr = [off](int idx) -> std::string {
|
opr->repr = [off](int idx) -> std::string {
|
||||||
return llvm::formatv("[ ${0} + {1} ]", idx, off);
|
std::stringstream ss;
|
||||||
|
ss << "[ $" << idx << " + " << off << " ]";
|
||||||
|
return ss.str();
|
||||||
};
|
};
|
||||||
|
|
||||||
return opr;
|
return opr;
|
||||||
@@ -98,14 +100,24 @@ PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr,
|
|||||||
|
|
||||||
std::string PTXBuilder::dump() const {
|
std::string PTXBuilder::dump() const {
|
||||||
llvm::SmallVector<std::string> lines;
|
llvm::SmallVector<std::string> lines;
|
||||||
for (auto &instr : instrs) {
|
for (auto &exec : executions) {
|
||||||
lines.push_back(instr->dump());
|
lines.push_back(exec->dump());
|
||||||
}
|
}
|
||||||
|
|
||||||
return strJoin(lines, "\n\t");
|
return strJoin(lines, "\r\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PTXInstrCommon::dump() const {
|
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs) {
|
||||||
|
builder->executions.emplace_back(
|
||||||
|
std::make_unique<PTXInstrExecution>(this, oprs));
|
||||||
|
return *builder->executions.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
|
||||||
|
return call(oprs);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PTXInstrExecution::dump() const {
|
||||||
std::string osStr;
|
std::string osStr;
|
||||||
llvm::raw_string_ostream os(osStr);
|
llvm::raw_string_ostream os(osStr);
|
||||||
if (pred)
|
if (pred)
|
||||||
@@ -114,7 +126,7 @@ std::string PTXInstrCommon::dump() const {
|
|||||||
else
|
else
|
||||||
os << pred->repr(pred->idx);
|
os << pred->repr(pred->idx);
|
||||||
|
|
||||||
std::string instrRepr = strJoin(instrParts, ".");
|
std::string instrRepr = strJoin(instr->instrParts, ".");
|
||||||
|
|
||||||
llvm::SmallVector<std::string, 4> argReprs;
|
llvm::SmallVector<std::string, 4> argReprs;
|
||||||
for (auto *arg : argsInOrder) {
|
for (auto *arg : argsInOrder) {
|
||||||
@@ -128,7 +140,8 @@ std::string PTXInstrCommon::dump() const {
|
|||||||
return osStr;
|
return osStr;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<PTXInstrCommon::Operand *> PTXInstrCommon::getArgList() const {
|
SmallVector<PTXInstrExecution::Operand *>
|
||||||
|
PTXInstrExecution::getArgList() const {
|
||||||
SmallVector<Operand *> args;
|
SmallVector<Operand *> args;
|
||||||
for (auto *arg : argsInOrder) {
|
for (auto *arg : argsInOrder) {
|
||||||
if (arg->isList())
|
if (arg->isList())
|
||||||
@@ -138,11 +151,5 @@ SmallVector<PTXInstrCommon::Operand *> PTXInstrCommon::getArgList() const {
|
|||||||
}
|
}
|
||||||
return args;
|
return args;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
|
|
||||||
for (auto *opr : oprs) {
|
|
||||||
addOperand(opr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@@ -756,12 +756,12 @@ struct StoreOpConversion
|
|||||||
llMask ? maskElems[vecStart]
|
llMask ? maskElems[vecStart]
|
||||||
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
|
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
|
||||||
rewriter.getIntegerType(1), 1);
|
rewriter.getIntegerType(1), 1);
|
||||||
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
|
ptxStoreInstr.global().b(width).v(nWords);
|
||||||
|
|
||||||
auto *asmAddr =
|
auto *asmAddr =
|
||||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||||
|
|
||||||
ptxStoreInstr(asmAddr, asmArgList);
|
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
|
||||||
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
||||||
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
|
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
|
||||||
for (int i = 0; i < nWords; i++)
|
for (int i = 0; i < nWords; i++)
|
||||||
@@ -772,7 +772,7 @@ struct StoreOpConversion
|
|||||||
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
||||||
loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
|
loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
|
||||||
ptxBuilder.dump(), // asm_string
|
ptxBuilder.dump(), // asm_string
|
||||||
ptxBuilder.getConstrains(), // constraints
|
ptxBuilder.getConstraints(), // constraints
|
||||||
// TODO(Superjomn) determine the side effect.
|
// TODO(Superjomn) determine the side effect.
|
||||||
true, // has_side_effects
|
true, // has_side_effects
|
||||||
false, // is_align_stack
|
false, // is_align_stack
|
||||||
@@ -1045,8 +1045,7 @@ struct LoadOpConversion
|
|||||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||||
|
|
||||||
// Define the instruction opcode
|
// Define the instruction opcode
|
||||||
ld.predicate(pred, "b")
|
ld.o("volatile", op.isVolatile())
|
||||||
.o("violatile", op.isVolatile())
|
|
||||||
.global()
|
.global()
|
||||||
.o("ca", op.cache() == triton::CacheModifier::CA)
|
.o("ca", op.cache() == triton::CacheModifier::CA)
|
||||||
.o("cg", op.cache() == triton::CacheModifier::CG)
|
.o("cg", op.cache() == triton::CacheModifier::CG)
|
||||||
@@ -1064,15 +1063,15 @@ struct LoadOpConversion
|
|||||||
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
|
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
|
||||||
|
|
||||||
if (!evictOpr)
|
if (!evictOpr)
|
||||||
ld(dstsOpr, addrOpr);
|
ld(dstsOpr, addrOpr).predicate(pred, "b");
|
||||||
else
|
else
|
||||||
ld(dstsOpr, addrOpr, evictOpr);
|
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
|
||||||
|
|
||||||
SmallVector<Value> others;
|
SmallVector<Value> others;
|
||||||
if (other) {
|
if (other) {
|
||||||
for (size_t ii = 0; ii < nWords; ii++) {
|
for (size_t ii = 0; ii < nWords; ii++) {
|
||||||
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
||||||
mov.predicateNot(pred, "b").o("u", width);
|
mov.o("u", width);
|
||||||
|
|
||||||
size_t size = width / valueElemNbits;
|
size_t size = width / valueElemNbits;
|
||||||
|
|
||||||
@@ -1096,7 +1095,7 @@ struct LoadOpConversion
|
|||||||
others.push_back(v);
|
others.push_back(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
mov(dstsOpr->listGet(ii), opr);
|
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1114,7 +1113,7 @@ struct LoadOpConversion
|
|||||||
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
|
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
|
||||||
loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(),
|
loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(),
|
||||||
/*asm_string=*/ptxBuilder.dump(),
|
/*asm_string=*/ptxBuilder.dump(),
|
||||||
/*constraints=*/ptxBuilder.getConstrains(),
|
/*constraints=*/ptxBuilder.getConstraints(),
|
||||||
/*has_side_effects=*/true,
|
/*has_side_effects=*/true,
|
||||||
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
|
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
|
||||||
/*operand_attrs=*/ArrayAttr());
|
/*operand_attrs=*/ArrayAttr());
|
||||||
|
@@ -29,8 +29,8 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
|
|||||||
%vs = tt.splat %a : (f32) -> tensor<128xf32>
|
%vs = tt.splat %a : (f32) -> tensor<128xf32>
|
||||||
%mask = tt.splat %true : (i1) -> tensor<128xi1>
|
%mask = tt.splat %true : (i1) -> tensor<128xi1>
|
||||||
|
|
||||||
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$1 st.global.b32 [ $2 + 0 ], { $0 };",
|
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };",
|
||||||
// CHECK-SAME: "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
|
// CHECK-SAME: "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
|
||||||
tt.store %ptrs, %vs, %mask : tensor<128xf32>
|
tt.store %ptrs, %vs, %mask : tensor<128xf32>
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@@ -217,9 +217,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
// CHECK-LABEL: basic_store
|
// CHECK-LABEL: basic_store
|
||||||
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||||
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
|
// CHECK-SAME: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||||
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,b,l" %{{.*}}, %{{.*}}, %{{.*}} : (i32, i1, !llvm.ptr<f32, 1>) -> !llvm.void
|
// CHECK-SAME: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
|
||||||
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
|
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@@ -16,6 +16,7 @@ function(add_triton_ut)
|
|||||||
${__SRCS})
|
${__SRCS})
|
||||||
target_link_libraries(
|
target_link_libraries(
|
||||||
${__NAME}
|
${__NAME}
|
||||||
|
PRIVATE
|
||||||
GTest::gtest_main
|
GTest::gtest_main
|
||||||
gmock
|
gmock
|
||||||
${__LIBS})
|
${__LIBS})
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
add_triton_ut(
|
add_triton_ut(
|
||||||
NAME TritonGPUToLLVMTests
|
NAME PtxAsmFormatTest
|
||||||
SRCS TritonGPUToLLVMTests.cpp
|
SRCS PtxAsmFormatTest.cpp
|
||||||
LIBS TritonGPUToLLVM
|
LIBS TritonGPUToLLVM
|
||||||
)
|
)
|
||||||
|
125
unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp
Normal file
125
unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
||||||
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace triton {
|
||||||
|
class PtxAsmFormatTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
static constexpr int numValues = 4;
|
||||||
|
|
||||||
|
PtxAsmFormatTest() {
|
||||||
|
ctx.loadDialect<arith::ArithmeticDialect>();
|
||||||
|
|
||||||
|
createValues();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates the test values.
|
||||||
|
void createValues() {
|
||||||
|
OpBuilder builder(&ctx);
|
||||||
|
builder.setInsertionPointToStart(&block);
|
||||||
|
|
||||||
|
// a b1 value for predicate.
|
||||||
|
v[0] = builder.create<arith::ConstantIntOp>(builder.getUnknownLoc(), 1, 1);
|
||||||
|
for (int i = 0; i < numValues; i++) {
|
||||||
|
v[i + 1] =
|
||||||
|
builder.create<arith::ConstantIntOp>(builder.getUnknownLoc(), i, 32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MLIRContext ctx;
|
||||||
|
Block block;
|
||||||
|
Value v[numValues + 1];
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(PtxAsmFormatTest, basic) {
|
||||||
|
PTXBuilder builder;
|
||||||
|
|
||||||
|
// Create the operands needed by the instructions in the PTX code.
|
||||||
|
auto *cst = builder.newConstantOperand(1);
|
||||||
|
auto *val = builder.newOperand(v[1], "=r");
|
||||||
|
|
||||||
|
// create an instruction
|
||||||
|
auto &mov = *builder.create("mov.b16");
|
||||||
|
|
||||||
|
mov(val, cst).predicate(v[0]);
|
||||||
|
ASSERT_EQ(builder.dump(), "@$1 mov.b16 $0, 0x1;");
|
||||||
|
|
||||||
|
auto values = builder.getAllMLIRArgs();
|
||||||
|
ASSERT_EQ(values[0], v[1]); // $0 -> v[1]
|
||||||
|
ASSERT_EQ(values[1], v[0]); // $1 -> v[0]
|
||||||
|
|
||||||
|
auto constraints = builder.getConstraints();
|
||||||
|
ASSERT_EQ(constraints, "=r,b"); // $0 -> =r, $1 -> b
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PtxAsmFormatTest, complexInstruction) {
|
||||||
|
using triton::CacheModifier;
|
||||||
|
using triton::EvictionPolicy;
|
||||||
|
|
||||||
|
PTXBuilder builder;
|
||||||
|
|
||||||
|
int width = 16;
|
||||||
|
int nWords = 2;
|
||||||
|
|
||||||
|
Value predicateVal = v[0];
|
||||||
|
Value addrVal = v[1];
|
||||||
|
|
||||||
|
auto addr = builder.newAddrOperand(addrVal, "l", 128 /*offset*/);
|
||||||
|
|
||||||
|
bool isVolatile = false;
|
||||||
|
auto cache = triton::CacheModifier::CA;
|
||||||
|
auto cachePriority = triton::EvictionPolicy::EVICT_FIRST;
|
||||||
|
bool hasL2EvictPolicy = true;
|
||||||
|
|
||||||
|
auto &ld =
|
||||||
|
builder
|
||||||
|
.create<PtxIOInstr>("ld") //
|
||||||
|
->o("volatile", isVolatile)
|
||||||
|
.global()
|
||||||
|
.o("ca", cache == CacheModifier::CA)
|
||||||
|
.o("cg", cache == CacheModifier::CG)
|
||||||
|
.o("L1::evict_first", cachePriority == EvictionPolicy::EVICT_FIRST)
|
||||||
|
.o("L1::evict_last", cachePriority == EvictionPolicy::EVICT_LAST)
|
||||||
|
.o("L1::cache_hint", hasL2EvictPolicy)
|
||||||
|
.v(nWords)
|
||||||
|
.b(width);
|
||||||
|
|
||||||
|
// Link the instruction to operands
|
||||||
|
ld(addr).predicate(predicateVal);
|
||||||
|
|
||||||
|
EXPECT_EQ(
|
||||||
|
builder.dump(),
|
||||||
|
"@$1 ld.global.ca.L1::evict_first.L1::cache_hint.v2.b16 [ $0 + 128 ];");
|
||||||
|
auto values = builder.getAllMLIRArgs();
|
||||||
|
EXPECT_EQ(values[0], addrVal); // $0 -> predicate
|
||||||
|
EXPECT_EQ(values[1], predicateVal); // $1 -> addr
|
||||||
|
EXPECT_EQ(builder.getConstraints(), "l,b");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PtxAsmFormatTest, MultiLinePTX) {
|
||||||
|
PTXBuilder builder;
|
||||||
|
|
||||||
|
auto *constVal = builder.newConstantOperand(1);
|
||||||
|
auto *valVal0 = builder.newOperand(v[1], "=r");
|
||||||
|
auto *valVal1 = builder.newOperand(v[2], "=r");
|
||||||
|
|
||||||
|
auto &mov = *builder.create("mov");
|
||||||
|
|
||||||
|
mov(valVal0, constVal);
|
||||||
|
mov(valVal1, constVal);
|
||||||
|
mov(valVal1, valVal0);
|
||||||
|
|
||||||
|
EXPECT_EQ(builder.dump(), "mov $0, 0x1;\r\n"
|
||||||
|
"mov $1, 0x1;\r\n"
|
||||||
|
"mov $1, $0;");
|
||||||
|
|
||||||
|
auto values = builder.getAllMLIRArgs();
|
||||||
|
EXPECT_EQ(values[0], v[1]); // $0 -> v[1]
|
||||||
|
EXPECT_EQ(values[1], v[2]); // $1 -> v[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace triton
|
||||||
|
} // namespace mlir
|
@@ -1,14 +0,0 @@
|
|||||||
//===- TritonGPUToLLVMTests.cpp - Tests for
|
|
||||||
// TritonGPUToLLVM----------------------------------===//
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
|
||||||
#include <gmock/gmock.h>
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
|
|
||||||
TEST(PtxAsmFormatTest, BasicTest) { EXPECT_EQ(true, true); }
|
|
||||||
|
|
||||||
} // namespace mlir
|
|
Reference in New Issue
Block a user