adding tunable parameters

Philippe Tillet
2019-08-22 19:21:01 -07:00
parent 87072203c1
commit 845c0e5b93
9 changed files with 76 additions and 99 deletions

View File

@@ -78,51 +78,51 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
std::string res =
R"(
#define TM 128
#define TN 128
#define TK 32
#define bool _Bool
#define true 1
#define false 0
#define __bool_true_false_are_defined 1
extern int get_program_id(int);
#define bool _Bool
#define true 1
#define false 0
#define __bool_true_false_are_defined 1
#define TN 128
#define TK 32
extern int get_program_id(int);
static const int TM = 128;
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))),
restrict )" + c_ty + R"( * C __attribute__((aligned(16))),
int M, int N, int K,
int lda __attribute__((multiple_of(8))),
int ldb __attribute__((multiple_of(8))),
int ldc) {
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rxa[{TM}] = ridx * TM + 0 ... TM;
int ryb[{TN}] = ridy * TN + 0 ... TN;
int rka[{TK}] = 0 ... TK;
int rkb[{TK}] = 0 ... TK;
float xc[{)" + XCS + R"(}] = 0;
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
for(int k = K; k > 0; k = k - TK){
xc = )" + usea + " @ " + useb + R"( + xc;
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
a = *pa;
b = *pb;
}
int rxc[{TM}] = ridx * TM + (0 ... TM);
int ryc[{TN}] = ridy * TN + (0 ... TN);
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
)" + c_ty + R"( c[{TM, TN}] = xc;
bool checkc0[{TM}] = rxc < M;
bool checkc1[{TN}] = ryc < N;
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
*pc = c;
}
)";
return res;
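
The change above is the heart of this commit: the tile size TM moves from a preprocessor #define, which is plain text substitution, to a global `static const int`, which the front end can see as a named symbol and the autotuner can later pick up from the module's globals as a tunable parameter. A minimal, self-contained sketch of that idea (hypothetical helper, not the actual Triton API):

#include <iostream>
#include <string>

// Hypothetical helper (not the actual Triton API): emit a tile size either as
// a macro (opaque text substitution) or as a global constant that can later be
// recognised as a named, tunable symbol without editing the kernel text.
std::string emit_tile_size(const std::string& name, int value, bool tunable) {
  if (tunable)
    return "static const int " + name + " = " + std::to_string(value) + ";\n";
  return "#define " + name + " " + std::to_string(value) + "\n";
}

int main() {
  std::cout << emit_tile_size("TM", 128, /*tunable=*/true)   // as in this commit
            << emit_tile_size("TN", 128, /*tunable=*/false)  // still a macro here
            << emit_tile_size("TK", 32,  /*tunable=*/false);
}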

View File

@@ -25,7 +25,7 @@ class Parser {
friend class Generator;
public:
explicit Parser(const TokenSequence& ts)
explicit Parser(TokenSequence& ts)
: unit_(TranslationUnit::New()),
ts_(ts),
externalSymbols_(new Scope(nullptr, S_BLOCK)),
@@ -224,7 +224,7 @@ private:
// The root of the AST
TranslationUnit* unit_;
TokenSequence ts_;
TokenSequence& ts_;
// It is not the real scope,
// It contains all external symbols(resolved and not resolved)
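
Both hunks above switch the Parser from owning a copy of the TokenSequence to holding a reference, so the runtime can construct the parser first and let the preprocessor fill the token stream afterwards (see the new function constructor later in this commit). The cost is a lifetime constraint: the referenced sequence must outlive the parser. A minimal sketch of the pattern with stand-in types (not the real wgtcc classes):

#include <cstddef>
#include <cstdio>
#include <vector>

struct TokenSequence { std::vector<int> toks; };

class Parser {
public:
  explicit Parser(TokenSequence& ts) : ts_(ts) {}   // bound once, never copied
  std::size_t NumTokens() const { return ts_.toks.size(); }
private:
  TokenSequence& ts_;   // must outlive the Parser
};

int main() {
  TokenSequence ts;
  Parser parser(ts);           // constructed before any tokens exist
  ts.toks = {1, 2, 3};         // "preprocessing" fills the stream later
  std::printf("%zu tokens\n", parser.NumTokens());   // sees the live object: 3
}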

View File

@@ -223,7 +223,7 @@ public:
virtual bool IsInteger() const { return !IsFloat() && !IsComplex(); }
virtual bool IsUnsigned() const { return tag_ & T_UNSIGNED; }
virtual bool IsFloat() const {
return (tag_ & T_FLOAT) || (tag_ & T_DOUBLE);
return (tag_ & T_HALF) || (tag_ & T_FLOAT) || (tag_ & T_DOUBLE);
}
virtual bool IsBool() const { return tag_ & T_BOOL; }
bool IsComplex() const { return tag_ & T_COMPLEX; }
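
The one-line change above adds the half-precision tag to the floating-point predicate, so `half` values and tiles follow the same code paths as float and double during type checking and code generation. A tiny self-contained illustration of the bit-flag test (illustrative tag values, not wgtcc's actual encoding):

#include <cstdio>

// Illustrative tag bits (not the real wgtcc values).
enum : unsigned { T_HALF = 1u << 0, T_FLOAT = 1u << 1, T_DOUBLE = 1u << 2, T_UNSIGNED = 1u << 3 };

// Before the patch only T_FLOAT and T_DOUBLE counted; adding T_HALF makes
// half-precision types answer "yes" to IsFloat().
bool is_float(unsigned tag) {
  return (tag & T_HALF) || (tag & T_FLOAT) || (tag & T_DOUBLE);
}

int main() {
  std::printf("half:  %d\n", is_float(T_HALF));      // 1 with this patch
  std::printf("float: %d\n", is_float(T_FLOAT));     // 1
  std::printf("uint:  %d\n", is_float(T_UNSIGNED));  // 0
}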

View File

@@ -20,6 +20,7 @@
#include "triton/codegen/transform/shmem/barriers.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/vectorize.h"
#include "triton/lang/wgtcc/parser.h"
namespace llvm {
class Module;
@@ -87,9 +88,9 @@ private:
typedef std::pair<options, caller> cache_val_t;
private:
triton::lang::translation_unit *make_ast(const char *src);
std::unique_ptr<ir::module> make_ir(triton::lang::translation_unit *program);
options autotune(lang::translation_unit *ast, driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
triton::lang::translation_unit *make_ast(const std::string &src);
std::unique_ptr<ir::module> make_ir(Parser &parser);
options autotune(Parser &parser, driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options &opt);
@@ -100,11 +101,12 @@ public:
std::string make_tensorflow_src(const std::vector<size_t> &outputs, const std::string &macro);
private:
TokenSequence ts_;
Parser parser_;
// execution context
ir::context ctx_;
// program representations
std::string src_;
lang::translation_unit *ast_;
std::map<cache_key_t, cache_val_t> cache_;
};
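
The header now gives `function` a TokenSequence and a Parser member in place of the old `lang::translation_unit*`. One detail this relies on: `parser_(ts_)` in the constructor (later in this commit) is only safe because members are initialized in declaration order, and `ts_` is declared before `parser_` here. A minimal sketch of that ordering rule with stand-in types:

#include <vector>

struct TokenSequence { std::vector<int> toks; };
struct Parser {
  explicit Parser(TokenSequence& ts) : ts_(ts) {}
  TokenSequence& ts_;
};

// Members are constructed in *declaration* order, not initializer-list order,
// so ts_ must be declared before parser_ for parser_(ts_) to bind to a live
// object.
class function {
public:
  function() : parser_(ts_) {}   // fine: ts_ already exists when parser_ is built
private:
  TokenSequence ts_;   // declared first  -> constructed first
  Parser parser_;      // declared second -> may safely reference ts_
};

int main() { function f; (void)f; }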

View File

@@ -106,10 +106,10 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
file_type_t ft) {
init_llvm();
// debug
// llvm::legacy::PassManager pm;
// pm.add(llvm::createPrintModulePass(llvm::outs()));
// pm.add(llvm::createVerifierPass());
// pm.run(*module);
llvm::legacy::PassManager pm;
pm.add(llvm::createPrintModulePass(llvm::outs()));
pm.add(llvm::createVerifierPass());
pm.run(*module);
// create machine
module->setTargetTriple(triple);
std::string error;
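
This hunk enables the previously commented-out debug passes, so the LLVM module is dumped to stdout and run through the verifier before any target lowering. A minimal standalone version of the same print-and-verify pattern with the legacy pass manager (standard LLVM headers of that era; include paths may vary across LLVM versions):

#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"

// Dump the module and check IR well-formedness before lowering.
void print_and_verify(llvm::Module& module) {
  llvm::legacy::PassManager pm;
  pm.add(llvm::createPrintModulePass(llvm::outs()));
  pm.add(llvm::createVerifierPass());
  pm.run(module);
}

int main() {
  llvm::LLVMContext ctx;
  llvm::Module module("demo", ctx);
  print_and_verify(module);   // prints an empty module and verifies it
}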

View File

@@ -461,7 +461,10 @@ void BinaryOp::MatmulOpTypeChecking() {
QualType retType = lhsType->Derived();
if(retType != rhsType->Derived())
Error(this, "matrix multiplication operands have incompatible data types");
type_ = TileType::New(retShape, lhsType->Derived());
ArithmType* ScalType = lhsType->ScalarType()->ToArithm();
if(ScalType->Tag() & T_HALF)
ScalType = ArithmType::New(T_FLOAT);
type_ = TileType::New(retShape, ScalType);
}
void BinaryOp::ShiftOpTypeChecking() {
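
The change above fixes the result type of a half-precision matrix multiply: when the operand element type carries T_HALF, the output tile's element type is promoted to float, so `half @ half` accumulates in fp32 instead of producing a half tile. A small sketch of the promotion rule with illustrative scalar kinds (not the wgtcc ArithmType machinery):

#include <cstdio>

enum class Scalar { Half, Float, Double };

// Sketch of the rule added here: the element type of C = A @ B is promoted to
// at least float, so half inputs are accumulated in fp32.
Scalar matmul_result_scalar(Scalar operand) {
  return operand == Scalar::Half ? Scalar::Float : operand;
}

int main() {
  std::printf("half operands  -> %s\n",
              matmul_result_scalar(Scalar::Half) == Scalar::Float ? "float accumulator" : "half");
  std::printf("float operands -> %s\n",
              matmul_result_scalar(Scalar::Float) == Scalar::Float ? "float accumulator" : "other");
}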

View File

@@ -29,7 +29,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
ir::value* lhs = ret_;
// op info
auto type = binary->lhs_->Type();
auto type = binary->lhs_->Type()->ScalarType();
auto flt = type->IsFloat();
auto sign = !type->IsUnsigned();

View File

@@ -318,7 +318,7 @@ bool ArrayType::Compatible(const Type& other) const {
bool TileType::Compatible(const Type& other) const {
// For two tile type to be compatible,
// the element types must be compatible
// and they must have the same shape
// and they must have the same shapea
auto otherTile = other.ToTile();
if(!otherTile)
return false;

View File

@@ -17,6 +17,7 @@
#include "triton/ir/function.h"
#include "triton/tools/bench.hpp"
#include "llvm/IR/Module.h"
#include "triton/ir/print.h"
typedef struct yy_buffer_state * YY_BUFFER_STATE;
@@ -117,50 +118,17 @@ void function::caller::operator ()(driver::stream *stream, const std::array<size
// module
triton::lang::translation_unit *function::make_ast(const char *csrc) {
std::string src(csrc);
std::cout << src << std::endl;
Preprocessor cpp(&src, true);
// for (auto& def: defines)
// DefineMacro(cpp, def);
// for (auto& path: include_paths)
// cpp.AddSearchPath(path);
FILE* fp = stdout;
// if (specified_out_name) {
// fp = fopen(filename_out.c_str(), "w");
// }
TokenSequence ts;
cpp.Process(ts);
Parser parser(ts);
parser.Parse();
Generator gen(&parser);
ir::module out("", ctx_);
gen.Gen(&out);
exit(EXIT_FAILURE);
// if (only_preprocess) {
// ts.Print(fp);
// return 0;
// }
YY_BUFFER_STATE buffer = yy_scan_string(csrc);
yyparse();
yy_delete_buffer(buffer);
triton::lang::translation_unit *program = ast_root;
return program;
}
std::unique_ptr<ir::module> function::make_ir(triton::lang::translation_unit *program) {
std::unique_ptr<ir::module> function::make_ir(Parser& parser) {
// create Triton-IR from AST
ir::module* module = new ir::module("", ctx_);
program->codegen(module);
Generator gen(&parser);
gen.Gen(module);
return std::unique_ptr<ir::module>(module);
}
options function::autotune(lang::translation_unit *ast, driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector<arg>& args) {
std::unique_ptr<ir::module> ir = make_ir(ast);
options function::autotune(Parser& parser, driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector<arg>& args) {
std::unique_ptr<ir::module> ir = make_ir(parser);
// extract tunable values
std::vector<std::pair<std::string, ir::metaparameter*>> values;
for(auto it: ir->globals())
@@ -186,7 +154,7 @@ options function::autotune(lang::translation_unit *ast, driver::stream* stream,
for(auto it: values)
opt.params[it.first] = params[i++];
// make binary
auto ir = make_ir(ast);
auto ir = make_ir(parser);
auto bin = make_bin(*ir, stream->context(), opt);
// benchmark
ir::function *tmp = ir->get_function_list()[0];
@@ -242,6 +210,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
// generate llvm code
llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
ir::print(module, std::cout);
exit(EXIT_FAILURE);
selection.run(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, llvm.get()));
@@ -249,9 +219,11 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
}
function::function(const std::string &src): src_(src) {
// src -> ast
ast_ = make_ast(src_.c_str());
function::function(const std::string &src): parser_(ts_), src_(src){
Preprocessor cpp(&src_, true);
cpp.Process(ts_);
ts_.Print();
parser_.Parse();
}
void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
@@ -277,8 +249,8 @@ void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_f
}
/* re-tune and re-compile */
options opt = autotune(ast_, stream, grid_fn, args);
std::unique_ptr<ir::module> ir = make_ir(ast_);
options opt = autotune(parser_, stream, grid_fn, args);
std::unique_ptr<ir::module> ir = make_ir(parser_);
std::unique_ptr<driver::module> bin = make_bin(*ir, stream->context(), opt);
ir::function* fn = ir->get_function_list().front();
const caller& run = cache_.insert({key, cache_val_t{opt, caller(fn, std::move(bin), opt.num_warps*32)}}).first->second.second;
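
Taken together, the runtime hunks retire the old flex/bison path (`yy_scan_string`, `yyparse`, `lang::translation_unit`) in favour of the wgtcc pipeline: preprocess the source into a TokenSequence, parse it once in the constructor, then regenerate Triton-IR from the parsed AST for every autotuning candidate and for the final compilation. A condensed sketch of that data flow with placeholder types (stand-ins for the real Triton/wgtcc classes, not their actual interfaces):

#include <memory>
#include <string>

// Placeholder stand-ins: the point is the data flow, not the real API.
struct TokenSequence {};
struct Preprocessor {
  explicit Preprocessor(const std::string* src) { (void)src; }
  void Process(TokenSequence&) {}
};
struct Parser {
  explicit Parser(TokenSequence& ts) : ts_(ts) {}
  void Parse() {}
  TokenSequence& ts_;
};
struct Module {};
struct Generator {
  explicit Generator(Parser* p) { (void)p; }
  void Gen(Module*) {}
};

// AST -> IR, rebuilt for each tuning configuration (cf. make_ir above).
std::unique_ptr<Module> make_ir(Parser& parser) {
  auto module = std::make_unique<Module>();
  Generator gen(&parser);
  gen.Gen(module.get());
  return module;
}

int main() {
  std::string src = "/* kernel source */";
  TokenSequence ts;
  Preprocessor cpp(&src);
  cpp.Process(ts);                 // source -> token stream
  Parser parser(ts);
  parser.Parse();                  // token stream -> AST, done once
  for (int candidate = 0; candidate < 3; ++candidate) {
    auto ir = make_ir(parser);     // fresh IR per autotuning candidate
    (void)ir;
  }
}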