Files
triton/examples/python/pytorch/main.py
2019-05-03 14:30:06 -04:00

12 lines
334 B
Python

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.cpp_extension import load
from torch.distributions import categorical
from itertools import product
# Build and load the 'conv_triton' extension from its C++ and CUDA sources
# at import time, passing -O3 as an extra compiler flag, and bind the
# resulting module to `conv_triton`.
# NOTE(review): the source paths are relative -- presumably resolved against
# the current working directory; confirm the script is run from this folder.
conv_triton = load(
    name='conv_triton',
    sources=['conv.cpp', 'conv.cu'],
    extra_cflags=['-O3'],
)