[PYTHON][EXAMPLES] Tentative support for einsum with transpositions
This commit is contained in:
@@ -430,7 +430,6 @@ public:
|
||||
void AddrOpTypeChecking();
|
||||
void DerefOpTypeChecking();
|
||||
void ReduceOpTypeChecking();
|
||||
void TransOpTypeChecking();
|
||||
void UnaryArithmOpTypeChecking();
|
||||
void CastOpTypeChecking();
|
||||
|
||||
@@ -448,6 +447,28 @@ protected:
|
||||
Expr* operand_;
|
||||
};
|
||||
|
||||
class TransOp: public Expr {
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
using PermInt = std::vector<int>;
|
||||
|
||||
public:
|
||||
static TransOp* New(const PermInt& perm, Expr* operand);
|
||||
const PermInt& getPerm() const { return perm_; }
|
||||
void Accept(Visitor* v);
|
||||
bool IsLVal() { return false; }
|
||||
void TypeChecking();
|
||||
|
||||
protected:
|
||||
TransOp(const PermInt& perm, Expr* operand)
|
||||
: Expr(operand->Tok(), nullptr), operand_(operand), perm_(perm) {}
|
||||
|
||||
private:
|
||||
Expr* operand_;
|
||||
PermInt perm_;
|
||||
};
|
||||
|
||||
|
||||
// cond ? true : false
|
||||
class ConditionalOp : public Expr {
|
||||
|
@@ -58,6 +58,7 @@ public:
|
||||
// Expression
|
||||
void VisitBinaryOp(BinaryOp* binaryOp);
|
||||
void VisitUnaryOp(UnaryOp* unaryOp);
|
||||
void VisitTransOp(TransOp* transOp);
|
||||
void VisitConditionalOp(ConditionalOp* condOp);
|
||||
void VisitFuncCall(FuncCall* funcCall);
|
||||
void VisitObject(Object* obj);
|
||||
@@ -130,6 +131,7 @@ public:
|
||||
|
||||
void VisitConditionalOp(ConditionalOp*) { should_not_happen(); }
|
||||
void VisitFuncCall(FuncCall*) { should_not_happen(); }
|
||||
void VisitTransOp(TransOp*) { should_not_happen(); }
|
||||
void VisitEnumerator(Enumerator*) { should_not_happen(); }
|
||||
void VisitConstant(Constant*) { should_not_happen(); }
|
||||
void VisitTempVar(TempVar*) { should_not_happen(); }
|
||||
|
@@ -30,6 +30,9 @@ public:
|
||||
virtual void VisitIdentifier(Identifier* ident) {
|
||||
Error(ident, "expect constant expression");
|
||||
}
|
||||
virtual void VisitTransOp(TransOp* trans) {
|
||||
Error(trans, "expect constant expression");
|
||||
}
|
||||
virtual void VisitObject(Object* obj) {
|
||||
Error(obj, "expect constant expression");
|
||||
}
|
||||
@@ -83,6 +86,9 @@ public:
|
||||
virtual void VisitFuncCall(FuncCall* funcCall) {
|
||||
Error(funcCall, "expect constant expression");
|
||||
}
|
||||
virtual void VisitTransOp(TransOp* trans) {
|
||||
Error(trans, "expect constant expression");
|
||||
}
|
||||
virtual void VisitEnumerator(Enumerator* enumer) {
|
||||
addr_.offset_ = enumer->Val();
|
||||
}
|
||||
|
@@ -6,6 +6,7 @@
|
||||
|
||||
class BinaryOp;
|
||||
class UnaryOp;
|
||||
class TransOp;
|
||||
class ConditionalOp;
|
||||
class FuncCall;
|
||||
class Identifier;
|
||||
@@ -31,6 +32,7 @@ public:
|
||||
virtual ~Visitor() {}
|
||||
virtual void VisitBinaryOp(BinaryOp* binary) = 0;
|
||||
virtual void VisitUnaryOp(UnaryOp* unary) = 0;
|
||||
virtual void VisitTransOp(TransOp* trans) = 0;
|
||||
virtual void VisitConditionalOp(ConditionalOp* cond) = 0;
|
||||
virtual void VisitFuncCall(FuncCall* funcCall) = 0;
|
||||
virtual void VisitEnumerator(Enumerator* enumer) = 0;
|
||||
|
@@ -7,6 +7,7 @@
|
||||
|
||||
|
||||
static MemPoolImp<BinaryOp> binaryOpPool;
|
||||
static MemPoolImp<TransOp> transOpPool;
|
||||
static MemPoolImp<ConditionalOp> conditionalOpPool;
|
||||
static MemPoolImp<FuncCall> funcCallPool;
|
||||
static MemPoolImp<Declaration> initializationPool;
|
||||
@@ -78,6 +79,9 @@ void UnaryOp::Accept(Visitor* v) {
|
||||
v->VisitUnaryOp(this);
|
||||
}
|
||||
|
||||
void TransOp::Accept(Visitor* v) {
|
||||
v->VisitTransOp(this);
|
||||
}
|
||||
|
||||
void ConditionalOp::Accept(Visitor* v) {
|
||||
v->VisitConditionalOp(this);
|
||||
@@ -645,9 +649,6 @@ void UnaryOp::TypeChecking() {
|
||||
case Token::CAST:
|
||||
return CastOpTypeChecking();
|
||||
|
||||
case '^':
|
||||
return TransOpTypeChecking();
|
||||
|
||||
case Token::REDUCE:
|
||||
return ReduceOpTypeChecking();
|
||||
|
||||
@@ -702,15 +703,6 @@ void UnaryOp::ReduceOpTypeChecking() {
|
||||
type_ = TileType::New(shape, tileType->Derived());
|
||||
}
|
||||
|
||||
void UnaryOp::TransOpTypeChecking() {
|
||||
auto tileType = operand_->Type()->ToTile();
|
||||
if(!tileType)
|
||||
Error(this, "tile expected for transposition operator '^'");
|
||||
auto shape = tileType->Shape();
|
||||
std::rotate(shape.begin(), shape.begin() + 1, shape.end());
|
||||
type_ = TileType::New(shape, tileType->Derived());
|
||||
}
|
||||
|
||||
void UnaryOp::UnaryArithmOpTypeChecking() {
|
||||
auto scalType = TryExtractScalarType(this, operand_);
|
||||
if (Token::PLUS == op_ || Token::MINUS == op_) {
|
||||
@@ -769,6 +761,29 @@ void UnaryOp::CastOpTypeChecking() {
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Transposition Operator
|
||||
*/
|
||||
void TransOp::TypeChecking() {
|
||||
auto tileType = operand_->Type()->ToTile();
|
||||
if(!tileType)
|
||||
Error(this, "tile expected for transposition operator '^'");
|
||||
auto opShape = tileType->Shape();
|
||||
if(perm_.size() != opShape.size())
|
||||
Error(this, "invalid permutations");
|
||||
// permutate input shape
|
||||
TileType::ShapeInt resShape(opShape.size());
|
||||
for(int d = 0; d < opShape.size(); d++)
|
||||
resShape[d] = opShape[perm_[d]];
|
||||
type_ = TileType::New(resShape, tileType->Derived());
|
||||
}
|
||||
|
||||
TransOp* TransOp::New(const PermInt& perm, Expr* operand) {
|
||||
auto ret = new (transOpPool.Alloc()) TransOp(perm, operand);
|
||||
ret->pool_ = &transOpPool;
|
||||
ret->TypeChecking();
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Conditional Operator
|
||||
|
@@ -185,7 +185,6 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
||||
case '~': return set_ret(bld_->create_neg(arg));
|
||||
case '!': return set_ret(bld_->create_not(arg));
|
||||
case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
||||
case '^': return set_ret(bld_->create_trans(arg));
|
||||
case Token::REDUCE: {
|
||||
int ax, tag;
|
||||
UnaryOp::decodeRed(unary->info_, ax, tag);
|
||||
@@ -198,6 +197,12 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
||||
return error_not_implemented();
|
||||
}
|
||||
|
||||
void Generator::VisitTransOp(TransOp *trans) {
|
||||
Visit(trans->operand_);
|
||||
ir::value* arg = ret_;
|
||||
return set_ret(bld_->create_trans(arg, trans->getPerm()));
|
||||
}
|
||||
|
||||
void Generator::VisitConditionalOp(ConditionalOp* condOp) {
|
||||
// auto &instructions = bld_->get_insert_block()->get_inst_list();
|
||||
VisitExpr(condOp->cond_);
|
||||
|
@@ -451,6 +451,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
||||
QualType lhsQual = lhsTile->Derived();
|
||||
// create ret shape
|
||||
TileType::ShapeInt shape;
|
||||
TileType::ShapeInt axVec;
|
||||
size_t i = 0;
|
||||
const Token* tok;
|
||||
std::vector<std::pair<int, int>> redInfo;
|
||||
@@ -469,10 +470,22 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
||||
case Token::SUB:
|
||||
case Token::MAX:
|
||||
case Token::MIN:{
|
||||
int info = UnaryOp::encodeRed(i, tok->tag_);
|
||||
redInfo.push_back({i, info});
|
||||
shape.push_back(lhsShape[i++]);
|
||||
break;
|
||||
int info = UnaryOp::encodeRed(i, tok->tag_);
|
||||
redInfo.push_back({i, info});
|
||||
shape.push_back(lhsShape[i++]);
|
||||
break;
|
||||
}
|
||||
|
||||
case '^':{
|
||||
Expr* expr = ParseConditionalExpr();
|
||||
EnsureInteger(expr);
|
||||
int ax = Evaluator<long>().Eval(expr);
|
||||
axVec.push_back(ax);
|
||||
if(ax < 0 || ax >= lhsShape.size())
|
||||
Error(tok, "unknown axis %d in transposition", ax);
|
||||
shape.push_back(lhsShape[ax]);
|
||||
i++;
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
@@ -481,8 +494,19 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
||||
}
|
||||
}while(ts_.Try(','));
|
||||
ts_.Expect(']');
|
||||
|
||||
// transposition mode
|
||||
std::set<int> axSet(axVec.begin(), axVec.end());
|
||||
if(!axSet.empty()){
|
||||
if(axSet.size()!=lhsShape.size())
|
||||
Error(tok, "transposition must address all axes of input array");
|
||||
return TransOp::New(axVec, lhs);
|
||||
}
|
||||
|
||||
// broadcasting mode
|
||||
if(lhsShape.size() > i)
|
||||
Error(tok, "broadcasting not using all operand axes");
|
||||
|
||||
// create ret tile
|
||||
Expr* res = lhs;
|
||||
for(auto r: redInfo){
|
||||
@@ -553,7 +577,15 @@ Expr* Parser::ParseUnaryExpr() {
|
||||
case '-': return ParseUnaryOp(tok, Token::MINUS);
|
||||
case '~': return ParseUnaryOp(tok, '~');
|
||||
case '!': return ParseUnaryOp(tok, '!');
|
||||
case '^': return ParseUnaryOp(tok, Token::XOR);
|
||||
case '^': {
|
||||
auto operand = ParseCastExpr();
|
||||
TileType::ShapeInt shape = operand->Type()->ToTile()->Shape();
|
||||
TransOp::PermInt perm(shape.size());
|
||||
for(int d = 0; d < shape.size(); d++)
|
||||
perm[d] = d;
|
||||
std::rotate(perm.begin(), perm.begin() + 1, perm.end());
|
||||
return TransOp::New(perm, operand);
|
||||
}
|
||||
default:
|
||||
ts_.PutBack();
|
||||
return ParsePostfixExpr();
|
||||
|
@@ -8,41 +8,36 @@ class _einsum(triton.function):
|
||||
int std_A0, int std_B0, int std_C0,
|
||||
int std_A1, int std_B1, int std_C1) {
|
||||
// program id
|
||||
int pid0 = get_program_id(0);
|
||||
int pid1 = get_program_id(1);
|
||||
int pid2 = get_program_id(2);
|
||||
int pgm = get_program_id(0);
|
||||
int pgn = get_program_id(1);
|
||||
int pgb = get_program_id(2);
|
||||
// range
|
||||
int rma[TM] = pid0 * TM + 0 ... TM;
|
||||
int rnb[TN] = pid1 * TN + 0 ... TN;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
int rba[TB] = pid2 * TB + 0 ... TB;
|
||||
int rbb[TB] = pid2 * TB + 0 ... TB;
|
||||
int rm[TM] = pgm * TM + 0 ... TM;
|
||||
int rn[TN] = pgn * TN + 0 ... TN;
|
||||
int rb[TB] = pgb * TB + 0 ... TB;
|
||||
int rk[TK] = 0 ... TK;
|
||||
// accumulator
|
||||
TYPE c[TM, TN, TB] = 0;
|
||||
// pointers to a
|
||||
TYPE *pa[TM, TK, TB] = A + rka[newaxis, :, newaxis] * 1
|
||||
+ rma[:, newaxis, newaxis] * std_A1
|
||||
+ rba[newaxis, newaxis, :] * std_A0;
|
||||
TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK
|
||||
+ rm[BROADCAST_AM] * STRIDE_AM
|
||||
+ rb[newaxis, newaxis, :] * std_A0;
|
||||
// pointers to b
|
||||
TYPE *pb[TK, TN, TB] = B + rkb[:, newaxis, newaxis] * 1
|
||||
+ rnb[newaxis, :, newaxis] * std_B1
|
||||
+ rbb[newaxis, newaxis, :] * std_B0;
|
||||
TYPE *pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK
|
||||
+ rn[BROADCAST_BN] * STRIDE_BN
|
||||
+ rb[newaxis, newaxis, :] * std_B0;
|
||||
// accumulation
|
||||
for(int k = dim_K; k > 0; k -= TK) {
|
||||
TYPE a[TM, TK, TB] = *pa;
|
||||
TYPE b[TK, TN, TB] = *pb;
|
||||
TYPE a[SHAPE_A] = *pa;
|
||||
TYPE b[SHAPE_B] = *pb;
|
||||
c += a @ b;
|
||||
pa += TK;
|
||||
pb += TK;
|
||||
}
|
||||
// write-back
|
||||
int rmc[TM] = pid0 * TM + 0 ... TM;
|
||||
int rnc[TN] = pid1 * TN + 0 ... TN;
|
||||
int rbc[TB] = pid2 * TB + 0 ... TB;
|
||||
TYPE *pc[TM, TN, TB] = C + rmc[:, newaxis, newaxis] * std_C1
|
||||
+ rnc[newaxis, :, newaxis] * 1
|
||||
+ rbc[newaxis, newaxis, :] * std_C0;
|
||||
TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1
|
||||
+ rn[newaxis, :, newaxis] * 1
|
||||
+ rb[newaxis, newaxis, :] * std_C0;
|
||||
*pc = c;
|
||||
}
|
||||
"""
|
||||
@@ -138,12 +133,25 @@ class _einsum(triton.function):
|
||||
grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')),
|
||||
triton.cdiv(bmnk[2], opt.d('TN')),
|
||||
triton.cdiv(bmnk[0], opt.d('TB'))]
|
||||
#print(std0, std1)
|
||||
macros = {# handle A transposition
|
||||
'USE_A' : 'a[^1, ^0, ^2]' if trans_a else 'a',
|
||||
'STRIDE_AK' : 'std_A1' if trans_a else '1',
|
||||
'STRIDE_AM' : '1' if trans_a else 'std_A1',
|
||||
'BROADCAST_AK': ':, newaxis, newaxis' if trans_a else 'newaxis, :, newaxis',
|
||||
'BROADCAST_AM': 'newaxis, :, newaxis' if trans_a else ':, newaxis, newaxis',
|
||||
'SHAPE_A' : 'TK, TM, TB' if trans_a else 'TM, TK, TB',
|
||||
# handle B transposition
|
||||
'USE_B' : 'b[^1, ^0, ^2]' if not trans_b else 'b',
|
||||
'STRIDE_BK' : 'std_B1' if not trans_b else '1',
|
||||
'STRIDE_BN' : '1' if not trans_b else 'std_B1',
|
||||
'BROADCAST_BK': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
|
||||
'BROADCAST_BN': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
|
||||
'SHAPE_B' : 'TN, TK, TB' if not trans_b else 'TK, TN, TB'}
|
||||
return _einsum.kernel(a, b, c,
|
||||
bmnk[1], bmnk[2], bmnk[3],
|
||||
std0[0], std0[1], std0[2],
|
||||
std1[0], std1[1], std1[2],
|
||||
grid,
|
||||
grid, **macros,
|
||||
TYPE='float', TM=32, TN=32, TK=8, TB=1)
|
||||
|
||||
|
||||
|
@@ -86,14 +86,14 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
||||
// macros
|
||||
rt::function::options_space_t opt;
|
||||
// A access patterns
|
||||
opt.defines.push_back({"USEA", {AT? "^a" : "a" }});
|
||||
opt.defines.push_back({"USEA", {AT? "a[^1, ^0]" : "a" }});
|
||||
opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }});
|
||||
opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }});
|
||||
opt.defines.push_back({"SHAPE_A", {AT? "TK, TM" : "TM, TK" }});
|
||||
opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
|
||||
opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
|
||||
// B access patterns
|
||||
opt.defines.push_back({"USEB", {BT? "^b" : "b" }});
|
||||
opt.defines.push_back({"USEB", {BT? "b[^1, ^0]" : "b" }});
|
||||
opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }});
|
||||
opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }});
|
||||
opt.defines.push_back({"SHAPE_B", {BT? "TN, TK" : "TK, TN" }});
|
||||
|
Reference in New Issue
Block a user