@@ -168,7 +168,7 @@ Scheduling languages are, without a doubt, one of the most popular approaches fo
 Limitations
 ++++++++++++

-This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indice without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
+This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indices without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.

 .. table::
 :widths: 50 50
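The sparse-computation limitation above is easiest to see on a concrete kernel. Below is a minimal plain-Python sketch of a CSR sparse matrix-vector product (an illustrative example, not drawn from the patch): the inner loop's bounds come from the row-pointer array and therefore depend on the outer loop index, which is exactly the data-dependent loop structure that existing scheduling languages struggle to express.

def csr_spmv(indptr, indices, data, x):
    """y = A @ x for a CSR matrix A described by (indptr, indices, data)."""
    y = [0.0] * (len(indptr) - 1)
    for i in range(len(indptr) - 1):          # outer loop over rows
        row_start, row_end = indptr[i], indptr[i + 1]
        for k in range(row_start, row_end):   # bounds depend on the outer index i
            y[i] += data[k] * x[indices[k]]
    return y

# 2x3 matrix [[1, 0, 2], [0, 3, 0]] times [1, 1, 1] -> [3.0, 3.0]
print(csr_spmv([0, 2, 3], [0, 2, 1], [1.0, 2.0, 3.0], [1.0, 1.0, 1.0]))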
@@ -44,7 +44,7 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
 if (pessimistic) {
 return markAllPessimisticFixpoint(op->getResults());
 }
-// Join all latice elements
+// Join all lattice elements
 ChangeResult result = ChangeResult::NoChange;
 for (Value value : op->getResults()) {
 result |= getLatticeElement(value).join(aliasInfo);
@@ -235,7 +235,7 @@ private:
 }
 }

-/// Extends the liveness range by unioning the liveness range of the aliased
+/// Extends the liveness range by unionizing the liveness range of the aliased
 /// values because each allocated buffer could be an alias of others, if block
 /// arguments are involved.
 void resolveAliasBufferLiveness(
@@ -372,7 +372,7 @@ unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {

 unsigned
 DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-assert(0 && "DotOPerandEncodingAttr::getElemsPerThread not implemented");
+assert(0 && "DotOperandEncodingAttr::getElemsPerThread not implemented");
 return 0;
 }

@@ -24,7 +24,7 @@ struct CanonicalizePass
 // The following piece of code is a workaround to
 // very crudely remove dead code, by making an iteration
 // argument yield itself if it is not used to create
-// side-effects anywhere.
+// side effects anywhere.
 getOperation()->walk([&](scf::ForOp forOp) -> void {
 for (size_t i = 0; i < forOp.getNumResults(); ++i) {
 // condition 1: no other iter arguments depend on it
@@ -29,7 +29,7 @@ namespace {
 // convert(blocked, dot_operand) ->
 // convert(blocked, mma) + convert(mma, dot_operand)
 // if this value is itself the result of a dot operation
-// this is a heuristic to accomodate some pattern seen in fused attention
+// this is a heuristic to accommodate some pattern seen in fused attention
 // kernels.
 // TODO: replace this by something more generic, i.e. layout-aware CSE
 class DecomposeDotOperand : public mlir::RewritePattern {
@@ -81,7 +81,7 @@ public:
 auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
 auto dstType = convert.getType().cast<RankedTensorType>();
 // we don't handle conversions to DotOperandEncodingAttr
-// this is a heuristics to accomodate fused attention
+// this is a heuristics to accommodate fused attention
 // if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
 // return mlir::failure();
 // convert to the same layout -- we can delete
@@ -265,7 +265,7 @@ public:
 isSharedEncoding(cvt->getOperand(0)))
 return mlir::failure();
 // we don't handle conversions to DotOperandEncodingAttr
-// this is a heuristics to accomodate fused attention
+// this is a heuristics to accommodate fused attention
 auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
 if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
 return mlir::failure();
@@ -285,7 +285,7 @@ public:
 // we stop everything
 if (expensive_to_remat(currOp))
 break;
-// a conversion will be removed here (i.e. transfered to operands)
+// a conversion will be removed here (i.e. transferred to operands)
 numCvts -= 1;
 // done processing
 processed.insert(currOp);
@@ -110,7 +110,7 @@ Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
 }

 void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
-// Loop-invarant value. skip
+// Loop-invariant value, skip
 if (v.getParentRegion() != &forOp.getLoopBody())
 return;

@@ -125,7 +125,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
 collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
 } else { // value
 // v might be in deps, but we still need to visit v.
-// This is because v might depends on value in previous iterations
+// This is because v might depend on value in previous iterations
 deps.insert(v);
 for (Value op : v.getDefiningOp()->getOperands())
 collectDeps(op, stages, deps);
@@ -175,18 +175,18 @@ LogicalResult LoopPipeliner::initialize() {
 // other load in the prologue, which is against the point of the pipeline
 // pass)
 for (triton::LoadOp loadOp : allLoads) {
-bool isCandiate = true;
+bool isCandidate = true;
 for (triton::LoadOp other : allLoads) {
 if (loadDeps[loadOp].contains(other)) {
-isCandiate = false;
+isCandidate = false;
 break;
 }
 }

 // We only pipeline loads that have one covert_layout (to dot_op) use
 // TODO: lift this constraint in the future
-if (isCandiate && loadOp.getResult().hasOneUse()) {
-isCandiate = false;
+if (isCandidate && loadOp.getResult().hasOneUse()) {
+isCandidate = false;
 Operation *use = *loadOp.getResult().getUsers().begin();
 if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
 if (auto tensorType = convertLayout.getResult()
@@ -194,7 +194,7 @@ LogicalResult LoopPipeliner::initialize() {
 .dyn_cast<RankedTensorType>()) {
 if (auto dotOpEnc = tensorType.getEncoding()
 .dyn_cast<ttg::DotOperandEncodingAttr>()) {
-isCandiate = true;
+isCandidate = true;
 loadsMapping[loadOp] = convertLayout;
 auto ty = loadOp.getType().cast<RankedTensorType>();
 SmallVector<int64_t> bufferShape(ty.getShape().begin(),
@@ -208,9 +208,9 @@ LogicalResult LoopPipeliner::initialize() {
 }
 }
 } else
-isCandiate = false;
+isCandidate = false;

-if (isCandiate)
+if (isCandidate)
 loads.insert(loadOp);
 }

@@ -317,7 +317,7 @@ void LoopPipeliner::emitPrologue() {
 for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
 Value originalResult = op->getResult(dstIdx);
 // copy_async will update the value of its only use
-// TODO: load should no be used in the preheader?
+// TODO: load should not be used in the preheader?
 if (loads.contains(originalResult)) {
 break;
 // originalResult = loadsMapping[originalResult];
@@ -35,7 +35,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
 });

 //
-// materailizations
+// Materializations
 //
 // This will be called when (newArgType != origArgType)
 // This will create newArg, and map(origArg, newArg)
@@ -172,7 +172,7 @@ void init_triton_ir(py::module &&m) {
 if (mlir::Operation *definingOp = self.getDefiningOp())
 definingOp->setAttr(name, attr);
 else {
-/* issue an warning */
+/* issue a warning */
 }
 })
 .def("replace_all_uses_with",
@@ -180,7 +180,7 @@ void init_triton_ir(py::module &&m) {
 self.replaceAllUsesWith(newValue);
 });

-py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement");
+py::class_<mlir::BlockArgument, mlir::Value>(m, "block_argument");

 py::class_<mlir::Region>(m, "region")
 .def("get_parent_region", &mlir::Region::getParentRegion, ret::reference)
@@ -288,7 +288,7 @@ void init_triton_ir(py::module &&m) {
 py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp")
 .def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference)
 .def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
-py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
+py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "ConditionOp");

 // dynamic_attr is used to transfer ownership of the MLIR context to the
 // module
@@ -423,7 +423,7 @@ void init_triton_ir(py::module &&m) {
 .def("get_bool_attr", &mlir::OpBuilder::getBoolAttr)
 .def("get_int32_attr", &mlir::OpBuilder::getI32IntegerAttr)
 // Use arith.ConstantOp to create constants
-// // Constants
+// Constants
 .def("get_int1",
 [](mlir::OpBuilder &self, bool v) -> mlir::Value {
 auto loc = self.getUnknownLoc();
@@ -588,14 +588,14 @@ void init_triton_ir(py::module &&m) {
 auto loc = self.getUnknownLoc();
 return self.create<mlir::scf::WhileOp>(loc, retTypes, initArgs);
 })
-.def("create_condtion_op",
+.def("create_condition_op",
 [](mlir::OpBuilder &self, mlir::Value &cond,
 std::vector<mlir::Value> &args) -> mlir::scf::ConditionOp {
 auto loc = self.getUnknownLoc();
 return self.create<mlir::scf::ConditionOp>(loc, cond, args);
 })

-// miscellious
+// miscellaneous
 .def("create_make_range",
 [](mlir::OpBuilder &self, int start, int end) -> mlir::Value {
 auto loc = self.getUnknownLoc();
@@ -976,15 +976,15 @@ void init_triton_ir(py::module &&m) {
 auto loc = self.getUnknownLoc();
 return self.create<mlir::arith::OrIOp>(loc, lhs, rhs);
 })
-// // Input/Output
+// Input/Output
 .def("create_load",
 [](mlir::OpBuilder &self, mlir::Value &ptrs,
-mlir::triton::CacheModifier cacheModifer,
+mlir::triton::CacheModifier cacheModifier,
 mlir::triton::EvictionPolicy evictionPolicy,
 bool isVolatile) -> mlir::Value {
 auto loc = self.getUnknownLoc();
 return self.create<mlir::triton::LoadOp>(
-loc, ptrs, cacheModifer, evictionPolicy, isVolatile);
+loc, ptrs, cacheModifier, evictionPolicy, isVolatile);
 })
 .def("create_store",
 [](mlir::OpBuilder &self, mlir::Value &ptrs,
@@ -131,7 +131,7 @@ def vecadd_no_scf_tester(num_warps, block_size, shape):

 def vecadd_fcmp_no_scf_tester(num_warps, block_size, shape):
 '''
-vecadd tester with float comparation as load/store mask.
+vecadd tester with float comparison as load/store mask.
 '''
 @triton.jit
 def kernel(x_ptr,
@@ -523,8 +523,8 @@ class CodeGenerator(ast.NodeVisitor):
 [ty.to_ir(self.builder) for ty in ret_types])
 cond_block.merge_block_before(before_block)
 self.builder.set_insertion_point_to_end(before_block)
-# create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
-self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
+# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
+self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
 # merge the loop body
 after_block = self.builder.create_block_with_parent(while_op.get_after(),
 [ty.to_ir(self.builder) for ty in ret_types])
@@ -910,7 +910,7 @@ def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = Non
 :param mod: a TritonGPU dialect module
 :return:
 - PTX code
-- shared memory alloaction size
+- shared memory allocation size
 '''
 if compute_capability is None:
 device = torch.cuda.current_device()
@@ -1194,7 +1194,7 @@ class CacheManager:
 os.rename(filepath + ".tmp", filepath)


-# utilties for generating and compiling C wrappers
+# Utilities for generating and compiling C wrappers


 @functools.lru_cache()
@@ -768,7 +768,7 @@ def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=No
 """
 Returns the matrix product of two blocks.

-The two blocks must be two dimensionals and have compatible inner dimensions.
+The two blocks must be two-dimensional and have compatible inner dimensions.

 :param input: The first tensor to be multiplied.
 :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
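The docstring above belongs to tl.dot, which multiplies two 2D blocks held by one program instance. A minimal, hedged usage sketch follows; it assumes a CUDA device, row-major inputs, and that a single 16x16 tile covers the whole product (the kernel and tensor names are illustrative, not taken from the patch):

import torch
import triton
import triton.language as tl

@triton.jit
def single_tile_dot_kernel(a_ptr, b_ptr, c_ptr,
                           M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # One program computes the whole (M, N) product; assumes row-major inputs
    # and that M, N, K fit in a single tile (and are >= 16, as tl.dot requires).
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])  # (M, K) block
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])  # (K, N) block
    c = tl.dot(a, b)  # (M, N) block, accumulated in float32 for fp16 inputs
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

a = torch.randn((16, 16), device='cuda', dtype=torch.float16)
b = torch.randn((16, 16), device='cuda', dtype=torch.float16)
c = torch.empty((16, 16), device='cuda', dtype=torch.float32)
single_tile_dot_kernel[(1,)](a, b, c, M=16, N=16, K=16)
assert torch.allclose(c, (a @ b).float(), atol=1e-2, rtol=1e-2)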
@@ -1172,7 +1172,7 @@ def ravel(x):
 @triton.jit
 def swizzle2d(i, j, size_i, size_j, size_g):
 """
-transformes indices of a row-major size_i*size_j matrix into those
+Transforms indices of a row-major size_i*size_j matrix into those
 of one where indices are row major for each group of size_j rows.
 For example, for size_i = size_j = 4 and size_g = 2, it will transform
 [[0 , 1 , 2 , 3 ],
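To make the grouped re-indexing described in this docstring concrete, here is a small plain-Python sketch of one way such a swizzle can be computed (the arithmetic is an illustrative assumption for this example, not necessarily the exact formula used by triton.language.swizzle2d):

def swizzle2d_ref(i, j, size_i, size_j, size_g):
    # Linear index in the original row-major layout.
    ij = i * size_j + j
    # Each group covers size_g full rows, i.e. size_g * size_j elements.
    size_gj = size_g * size_j
    group_id = ij // size_gj
    off_i = group_id * size_g
    # The last group may contain fewer than size_g rows.
    rows_in_group = min(size_i - off_i, size_g)
    new_i = off_i + ij % rows_in_group
    new_j = (ij % size_gj) // rows_in_group
    return new_i, new_j

# Show where each original index lands for size_i = size_j = 4, size_g = 2.
out = [[None] * 4 for _ in range(4)]
for i in range(4):
    for j in range(4):
        ni, nj = swizzle2d_ref(i, j, 4, 4, 2)
        out[ni][nj] = i * 4 + j
for row in out:
    print(row)
# Prints: [0, 2, 4, 6] / [1, 3, 5, 7] / [8, 10, 12, 14] / [9, 11, 13, 15]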
@@ -7,12 +7,12 @@ from triton._C.libtriton.triton import ir


 # Create custom exception that prints message "hello"
-class IncompatibleTypeErrorimpl(Exception):
+class IncompatibleTypeErrorImpl(Exception):
 def __init__(self, type_a, type_b):
 self.type_a = type_a
 self.type_b = type_b
 self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__()
-super(IncompatibleTypeErrorimpl, self).__init__(self.message)
+super(IncompatibleTypeErrorImpl, self).__init__(self.message)


 # ===----------------------------------------------------------------------===##
@@ -88,13 +88,13 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
 def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
 if type_a.is_ptr():
 if not allow_ptr_a:
-raise IncompatibleTypeErrorimpl(type_a, type_b)
+raise IncompatibleTypeErrorImpl(type_a, type_b)
 # T* + U* with T != U
 if type_b.is_ptr() and (type_a != type_b):
-raise IncompatibleTypeErrorimpl(type_a, type_b)
+raise IncompatibleTypeErrorImpl(type_a, type_b)
 # T* + float
 if type_b.is_floating():
-raise IncompatibleTypeErrorimpl(type_a, type_b)
+raise IncompatibleTypeErrorImpl(type_a, type_b)


 def binary_op_type_checking_impl(lhs: tl.tensor,
@@ -223,7 +223,7 @@ def fdiv(input: tl.tensor,
 input_scalar_ty = input.type.scalar
 other_scalar_ty = other.type.scalar
 if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
-raise ValueError("both operands of fdiv must have floating poscalar type")
+raise ValueError("both operands of fdiv must have floating scalar type")
 input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
 ret = builder.create_fdiv(input.handle, other.handle)
 return tl.tensor(ret, input.type)
@@ -262,7 +262,7 @@ def bitwise_op_type_checking_impl(input: tl.tensor,
 input_sca_ty = input.type.scalar
 other_sca_ty = other.type.scalar
 if not input_sca_ty.is_int() or not other_sca_ty.is_int():
-raise IncompatibleTypeErrorimpl(input_sca_ty, other_sca_ty)
+raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty)
 ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty)
 if ret_sca_ty != input_sca_ty:
 input = cast(input, ret_sca_ty, builder)
@@ -63,7 +63,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
 grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
 # NOTE:
 # - each torch.tensor object is implicitly converted into a pointer to its first element.
-# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
+# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
 # - don't forget to pass meta-parameters as keywords arguments
 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
 # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
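The launch grid mentioned in these notes is just a callable that turns the problem size and the kernel's meta-parameters into a number of program instances. A tiny self-contained illustration of that arithmetic, assuming only triton.cdiv (ceiling division) and an example element count:

import triton

n_elements = 98432
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

# With BLOCK_SIZE=1024, ceil(98432 / 1024) = 97 programs are launched,
# each handling one contiguous block of 1024 elements (the last one masked).
assert triton.cdiv(n_elements, 1024) == (n_elements + 1024 - 1) // 1024 == 97
print(grid({'BLOCK_SIZE': 1024}))  # (97,)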
@@ -80,7 +80,7 @@ def softmax_kernel(
 row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
 # Subtract maximum for numerical stability
 row_minus_max = row - tl.max(row, axis=0)
-# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
+# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
 numerator = tl.exp(row_minus_max)
 denominator = tl.sum(numerator, axis=0)
 softmax_output = numerator / denominator
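The max-subtraction trick in this kernel relies on softmax(x) being equal to softmax(x - c) for any constant c, so shifting each row by its maximum changes nothing mathematically while keeping exp() from overflowing. A rough row-wise PyTorch equivalent (for reference only, not the tutorial's kernel):

import torch

def rowwise_stable_softmax(x: torch.Tensor) -> torch.Tensor:
    # Subtract each row's max before exponentiating, as in the kernel above.
    row_minus_max = x - x.max(dim=1, keepdim=True).values
    numerator = torch.exp(row_minus_max)
    denominator = numerator.sum(dim=1, keepdim=True)
    return numerator / denominator

# Entries this large can overflow exp() in float32 without the max subtraction.
x = torch.randn(4, 8) * 50.0
assert torch.allclose(rowwise_stable_softmax(x), torch.softmax(x, dim=1), atol=1e-6)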
@@ -188,4 +188,4 @@ benchmark.run(show_plots=True, print_data=True)
 #
 # - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
 # - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
-# Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.
+# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.