Skip to content
Snippets Groups Projects
Commit 2172e484 authored by Bhatu's avatar Bhatu
Browse files

Add support for float16 tensor types.

parent cb6efa9d
No related branches found
No related tags found
No related merge requests found
......@@ -36,12 +36,14 @@ def errIfTokensNotMinLen(tokens, minlen, lineNum, entity):
class DataTypeEnum(enum.Enum):
DT_INVALID = 0
DT_FLOAT = 1
DT_BOOL = 2
DT_INT32 = 3
DT_INT64 = 4
DT_FLOAT16 = 2
DT_BOOL = 3
DT_INT32 = 4
DT_INT64 = 5
def Parse(str):
if (str == "DT_FLOAT"): return DataTypeEnum.DT_FLOAT
elif (str == "DT_HALF"): return DataTypeEnum.DT_FLOAT16
elif (str == "DT_BOOL"): return DataTypeEnum.DT_BOOL
elif (str == "DT_INT32"): return DataTypeEnum.DT_INT32
elif (str == "DT_INT64"): return DataTypeEnum.DT_INT64
......@@ -51,6 +53,7 @@ class DataTypeEnum(enum.Enum):
def Size(dt):
if (dt == DataTypeEnum.DT_INVALID): return 0
elif (dt == DataTypeEnum.DT_FLOAT): return 4
elif (dt == DataTypeEnum.DT_FLOAT16): return 2
elif (dt == DataTypeEnum.DT_BOOL): return 1
elif (dt == DataTypeEnum.DT_INT32): return 4
elif (dt == DataTypeEnum.DT_INT64): return 8
......@@ -191,6 +194,8 @@ class Tensor:
self.__valArr = [self.__valInput]*numElements
elif ((self.__dtype == DataTypeEnum.DT_FLOAT) and (self.__valInput is not None)):
self.__valArr = [self.__valInput]*numElements
elif ((self.__dtype == DataTypeEnum.DT_FLOAT16) and (self.__valInput is not None)):
self.__valArr = [self.__valInput]*numElements
elif ((self.__dtype == DataTypeEnum.DT_INT32 or self.__dtype == DataTypeEnum.DT_INT64)
and (self.__valInput is not None)):
self.__valArr = [self.__valInput]*numElements
......@@ -310,6 +315,8 @@ class Tensor:
# self.__valArr = returnArr
if self.__dtype == DataTypeEnum.DT_FLOAT:
dtype = numpy.dtype('<f4')
elif self.__dtype == DataTypeEnum.DT_FLOAT16:
dtype = numpy.dtype('<f2')
elif self.__dtype == DataTypeEnum.DT_BOOL:
dtype = numpy.dtype('bool')
elif self.__dtype == DataTypeEnum.DT_INT32:
......@@ -324,6 +331,8 @@ class Tensor:
def getDType(self):
if self.__dtype == DataTypeEnum.DT_FLOAT:
dtype = numpy.dtype('<f4')
elif self.__dtype == DataTypeEnum.DT_FLOAT16:
dtype = numpy.dtype('<f2')
elif self.__dtype == DataTypeEnum.DT_BOOL:
dtype = numpy.dtype('bool')
elif self.__dtype == DataTypeEnum.DT_INT32:
......@@ -625,6 +634,7 @@ class Graph:
pass
elif (curToken == "versions" or curToken == "library"):
print("Versions/Library node found. Ignoring remainder graph. Line =", cnt, file=sys.stderr)
print("Graph parsing successful.", file=sys.stderr)
return True
else:
print("Unrecognized token in graph dump at line =", cnt, ", token =", curToken, file=sys.stderr)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment