This package provides graphical computation for the nn library in Torch7.
### torch-graph
This library requires the torch-graph package to be installed.
https://github.com/koraykv/torch-graph
You do not need graphviz to use this library, but if you have it installed you can display the graphs that you have created.
Right now, this repo is not distributed as part of the torch-pkg or luarocks system. To install it, follow these steps.

    git clone https://github.com/koraykv/torch-graph.git
    cd torch-graph
    torch-pkg deploy
    cd ..
    git clone https://github.com/koraykv/torch-nngraph.git
    cd torch-nngraph
    torch-pkg deploy
The aim of this library is to give users of the nn library tools to easily create complicated architectures. Each nn module is bundled into a graph node. The __call operator of an nn.Module instance is used to build architectures as if one were writing function calls.
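To make the call-based style concrete, here is a minimal sketch in plain Lua of how a `__call` metamethod can turn "calling a module" into "creating a graph node". It does not use nn or nngraph, and `Module`, `newModule`, `newNode` and `chain` are hypothetical names invented for this illustration; the real library may organize this differently.

```lua
-- Toy illustration only: Module, Node and their fields are hypothetical
-- stand-ins, not the nn or nngraph implementation.
local Node = {}
Node.__index = Node

local function newNode(module, parents)
  return setmetatable({ module = module, parents = parents or {} }, Node)
end

local Module = {}
Module.__index = Module
-- Calling a module instance wraps it in a node; the arguments become parents.
Module.__call = function(self, ...)
  return newNode(self, { ... })
end

local function newModule(name)
  return setmetatable({ name = name }, Module)
end

-- Follow the first parent of each node back to the input and list module names.
local function chain(node)
  local names = {}
  while node do
    table.insert(names, 1, node.module.name)
    node = node.parents[1]
  end
  return table.concat(names, ' -> ')
end

local x = newModule('Linear(20,10)')()   -- no arguments: a node without parents
local out = newModule('Linear(10,1)')(newModule('Tanh')(x))
print(chain(out))   -- prints: Linear(20,10) -> Tanh -> Linear(10,1)
```

Nesting the calls reads like ordinary function composition, while each call only records which node feeds which; no computation happens until the graph is evaluated.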
A network with two hidden layers

    require 'nngraph'

    -- calling a module with no arguments creates a node without parents;
    -- it serves as the graph input here
    x1 = nn.Linear(20,10)()
    mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1))))
    mlp = nn.gModule({x1},{mout})

    x = torch.rand(20)
    dx = torch.rand(1)
    mlp:updateOutput(x)
    mlp:updateGradInput(x,dx)
    mlp:accGradParameters(x,dx)

    -- draw graph (the forward graph, '.fg')
    graph.dot(mlp.fg,'MLP')
Read this diagram from top to bottom. The first and last nodes are dummy nodes that group all inputs and outputs of the graph. The 'module' entry describes the function of the node, as applied to 'input', producing a result of the shape 'gradOutput'; 'mapindex' contains pointers to the parent nodes.
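The annotations described above can also be summarized as text. The following is a small sketch in plain Lua, assuming a node's annotations are stored in a table with the fields `module`, `input`, `gradOutput` and `mapindex`; `describeNode` is a hypothetical helper, and the sample table is built by hand for illustration rather than produced by nngraph.

```lua
-- Sketch only: the field names follow the diagram labels; the sample
-- table is hand-built for illustration, not produced by nngraph.
local function describeNode(data)
  local parts = {}
  parts[#parts + 1] = 'module=' .. tostring(data.module)
  -- mapindex is assumed to list the annotation tables of the parent nodes
  local nparents = data.mapindex and #data.mapindex or 0
  parts[#parts + 1] = 'parents=' .. nparents
  parts[#parts + 1] = 'hasInput=' .. tostring(data.input ~= nil)
  parts[#parts + 1] = 'hasGradOutput=' .. tostring(data.gradOutput ~= nil)
  return table.concat(parts, ', ')
end

local parent = { module = 'nn.Tanh' }
local sample = { module = 'nn.Linear(10 -> 1)', input = {}, mapindex = { parent } }
print(describeNode(sample))
-- prints: module=nn.Linear(10 -> 1), parents=1, hasInput=true, hasGradOutput=false
```

If each graph node keeps these annotations in a `data` table and the graph keeps its nodes in a `nodes` list (both assumptions about the graph package), a loop such as `for _, node in ipairs(mlp.fg.nodes) do print(describeNode(node.data)) end` would print one line per node.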
A network with two inputs and two outputs

    require 'nngraph'

    -- two independent input branches
    x1=nn.Linear(20,20)()
    x2=nn.Linear(10,10)()
    m0=nn.Linear(20,1)(nn.Tanh()(x1))
    m1=nn.Linear(10,1)(nn.Tanh()(x2))
    -- the branches are summed, then fed to two output nodes
    madd=nn.CAddTable()({m0,m1})
    m2=nn.Sigmoid()(madd)
    m3=nn.Tanh()(madd)
    gmod = nn.gModule({x1,x2},{m2,m3})

    -- inputs and output gradients are passed as tables
    x = torch.rand(20)
    y = torch.rand(10)
    gmod:updateOutput({x,y})
    gmod:updateGradInput({x,y},{torch.rand(1),torch.rand(1)})
    graph.dot(gmod.fg,'Big MLP')
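A network that uses container modules (such as ParallelTable) whose output is a table

    -- split a 2x10 input into two rows and map each row with its own Linear
    m = nn.Sequential()
    m:add(nn.SplitTable(1))
    m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
    input = nn.Identity()()
    -- split(2) turns the table output into two separate nodes
    input1,input2 = m(input):split(2)
    m3 = nn.JoinTable(1)({input1,input2})
    g = nn.gModule({input},{m3})

    indata = torch.rand(2,10)
    gdata = torch.rand(50)   -- 20 + 30 joined outputs
    g:forward(indata)
    g:backward(indata,gdata)

    graph.dot(g.fg,'Forward Graph')
    graph.dot(g.bg,'Backward Graph')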
A multi-layer network in which each layer takes the outputs of the two previous layers as input

    input = nn.Identity()()
    L1 = nn.Tanh()(nn.Linear(10,20)(input))
    -- 10 (input) + 20 (L1) = 30 joined features
    L2 = nn.Tanh()(nn.Linear(30,60)(nn.JoinTable(1)({input,L1})))
    -- 20 (L1) + 60 (L2) = 80 joined features
    L3 = nn.Tanh()(nn.Linear(80,160)(nn.JoinTable(1)({L1,L2})))
    g = nn.gModule({input},{L3})

    indata = torch.rand(10)
    gdata = torch.rand(160)
    g:forward(indata)
    g:backward(indata,gdata)

    graph.dot(g.fg,'Forward Graph')
    graph.dot(g.bg,'Backward Graph')