[examples] added skeleton for pytorch wrapper
This commit is contained in:
11
examples/python/pytorch/main.py
Normal file
11
examples/python/pytorch/main.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
from torch.utils.cpp_extension import load
|
||||
from torch.distributions import categorical
|
||||
from itertools import product
|
||||
|
||||
conv_triton = load( 'conv_triton', ['conv.cpp', 'conv.cu'], extra_cflags=['-O3'])
|
Reference in New Issue
Block a user