Skip to content
Snippets Groups Projects
Commit b54294e8 authored by Bhatu's avatar Bhatu
Browse files

Implement SquaredDifference

We do this as a simplification on the tensorflow graph itself.
We transform SquaredDifference(a,b) into (a-b) * (a-b).
parent 39e78075
No related branches found
No related tags found
No related merge requests found
...@@ -501,11 +501,14 @@ class Value: ...@@ -501,11 +501,14 @@ class Value:
return self.__val return self.__val
class Node: class Node:
def __init__(self): def __init__(self, op="", inputs=None, name=""):
self.__name = "" #Name of node self.__name = name #Name of node
self.__op = "" #Name of operation carried out by node self.__op = op #Name of operation carried out by node
self.__inputs = [] #List of all inputs to the current node if inputs is None:
self.__attr = {} #Map of (attrName, Value) of all attributes for the current node self.__inputs = [] #List of all inputs to the current node
else:
self.__inputs = inputs
self.__attr = {} #Map of (attrName, Value) of all attributes for the current node
def getName(self): def getName(self):
return self.__name return self.__name
......
...@@ -112,6 +112,32 @@ def prefixAllPlaceHolderNodes(graph): ...@@ -112,6 +112,32 @@ def prefixAllPlaceHolderNodes(graph):
remNodes.append(curNode) remNodes.append(curNode)
graph.setNodesList(placeHolderNodes + remNodes) graph.setNodesList(placeHolderNodes + remNodes)
# List of Optimisations
# 1. Split squared difference into (a-b)*(a-b)
def simplifyGraph(graph):
allNodes = graph.getAllNodesRef()
nodesMap = graph.getAllNodes()
newNodes = []
inputsFixup = {}
for curNode in allNodes:
inputs = curNode.getInputsRef()
for i in range(len(inputs)):
if inputs[i] in inputsFixup:
inputs[i] = inputsFixup[inputs[i]]
if (curNode.getOp() == "SquaredDifference"):
sub = Graph.Node("Sub", inputs.copy(), curNode.getName() + "__sub")
mul = Graph.Node("Mul", [sub.getName(), sub.getName()], curNode.getName() + "__mul")
newNodes.append(sub)
newNodes.append(mul)
nodesMap[sub.getName()] = sub
nodesMap[mul.getName()] = mul
inputsFixup[curNode.getName()] = mul.getName()
nodesMap.pop(curNode.getName())
else:
newNodes.append(curNode)
graph.setNodesList(newNodes)
def main(): def main():
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
...@@ -131,6 +157,8 @@ def main(): ...@@ -131,6 +157,8 @@ def main():
sizeInfoFileName = os.path.join(folderName, 'sizeInfo.mtdata') sizeInfoFileName = os.path.join(folderName, 'sizeInfo.mtdata')
sizeInfo = readSizeInfo(sizeInfoFileName) sizeInfo = readSizeInfo(sizeInfoFileName)
# Tensorflow graph level optimisations
simplifyGraph(graph)
# Place all PlaceHolder nodes together at the beginning # Place all PlaceHolder nodes together at the beginning
prefixAllPlaceHolderNodes(graph) prefixAllPlaceHolderNodes(graph)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment