[GENERAL] Merged einsum feature branch. Various features, performance
improvements, and bugfixes:
* Added preliminary support for extended Einstein summation in PyTriton
* Significant performance improvement on FP32 kernels containing matrix multiplication
* Added re-coalescing pass for FP16 kernels containing matrix multiplication
* Various bugfixes
This commit is contained in:
@@ -5,24 +5,41 @@
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
|
||||
inline int gcd(int a, int b) {
|
||||
if (a == 0)
|
||||
return b;
|
||||
if (b == 0)
|
||||
return a;
|
||||
if (a == b)
|
||||
return a;
|
||||
if (a > b)
|
||||
return gcd(a - b, b);
|
||||
return gcd(a, b - a);
|
||||
/**
 * Extended Euclidean algorithm.
 *
 * Computes a common divisor g of `a` and `b` together with coefficients
 * satisfying  a * (*x) + b * (*y) == g.
 *
 * @param a  first operand
 * @param b  second operand
 * @param x  out: coefficient multiplying `a`
 * @param y  out: coefficient multiplying `b`
 * @return   the value left in `b` once the first operand has been reduced
 *           to zero (so the call with a == 0 returns `b` unchanged)
 */
int gcd_impl(int a, int b, int *x, int *y) {
  if (a != 0) {
    // Solve the reduced problem (b mod a, a) first.
    int inner_x, inner_y;
    const int divisor = gcd_impl(b % a, a, &inner_x, &inner_y);

    // Back-substitute b mod a == b - (b / a) * a to express the divisor
    // in terms of the original pair (a, b).
    const int quotient = b / a;
    *x = inner_y - quotient * inner_x;
    *y = inner_x;
    return divisor;
  }

  // Base case: with a == 0 the divisor is b, obtained as 0 * a + 1 * b.
  *x = 0;
  *y = 1;
  return b;
}
|
||||
|
||||
int gcd(int a, int b) {
|
||||
int x, y;
|
||||
return gcd_impl(a, b, &x, &y);
|
||||
}
|
||||
|
||||
|
||||
inline int lcm(int a, int b) {
|
||||
return (a * b) / gcd(a, b);
|
||||
}
|
||||
@@ -156,7 +173,7 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
|
||||
if(is_constant_.find(v) != is_constant_.end())
|
||||
return is_constant_.at(v);
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
return add_to_cache(v, {cst_info{true, (unsigned)x->get_value()}}, is_constant_);
|
||||
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
|
||||
if(dynamic_cast<ir::make_range_sta*>(v))
|
||||
return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||
@@ -448,7 +465,7 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
return populate_starting_multiple_binop(x);
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
return add_to_cache(x, {(unsigned)x->get_value()}, starting_multiple_);
|
||||
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::make_range*>(v))
|
||||
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::make_range_dyn*>(v))
|
||||
@@ -484,6 +501,7 @@ void align::populate(ir::value *v) {
|
||||
populate_is_constant(v);
|
||||
populate_starting_multiple(v);
|
||||
populate_max_contiguous(v);
|
||||
|
||||
}
|
||||
|
||||
void align::run(ir::module &mod) {
|
||||
|
Reference in New Issue
Block a user