[lang] added support for restrict; added macros for attributes

This commit is contained in:
Philippe Tillet
2019-08-23 20:29:12 -07:00
parent 8c6bac49d1
commit 44eb3891ae
4 changed files with 46 additions and 13 deletions

View File

@@ -82,14 +82,21 @@ R"(
#define true 1 #define true 1
#define false 0 #define false 0
#define __bool_true_false_are_defined 1 #define __bool_true_false_are_defined 1
#define __readonly __attribute__((readonly))
#define __writeonly __attribute__((writeonly))
#define __noalias __attribute__((noalias))
#define __aligned(A) __attribute__((aligned(A)))
#define __multipleof(A) __attribute__((multipleof(A)))
extern int get_program_id(int); extern int get_program_id(int);
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))), void matmul()" + a_ty + R"( * A __noalias __readonly __aligned(16),
restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))), )" + b_ty + R"( * B __noalias __readonly __aligned(16),
restrict )" + c_ty + R"( * C __attribute__((aligned(16))), )" + c_ty + R"( * C __noalias __readonly __aligned(16),
int M, int N, int K, int M, int N, int K,
int lda __attribute__((multiple_of(8))), int lda __multipleof(8),
int ldb __attribute__((multiple_of(8))), int ldb __multipleof(8),
int ldc) { int ldc) {
int ridx = get_program_id(0); int ridx = get_program_id(0);
int ridy = get_program_id(1); int ridy = get_program_id(1);

View File

@@ -58,7 +58,16 @@ class TranslationUnit;
class ASTNode { class ASTNode {
public: public:
struct Attr{ struct Attr{
std::string name;
enum KindT{
MULTIPLEOF,
ALIGNED,
NOALIAS,
READONLY,
WRITEONLY
};
KindT kind;
std::vector<Expr*> vals; std::vector<Expr*> vals;
}; };
using AttrList = std::vector<Attr>; using AttrList = std::vector<Attr>;

View File

@@ -356,6 +356,8 @@ void Generator::VisitFuncDef(FuncDef* funcDef) {
args[i]->set_name(name); args[i]->set_name(name);
for(ASTNode::Attr attr: obj->GetAttrList()) for(ASTNode::Attr attr: obj->GetAttrList())
fn->add_attr(i, GenIRAttr(attr)); fn->add_attr(i, GenIRAttr(attr));
if(obj->IsRestrictQualified())
fn->add_attr(i, ir::attribute(ir::noalias));
mod_->set_value(name, nullptr, args[i]); mod_->set_value(name, nullptr, args[i]);
mod_->get_scope().types[name] = args[i]->get_type(); mod_->get_scope().types[name] = args[i]->get_type();
i++; i++;
@@ -440,22 +442,22 @@ ir::value* Generator::GenCastOp(ir::value* src, ir::type* dst_ty) {
// Triton-IR Attr // Triton-IR Attr
ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) { ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) {
if(attr.name == "multiple_of") { if(attr.kind == ASTNode::Attr::MULTIPLEOF) {
VisitExpr(attr.vals[0]); VisitExpr(attr.vals[0]);
auto cst = dynamic_cast<ir::constant_int*>(ret_); auto cst = dynamic_cast<ir::constant_int*>(ret_);
if(!cst) should_not_happen(); if(!cst) should_not_happen();
return ir::attribute(ir::multiple_of, cst->get_value()); return ir::attribute(ir::multiple_of, cst->get_value());
} }
if(attr.name == "aligned") { if(attr.kind == ASTNode::Attr::ALIGNED) {
VisitExpr(attr.vals[0]); VisitExpr(attr.vals[0]);
auto cst = dynamic_cast<ir::constant_int*>(ret_); auto cst = dynamic_cast<ir::constant_int*>(ret_);
return ir::attribute(ir::aligned, cst->get_value()); return ir::attribute(ir::aligned, cst->get_value());
} }
if(attr.name == "noalias") if(attr.kind == ASTNode::Attr::NOALIAS)
return ir::attribute(ir::noalias); return ir::attribute(ir::noalias);
if(attr.name == "readonly") if(attr.kind == ASTNode::Attr::READONLY)
return ir::attribute(ir::readonly); return ir::attribute(ir::readonly);
if(attr.name == "writeonly") if(attr.kind == ASTNode::Attr::WRITEONLY)
return ir::attribute(ir::writeonly); return ir::attribute(ir::writeonly);
should_not_happen(); should_not_happen();
} }

View File

@@ -1806,7 +1806,8 @@ Object* Parser::ParseParamDecl() {
auto type = ParseDeclSpec(&storageSpec, &funcSpec, nullptr); auto type = ParseDeclSpec(&storageSpec, &funcSpec, nullptr);
auto tokTypePair = ParseDeclarator(type); auto tokTypePair = ParseDeclarator(type);
auto tok = tokTypePair.tok; auto tok = tokTypePair.tok;
type = Type::MayCast(tokTypePair.type, true); QualType fullType(tokTypePair.type.GetPtr(), type.Qual());
type = Type::MayCast(fullType, true);
auto attrs = tokTypePair.attrs; auto attrs = tokTypePair.attrs;
if (!tok) { // Abstract declarator if (!tok) { // Abstract declarator
return Object::NewAnony(ts_.Peek(), type, 0, Linkage::L_NONE); return Object::NewAnony(ts_.Peek(), type, 0, Linkage::L_NONE);
@@ -2692,7 +2693,21 @@ ASTNode::Attr Parser::ParseAttribute() {
if (!ts_.Test(Token::IDENTIFIER)) if (!ts_.Test(Token::IDENTIFIER))
return ret; return ret;
auto tok = ts_.Next(); auto tok = ts_.Next();
ret.name = tok->str_; std::string name = tok->str_;
// set kind
if(name == "aligned")
ret.kind = ASTNode::Attr::ALIGNED;
else if(name == "readonly")
ret.kind = ASTNode::Attr::READONLY;
else if(name == "writeonly")
ret.kind = ASTNode::Attr::WRITEONLY;
else if(name == "multipleof")
ret.kind = ASTNode::Attr::MULTIPLEOF;
else if(name == "noalias")
ret.kind = ASTNode::Attr::NOALIAS;
else
Error(tok, "unknown attribute kind");
// set exprs
if (ts_.Try('(')) { if (ts_.Try('(')) {
if (ts_.Try(')')) if (ts_.Try(')'))
return ret; return ret;