[PYTHON][EXAMPLES] Tentative support for einsum with transpositions

This commit is contained in:
Philippe Tillet
2019-10-25 19:01:21 -04:00
parent 8bd87fa19d
commit 76adcb755a
9 changed files with 137 additions and 46 deletions

View File

@@ -430,7 +430,6 @@ public:
void AddrOpTypeChecking();
void DerefOpTypeChecking();
void ReduceOpTypeChecking();
void TransOpTypeChecking();
void UnaryArithmOpTypeChecking();
void CastOpTypeChecking();
@@ -448,6 +447,28 @@ protected:
Expr* operand_;
};
class TransOp: public Expr {
friend class Generator;
public:
using PermInt = std::vector<int>;
public:
static TransOp* New(const PermInt& perm, Expr* operand);
const PermInt& getPerm() const { return perm_; }
void Accept(Visitor* v);
bool IsLVal() { return false; }
void TypeChecking();
protected:
TransOp(const PermInt& perm, Expr* operand)
: Expr(operand->Tok(), nullptr), operand_(operand), perm_(perm) {}
private:
Expr* operand_;
PermInt perm_;
};
// cond ? true false
class ConditionalOp : public Expr {

View File

@@ -58,6 +58,7 @@ public:
// Expression
void VisitBinaryOp(BinaryOp* binaryOp);
void VisitUnaryOp(UnaryOp* unaryOp);
void VisitTransOp(TransOp* transOp);
void VisitConditionalOp(ConditionalOp* condOp);
void VisitFuncCall(FuncCall* funcCall);
void VisitObject(Object* obj);
@@ -130,6 +131,7 @@ public:
void VisitConditionalOp(ConditionalOp*) { should_not_happen(); }
void VisitFuncCall(FuncCall*) { should_not_happen(); }
void VisitTransOp(TransOp*) { should_not_happen(); }
void VisitEnumerator(Enumerator*) { should_not_happen(); }
void VisitConstant(Constant*) { should_not_happen(); }
void VisitTempVar(TempVar*) { should_not_happen(); }

View File

@@ -30,6 +30,9 @@ public:
virtual void VisitIdentifier(Identifier* ident) {
Error(ident, "expect constant expression");
}
virtual void VisitTransOp(TransOp* trans) {
Error(trans, "expect constant expression");
}
virtual void VisitObject(Object* obj) {
Error(obj, "expect constant expression");
}
@@ -83,6 +86,9 @@ public:
virtual void VisitFuncCall(FuncCall* funcCall) {
Error(funcCall, "expect constant expression");
}
virtual void VisitTransOp(TransOp* trans) {
Error(trans, "expect constant expression");
}
virtual void VisitEnumerator(Enumerator* enumer) {
addr_.offset_ = enumer->Val();
}

View File

@@ -6,6 +6,7 @@
class BinaryOp;
class UnaryOp;
class TransOp;
class ConditionalOp;
class FuncCall;
class Identifier;
@@ -31,6 +32,7 @@ public:
virtual ~Visitor() {}
virtual void VisitBinaryOp(BinaryOp* binary) = 0;
virtual void VisitUnaryOp(UnaryOp* unary) = 0;
virtual void VisitTransOp(TransOp* trans) = 0;
virtual void VisitConditionalOp(ConditionalOp* cond) = 0;
virtual void VisitFuncCall(FuncCall* funcCall) = 0;
virtual void VisitEnumerator(Enumerator* enumer) = 0;

View File

@@ -7,6 +7,7 @@
static MemPoolImp<BinaryOp> binaryOpPool;
static MemPoolImp<TransOp> transOpPool;
static MemPoolImp<ConditionalOp> conditionalOpPool;
static MemPoolImp<FuncCall> funcCallPool;
static MemPoolImp<Declaration> initializationPool;
@@ -78,6 +79,9 @@ void UnaryOp::Accept(Visitor* v) {
v->VisitUnaryOp(this);
}
void TransOp::Accept(Visitor* v) {
v->VisitTransOp(this);
}
void ConditionalOp::Accept(Visitor* v) {
v->VisitConditionalOp(this);
@@ -645,9 +649,6 @@ void UnaryOp::TypeChecking() {
case Token::CAST:
return CastOpTypeChecking();
case '^':
return TransOpTypeChecking();
case Token::REDUCE:
return ReduceOpTypeChecking();
@@ -702,15 +703,6 @@ void UnaryOp::ReduceOpTypeChecking() {
type_ = TileType::New(shape, tileType->Derived());
}
void UnaryOp::TransOpTypeChecking() {
auto tileType = operand_->Type()->ToTile();
if(!tileType)
Error(this, "tile expected for transposition operator '^'");
auto shape = tileType->Shape();
std::rotate(shape.begin(), shape.begin() + 1, shape.end());
type_ = TileType::New(shape, tileType->Derived());
}
void UnaryOp::UnaryArithmOpTypeChecking() {
auto scalType = TryExtractScalarType(this, operand_);
if (Token::PLUS == op_ || Token::MINUS == op_) {
@@ -769,6 +761,29 @@ void UnaryOp::CastOpTypeChecking() {
}
}
/*
* Transposition Operator
*/
void TransOp::TypeChecking() {
auto tileType = operand_->Type()->ToTile();
if(!tileType)
Error(this, "tile expected for transposition operator '^'");
auto opShape = tileType->Shape();
if(perm_.size() != opShape.size())
Error(this, "invalid permutations");
// permutate input shape
TileType::ShapeInt resShape(opShape.size());
for(int d = 0; d < opShape.size(); d++)
resShape[d] = opShape[perm_[d]];
type_ = TileType::New(resShape, tileType->Derived());
}
TransOp* TransOp::New(const PermInt& perm, Expr* operand) {
auto ret = new (transOpPool.Alloc()) TransOp(perm, operand);
ret->pool_ = &transOpPool;
ret->TypeChecking();
return ret;
}
/*
* Conditional Operator

View File

@@ -185,7 +185,6 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
case '~': return set_ret(bld_->create_neg(arg));
case '!': return set_ret(bld_->create_not(arg));
case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
case '^': return set_ret(bld_->create_trans(arg));
case Token::REDUCE: {
int ax, tag;
UnaryOp::decodeRed(unary->info_, ax, tag);
@@ -198,6 +197,12 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
return error_not_implemented();
}
void Generator::VisitTransOp(TransOp *trans) {
Visit(trans->operand_);
ir::value* arg = ret_;
return set_ret(bld_->create_trans(arg, trans->getPerm()));
}
void Generator::VisitConditionalOp(ConditionalOp* condOp) {
// auto &instructions = bld_->get_insert_block()->get_inst_list();
VisitExpr(condOp->cond_);

View File

@@ -451,6 +451,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
QualType lhsQual = lhsTile->Derived();
// create ret shape
TileType::ShapeInt shape;
TileType::ShapeInt axVec;
size_t i = 0;
const Token* tok;
std::vector<std::pair<int, int>> redInfo;
@@ -469,10 +470,22 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
case Token::SUB:
case Token::MAX:
case Token::MIN:{
int info = UnaryOp::encodeRed(i, tok->tag_);
redInfo.push_back({i, info});
shape.push_back(lhsShape[i++]);
break;
int info = UnaryOp::encodeRed(i, tok->tag_);
redInfo.push_back({i, info});
shape.push_back(lhsShape[i++]);
break;
}
case '^':{
Expr* expr = ParseConditionalExpr();
EnsureInteger(expr);
int ax = Evaluator<long>().Eval(expr);
axVec.push_back(ax);
if(ax < 0 || ax >= lhsShape.size())
Error(tok, "unknown axis %d in transposition", ax);
shape.push_back(lhsShape[ax]);
i++;
break;
}
default:
@@ -481,8 +494,19 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
}
}while(ts_.Try(','));
ts_.Expect(']');
// transposition mode
std::set<int> axSet(axVec.begin(), axVec.end());
if(!axSet.empty()){
if(axSet.size()!=lhsShape.size())
Error(tok, "transposition must address all axes of input array");
return TransOp::New(axVec, lhs);
}
// broadcasting mode
if(lhsShape.size() > i)
Error(tok, "broadcasting not using all operand axes");
// create ret tile
Expr* res = lhs;
for(auto r: redInfo){
@@ -553,7 +577,15 @@ Expr* Parser::ParseUnaryExpr() {
case '-': return ParseUnaryOp(tok, Token::MINUS);
case '~': return ParseUnaryOp(tok, '~');
case '!': return ParseUnaryOp(tok, '!');
case '^': return ParseUnaryOp(tok, Token::XOR);
case '^': {
auto operand = ParseCastExpr();
TileType::ShapeInt shape = operand->Type()->ToTile()->Shape();
TransOp::PermInt perm(shape.size());
for(int d = 0; d < shape.size(); d++)
perm[d] = d;
std::rotate(perm.begin(), perm.begin() + 1, perm.end());
return TransOp::New(perm, operand);
}
default:
ts_.PutBack();
return ParsePostfixExpr();

View File

@@ -8,41 +8,36 @@ class _einsum(triton.function):
int std_A0, int std_B0, int std_C0,
int std_A1, int std_B1, int std_C1) {
// program id
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pid2 = get_program_id(2);
int pgm = get_program_id(0);
int pgn = get_program_id(1);
int pgb = get_program_id(2);
// range
int rma[TM] = pid0 * TM + 0 ... TM;
int rnb[TN] = pid1 * TN + 0 ... TN;
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
int rba[TB] = pid2 * TB + 0 ... TB;
int rbb[TB] = pid2 * TB + 0 ... TB;
int rm[TM] = pgm * TM + 0 ... TM;
int rn[TN] = pgn * TN + 0 ... TN;
int rb[TB] = pgb * TB + 0 ... TB;
int rk[TK] = 0 ... TK;
// accumulator
TYPE c[TM, TN, TB] = 0;
// pointers to a
TYPE *pa[TM, TK, TB] = A + rka[newaxis, :, newaxis] * 1
+ rma[:, newaxis, newaxis] * std_A1
+ rba[newaxis, newaxis, :] * std_A0;
TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK
+ rm[BROADCAST_AM] * STRIDE_AM
+ rb[newaxis, newaxis, :] * std_A0;
// pointers to b
TYPE *pb[TK, TN, TB] = B + rkb[:, newaxis, newaxis] * 1
+ rnb[newaxis, :, newaxis] * std_B1
+ rbb[newaxis, newaxis, :] * std_B0;
TYPE *pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK
+ rn[BROADCAST_BN] * STRIDE_BN
+ rb[newaxis, newaxis, :] * std_B0;
// accumulation
for(int k = dim_K; k > 0; k -= TK) {
TYPE a[TM, TK, TB] = *pa;
TYPE b[TK, TN, TB] = *pb;
TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb;
c += a @ b;
pa += TK;
pb += TK;
}
// write-back
int rmc[TM] = pid0 * TM + 0 ... TM;
int rnc[TN] = pid1 * TN + 0 ... TN;
int rbc[TB] = pid2 * TB + 0 ... TB;
TYPE *pc[TM, TN, TB] = C + rmc[:, newaxis, newaxis] * std_C1
+ rnc[newaxis, :, newaxis] * 1
+ rbc[newaxis, newaxis, :] * std_C0;
TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1
+ rn[newaxis, :, newaxis] * 1
+ rb[newaxis, newaxis, :] * std_C0;
*pc = c;
}
"""
@@ -138,12 +133,25 @@ class _einsum(triton.function):
grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')),
triton.cdiv(bmnk[2], opt.d('TN')),
triton.cdiv(bmnk[0], opt.d('TB'))]
#print(std0, std1)
macros = {# handle A transposition
'USE_A' : 'a[^1, ^0, ^2]' if trans_a else 'a',
'STRIDE_AK' : 'std_A1' if trans_a else '1',
'STRIDE_AM' : '1' if trans_a else 'std_A1',
'BROADCAST_AK': ':, newaxis, newaxis' if trans_a else 'newaxis, :, newaxis',
'BROADCAST_AM': 'newaxis, :, newaxis' if trans_a else ':, newaxis, newaxis',
'SHAPE_A' : 'TK, TM, TB' if trans_a else 'TM, TK, TB',
# handle B transposition
'USE_B' : 'b[^1, ^0, ^2]' if not trans_b else 'b',
'STRIDE_BK' : 'std_B1' if not trans_b else '1',
'STRIDE_BN' : '1' if not trans_b else 'std_B1',
'BROADCAST_BK': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
'BROADCAST_BN': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
'SHAPE_B' : 'TN, TK, TB' if not trans_b else 'TK, TN, TB'}
return _einsum.kernel(a, b, c,
bmnk[1], bmnk[2], bmnk[3],
std0[0], std0[1], std0[2],
std1[0], std1[1], std1[2],
grid,
grid, **macros,
TYPE='float', TM=32, TN=32, TK=8, TB=1)

View File

@@ -86,14 +86,14 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
// macros
rt::function::options_space_t opt;
// A access patterns
opt.defines.push_back({"USEA", {AT? "^a" : "a" }});
opt.defines.push_back({"USEA", {AT? "a[^1, ^0]" : "a" }});
opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }});
opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }});
opt.defines.push_back({"SHAPE_A", {AT? "TK, TM" : "TM, TK" }});
opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
// B access patterns
opt.defines.push_back({"USEB", {BT? "^b" : "b" }});
opt.defines.push_back({"USEB", {BT? "b[^1, ^0]" : "b" }});
opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }});
opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }});
opt.defines.push_back({"SHAPE_B", {BT? "TN, TK" : "TK, TN" }});