pax_global_header00006660000000000000000000000064130146556630014523gustar00rootroot0000000000000052 comment=37dac07a19eb3809b8081514bc781206241863a0 .gitignore000066400000000000000000000000101301465566300130460ustar00rootroot00000000000000build/ .travis.yml000066400000000000000000000044141301465566300132030ustar00rootroot00000000000000language: c compiler: - gcc - clang cache: directories: - $HOME/OpenBlasInstall - $HOME/GraphViz sudo: false env: - TORCH_LUA_VERSION=LUAJIT21 - TORCH_LUA_VERSION=LUA51 - TORCH_LUA_VERSION=LUA52 addons: apt: packages: - cmake - gfortran - gcc-multilib - gfortran-multilib - liblapack-dev - build-essential - gcc - g++ - curl - cmake - libreadline-dev - git-core - libqt4-core - libqt4-gui - libqt4-dev - libjpeg-dev - libpng-dev - ncurses-dev - imagemagick - libzmq3-dev - gfortran - unzip - gnuplot - gnuplot-x11 before_script: - export ROOT_TRAVIS_DIR=$(pwd) - export INSTALL_PREFIX=~/torch/install - ls $HOME/OpenBlasInstall/lib || (cd /tmp/ && git clone https://github.com/xianyi/OpenBLAS.git -b master && cd OpenBLAS && (make NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make PREFIX=$HOME/OpenBlasInstall install) - ls $HOME/GraphViz/lib || (cd /tmp/ && wget -c http://www.graphviz.org/pub/graphviz/stable/SOURCES/graphviz-2.38.0.tar.gz && tar -xvf graphviz-2.38.0.tar.gz && cd graphviz-2.38.0 && (./configure prefix=$HOME/GraphViz/ 2>/dev/null >/dev/null) && (make NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make install) - export LD_LIBRARY_PATH=$HOME/GraphViz/lib:$LD_LIBRARY_PATH - git clone https://github.com/torch/distro.git ~/torch --recursive - cd ~/torch && git submodule update --init --recursive - mkdir build && cd build - export CMAKE_LIBRARY_PATH=$HOME/OpenBlasInstall/include:$HOME/OpenBlasInstall/lib:$CMAKE_LIBRARY_PATH - cmake .. 
-DCMAKE_INSTALL_PREFIX="${INSTALL_PREFIX}" -DCMAKE_BUILD_TYPE=Release -DWITH_${TORCH_LUA_VERSION}=ON - make && make install - ${INSTALL_PREFIX}/bin/luarocks install totem - if [[ $TORCH_LUA_VERSION != 'LUAJIT21' && $TORCH_LUA_VERSION != 'LUAJIT20' ]]; then ${INSTALL_PREFIX}/bin/luarocks install luaffi; fi - cd $ROOT_TRAVIS_DIR - export LD_LIBRARY_PATH=${INSTALL_PREFIX}/lib:$LD_LIBRARY_PATH script: - ${INSTALL_PREFIX}/bin/luarocks make rocks/graph-scm-1.rockspec - export PATH=${INSTALL_PREFIX}/bin:$PATH - export TESTLUA=$(which luajit lua | head -n 1) - ${TESTLUA} -lgraph -e "print('graph loaded succesfully')" - cd test - ${TESTLUA} test_graph.lua - ${TESTLUA} test_graphviz.lua CMakeLists.txt000066400000000000000000000003031301465566300136230ustar00rootroot00000000000000 CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) CMAKE_POLICY(VERSION 2.6) FIND_PACKAGE(Torch REQUIRED) FILE(GLOB luasrc *.lua) ADD_TORCH_PACKAGE(graph "" "${luasrc}" "General Graph Package") COPYRIGHT.txt000066400000000000000000000036411301465566300132040ustar00rootroot00000000000000Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) Copyright (c) 2011-2013 NYU (Clement Farabet) Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) Copyright (c) 2006 Idiap Research Institute (Samy Bengio) Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the names of NEC Laboratories American and IDIAP Research Institute nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Edge.lua000066400000000000000000000003271301465566300124400ustar00rootroot00000000000000
--[[ A Directed Edge class
No methods, just three fields: `from`, `to` and an optional `weight`.
]]--
local Edge = torch.class('graph.Edge')

-- Create a directed edge from node `from` to node `to`.
-- `weight` is stored as given and may be nil.
function Edge:__init(from,to,weight)
   self.from = from
   self.to = to
   self.weight = weight
end
Node.lua000066400000000000000000000074041301465566300124640ustar00rootroot00000000000000
 --[[ Node class. This class is generally used with edge to add edges into a graph. graph:add(graph.Edge(graph.Node(),graph.Node())) But, one can also easily use this node class to create a graph. It will register all the edges into its children table and one can parse the graph from any given node. The drawback is there will be no global edge table and node table, which is mostly useful to run algorithms on graphs.
If all you need is just a data structure to store data and run DFS, BFS over the graph, then this method is also quick and nice.
--]]
local Node = torch.class('graph.Node')

