[DOCS] Fixed typo: triton.function -> torch.autograd.Function

This commit is contained in:
Philippe Tillet
2020-03-13 11:42:43 -04:00
committed by Philippe Tillet
parent a5e3397e6e
commit 55c800e632
2 changed files with 3 additions and 3 deletions

@@ -39,7 +39,7 @@ As you will see, a wrapper for the above Triton function can be created in just
 import torch
 import triton
-class _add(triton.function):
+class _add(torch.autograd.Function):
     # source-code for Triton compute kernel
     src = """
 __global__ void add(float* z, float* x, float* y, int N){
@@ -15,7 +15,7 @@ The PyTriton API provides a :code:`triton.function` class which automatically ha
 import triton
 # Entry point
-class _dot(triton.function):
+class _dot(torch.autograd.Function):
     @staticmethod
     # Forward Pass
@@ -170,7 +170,7 @@ Creating custom operations for Triton and PyTorch is very similar; programmers h
     return da, db, None, None, None, None, None, None, None
-A callable operation can be created using the :code:`apply` method of our :code:`triton.function` class.
+A callable operation can be created using the :code:`apply` method of the :code:`torch.autograd.Function` class.
 .. code:: python
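The code block following that directive is not shown in the diff. As a minimal sketch of the :code:`torch.autograd.Function` pattern the docs describe: subclass, implement static :code:`forward`/:code:`backward`, then expose the op via :code:`apply`. The class name `_add` and the plain-PyTorch body here are illustrative stand-ins; the actual docs launch a compiled Triton kernel in :code:`forward`.

```python
import torch

class _add(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # In the Triton docs this would launch the compiled `add` kernel;
        # plain PyTorch is used here as a stand-in.
        return x + y

    @staticmethod
    def backward(ctx, dz):
        # d(x + y)/dx = d(x + y)/dy = 1, so the incoming gradient
        # flows through unchanged to both inputs.
        return dz, dz

# A callable operation is created from the class's `apply` method.
add = _add.apply

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
z = add(x, y)
z.sum().backward()
```

`apply` must be called on the class itself (not an instance), which is why the docs bind it to a module-level callable.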