adding tunable parameters
This commit is contained in:
@@ -78,51 +78,51 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin
|
||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
|
||||
std::string res =
|
||||
R"(
|
||||
#define TM 128
|
||||
#define TN 128
|
||||
#define TK 32
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
#define __bool_true_false_are_defined 1
|
||||
extern int get_program_id(int);
|
||||
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
#define __bool_true_false_are_defined 1
|
||||
#define TN 128
|
||||
#define TK 32
|
||||
|
||||
extern int get_program_id(int);
|
||||
static const int TM = 128;
|
||||
|
||||
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
||||
restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))),
|
||||
restrict )" + c_ty + R"( * C __attribute__((aligned(16))),
|
||||
int M, int N, int K,
|
||||
int lda __attribute__((multiple_of(8))),
|
||||
int ldb __attribute__((multiple_of(8))),
|
||||
int ldc) {
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[{TM}] = ridx * TM + 0 ... TM;
|
||||
int ryb[{TN}] = ridy * TN + 0 ... TN;
|
||||
int rka[{TK}] = 0 ... TK;
|
||||
int rkb[{TK}] = 0 ... TK;
|
||||
float xc[{)" + XCS + R"(}] = 0;
|
||||
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
|
||||
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = )" + usea + " @ " + useb + R"( + xc;
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int rxc[{TM}] = ridx * TM + (0 ... TM);
|
||||
int ryc[{TN}] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty + R"( c[{TM, TN}] = xc;
|
||||
bool checkc0[{TM}] = rxc < M;
|
||||
bool checkc1[{TN}] = ryc < N;
|
||||
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
*pc = c;
|
||||
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
||||
restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))),
|
||||
restrict )" + c_ty + R"( * C __attribute__((aligned(16))),
|
||||
int M, int N, int K,
|
||||
int lda __attribute__((multiple_of(8))),
|
||||
int ldb __attribute__((multiple_of(8))),
|
||||
int ldc) {
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[{TM}] = ridx * TM + 0 ... TM;
|
||||
int ryb[{TN}] = ridy * TN + 0 ... TN;
|
||||
int rka[{TK}] = 0 ... TK;
|
||||
int rkb[{TK}] = 0 ... TK;
|
||||
float xc[{)" + XCS + R"(}] = 0;
|
||||
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
|
||||
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = )" + usea + " @ " + useb + R"( + xc;
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int rxc[{TM}] = ridx * TM + (0 ... TM);
|
||||
int ryc[{TN}] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty + R"( c[{TM, TN}] = xc;
|
||||
bool checkc0[{TM}] = rxc < M;
|
||||
bool checkc1[{TN}] = ryc < N;
|
||||
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
*pc = c;
|
||||
}
|
||||
)";
|
||||
|
||||
return res;
|
||||
|
@@ -25,7 +25,7 @@ class Parser {
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
explicit Parser(const TokenSequence& ts)
|
||||
explicit Parser(TokenSequence& ts)
|
||||
: unit_(TranslationUnit::New()),
|
||||
ts_(ts),
|
||||
externalSymbols_(new Scope(nullptr, S_BLOCK)),
|
||||
@@ -224,7 +224,7 @@ private:
|
||||
// The root of the AST
|
||||
TranslationUnit* unit_;
|
||||
|
||||
TokenSequence ts_;
|
||||
TokenSequence& ts_;
|
||||
|
||||
// It is not the real scope,
|
||||
// It contains all external symbols(resolved and not resolved)
|
||||
|
@@ -223,7 +223,7 @@ public:
|
||||
virtual bool IsInteger() const { return !IsFloat() && !IsComplex(); }
|
||||
virtual bool IsUnsigned() const { return tag_ & T_UNSIGNED; }
|
||||
virtual bool IsFloat() const {
|
||||
return (tag_ & T_FLOAT) || (tag_ & T_DOUBLE);
|
||||
return (tag_ & T_HALF) || (tag_ & T_FLOAT) || (tag_ & T_DOUBLE);
|
||||
}
|
||||
virtual bool IsBool() const { return tag_ & T_BOOL; }
|
||||
bool IsComplex() const { return tag_ & T_COMPLEX; }
|
||||
|
@@ -20,6 +20,7 @@
|
||||
#include "triton/codegen/transform/shmem/barriers.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/transform/vectorize.h"
|
||||
#include "triton/lang/wgtcc/parser.h"
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
@@ -87,9 +88,9 @@ private:
|
||||
typedef std::pair<options, caller> cache_val_t;
|
||||
|
||||
private:
|
||||
triton::lang::translation_unit *make_ast(const char *src);
|
||||
std::unique_ptr<ir::module> make_ir(triton::lang::translation_unit *program);
|
||||
options autotune(lang::translation_unit *ast, driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
|
||||
triton::lang::translation_unit *make_ast(const std::string &src);
|
||||
std::unique_ptr<ir::module> make_ir(Parser &parser);
|
||||
options autotune(Parser &parser, driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
|
||||
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options &opt);
|
||||
|
||||
|
||||
@@ -100,11 +101,12 @@ public:
|
||||
std::string make_tensorflow_src(const std::vector<size_t> &outputs, const std::string &macro);
|
||||
|
||||
private:
|
||||
TokenSequence ts_;
|
||||
Parser parser_;
|
||||
// execution context
|
||||
ir::context ctx_;
|
||||
// program representations
|
||||
std::string src_;
|
||||
lang::translation_unit *ast_;
|
||||
std::map<cache_key_t, cache_val_t> cache_;
|
||||
};
|
||||
|
||||
|
@@ -106,10 +106,10 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
|
||||
file_type_t ft) {
|
||||
init_llvm();
|
||||
// debug
|
||||
// llvm::legacy::PassManager pm;
|
||||
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
// pm.add(llvm::createVerifierPass());
|
||||
// pm.run(*module);
|
||||
llvm::legacy::PassManager pm;
|
||||
pm.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
pm.add(llvm::createVerifierPass());
|
||||
pm.run(*module);
|
||||
// create machine
|
||||
module->setTargetTriple(triple);
|
||||
std::string error;
|
||||
|
@@ -461,7 +461,10 @@ void BinaryOp::MatmulOpTypeChecking() {
|
||||
QualType retType = lhsType->Derived();
|
||||
if(retType != rhsType->Derived())
|
||||
Error(this, "matrix multiplication operands have incompatible data types");
|
||||
type_ = TileType::New(retShape, lhsType->Derived());
|
||||
ArithmType* ScalType = lhsType->ScalarType()->ToArithm();
|
||||
if(ScalType->Tag() & T_HALF)
|
||||
ScalType = ArithmType::New(T_FLOAT);
|
||||
type_ = TileType::New(retShape, ScalType);
|
||||
}
|
||||
|
||||
void BinaryOp::ShiftOpTypeChecking() {
|
||||
|
@@ -29,7 +29,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
ir::value* lhs = ret_;
|
||||
|
||||
// op info
|
||||
auto type = binary->lhs_->Type();
|
||||
auto type = binary->lhs_->Type()->ScalarType();
|
||||
auto flt = type->IsFloat();
|
||||
auto sign = !type->IsUnsigned();
|
||||
|
||||
|
@@ -318,7 +318,7 @@ bool ArrayType::Compatible(const Type& other) const {
|
||||
bool TileType::Compatible(const Type& other) const {
|
||||
// For two tile types to be compatible,
|
||||
// the element types must be compatible
|
||||
// and they must have the same shape
|
||||
// and they must have the same shape
|
||||
auto otherTile = other.ToTile();
|
||||
if(!otherTile)
|
||||
return false;
|
||||
|
@@ -17,6 +17,7 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "triton/ir/print.h"
|
||||
|
||||
|
||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||
@@ -117,50 +118,17 @@ void function::caller::operator ()(driver::stream *stream, const std::array<size
|
||||
|
||||
|
||||
|
||||
// module
|
||||
triton::lang::translation_unit *function::make_ast(const char *csrc) {
|
||||
std::string src(csrc);
|
||||
std::cout << src << std::endl;
|
||||
Preprocessor cpp(&src, true);
|
||||
// for (auto& def: defines)
|
||||
// DefineMacro(cpp, def);
|
||||
// for (auto& path: include_paths)
|
||||
// cpp.AddSearchPath(path);
|
||||
|
||||
FILE* fp = stdout;
|
||||
// if (specified_out_name) {
|
||||
// fp = fopen(filename_out.c_str(), "w");
|
||||
// }
|
||||
TokenSequence ts;
|
||||
cpp.Process(ts);
|
||||
Parser parser(ts);
|
||||
parser.Parse();
|
||||
Generator gen(&parser);
|
||||
ir::module out("", ctx_);
|
||||
gen.Gen(&out);
|
||||
exit(EXIT_FAILURE);
|
||||
|
||||
// if (only_preprocess) {
|
||||
// ts.Print(fp);
|
||||
// return 0;
|
||||
// }
|
||||
|
||||
YY_BUFFER_STATE buffer = yy_scan_string(csrc);
|
||||
yyparse();
|
||||
yy_delete_buffer(buffer);
|
||||
triton::lang::translation_unit *program = ast_root;
|
||||
return program;
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::module> function::make_ir(triton::lang::translation_unit *program) {
|
||||
std::unique_ptr<ir::module> function::make_ir(Parser& parser) {
|
||||
// create Triton-IR from AST
|
||||
ir::module* module = new ir::module("", ctx_);
|
||||
program->codegen(module);
|
||||
Generator gen(&parser);
|
||||
gen.Gen(module);
|
||||
return std::unique_ptr<ir::module>(module);
|
||||
}
|
||||
|
||||
options function::autotune(lang::translation_unit *ast, driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector<arg>& args) {
|
||||
std::unique_ptr<ir::module> ir = make_ir(ast);
|
||||
options function::autotune(Parser& parser, driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector<arg>& args) {
|
||||
std::unique_ptr<ir::module> ir = make_ir(parser);
|
||||
// extract tunable values
|
||||
std::vector<std::pair<std::string, ir::metaparameter*>> values;
|
||||
for(auto it: ir->globals())
|
||||
@@ -186,7 +154,7 @@ options function::autotune(lang::translation_unit *ast, driver::stream* stream,
|
||||
for(auto it: values)
|
||||
opt.params[it.first] = params[i++];
|
||||
// make binary
|
||||
auto ir = make_ir(ast);
|
||||
auto ir = make_ir(parser);
|
||||
auto bin = make_bin(*ir, stream->context(), opt);
|
||||
// benchmark
|
||||
ir::function *tmp = ir->get_function_list()[0];
|
||||
@@ -242,6 +210,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
// generate llvm code
|
||||
llvm::LLVMContext ctx;
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
ir::print(module, std::cout);
|
||||
exit(EXIT_FAILURE);
|
||||
selection.run(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, llvm.get()));
|
||||
@@ -249,9 +219,11 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
}
|
||||
|
||||
|
||||
function::function(const std::string &src): src_(src) {
|
||||
// src -> ast
|
||||
ast_ = make_ast(src_.c_str());
|
||||
function::function(const std::string &src): parser_(ts_), src_(src){
|
||||
Preprocessor cpp(&src_, true);
|
||||
cpp.Process(ts_);
|
||||
ts_.Print();
|
||||
parser_.Parse();
|
||||
}
|
||||
|
||||
void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
|
||||
@@ -277,8 +249,8 @@ void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_f
|
||||
}
|
||||
|
||||
/* re-tune and re-compile */
|
||||
options opt = autotune(ast_, stream, grid_fn, args);
|
||||
std::unique_ptr<ir::module> ir = make_ir(ast_);
|
||||
options opt = autotune(parser_, stream, grid_fn, args);
|
||||
std::unique_ptr<ir::module> ir = make_ir(parser_);
|
||||
std::unique_ptr<driver::module> bin = make_bin(*ir, stream->context(), opt);
|
||||
ir::function* fn = ir->get_function_list().front();
|
||||
const caller& run = cache_.insert({key, cache_val_t{opt, caller(fn, std::move(bin), opt.num_warps*32)}}).first->second.second;
|
||||
|
Reference in New Issue
Block a user