-- Create a node carrying `d` as its payload.
-- `p` is accepted for backward compatibility but is not used.
function Node:__init(d,p)
   self.data = d
   self.id = 0
   self.children = {}
   self.visited = false
   self.marked = false
end

-- Add `child` as a child of this node. A plain Lua array (a table without a
-- torch type) is treated as a list of children and added element by element.
-- `children` is both an ordered array and a child -> position lookup, so
-- adding the same child twice is a no-op.
function Node:add(child)
   local children = self.children
   if type(child) == 'table' and not torch.typename(child) then
      for i,v in ipairs(child) do
         self:add(v)
      end
   elseif not children[child] then
      table.insert(children,child)
      children[child] = #children
   end
end

-- Depth-first visitor: calls pre_func(node) on entry and post_func(node)
-- after all children have been visited. Nodes whose `visited` flag is true
-- are skipped. This method never sets `visited` itself; callers are expected
-- to do it from pre_func/post_func (see dfs_dirty and hasCycle).
function Node:visit(pre_func,post_func)
   if not self.visited then
      if pre_func then pre_func(self) end
      for i,child in ipairs(self.children) do
         child:visit(pre_func, post_func)
      end
      if post_func then post_func(self) end
   end
end

-- Text used for this node's label when exporting to dot.
function Node:label()
   return tostring(self.data)
end

-- Create a graph from the Node traversal: one edge per parent -> child link
-- reachable (breadth-first) from this node. A node without children yields
-- an empty graph.
function Node:graph()
   local g = graph.Graph()
   local function build_graph(node)
      for i,child in ipairs(node.children) do
         g:add(graph.Edge(node,child))
      end
   end
   self:bfs(build_graph)
   return g
end

-- Depth-first traversal that calls func(node) in post-order (after the
-- node's children). Returns the nodes in that order and leaves their
-- `visited` flag set to true; use Node:dfs for a self-cleaning version.
function Node:dfs_dirty(func)
   local visitednodes = {}
   local dfs_func = function(node)
      func(node)
      table.insert(visitednodes,node)
   end
   local dfs_func_pre = function(node)
      node.visited = true
   end
   self:visit(dfs_func_pre, dfs_func)
   return visitednodes
end

-- Post-order depth-first traversal; resets `visited` on every node afterwards.
function Node:dfs(func)
   for i,node in ipairs(self:dfs_dirty(func)) do
      node.visited = false
   end
end

-- Breadth-first traversal that calls func(node) in FIFO order. Returns the
-- nodes in visit order and leaves their `marked` flag set to true; use
-- Node:bfs for a self-cleaning version.
function Node:bfs_dirty(func)
   local visitednodes = {}
   -- FIFO queue consumed through a moving head index: table.remove(t, 1)
   -- would shift the whole array on every dequeue (quadratic overall).
   local queue = {self}
   local head = 1
   self.marked = true
   while head <= #queue do
      local node = queue[head]
      head = head + 1
      table.insert(visitednodes,node)
      func(node)
      for i,child in ipairs(node.children) do
         if not child.marked then
            child.marked = true
            table.insert(queue,child)
         end
      end
   end
   return visitednodes
end

-- Breadth-first traversal; resets `marked` on every node afterwards.
function Node:bfs(func)
   for i,node in ipairs(self:bfs_dirty(func)) do
      node.marked = false
   end
end

-- Return true if a directed cycle is reachable from this node.
-- `marked` flags nodes on the current DFS path, `visited` flags finished
-- nodes; both are cleared again before returning.
function Node:hasCycle()
   local hascycle = false
   local explorednodes = {}
   local function pre(node)
      -- if someone found a cycle, just back up
      -- (Node:visit still descends into this node's children afterwards;
      -- pre cannot prune, it only stops marking new nodes)
      if hascycle then
         return hascycle
      end
      -- if this node was marked during dfs, then
      -- it is still being explored, which means we hit a cycle.
      if node.marked then
         -- at this point set visited to true so that Node:visit() does
         -- not explore this node again
         node.visited = true
         hascycle = true
         return hascycle
      end
      node.marked = true
   end
   local function post(node)
      -- we are done with this node, so just remove marked info
      node.marked = false
      -- set visited to true flagging that this node is done.
      -- we might hit it in the future through a separate path, but
      -- at that point we should not explore it, the Node:visit()
      -- will avoid visiting any visited node.
      node.visited = true
      explorednodes[node] = true
   end
   self:visit(pre, post)
   -- now clean-up all the nodes
   for node, _ in pairs(explorednodes) do
      node.visited = false
   end
   return hascycle
end
README.md000066400000000000000000000007041301465566300123470ustar00rootroot00000000000000# Graph Package

This package provides graphical computation for [Torch](https://github.com/torch/torch7/blob/master/README.md).

## Requirements

You need *not* `graphviz` to be able to use this library but, if you have it, you will be able to display the graphs that you have created.
For installing the package run the appropriate command below:

```bash
# Mac users
brew install graphviz

# Debian/Ubuntu users
sudo apt-get install graphviz -y
```
graphviz.lua000066400000000000000000000151011301465566300134220ustar00rootroot00000000000000
require 'torch'

local ffiOk = false
local graphvizOk = false
local cgraphOk = false
local ffi
local graphviz
local cgraph

ffiOk, ffi = pcall(require, 'ffi')
if ffiOk then
   ffi.cdef[[
typedef struct FILE FILE;

typedef struct Agraph_s Agraph_t;
typedef struct Agnode_s Agnode_t;

extern Agraph_t *agmemread(const char *cp);
extern char *agget(void *obj, char *name);
extern int agclose(Agraph_t * g);
extern Agnode_t *agfstnode(Agraph_t * g);
extern Agnode_t *agnxtnode(Agraph_t * g, Agnode_t * n);
extern Agnode_t *aglstnode(Agraph_t * g);
extern Agnode_t *agprvnode(Agraph_t * g, Agnode_t * n);

typedef struct Agraph_s graph_t;
typedef struct GVJ_s GVJ_t;
typedef struct GVG_s GVG_t;
typedef struct GVC_s GVC_t;

extern GVC_t *gvContext(void);
extern int gvLayout(GVC_t *context, graph_t *g, const char *engine);
extern int gvRender(GVC_t *context, graph_t *g, const char *format, FILE *out);
extern int gvFreeLayout(GVC_t *context, graph_t *g);
extern int gvFreeContext(GVC_t *context);

FILE * fopen ( const char * filename, const char * mode );
int fclose ( FILE * stream );
]]

   -- Try the generic library name first, then the versioned .so name.
   local libgvc = ffi.os == 'Windows' and 'gvc' or 'libgvc'
   graphvizOk, graphviz = pcall(function() return ffi.load(libgvc, true) end)
   if not graphvizOk then
      graphvizOk, graphviz = pcall(function() return ffi.load('libgvc.so.6', true) end)
   end

   local libcgraph = ffi.os == 'Windows' and 'cgraph' or 'libcgraph'
   cgraphOk, cgraph = pcall(function() return ffi.load(libcgraph, true) end)
   if not cgraphOk then
      cgraphOk, cgraph = pcall(function() return ffi.load('libcgraph.so.6', true) end)
   end
else
   graphvizOk = false
   cgraphOk = false
end

local unpack = unpack or table.unpack -- Lua52 compatibility
local NULL = (ffiOk and (not jit)) and ffi.C.NULL or nil -- LuaJIT compatibility

-- Retrieve attribute data from a graphviz object.
-- Raises an error if agget returns NULL for the attribute.
local function getAttribute(obj, name)
   local res = cgraph.agget(obj, ffi.cast("char*", name))
   assert(res ~= ffi.cast("char*", nil), 'could not get attr ' .. name)
   local out = ffi.string(res)
   return out
end

-- Iterate through nodes of a graphviz graph.
-- agnxtnode returns NULL after the last node, which ends the iteration.
local function nodeIterator(graph)
   local node = cgraph.agfstnode(graph)
   return function()
      if node == NULL then return end
      local result = node
      node = cgraph.agnxtnode(graph, node)
      return result
   end
end

-- Convert a string of comma-separated numbers to actual numbers.
-- Asserts that exactly `n` numbers were found and returns them unpacked.
local function extractNumbers(n, attr)
   local res = {}
   for number in string.gmatch(attr, "[^%,]+") do
      table.insert(res, tonumber(number))
   end
   assert(#res == n, "attribute is not of expected form")
   return unpack(res)
end

-- Transform from graphviz coordinates to unit square.
-- `bbox` is {x0, y0, width, height} of the laid-out graph.
local function getRelativePosition(node, bbox)
   local x0, y0, w, h = unpack(bbox)
   local x, y = extractNumbers(2, getAttribute(node, 'pos'))
   local xt = (x - x0) / w
   local yt = (y - y0) / h
   assert(xt >= 0 and xt <= 1, "bad x coordinate")
   assert(yt >= 0 and yt <= 1, "bad y coordinate")
   return xt, yt
end

-- Retrieve a node's ID based on its label string.
-- Labels emitted by Graph:todot start with "Node<id>". The fallback pattern
-- looks for "(<id>)" followed by a literal backslash-n; presumably this is
-- how custom graphNodeName() labels embed the id -- TODO confirm.
local function getID(node)
   local label = getAttribute(node, 'label')
   -- Only the capture (third return value of string.find) is needed. The
   -- fallback must be tried explicitly: wrapping the results in a table
   -- constructor would always be truthy and the `or` branch would never run.
   local _, _, id = string.find(label, "^Node(%d+)")
   if id == nil then
      _, _, id = string.find(label, "%((%d+)%)\\n")
   end
   assert(id ~= nil, "could not get ID from node label : <" .. tostring(label) .. ">")
   return tonumber(id)
end

--[[ Lay out a graph and return the positions of the nodes.

Args:
* `g` - graph to lay out.
* `algorithm` - name of the graphviz layout engine to use. (default: "dot")

Returns:
* `torch.Tensor(n, 2)` containing the resulting positions of the nodes,
  where `n` is the number of nodes in the graph. Coordinates are in the
  interval [0, 1].
]]
function graph.graphvizLayout(g, algorithm)
   if not graphvizOk or not cgraphOk then
      error("graphviz library could not be loaded.")
   end
   algorithm = algorithm or "dot"
   local nNodes = #g.nodes
   local context = graphviz.gvContext()
   local graphvizGraph = cgraph.agmemread(g:todot())
   if graphvizGraph == NULL then
      graphviz.gvFreeContext(context)
      error("graphviz could not parse the dot description of the graph")
   end

   local laidOut = false
   local ok, result = pcall(function()
      assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm), "graphviz layout failed")
      laidOut = true
      -- Render to the "dot" *format* with no output file so that graphviz
      -- attaches the layout results ('bb', 'pos') as attributes, which are
      -- read back below. The format is independent of the layout engine;
      -- passing the engine name here would fail for e.g. "neato".
      assert(0 == graphviz.gvRender(context, graphvizGraph, "dot", NULL), "graphviz render failed")

      -- Extract bounding box.
      local x0, y0, x1, y1 = extractNumbers(4, getAttribute(graphvizGraph, 'bb'))
      local bbox = { x0, y0, x1 - x0, y1 - y0 }

      -- Extract node positions.
      local positions = torch.zeros(nNodes, 2)
      for node in nodeIterator(graphvizGraph) do
         local id = getID(node)
         local x, y = getRelativePosition(node, bbox)
         positions[id][1] = x
         positions[id][2] = y
      end
      return positions
   end)

   -- Clean up on success and on failure alike, then propagate any error.
   if laidOut then graphviz.gvFreeLayout(context, graphvizGraph) end
   cgraph.agclose(graphvizGraph)
   graphviz.gvFreeContext(context)
   if not ok then error(result, 0) end
   return result
end

--[[ Lay out `g` with `algorithm` (default 'dot') and write it to `fname`.
The output format is taken from the file extension (e.g. "out.svg" -> svg).
`title` is optional and, when given, is rendered as the graph title.
]]
function graph.graphvizFile(g, algorithm, fname, title)
   if not graphvizOk or not cgraphOk then
      error("graphviz library could not be loaded.")
   end
   algorithm = algorithm or 'dot'
   local _,_,rendertype = fname:reverse():find('(%a+)%.%w+')
   if rendertype == nil then
      error("could not infer an output format from file name '" .. fname .. "'")
   end
   rendertype = rendertype:reverse()

   local context = graphviz.gvContext()
   local graphvizGraph = cgraph.agmemread(g:todot(title))
   if graphvizGraph == NULL then
      graphviz.gvFreeContext(context)
      error("graphviz could not parse the dot description of the graph")
   end

   local laidOut = false
   local ok, err = pcall(function()
      assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm), "graphviz layout failed")
      laidOut = true
      local fhandle = ffi.C.fopen(fname, 'w')
      -- fclose(NULL) is undefined behaviour, so bail out before rendering.
      assert(fhandle ~= NULL, "could not open '" .. fname .. "' for writing")
      local ret = graphviz.gvRender(context, graphvizGraph, rendertype, fhandle)
      ffi.C.fclose(fhandle)
      assert(0 == ret, "graphviz render failed")
   end)

   if laidOut then graphviz.gvFreeLayout(context, graphvizGraph) end
   cgraph.agclose(graphvizGraph)
   graphviz.gvFreeContext(context)
   if not ok then error(err, 0) end
