[DOCS] Fixed typo: triton.function -> torch.autograd.Function
committed by Philippe Tillet
parent a5e3397e6e
commit 55c800e632
@@ -39,7 +39,7 @@ As you will see, a wrapper for the above Triton function can be created in just
 import torch
 import triton

-class _add(triton.function):
+class _add(torch.autograd.Function):
     # source-code for Triton compute kernel
     src = """
     __global__ void add(float* z, float* x, float* y, int N){
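For context, a minimal sketch of the wrapper this hunk edits, assuming only standard :code:`torch.autograd.Function` machinery; the Triton kernel launch from the tutorial is elided, and a plain PyTorch add stands in for it:

.. code:: python

    import torch

    class _add(torch.autograd.Function):

        @staticmethod
        def forward(ctx, x, y):
            # The tutorial launches its compiled Triton `add` kernel here;
            # the equivalent PyTorch computation stands in for it.
            return x + y

        @staticmethod
        def backward(ctx, dz):
            # z = x + y, so the upstream gradient passes through
            # unchanged to both inputs.
            return dz, dz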
@@ -15,7 +15,7 @@ The PyTriton API provides a :code:`triton.function` class which automatically ha
 import triton

 # Entry point
-class _dot(triton.function):
+class _dot(torch.autograd.Function):

     @staticmethod
     # Forward Pass
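Likewise, a hedged sketch of how :code:`_dot` is shaped as a :code:`torch.autograd.Function`, with :code:`torch.matmul` standing in for the tutorial's Triton matmul kernel; the real :code:`_dot` takes several extra configuration arguments (hence the trailing :code:`None` values returned by :code:`backward` in the next hunk), which this sketch omits:

.. code:: python

    import torch

    # Entry point
    class _dot(torch.autograd.Function):

        # Forward Pass
        @staticmethod
        def forward(ctx, a, b):
            ctx.save_for_backward(a, b)
            # The tutorial launches a Triton matmul kernel here.
            return torch.matmul(a, b)

        # Backward Pass
        @staticmethod
        def backward(ctx, dc):
            a, b = ctx.saved_tensors
            # For c = a @ b: da = dc @ b^T and db = a^T @ dc.
            return torch.matmul(dc, b.t()), torch.matmul(a.t(), dc)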
@@ -170,7 +170,7 @@ Creating custom operations for Triton and PyTorch is very similar; programmers h
     return da, db, None, None, None, None, None, None, None

-A callable operation can be created using the :code:`apply` method of our :code:`triton.function` class.
+A callable operation can be created using the :code:`apply` method of the :code:`torch.autograd.Function` class.

 .. code:: python
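The :code:`apply` idiom the corrected sentence refers to is standard :code:`torch.autograd.Function` usage; a small usage example, with names and shapes chosen purely for illustration:

.. code:: python

    # Bind the autograd functions to plain callables.
    add = _add.apply
    dot = _dot.apply

    x = torch.randn(1024, requires_grad=True)
    y = torch.randn(1024, requires_grad=True)
    z = add(x, y)        # runs _add.forward
    z.sum().backward()   # runs _add.backward, populating x.grad and y.grad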