adding tunable parameters
This commit is contained in:
@@ -78,51 +78,51 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin
|
|||||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
|
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
|
||||||
std::string res =
|
std::string res =
|
||||||
R"(
|
R"(
|
||||||
#define TM 128
|
#define bool _Bool
|
||||||
#define TN 128
|
#define true 1
|
||||||
#define TK 32
|
#define false 0
|
||||||
|
#define __bool_true_false_are_defined 1
|
||||||
|
extern int get_program_id(int);
|
||||||
|
|
||||||
#define bool _Bool
|
#define TN 128
|
||||||
#define true 1
|
#define TK 32
|
||||||
#define false 0
|
|
||||||
#define __bool_true_false_are_defined 1
|
|
||||||
|
|
||||||
extern int get_program_id(int);
|
static const int TM = 128;
|
||||||
|
|
||||||
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
||||||
restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))),
|
restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))),
|
||||||
restrict )" + c_ty + R"( * C __attribute__((aligned(16))),
|
restrict )" + c_ty + R"( * C __attribute__((aligned(16))),
|
||||||
int M, int N, int K,
|
int M, int N, int K,
|
||||||
int lda __attribute__((multiple_of(8))),
|
int lda __attribute__((multiple_of(8))),
|
||||||
int ldb __attribute__((multiple_of(8))),
|
int ldb __attribute__((multiple_of(8))),
|
||||||
int ldc) {
|
int ldc) {
|
||||||
int ridx = get_program_id(0);
|
int ridx = get_program_id(0);
|
||||||
int ridy = get_program_id(1);
|
int ridy = get_program_id(1);
|
||||||
int rxa[{TM}] = ridx * TM + 0 ... TM;
|
int rxa[{TM}] = ridx * TM + 0 ... TM;
|
||||||
int ryb[{TN}] = ridy * TN + 0 ... TN;
|
int ryb[{TN}] = ridy * TN + 0 ... TN;
|
||||||
int rka[{TK}] = 0 ... TK;
|
int rka[{TK}] = 0 ... TK;
|
||||||
int rkb[{TK}] = 0 ... TK;
|
int rkb[{TK}] = 0 ... TK;
|
||||||
float xc[{)" + XCS + R"(}] = 0;
|
float xc[{)" + XCS + R"(}] = 0;
|
||||||
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||||
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||||
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
|
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
|
||||||
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
|
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
|
||||||
for(int k = K; k > 0; k = k - TK){
|
for(int k = K; k > 0; k = k - TK){
|
||||||
xc = )" + usea + " @ " + useb + R"( + xc;
|
xc = )" + usea + " @ " + useb + R"( + xc;
|
||||||
pa = pa + TK)" + lda0 + R"(;
|
pa = pa + TK)" + lda0 + R"(;
|
||||||
pb = pb + TK)" + ldb0 + R"(;
|
pb = pb + TK)" + ldb0 + R"(;
|
||||||
a = *pa;
|
a = *pa;
|
||||||
b = *pb;
|
b = *pb;
|
||||||
}
|
|
||||||
int rxc[{TM}] = ridx * TM + (0 ... TM);
|
|
||||||
int ryc[{TN}] = ridy * TN + (0 ... TN);
|
|
||||||
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
|
||||||
)" + c_ty + R"( c[{TM, TN}] = xc;
|
|
||||||
bool checkc0[{TM}] = rxc < M;
|
|
||||||
bool checkc1[{TN}] = ryc < N;
|
|
||||||
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
|
||||||
*pc = c;
|
|
||||||
}
|
}
|
||||||
|
int rxc[{TM}] = ridx * TM + (0 ... TM);
|
||||||
|
int ryc[{TN}] = ridy * TN + (0 ... TN);
|
||||||
|
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||||
|
)" + c_ty + R"( c[{TM, TN}] = xc;
|
||||||
|
bool checkc0[{TM}] = rxc < M;
|
||||||
|
bool checkc1[{TN}] = ryc < N;
|
||||||
|
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||||
|
*pc = c;
|
||||||
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
|
@@ -25,7 +25,7 @@ class Parser {
|
|||||||
friend class Generator;
|
friend class Generator;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit Parser(const TokenSequence& ts)
|
explicit Parser(TokenSequence& ts)
|
||||||
: unit_(TranslationUnit::New()),
|
: unit_(TranslationUnit::New()),
|
||||||
ts_(ts),
|
ts_(ts),
|
||||||
externalSymbols_(new Scope(nullptr, S_BLOCK)),
|
externalSymbols_(new Scope(nullptr, S_BLOCK)),
|
||||||
@@ -224,7 +224,7 @@ private:
|
|||||||
// The root of the AST
|
// The root of the AST
|
||||||
TranslationUnit* unit_;
|
TranslationUnit* unit_;
|
||||||
|
|
||||||
TokenSequence ts_;
|
TokenSequence& ts_;
|
||||||
|
|
||||||
// It is not the real scope,
|
// It is not the real scope,
|
||||||
// It contains all external symbols(resolved and not resolved)
|
// It contains all external symbols(resolved and not resolved)
|
||||||
|
@@ -223,7 +223,7 @@ public:
|
|||||||
virtual bool IsInteger() const { return !IsFloat() && !IsComplex(); }
|
virtual bool IsInteger() const { return !IsFloat() && !IsComplex(); }
|
||||||
virtual bool IsUnsigned() const { return tag_ & T_UNSIGNED; }
|
virtual bool IsUnsigned() const { return tag_ & T_UNSIGNED; }
|
||||||
virtual bool IsFloat() const {
|
virtual bool IsFloat() const {
|
||||||
return (tag_ & T_FLOAT) || (tag_ & T_DOUBLE);
|
return (tag_ & T_HALF) || (tag_ & T_FLOAT) || (tag_ & T_DOUBLE);
|
||||||
}
|
}
|
||||||
virtual bool IsBool() const { return tag_ & T_BOOL; }
|
virtual bool IsBool() const { return tag_ & T_BOOL; }
|
||||||
bool IsComplex() const { return tag_ & T_COMPLEX; }
|
bool IsComplex() const { return tag_ & T_COMPLEX; }
|
||||||
|
@@ -20,6 +20,7 @@
|
|||||||
#include "triton/codegen/transform/shmem/barriers.h"
|
#include "triton/codegen/transform/shmem/barriers.h"
|
||||||
#include "triton/codegen/transform/reassociate.h"
|
#include "triton/codegen/transform/reassociate.h"
|
||||||
#include "triton/codegen/transform/vectorize.h"
|
#include "triton/codegen/transform/vectorize.h"
|
||||||
|
#include "triton/lang/wgtcc/parser.h"
|
||||||
|
|
||||||
namespace llvm {
|
namespace llvm {
|
||||||
class Module;
|
class Module;
|
||||||
@@ -87,9 +88,9 @@ private:
|
|||||||
typedef std::pair<options, caller> cache_val_t;
|
typedef std::pair<options, caller> cache_val_t;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
triton::lang::translation_unit *make_ast(const char *src);
|
triton::lang::translation_unit *make_ast(const std::string &src);
|
||||||
std::unique_ptr<ir::module> make_ir(triton::lang::translation_unit *program);
|
std::unique_ptr<ir::module> make_ir(Parser &parser);
|
||||||
options autotune(lang::translation_unit *ast, driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
|
options autotune(Parser &parser, driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
|
||||||
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options &opt);
|
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options &opt);
|
||||||
|
|
||||||
|
|
||||||
@@ -100,11 +101,12 @@ public:
|
|||||||
std::string make_tensorflow_src(const std::vector<size_t> &outputs, const std::string ¯o);
|
std::string make_tensorflow_src(const std::vector<size_t> &outputs, const std::string ¯o);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
TokenSequence ts_;
|
||||||
|
Parser parser_;
|
||||||
// execution context
|
// execution context
|
||||||
ir::context ctx_;
|
ir::context ctx_;
|
||||||
// program representations
|
// program representations
|
||||||
std::string src_;
|
std::string src_;
|
||||||
lang::translation_unit *ast_;
|
|
||||||
std::map<cache_key_t, cache_val_t> cache_;
|
std::map<cache_key_t, cache_val_t> cache_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -106,10 +106,10 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
|
|||||||
file_type_t ft) {
|
file_type_t ft) {
|
||||||
init_llvm();
|
init_llvm();
|
||||||
// debug
|
// debug
|
||||||
// llvm::legacy::PassManager pm;
|
llvm::legacy::PassManager pm;
|
||||||
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
pm.add(llvm::createPrintModulePass(llvm::outs()));
|
||||||
// pm.add(llvm::createVerifierPass());
|
pm.add(llvm::createVerifierPass());
|
||||||
// pm.run(*module);
|
pm.run(*module);
|
||||||
// create machine
|
// create machine
|
||||||
module->setTargetTriple(triple);
|
module->setTargetTriple(triple);
|
||||||
std::string error;
|
std::string error;
|
||||||
|
@@ -461,7 +461,10 @@ void BinaryOp::MatmulOpTypeChecking() {
|
|||||||
QualType retType = lhsType->Derived();
|
QualType retType = lhsType->Derived();
|
||||||
if(retType != rhsType->Derived())
|
if(retType != rhsType->Derived())
|
||||||
Error(this, "matrix multiplication operands have incompatible data types");
|
Error(this, "matrix multiplication operands have incompatible data types");
|
||||||
type_ = TileType::New(retShape, lhsType->Derived());
|
ArithmType* ScalType = lhsType->ScalarType()->ToArithm();
|
||||||
|
if(ScalType->Tag() & T_HALF)
|
||||||
|
ScalType = ArithmType::New(T_FLOAT);
|
||||||
|
type_ = TileType::New(retShape, ScalType);
|
||||||
}
|
}
|
||||||
|
|
||||||
void BinaryOp::ShiftOpTypeChecking() {
|
void BinaryOp::ShiftOpTypeChecking() {
|
||||||
|
@@ -29,7 +29,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
ir::value* lhs = ret_;
|
ir::value* lhs = ret_;
|
||||||
|
|
||||||
// op info
|
// op info
|
||||||
auto type = binary->lhs_->Type();
|
auto type = binary->lhs_->Type()->ScalarType();
|
||||||
auto flt = type->IsFloat();
|
auto flt = type->IsFloat();
|
||||||
auto sign = !type->IsUnsigned();
|
auto sign = !type->IsUnsigned();
|
||||||
|
|
||||||
|
@@ -318,7 +318,7 @@ bool ArrayType::Compatible(const Type& other) const {
|
|||||||
bool TileType::Compatible(const Type& other) const {
|
bool TileType::Compatible(const Type& other) const {
|
||||||
// For two tile type to be compatible,
|
// For two tile type to be compatible,
|
||||||
// the element types must be compatible
|
// the element types must be compatible
|
||||||
// and they must have the same shape
|
// and they must have the same shapea
|
||||||
auto otherTile = other.ToTile();
|
auto otherTile = other.ToTile();
|
||||||
if(!otherTile)
|
if(!otherTile)
|
||||||
return false;
|
return false;
|
||||||
|
@@ -17,6 +17,7 @@
|
|||||||
#include "triton/ir/function.h"
|
#include "triton/ir/function.h"
|
||||||
#include "triton/tools/bench.hpp"
|
#include "triton/tools/bench.hpp"
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
|
#include "triton/ir/print.h"
|
||||||
|
|
||||||
|
|
||||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||||
@@ -117,50 +118,17 @@ void function::caller::operator ()(driver::stream *stream, const std::array<size
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
// module
|
|
||||||
triton::lang::translation_unit *function::make_ast(const char *csrc) {
|
|
||||||
std::string src(csrc);
|
|
||||||
std::cout << src << std::endl;
|
|
||||||
Preprocessor cpp(&src, true);
|
|
||||||
// for (auto& def: defines)
|
|
||||||
// DefineMacro(cpp, def);
|
|
||||||
// for (auto& path: include_paths)
|
|
||||||
// cpp.AddSearchPath(path);
|
|
||||||
|
|
||||||
FILE* fp = stdout;
|
std::unique_ptr<ir::module> function::make_ir(Parser& parser) {
|
||||||
// if (specified_out_name) {
|
|
||||||
// fp = fopen(filename_out.c_str(), "w");
|
|
||||||
// }
|
|
||||||
TokenSequence ts;
|
|
||||||
cpp.Process(ts);
|
|
||||||
Parser parser(ts);
|
|
||||||
parser.Parse();
|
|
||||||
Generator gen(&parser);
|
|
||||||
ir::module out("", ctx_);
|
|
||||||
gen.Gen(&out);
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
|
|
||||||
// if (only_preprocess) {
|
|
||||||
// ts.Print(fp);
|
|
||||||
// return 0;
|
|
||||||
// }
|
|
||||||
|
|
||||||
YY_BUFFER_STATE buffer = yy_scan_string(csrc);
|
|
||||||
yyparse();
|
|
||||||
yy_delete_buffer(buffer);
|
|
||||||
triton::lang::translation_unit *program = ast_root;
|
|
||||||
return program;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<ir::module> function::make_ir(triton::lang::translation_unit *program) {
|
|
||||||
// create Triton-IR from AST
|
// create Triton-IR from AST
|
||||||
ir::module* module = new ir::module("", ctx_);
|
ir::module* module = new ir::module("", ctx_);
|
||||||
program->codegen(module);
|
Generator gen(&parser);
|
||||||
|
gen.Gen(module);
|
||||||
return std::unique_ptr<ir::module>(module);
|
return std::unique_ptr<ir::module>(module);
|
||||||
}
|
}
|
||||||
|
|
||||||
options function::autotune(lang::translation_unit *ast, driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector<arg>& args) {
|
options function::autotune(Parser& parser, driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector<arg>& args) {
|
||||||
std::unique_ptr<ir::module> ir = make_ir(ast);
|
std::unique_ptr<ir::module> ir = make_ir(parser);
|
||||||
// extract tunable values
|
// extract tunable values
|
||||||
std::vector<std::pair<std::string, ir::metaparameter*>> values;
|
std::vector<std::pair<std::string, ir::metaparameter*>> values;
|
||||||
for(auto it: ir->globals())
|
for(auto it: ir->globals())
|
||||||
@@ -186,7 +154,7 @@ options function::autotune(lang::translation_unit *ast, driver::stream* stream,
|
|||||||
for(auto it: values)
|
for(auto it: values)
|
||||||
opt.params[it.first] = params[i++];
|
opt.params[it.first] = params[i++];
|
||||||
// make binary
|
// make binary
|
||||||
auto ir = make_ir(ast);
|
auto ir = make_ir(parser);
|
||||||
auto bin = make_bin(*ir, stream->context(), opt);
|
auto bin = make_bin(*ir, stream->context(), opt);
|
||||||
// benchmark
|
// benchmark
|
||||||
ir::function *tmp = ir->get_function_list()[0];
|
ir::function *tmp = ir->get_function_list()[0];
|
||||||
@@ -242,6 +210,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
// generate llvm code
|
// generate llvm code
|
||||||
llvm::LLVMContext ctx;
|
llvm::LLVMContext ctx;
|
||||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||||
|
ir::print(module, std::cout);
|
||||||
|
exit(EXIT_FAILURE);
|
||||||
selection.run(module, *llvm);
|
selection.run(module, *llvm);
|
||||||
// return binary
|
// return binary
|
||||||
std::unique_ptr<driver::module> res(driver::module::create(context, llvm.get()));
|
std::unique_ptr<driver::module> res(driver::module::create(context, llvm.get()));
|
||||||
@@ -249,9 +219,11 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
function::function(const std::string &src): src_(src) {
|
function::function(const std::string &src): parser_(ts_), src_(src){
|
||||||
// src -> ast
|
Preprocessor cpp(&src_, true);
|
||||||
ast_ = make_ast(src_.c_str());
|
cpp.Process(ts_);
|
||||||
|
ts_.Print();
|
||||||
|
parser_.Parse();
|
||||||
}
|
}
|
||||||
|
|
||||||
void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
|
void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
|
||||||
@@ -277,8 +249,8 @@ void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_f
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* re-tune and re-compile */
|
/* re-tune and re-compile */
|
||||||
options opt = autotune(ast_, stream, grid_fn, args);
|
options opt = autotune(parser_, stream, grid_fn, args);
|
||||||
std::unique_ptr<ir::module> ir = make_ir(ast_);
|
std::unique_ptr<ir::module> ir = make_ir(parser_);
|
||||||
std::unique_ptr<driver::module> bin = make_bin(*ir, stream->context(), opt);
|
std::unique_ptr<driver::module> bin = make_bin(*ir, stream->context(), opt);
|
||||||
ir::function* fn = ir->get_function_list().front();
|
ir::function* fn = ir->get_function_list().front();
|
||||||
const caller& run = cache_.insert({key, cache_val_t{opt, caller(fn, std::move(bin), opt.num_warps*32)}}).first->second.second;
|
const caller& run = cache_.insert({key, cache_val_t{opt, caller(fn, std::move(bin), opt.num_warps*32)}}).first->second.second;
|
||||||
|
Reference in New Issue
Block a user