end

--[[ Given a graph, dump an SVG or display it using graphviz.

Args:
* `g` - graph to display
* `title` - Title to display in the graph
* `fname` - [optional] if given it should contain a file name without an
  extension; the graph is saved on disk as fname.svg and fname.dot and is not
  displayed. If not given the graph is shown on qt display (you need to have
  qtsvg installed and running qlua)

Returns:
* `qs` - the window handle for the qt display (if fname is not given) or nil
]]
function graph.dot(g,title,fname)
   local qt_display = fname == nil
   fname = fname or os.tmpname()
   local fnsvg = fname .. '.svg'
   local fndot = fname .. '.dot'
   graph.graphvizFile(g, 'dot', fnsvg, title)
   graph.graphvizFile(g, 'dot', fndot, title)
   if qt_display then
      require 'qtsvg'
      local qs = qt.QSvgWidget(fnsvg)
      qs:show()
      os.remove(fnsvg)
      os.remove(fndot)
      -- os.tmpname() may create the base file itself; remove it as well
      -- (ignoring failure if it does not exist).
      os.remove(fname)
      return qs
   end
end
init.lua000066400000000000000000000146461301465566300125500ustar00rootroot00000000000000
require 'torch'

-- Global package namespace; graphviz.lua, Node.lua and Edge.lua register
-- their functions and classes into it, so it must exist before they load.
graph = {}

