From 29656c7a3b7b8865c1168ba3647b12568702978f Mon Sep 17 00:00:00 2001
From: Pratik Bhatu <prbhatu@microsoft.com>
Date: Wed, 3 Jun 2020 03:22:54 +0530
Subject: [PATCH] Add back conv3d

---
 Athos/TFCompiler/TFNodesAST.py | 77 ++++++++++++++++++++++++++++++----
 1 file changed, 68 insertions(+), 9 deletions(-)

diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py
index 194532b..8e05018 100644
--- a/Athos/TFCompiler/TFNodesAST.py
+++ b/Athos/TFCompiler/TFNodesAST.py
@@ -280,13 +280,16 @@ class TFNodesAST:
 		assert(len(inputsRef) == 2)
 		return (None, AST.Reshape(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), extraNodeInfoDict[curNode.getName()][0], None))
 
-	def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr):
-		zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = -1
+	def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD = None, FD = None, strideD = None):
+		if imgD:
+			assert(FD)
+			assert(strideD)
+		zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1
 		if (paddingUsedStr == "\"SAME\""):
 			# Reference for following:
-			#	https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn
-			totalPaddingH = totalPaddingW = 0
-			
+			#       https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn
+			totalPaddingH = totalPaddingW = totalPaddingD = 0
+
 			if (imgH % strideH == 0):
 				totalPaddingH = max(FH - strideH, 0)
 			else:
@@ -297,17 +300,30 @@ class TFNodesAST:
 			else:
 				totalPaddingW = max(FW - (imgW % strideW), 0)
 
+			if imgD:
+				if (imgD % strideD == 0):
+					totalPaddingD = max(FD - strideD, 0)
+				else:
+					totalPaddingD = max(FD - (imgD % strideD), 0)
+
 			zPadHLeft = totalPaddingH // 2
 			zPadHRight = totalPaddingH - zPadHLeft
 
 			zPadWLeft = totalPaddingW // 2
-			zPadWRight = totalPaddingW - zPadWLeft		
+			zPadWRight = totalPaddingW - zPadWLeft
+
+			zPadDLeft = totalPaddingD // 2
+			zPadDRight = totalPaddingD - zPadDLeft
+
 		elif (paddingUsedStr == "\"VALID\""):
-			zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = 0
+			zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = 0
 		else:
-			zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = -1
+			zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1
 
-		return [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight]
+		if imgD:
+			return [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight]
+		else:
+			return [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight]
 
 	def Conv2D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
 		inputsRef = curNode.getInputsRef()
@@ -348,6 +364,49 @@ class TFNodesAST:
 								AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), 
 								options))
 
+	def Conv3D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
+		inputsRef = curNode.getInputsRef()
+		assert(len(inputsRef)==2)
+
+		stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi()
+		assert(stridesUsed[0]==1 and stridesUsed[4]==1)
+		strideD = stridesUsed[1]
+		strideH = stridesUsed[2]
+		strideW = stridesUsed[3]
+
+		inputShape = extraNodeInfoDict[inputsRef[0]][0]
+		imgD = inputShape[1]
+		imgH = inputShape[2]
+		imgW = inputShape[3]
+
+		filterShape = extraNodeInfoDict[inputsRef[1]][0]
+		FD = filterShape[0]
+		FH = filterShape[1]
+		FW = filterShape[2]
+
+		paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS()
+
+		[zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD, FD, strideD )
+
+		options = {}
+		options[AST.PaddingKeysDict.FD] = FD
+		options[AST.PaddingKeysDict.FH] = FH
+		options[AST.PaddingKeysDict.FW] = FW
+		options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft
+		options[AST.PaddingKeysDict.zPadDRight] = zPadDRight
+		options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft
+		options[AST.PaddingKeysDict.zPadHRight] = zPadHRight
+		options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft
+		options[AST.PaddingKeysDict.zPadWRight] = zPadWRight
+		options[AST.PaddingKeysDict.strideD] = strideD
+		options[AST.PaddingKeysDict.strideH] = strideH
+		options[AST.PaddingKeysDict.strideW] = strideW
+		options[AST.PaddingKeysDict.ConvDim] = 3
+		return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
+																TFNodesAST.getOperatorsIdx('#'),
+																AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
+																options))
+
 	def helper_processPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict, typeOfPool:str):
 		inputsRef = curNode.getInputsRef()
 		assert(len(inputsRef)==1)
-- 
GitLab