[BACKEND] Add C++ tests for PTXFormat and some tiny refinement (#109)
This PR does 1. Add some C++ tests for `PTXFormat` 2. Enhance the functionality of `PTXFormat`, make a `PTXInstr` instance can be called multiple times similar as a C function.
This commit is contained in:
@@ -4,8 +4,6 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Format.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
@@ -15,6 +13,7 @@ using llvm::StringRef;
|
||||
|
||||
class PTXInstr;
|
||||
class PTXInstrCommon;
|
||||
class PTXInstrExecution;
|
||||
|
||||
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
|
||||
// instructions.
|
||||
@@ -25,7 +24,8 @@ class PTXInstrCommon;
|
||||
// string and C++ if-else code.
|
||||
//
|
||||
// Usage:
|
||||
// To build: asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k), "b"(p));
|
||||
// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k),
|
||||
// "b"(p));
|
||||
//
|
||||
// PTXBuilder builder;
|
||||
// auto& add = builder.create<>();
|
||||
@@ -35,7 +35,7 @@ class PTXInstrCommon;
|
||||
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
|
||||
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
|
||||
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
|
||||
// add(iOpr, jOpr, kOpr); // set operands
|
||||
// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate
|
||||
//
|
||||
// To get the asm code:
|
||||
// builder.dump()
|
||||
@@ -45,16 +45,25 @@ class PTXInstrCommon;
|
||||
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
|
||||
//
|
||||
// To get the string containing all the contraints with "," seperated,
|
||||
// builder.getConstrains() // get "=r,r,k"
|
||||
// builder.getConstraints() // get "=r,r,k"
|
||||
//
|
||||
// PTXBuilder can build a PTX asm with multiple instructions, sample code:
|
||||
//
|
||||
// PTXBuilder builder;
|
||||
// auto& instr0 = builder.create<>();
|
||||
// auto& instr1 = builder.create<>();
|
||||
// auto& instr2 = builder.create<>();
|
||||
// auto& mov = builder.create("mov");
|
||||
// auto& cp = builder.create("cp");
|
||||
// mov(...);
|
||||
// cp(...);
|
||||
// This will get a PTX code with two instructions.
|
||||
//
|
||||
// NOTE, the instructions will be serialized in the order of creation.
|
||||
// Similar to a C function, a declared PTXInstr instance can be launched
|
||||
// multiple times with different operands, e.g.
|
||||
//
|
||||
// auto& mov = builder.create("mov");
|
||||
// mov(... some operands ...);
|
||||
// mov(... some different operands ...);
|
||||
//
|
||||
// Finally, we will get a PTX code with two mov instructions.
|
||||
//
|
||||
// There are several derived instruction type for typical instructions, for
|
||||
// example, the PtxIOInstr for ld and st instructions.
|
||||
@@ -68,6 +77,7 @@ struct PTXBuilder {
|
||||
|
||||
// for list
|
||||
Operand() = default;
|
||||
Operand(const Operation &) = delete;
|
||||
Operand(Value value, StringRef constraint)
|
||||
: value(value), constraint(constraint) {}
|
||||
|
||||
@@ -117,7 +127,7 @@ struct PTXBuilder {
|
||||
|
||||
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
|
||||
|
||||
std::string getConstrains() const;
|
||||
std::string getConstraints() const;
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
@@ -128,10 +138,12 @@ private:
|
||||
}
|
||||
|
||||
friend class PTXInstr;
|
||||
friend class PTXInstrCommon;
|
||||
|
||||
protected:
|
||||
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
||||
llvm::SmallVector<std::unique_ptr<PTXInstrCommon>, 2> instrs;
|
||||
llvm::SmallVector<std::unique_ptr<PTXInstrExecution>, 4> executions;
|
||||
int oprCounter{};
|
||||
};
|
||||
|
||||
@@ -142,36 +154,26 @@ struct PTXInstrCommon {
|
||||
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
llvm::SmallVector<Operand *> getArgList() const;
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
// clang-format off
|
||||
void operator()(Operand* a) { operator()({a}); }
|
||||
void operator()(Operand* a, Operand* b) { operator()({a, b}); }
|
||||
void operator()(Operand* a, Operand* b, Operand* c) { operator()({a, b, c}); }
|
||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d) { operator()({a, b, c, d}); }
|
||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { operator()({a, b, c, d, e}); }
|
||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { operator()({a, b, c, d, e, f}); }
|
||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { operator()({a, b, c, d, e, f, g}); }
|
||||
PTXInstrExecution& operator()(Operand* a) { return call({a}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}); }
|
||||
// clang-format on
|
||||
|
||||
// Set operands of this instruction.
|
||||
void operator()(llvm::ArrayRef<Operand *> oprs);
|
||||
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs);
|
||||
|
||||
protected:
|
||||
// Append the operand to the instruction's operand list.
|
||||
Operand *addOperand(Operand *opr) {
|
||||
assert(std::find(argsInOrder.begin(), argsInOrder.end(), opr) ==
|
||||
argsInOrder.end());
|
||||
argsInOrder.push_back(opr);
|
||||
return opr;
|
||||
}
|
||||
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs);
|
||||
|
||||
PTXBuilder *builder{};
|
||||
Operand *pred{};
|
||||
llvm::SmallVector<std::string, 4> instrParts;
|
||||
llvm::SmallVector<Operand *> argsInOrder;
|
||||
|
||||
friend class PTXInstrExecution;
|
||||
};
|
||||
|
||||
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
||||
@@ -192,19 +194,6 @@ template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
||||
instrParts.push_back(suffix);
|
||||
return *static_cast<ConcreteT *>(this);
|
||||
}
|
||||
|
||||
// Prefix a predicate to the instruction.
|
||||
ConcreteT &predicate(mlir::Value value, StringRef constraint) {
|
||||
pred = builder->newOperand(value, constraint);
|
||||
return *static_cast<ConcreteT *>(this);
|
||||
}
|
||||
|
||||
// Prefix a !predicate to the instruction.
|
||||
ConcreteT &predicateNot(mlir::Value value, StringRef constraint) {
|
||||
pred = builder->newOperand(value, constraint);
|
||||
pred->repr = [](int idx) { return llvm::formatv("@!%{0}", idx); };
|
||||
return *static_cast<ConcreteT *>(this);
|
||||
}
|
||||
};
|
||||
|
||||
struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
||||
@@ -228,18 +217,50 @@ struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
|
||||
// Add ".v" suffix to instruction
|
||||
PtxIOInstr &v(int vecWidth, bool predicate = true) {
|
||||
if (vecWidth > 1) {
|
||||
o(llvm::formatv("v{0}", vecWidth), predicate);
|
||||
o("v" + std::to_string(vecWidth), predicate);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Add ".b" suffix to instruction
|
||||
PtxIOInstr &b(int width) {
|
||||
o(llvm::formatv("b{0}", width));
|
||||
o("b" + std::to_string(width));
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
// Record the operands and context for "launching" a PtxInstr.
|
||||
struct PTXInstrExecution {
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
llvm::SmallVector<Operand *> argsInOrder;
|
||||
|
||||
PTXInstrExecution() = default;
|
||||
explicit PTXInstrExecution(PTXInstrCommon *instr,
|
||||
llvm::ArrayRef<Operand *> oprs)
|
||||
: instr(instr), argsInOrder(oprs.begin(), oprs.end()) {}
|
||||
|
||||
// Prefix a predicate to the instruction.
|
||||
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
|
||||
pred = instr->builder->newOperand(value, constraint);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Prefix a !predicate to the instruction.
|
||||
PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
|
||||
pred = instr->builder->newOperand(value, constraint);
|
||||
pred->repr = [](int idx) { return "@!%" + std::to_string(idx); };
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
SmallVector<Operand *> getArgList() const;
|
||||
|
||||
PTXInstrCommon *instr{};
|
||||
Operand *pred{};
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
|
Reference in New Issue
Block a user