require('graph.graphviz')
require('graph.Node')
require('graph.Edge')

--[[ Defines a graph and general operations on graphs like topsort,
connected components, ...
uses two tables, one for nodes, one for edges
]]--
local Graph = torch.class('graph.Graph')

function Graph:__init()
   self.nodes = {}
   self.edges = {}
end

-- add a new edge into the graph.
-- an edge has two fields, from and to that are inserted into the
-- nodes table. the edge itself is inserted into the edges table.
-- `edge` may be a single graph.Edge or a plain Lua array of edges.
-- Both `edges` and `nodes` are ordered arrays that double as
-- object -> index lookups, so re-adding an edge or node is a no-op.
function Graph:add(edge)
   if type(edge) ~= 'table' then
      error('graph.Edge or {graph.Edges} expected')
   end
   if torch.typename(edge) then
      -- add edge
      if not self.edges[edge] then
         table.insert(self.edges,edge)
         self.edges[edge] = #self.edges
      end
      -- add from node
      if not self.nodes[edge.from] then
         table.insert(self.nodes,edge.from)
         self.nodes[edge.from] = #self.nodes
      end
      -- add to node
      if not self.nodes[edge.to] then
         table.insert(self.nodes,edge.to)
         self.nodes[edge.to] = #self.nodes
      end
      -- add the edge to the node for parsing in nodes
      edge.from:add(edge.to)
      -- node ids are the node's position in this graph's node list
      edge.from.id = self.nodes[edge.from]
      edge.to.id = self.nodes[edge.to]
   else
      for i,e in ipairs(edge) do
         self:add(e)
      end
   end
end

-- Clone a Graph
-- this will create new nodes, but will share the data.
-- Note that primitive data types like numbers can not be shared.
-- Edge weights are copied onto the cloned edges.
function Graph:clone()
   local clone = graph.Graph()
   local nodes = {}
   for i,n in ipairs(self.nodes) do
      table.insert(nodes,n.new(n.data))
   end
   for i,e in ipairs(self.edges) do
      local from = nodes[self.nodes[e.from]]
      local to = nodes[self.nodes[e.to]]
      clone:add(e.new(from,to,e.weight))
   end
   return clone
end

-- It returns a new graph where the edges are reversed.
-- The nodes share the data. Note that primitive data types can
function Graph:reverse() local rg = graph.Graph() local mapnodes = {} for i,e in ipairs(self.edges) do mapnodes[e.from] = mapnodes[e.from] or e.from.new(e.from.data) mapnodes[e.to] = mapnodes[e.to] or e.to.new(e.to.data) local from = mapnodes[e.from] local to = mapnodes[e.to] rg:add(e.new(to,from)) end return rg,mapnodes end function Graph:hasCycle() local roots = self:roots() if #roots == 0 then return true end for i, root in ipairs(roots) do if root:hasCycle() then return true end end return false end --[[ Topological Sort ]]-- function Graph:topsort() local dummyRoot -- reverse the graph local rg,map = self:reverse() local rmap = {} for k,v in pairs(map) do rmap[v] = k end -- work on the sorted graph local sortednodes = {} local rootnodes = rg:roots() if #rootnodes == 0 then error('Graph has cycles') end if #rootnodes > 1 then dummyRoot = graph.Node('dummy_root') for _, root in ipairs(rootnodes) do dummyRoot:add(root) end else dummyRoot = rootnodes[1] end -- run -- the trick is since the dummy node does not exist in original graph, -- rmap[dummyRoot] = nil hence nothing gets inserted into the table dummyRoot:dfs(function(node) table.insert(sortednodes,rmap[node]) end) if #sortednodes ~= #self.nodes then error('Graph has cycles') end return sortednodes,rg,rootnodes end -- find root nodes function Graph:roots() local edges = self.edges local rootnodes = {} for i,edge in ipairs(edges) do --table.insert(rootnodes,edge.from) if not rootnodes[edge.from] then rootnodes[edge.from] = #rootnodes+1 end end for i,edge in ipairs(edges) do if rootnodes[edge.to] then rootnodes[edge.to] = nil end end local roots = {} for root,i in pairs(rootnodes) do table.insert(roots, root) end table.sort(roots,function(a,b) return self.nodes[a] < self.nodes[b] end ) return roots end -- find root nodes function Graph:leaves() local edges = self.edges local leafnodes = {} for i,edge in ipairs(edges) do --table.insert(rootnodes,edge.from) if not leafnodes[edge.to] then leafnodes[edge.to] = 
#leafnodes+1 end end for i,edge in ipairs(edges) do if leafnodes[edge.from] then leafnodes[edge.from] = nil end end local leaves = {} for leaf,i in pairs(leafnodes) do table.insert(leaves, leaf) end table.sort(leaves,function(a,b) return self.nodes[a] < self.nodes[b] end ) return leaves end function graph._dotEscape(str) if string.find(str, '[^a-zA-Z]') then -- Escape newlines and quotes. local escaped = string.gsub(str, '\n', '\\n') escaped = string.gsub(escaped, '"', '\\"') str = '"' .. escaped .. '"' end return str end --[[ Generate a string like 'color=blue tailport=s' from a table (e.g. {color = 'blue', tailport = 's'}. Its up to the user to escape strings properly. ]] local function makeAttributeString(attributes) local str = {} local keys = {} for k, _ in pairs(attributes) do table.insert(keys, k) end table.sort(keys) for _, k in ipairs(keys) do local v = attributes[k] table.insert(str, tostring(k) .. '=' .. graph._dotEscape(tostring(v))) end return ' ' .. table.concat(str, ' ') end function Graph:todot(title) local nodes = self.nodes local edges = self.edges local str = {} table.insert(str,'digraph G {\n') if title then table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n') end table.insert(str,'node [shape = oval]; ') local nodelabels = {} for i,node in ipairs(nodes) do local nodeName if node.graphNodeName then nodeName = node:graphNodeName() else nodeName = 'Node' .. node.id end local l = graph._dotEscape(nodeName .. '\n' .. node:label()) nodelabels[node] = 'n' .. node.id local graphAttributes = '' if node.graphNodeAttributes then graphAttributes = makeAttributeString( node:graphNodeAttributes()) end table.insert(str, '\n' .. nodelabels[node] .. '[label=' .. l .. graphAttributes .. '];') end table.insert(str,'\n') for i,edge in ipairs(edges) do table.insert(str,nodelabels[edge.from] .. ' -> ' .. nodelabels[edge.to] .. 
';\n') end table.insert(str,'}') return table.concat(str,'') end rocks/000077500000000000000000000000001301465566300122105ustar00rootroot00000000000000rocks/graph-scm-1.rockspec000066400000000000000000000010551301465566300157630ustar00rootroot00000000000000package = "graph" version = "scm-1" source = { url = "git://github.com/torch/graph", tag = "master" } description = { summary = "Graph package for Torch", homepage = "https://github.com/torch/graph", license = "UNKNOWN" } dependencies = { "torch >= 7.0" } build = { type = "command", build_command = [[ cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE) ]], install_command = "cd build && $(MAKE) install" } test/000077500000000000000000000000001301465566300120465ustar00rootroot00000000000000test/test_graph.lua000066400000000000000000000172251301465566300147200ustar00rootroot00000000000000require 'graph' require 'torch' local tester = torch.Tester() local tests = torch.TestSuite() local function create_graph(nlayers, ninputs, noutputs, nhiddens, droprate) local g = graph.Graph() local conmat = torch.rand(nlayers, nhiddens, nhiddens):ge(droprate)[{ {1, -2}, {}, {} }] -- create nodes local nodes = { [0] = {}, [nlayers+1] = {} } local nodecntr = 1 for inode = 1, ninputs do local node = graph.Node(nodecntr) nodes[0][inode] = node nodecntr = nodecntr + 1 end for ilayer = 1, nlayers do nodes[ilayer] = {} for inode = 1, nhiddens do local node = graph.Node(nodecntr) nodes[ilayer][inode] = node nodecntr = nodecntr + 1 end end for inode = 1, noutputs do local node = graph.Node(nodecntr) nodes[nlayers+1][inode] = node nodecntr = nodecntr + 1 end -- now connect inputs to all first layer hiddens for iinput = 1, ninputs do for inode = 1, nhiddens do g:add(graph.Edge(nodes[0][iinput], nodes[1][inode])) end end -- now run through layers and connect them for ilayer = 1, nlayers-1 do for jnode = 1, nhiddens do for knode = 
1, nhiddens do if conmat[ilayer][jnode][knode] == 1 then g:add(graph.Edge(nodes[ilayer][jnode], nodes[ilayer+1][knode])) end end end end -- now connect last layer hiddens to outputs for inode = 1, nhiddens do for ioutput = 1, noutputs do g:add(graph.Edge(nodes[nlayers][inode], nodes[nlayers+1][ioutput])) end end -- there might be nodes left out and not connected to anything. Connect them for i = 1, nlayers do for j = 1, nhiddens do if not g.nodes[nodes[i][j]] then local jto = torch.random(1, nhiddens) g:add(graph.Edge(nodes[i][j], nodes[i+1][jto])) conmat[i][j][jto] = 1 end end end return g, conmat end function tests.graph() local nlayers = torch.random(2,5) local ninputs = torch.random(1,10) local noutputs = torch.random(1,10) local nhiddens = torch.random(10,20) local droprates = {0, torch.uniform(0.2, 0.8), 1} for i, droprate in ipairs(droprates) do local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate) local nedges = nhiddens * (ninputs+noutputs) + c:sum() local nnodes = ninputs + noutputs + nhiddens*nlayers local nroots = ninputs + c:sum(2):eq(0):sum() local nleaves = noutputs + c:sum(3):eq(0):sum() tester:asserteq(#g.edges, nedges, 'wrong number of edges') tester:asserteq(#g.nodes, nnodes, 'wrong number of nodes') tester:asserteq(#g:roots(), nroots, 'wrong number of roots') tester:asserteq(#g:leaves(), nleaves, 'wrong number of leaves') end end function tests.test_dfs() local nlayers = torch.random(5,10) local ninputs = 1 local noutputs = 1 local nhiddens = 1 local droprate = 0 local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate) local roots = g:roots() local leaves = g:leaves() tester:asserteq(#roots, 1, 'expected a single root') tester:asserteq(#leaves, 1, 'expected a single leaf') local dfs_nodes = {} roots[1]:dfs(function(node) table.insert(dfs_nodes, node) end) for i, node in ipairs(dfs_nodes) do tester:asserteq(node.data, #dfs_nodes - i +1, 'dfs order wrong') end end function tests.test_bfs() local nlayers = 
torch.random(5,10) local ninputs = 1 local noutputs = 1 local nhiddens = 1 local droprate = 0 local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate) local roots = g:roots() local leaves = g:leaves() tester:asserteq(#roots, 1, 'expected a single root') tester:asserteq(#leaves, 1, 'expected a single leaf') local bfs_nodes = {} roots[1]:bfs(function(node) table.insert(bfs_nodes, node) end) for i, node in ipairs(bfs_nodes) do tester:asserteq(node.data, i, 'bfs order wrong') end end function tests.test_topsort() local n1 = graph.Node(1) local n2 = graph.Node(2) local n3 = graph.Node(3) local n4 = graph.Node(4) local g = graph.Graph() g:add(graph.Edge(n1, n2)) g:add(graph.Edge(n1, n3)) g:add(graph.Edge(n2, n3)) g:add(graph.Edge(n2, n4)) g:add(graph.Edge(n3, n4)) local sorted = g:topsort() tester:assert(sorted[1] == n1, 'wrong sort order' ) tester:assert(sorted[2] == n2, 'wrong sort order' ) tester:assert(sorted[3] == n3, 'wrong sort order' ) tester:assert(sorted[4] == n4, 'wrong sort order' ) -- add an extra root local n0 = graph.Node(0) g:add(graph.Edge(n0, n2)) local sorted2 = g:topsort() tester:assert(sorted2[1] == n1 or sorted2[1] == n0, 'wrong sort order' ) tester:assert(sorted2[5] == n4, 'wrong sort order' ) -- add an extra leaf local n5 = graph.Node(5) g:add(graph.Edge(n3, n5)) local sorted2 = g:topsort() tester:assert(sorted2[1] == n1 or sorted2[1] == n0, 'wrong sort order' ) tester:assert(sorted2[6] == n4 or sorted2[6] == n5, 'wrong sort order' ) tester:assert(sorted2[5] == n4 or sorted2[5] == n5, 'wrong sort order' ) tester:assert(sorted2[6] ~= sorted2[5], 'wrong sort order' ) -- add a bottleneck and a new set of nodes local n11 = graph.Node(11) local n12 = graph.Node(12) local n13 = graph.Node(13) local n14 = graph.Node(14) local n15 = graph.Node(15) local n16 = graph.Node(16) g:add(graph.Edge(n4, n11)) g:add(graph.Edge(n5, n11)) g:add(graph.Edge(n11, n12)) g:add(graph.Edge(n11, n13)) g:add(graph.Edge(n12, n13)) g:add(graph.Edge(n13, n14)) 
g:add(graph.Edge(n14, n15)) g:add(graph.Edge(n12, n15)) g:add(graph.Edge(n13, n16)) local sorted3 = g:topsort() -- check all the first 6 sorted elements have data <= 5 for i=1, 6 do tester:assert(sorted3[i].data <= 5, 'wrong sort order') end tester:assert(sorted3[7] == n11, 'wrong sort order') tester:assert(sorted3[8] == n12, 'wrong sort order' ) tester:assert(sorted3[9] == n13, 'wrong sort order' ) tester:assert(sorted3[11] == n16 or sorted3[12] == n16, 'wrong sort order') end function tests.test_cycle() local n1 = graph.Node(1) local n2 = graph.Node(2) local n3 = graph.Node(3) local n4 = graph.Node(4) local cycle = graph.Graph() cycle:add(graph.Edge(n1, n2)) cycle:add(graph.Edge(n1, n3)) cycle:add(graph.Edge(n2, n3)) cycle:add(graph.Edge(n3, n2)) cycle:add(graph.Edge(n2, n4)) cycle:add(graph.Edge(n3, n4)) tester:asserteq(cycle:hasCycle(), true, 'Graph is supposed to have cycle') local n1 = graph.Node(1) local n2 = graph.Node(2) local n3 = graph.Node(3) local n4 = graph.Node(4) local nocycle = graph.Graph() nocycle:add(graph.Edge(n1, n2)) nocycle:add(graph.Edge(n1, n3)) nocycle:add(graph.Edge(n2, n3)) nocycle:add(graph.Edge(n2, n4)) nocycle:add(graph.Edge(n3, n4)) tester:asserteq(nocycle:hasCycle(), false, 'Graph is not supposed to have cycle') local function create_cycle(g, node0, length) local node1, node2 = node0, nil for i = 1, length-1 do node2 = graph.Node('c' .. 
i) local e = graph.Edge(node1, node2) g:add(e) node1 = node2 end g:add(graph.Edge(node1, node0)) end local bigcycle = graph.Graph() local n1 = graph.Node(1) local n2 = graph.Node(2) local n3 = graph.Node(3) local n4 = graph.Node(4) bigcycle:add(graph.Edge(n1, n2)) bigcycle:add(graph.Edge(n1, n3)) bigcycle:add(graph.Edge(n2, n3)) bigcycle:add(graph.Edge(n2, n4)) bigcycle:add(graph.Edge(n3, n4)) create_cycle(bigcycle, n2, 5) tester:asserteq(cycle:hasCycle(), true, 'Graph is supposed to have cycle') end return tester:add(tests):run() test/test_graphviz.lua000066400000000000000000000021621301465566300154430ustar00rootroot00000000000000require 'graph' require 'torch' local tester = torch.Tester() local tests = torch.TestSuite() function tests.layout() local g = graph.Graph() local root = graph.Node(10) local n1 = graph.Node(1) local n2 = graph.Node(2) g:add(graph.Edge(root, n1)) g:add(graph.Edge(n1, n2)) local positions = graph.graphvizLayout(g, 'dot') local xs = positions:select(2, 1) local ys = positions:select(2, 2) tester:assertlt(xs:add(-xs:mean()):norm(), 1e-3, "x coordinates should be the same") tester:assertTensorEq(ys, torch.sort(ys, true), 1e-3, "y coordinates should be ordered") end function tests.testDotEscape() tester:assert(graph._dotEscape('red') == 'red', 'Don\'t escape single words') tester:assert(graph._dotEscape('My label') == '"My label"', 'Use quotes for spaces') tester:assert(graph._dotEscape('Non[an') == '"Non[an"', 'Use quotes for non-alpha characters') tester:assert(graph._dotEscape('My\nnewline') == '"My\\nnewline"', 'Escape newlines') tester:assert(graph._dotEscape('Say "hello"') == '"Say \\"hello\\""', 'Escape quotes') end return tester:add(tests):run() test/test_old.lua000066400000000000000000000010771301465566300143730ustar00rootroot00000000000000require 'graph' dofile 'graphviz.lua' g=graph.Graph() root=graph.Node(10) n1=graph.Node(1) n2=graph.Node(2) g:add(graph.Edge(root,n1)) g:add(graph.Edge(root,n2)) nend = graph.Node(20) 
g:add(graph.Edge(n1,nend)) g:add(graph.Edge(n2,nend)) -- g:add(graph.Edge(nend,root)) local i = 0 print('======= BFS ==========') root:bfs(function(node) i=i+1;print('i='..i);print(node:label())end) print('======= DFS ==========') i = 0 root:dfs(function(node) i=i+1;print('i='..i);print(node:label())end) print('======= topsort ==========') s,rg,rn = g:topsort() graph.dot(g, 'g', 'g')