Skip to content
Snippets Groups Projects
Commit b54294e8 authored by Bhatu's avatar Bhatu
Browse files

Implement SquaredDifference

We do this as a simplification on the tensorflow graph itself.
We transform SquaredDifference(a,b) into (a-b) * (a-b).
parent 39e78075
No related branches found
No related tags found
No related merge requests found
......@@ -501,11 +501,14 @@ class Value:
return self.__val
class Node:
def __init__(self):
self.__name = "" #Name of node
self.__op = "" #Name of operation carried out by node
self.__inputs = [] #List of all inputs to the current node
self.__attr = {} #Map of (attrName, Value) of all attributes for the current node
def __init__(self, op="", inputs=None, name=""):
self.__name = name #Name of node
self.__op = op #Name of operation carried out by node
if inputs is None:
self.__inputs = [] #List of all inputs to the current node
else:
self.__inputs = inputs
self.__attr = {} #Map of (attrName, Value) of all attributes for the current node
def getName(self):
return self.__name
......
......@@ -112,6 +112,32 @@ def prefixAllPlaceHolderNodes(graph):
remNodes.append(curNode)
graph.setNodesList(placeHolderNodes + remNodes)
# List of Optimisations
# 1. Split squared difference into (a-b)*(a-b)
def simplifyGraph(graph):
allNodes = graph.getAllNodesRef()
nodesMap = graph.getAllNodes()
newNodes = []
inputsFixup = {}
for curNode in allNodes:
inputs = curNode.getInputsRef()
for i in range(len(inputs)):
if inputs[i] in inputsFixup:
inputs[i] = inputsFixup[inputs[i]]
if (curNode.getOp() == "SquaredDifference"):
sub = Graph.Node("Sub", inputs.copy(), curNode.getName() + "__sub")
mul = Graph.Node("Mul", [sub.getName(), sub.getName()], curNode.getName() + "__mul")
newNodes.append(sub)
newNodes.append(mul)
nodesMap[sub.getName()] = sub
nodesMap[mul.getName()] = mul
inputsFixup[curNode.getName()] = mul.getName()
nodesMap.pop(curNode.getName())
else:
newNodes.append(curNode)
graph.setNodesList(newNodes)
def main():
sys.setrecursionlimit(10000)
......@@ -131,6 +157,8 @@ def main():
sizeInfoFileName = os.path.join(folderName, 'sizeInfo.mtdata')
sizeInfo = readSizeInfo(sizeInfoFileName)
# Tensorflow graph level optimisations
simplifyGraph(graph)
# Place all PlaceHolder nodes together at the beginning
prefixAllPlaceHolderNodes(graph)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment