diff --git a/Athos/.gitignore b/Athos/.gitignore index 2467b19510f899768ec81f95a5ae3abc67605f1c..5386dd8b3676fb1a30d95b24bd5fa41af0344cf5 100644 --- a/Athos/.gitignore +++ b/Athos/.gitignore @@ -1,8 +1,11 @@ -*.inp +*.inprm *.outp *.mtdata *.pkl *.out *.ezpc *.cpp -__pycache__/ \ No newline at end of file +SeeDot/debug/ +*__temp1.ezpc +*__temp2.ezpc +__pycache__/ diff --git a/Athos/HelperScripts/process_models/change_onnx_output.py b/Athos/HelperScripts/process_models/change_onnx_output.py new file mode 100644 index 0000000000000000000000000000000000000000..ba86f1ced0f312d360aee29f951adfdba98e1430 --- /dev/null +++ b/Athos/HelperScripts/process_models/change_onnx_output.py @@ -0,0 +1,139 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2020 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +''' + +import onnx +import onnxruntime +import numpy as np +from onnx import helper, shape_inference, checker +from onnx import ValueInfoProto, ModelProto, TensorProto +import os + +model_name = "shufflenet_may17.onnx" +output_model_name = "processed_" + model_name +inputs = ['data'] +nodes_to_remove = ['LabelSelector', 'LabelIndexExtractor', 'ZipMap', + 'activation37'] +new_output_names = ['fc'] +batch_size = 1 + +def fix_shape(shape_list, batch_size): + if 'None' not in shape_list: + return shape_list + else: + shape_list[0] = batch_size + assert ('None' not in shape_list) , """Other than batch size there are input + params with unkown dimension""" + return shape_list + +def fix_inp_shape(inp, batch_size): + if inp.type.tensor_type.shape.dim[0].dim_param == 'None': + inp.type.tensor_type.shape.dim[0].dim_value = batch_size + return + +def get_np_type_from_onnxruntime(typ_str): + np_types = { + 'tensor(float)' : np.float32, + 'tensor(float64)' : np.float64, + 'tensor(int)' : np.int32, + 'tensor(int64)' : np.int64 + } + return np_types[typ_str] + +def get_onnx_type(arr): + onnx_types = { + np.float32 : TensorProto.FLOAT, + np.float64 : TensorProto.DOUBLE, + np.int32 : TensorProto.INT32, + np.int64 : TensorProto.INT64 + } + return onnx_types[arr.dtype.type] + + +model = onnx.load(model_name) +# 1. Inputs to remove +# Inputs to dead nodes should not show up as inputs for the model +# and also not in the initialization list. +inputs_to_remove = [ inp for i in model.graph.node + if i.name in nodes_to_remove for inp in i.input ] +new_inputs = [ i for i in model.graph.input if i.name not in inputs_to_remove ] + +# Fix batch size +fix_inp_shape(new_inputs[0], batch_size) + +# 2. Remove their initializers +new_initializers = [ init for init in model.graph.initializer + if init.name not in nodes_to_remove + and init.name not in inputs_to_remove ] + +# 3. 
Remove nodes +new_nodes = [ n for n in model.graph.node if n.name not in nodes_to_remove ] + + +# Get Ouput Tensor Types to create ValueInfo for output info +# by running model on dummy input +temp_model = ModelProto() +temp_model.CopyFrom(model) +for i in new_output_names: + op = ValueInfoProto() + op.name = i + temp_model.graph.output.append(op) +onnx.save(temp_model, '__temp.onnx') +sess = onnxruntime.InferenceSession('__temp.onnx') +sess_inps = sess.get_inputs() +input_dict = {} +for i in sess_inps: + shape = fix_shape(i.shape, batch_size) + typ = get_np_type_from_onnxruntime(i.type) + input_dict[i.name] = np.random.rand(*shape).astype(typ) + +output_tensors = sess.run(new_output_names, input_dict) +if os.path.exists("__temp.onnx"): + os.remove("__temp.onnx") + +# 4. Create new output list +new_outputs = [] +for i in range(0,len(new_output_names)): + name = new_output_names[i] + typ = get_onnx_type(output_tensors[i]) + shape = output_tensors[i].shape + val_info = helper.make_tensor_value_info(name, typ, shape) + new_outputs.append(val_info) + +new_graph = helper.make_graph(new_nodes, + model.graph.name, + new_inputs, + new_outputs, + initializer=new_initializers, + doc_string=model.graph.doc_string, + value_info=model.graph.value_info) +new_model = helper.make_model(new_graph, + ir_version=model.ir_version, + doc_string=model.doc_string, + model_version=model.model_version, + domain=model.domain, + producer_name='MPCOpRemover') +new_model.metadata_props.extend(model.metadata_props) +new_model.opset_import.pop() +new_model.opset_import.extend(model.opset_import) +onnx.save(new_model, 'processed_'+model_name) diff --git a/Athos/HelperScripts/process_models/convert_keras_to_onnx.py b/Athos/HelperScripts/process_models/convert_keras_to_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..17e04d1df5c97009d292f5c21955f84f58194b57 --- /dev/null +++ b/Athos/HelperScripts/process_models/convert_keras_to_onnx.py @@ -0,0 +1,36 @@ +import tensorflow 
as tf +import onnx +from onnx import shape_inference +import keras2onnx + +model_filename = 'chest_xray_covid19_model.h5' +output_filename = 'covid_resnet.onnx' +input_h = 224 +input_w = 224 + +tf.keras.backend.set_learning_phase(0) +keras_model = tf.keras.models.load_model(model_filename) +onnx_model = keras2onnx.convert_keras(keras_model, keras_model.name) + +def set_input_dim(onnx_model, idx, val): + onnx_model.graph.input[0].type.tensor_type.shape.dim[idx].dim_value = val + +def get_input_dim(onnx_model, idx): + return onnx_model.graph.input[0].type.tensor_type.shape.dim[idx].dim_value + +#If input dims are parametric we need to materialize the dims to constants +# N H W C +dims = { "n" : 0, "h" : 1, "w" : 2, "c" : 3} +n = get_input_dim(onnx_model, dims["n"]) +h = get_input_dim(onnx_model, dims["h"]) +w = get_input_dim(onnx_model, dims["w"]) +c = get_input_dim(onnx_model, dims["c"]) + +if 0 in [n,h,w,c]: + set_input_dim(onnx_model, dims["n"], 1) + set_input_dim(onnx_model, dims["h"], input_h) + set_input_dim(onnx_model, dims["w"], input_w) + +fixed_model = onnx.shape_inference.infer_shapes(onnx_model) +onnx.checker.check_model(fixed_model) +onnx.save_model(fixed_model, output_filename) diff --git a/Athos/HelperScripts/process_models/convert_keras_to_tf.py b/Athos/HelperScripts/process_models/convert_keras_to_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..f98eebd8690b0cb4a74d0d071e8426f442a9d8db --- /dev/null +++ b/Athos/HelperScripts/process_models/convert_keras_to_tf.py @@ -0,0 +1,26 @@ +import tensorflow as tf + +model_filename = 'chest_xray_covid19_model.h5' +output_filename = 'covid_resnet.pb' + +def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): + graph = session.graph + with graph.as_default(): + freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) + output_names = output_names or [] + output_names += [v.op.name for v in 
tf.global_variables()] + input_graph_def = graph.as_graph_def() + if clear_devices: + for node in input_graph_def.node: + node.device = "" + frozen_graph = tf.graph_util.convert_variables_to_constants( + session, input_graph_def, output_names, freeze_var_names) + return frozen_graph + +tf.keras.backend.set_learning_phase(0) + +with tf.keras.utils.CustomObjectScope({'GlorotUniform': tf.keras.initializers.glorot_uniform()}): + model = tf.keras.models.load_model(model_filename) + frozen_graph = freeze_session(tf.keras.backend.get_session(), + output_names=[out.op.name for out in model.outputs]) + tf.train.write_graph(frozen_graph, ".", output_filename, as_text=False) diff --git a/Athos/ONNXCompiler/.gitignore b/Athos/ONNXCompiler/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..86737c88997aa49fd1bec4b47e872f9975af66f0 --- /dev/null +++ b/Athos/ONNXCompiler/.gitignore @@ -0,0 +1,8 @@ +models/ +debug/ +*.cpp +*.inp +*.h +*.ezpc +*.h +*.npy \ No newline at end of file diff --git a/Athos/ONNXCompiler/ONNXNodesAST.py b/Athos/ONNXCompiler/ONNXNodesAST.py new file mode 100644 index 0000000000000000000000000000000000000000..2509f13e68f4342557873379b35d51764573d5d1 --- /dev/null +++ b/Athos/ONNXCompiler/ONNXNodesAST.py @@ -0,0 +1,898 @@ +''' + +Authors: Shubham Ugare. + +Copyright: +Copyright (c) 2018 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' + +import AST.AST as AST +from onnx import mapping +from onnx import TensorProto +from numbers import Number + +DEBUG = False +out_var_prefix = 'J' + +class OnnxNode(object): + """ + Reimplementation of NodeProto from ONNX, but in a form + more convenient to work with from Python. + """ + + def __init__(self, node): + self.name = str(node.name) + self.op_type = str(node.op_type) + self.domain = str(node.domain) + self.attrs = dict([(attr.name, + translate_onnx(attr.name, convert_onnx(attr))) + for attr in node.attribute]) + self.inputs = list(node.input) + self.outputs = list(node.output) + self.node_proto = node + +__onnx_attr_translator = { + "axis": lambda x: int(x), + "axes": lambda x: [int(a) for a in x], + "dtype": lambda x: onnx2seedot(x), + "keepdims": lambda x: bool(x), + "to": lambda x: onnx2seedot(x), +} + + +def convert_onnx(attr): + return __convert_onnx_attribute_proto(attr) + + +def __convert_onnx_attribute_proto(attr_proto): + """ + Convert an ONNX AttributeProto into an appropriate Python object + for the type. + NB: Tensor attribute gets returned as the straight proto. + """ + if attr_proto.HasField('f'): + return attr_proto.f + elif attr_proto.HasField('i'): + return attr_proto.i + elif attr_proto.HasField('s'): + return str(attr_proto.s, 'utf-8') + elif attr_proto.HasField('t'): + return attr_proto.t # this is a proto! 
+ elif attr_proto.HasField('g'): + return attr_proto.g + elif attr_proto.floats: + return list(attr_proto.floats) + elif attr_proto.ints: + return list(attr_proto.ints) + elif attr_proto.strings: + str_list = list(attr_proto.strings) + if IS_PYTHON3: + str_list = list(map(lambda x: str(x, 'utf-8'), str_list)) + return str_list + elif attr_proto.HasField('sparse_tensor'): + return attr_proto.sparse_tensor + else: + raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto)) + +def translate_onnx(key, val): + return __onnx_attr_translator.get(key, lambda x: x)(val) + +def onnx2seedot(dtype): + return TENSOR_TYPE_TO_SEEDOT_TYPE[_onnx_dtype(dtype)] + +def _onnx_dtype(dtype): + if isinstance(dtype, Number): + onnx_dype = dtype + elif isinstance(dtype, str): + onnx_dype = TensorProto.DataType.Value(dtype) + else: + raise RuntimeError("dtype should be number or str.") + return onnx_dype + +TENSOR_TYPE_TO_SEEDOT_TYPE = { + int(TensorProto.FLOAT): 'float32', + int(TensorProto.UINT8): 'uint8', + int(TensorProto.INT8): 'int8', + int(TensorProto.UINT16): 'uint16', + int(TensorProto.INT16): 'int16', + int(TensorProto.INT32): 'int32', + int(TensorProto.INT64): 'int64', + int(TensorProto.BOOL): 'bool', + int(TensorProto.FLOAT16): 'float16', + int(TensorProto.DOUBLE): 'float64', + int(TensorProto.COMPLEX64): 'complex64', + int(TensorProto.COMPLEX128): 'complex128', + int(TensorProto.UINT32): 'uint32', + int(TensorProto.UINT64): 'uint64', + int(TensorProto.STRING): 'string' +} + +def getOperatorsIdx(token): + #TODO : remove usage of this + return AST.Operators.convSymbolToEnumValue(token) + +def get_seedot_shape_order(old_shape): + if(len(old_shape) == 4): + # Case when spatial dimension is 2 + # inverse of [1, 3, 4, 2] is [1, 4, 2, 3] + return ([old_shape[0], old_shape[2], old_shape[3], old_shape[1]], [1, 4, 2, 3]) + else: + # Casr when spatial dimension is 3 + # inverse of [1, 3, 4, 5, 2] is [1, 5, 2, 3, 4] + return ([old_shape[0], old_shape[2], old_shape[3], 
old_shape[4], old_shape[1]], [1, 5, 2, 3, 4]) + +def get_seedot_filter_shape_order(filter_shape): + if(len(filter_shape) == 4): + # Case when spatial dimension is 2 + # inverse of [3, 4, 2, 1] is [4, 3, 1, 2] + return ([filter_shape[2], filter_shape[3], filter_shape[1], filter_shape[0]], [4, 3, 1, 2]) + else: + # Casr when spatial dimension is 3 + # inverse of [3, 4, 5, 2, 1] is [5, 4, 1, 2, 3] + return ([filter_shape[2], filter_shape[3], filter_shape[4], filter_shape[1], filter_shape[0]], [5, 4, 1, 2, 3]) + +def get_onnx_order(onnx_shape): + if(len(onnx_shape) == 4): + # inverse of [1, 4, 2, 3] is [1, 3, 4, 2] + return [1, 3, 4, 2] + else: + # inverse of [1, 5, 2, 3, 4] is [1, 3, 4, 5, 2] + return [1, 3, 4, 5, 2] + +def get_reshaped_input_ast(input_name, value_info, node_name_to_out_var_dict): + onnx_input_shape = list(value_info[input_name][1]) + (seedot_input_shape, seedot_input_order) = get_seedot_shape_order(onnx_input_shape) + return AST.Reshape(AST.ID(node_name_to_out_var_dict[input_name]), seedot_input_shape, seedot_input_order) + +def get_reshaped_bias_ast(bias_name, value_info, node_name_to_out_var_dict, dim): + if(dim == 2): + return AST.Reshape(AST.ID(node_name_to_out_var_dict[bias_name]), [1, 1, 1, value_info[bias_name][1][0]], None) + else: + return AST.Reshape(AST.ID(node_name_to_out_var_dict[bias_name]), [1, 1, 1, 1, value_info[bias_name][1][0]], None) + +def get_reshaped_filter_ast(filter_name, value_info, node_name_to_out_var_dict): + onnx_filter_shape = list(value_info[filter_name][1]) + (seedot_filter_shape, seedot_filter_order) = get_seedot_filter_shape_order(onnx_filter_shape) + return AST.Reshape(AST.ID(node_name_to_out_var_dict[filter_name]), seedot_filter_shape, seedot_filter_order) + +def get_reshaped_output_ast(onnx_output_name, value_info, output_name): + onnx_output_shape = list(value_info[onnx_output_name][1]) + onnx_output_order = get_onnx_order(onnx_output_shape) + return AST.Reshape(AST.ID(output_name), onnx_output_shape, 
onnx_output_order) + +def get_new_var_name(out_var_count): + return out_var_prefix + str(out_var_count) + +def update_program_with_new_node(innermost_let_ast_node, new_node, new_node_name, mtdAST): + cur_out_var_ast_node = AST.ID(new_node_name) + new_let_node = AST.Let(cur_out_var_ast_node, new_node, cur_out_var_ast_node) + mtdAST.visit(new_let_node, {AST.ASTNode.mtdKeyTFOpName : 'no', AST.ASTNode.mtdKeyTFNodeName : 'no'}) + # Updating the innermost Let AST node and the expression for previous Let Node + innermost_let_ast_node.expr = new_let_node + innermost_let_ast_node = new_let_node + + # node_name_to_out_var_dict[node.outputs[0]] = new_node_name + return innermost_let_ast_node + +class ONNXNodesAST: + + # value_info: dictionary of name -> (type, dimension tuple) + def Input(node, value_info, node_name_to_out_var_dict): + if(DEBUG): + print(node.outputs[0]) + # There are two types of inputs + dims = list(node.dims if hasattr(node, 'dims') else ([val.dim_value for val in node.type.tensor_type.shape.dim])) + data_type = node.data_type if hasattr (node, 'data_type') else node.type.tensor_type.elem_type + return AST.Input(dims, onnx2seedot(data_type)) + + + def Cast(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + if(DEBUG): + print(node) + inputsRef = node.inputs + assert(len(inputsRef) == 1) + # destType = node.attrs['to'] + + # seedot_output_ast = AST.UninterpFuncCall(value_info[node.outputs[0]][1], + # 'Cast', + # [AST.ID(inputsRef[0]), + # AST.ID(destType), + # AST.ID(destType) + # ]) + # output_name = get_new_var_name(out_var_count) + # innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + # out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = inputsRef[0] + + return (innermost_let_ast_node, out_var_count) + + def Pad(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + 
node = OnnxNode(node) + if(DEBUG): + print(node) + inputsRef = node.inputs + # Skip constant_val input (last input) + inpLen = len(inputsRef) - 1 + assert(inpLen == 2) + inputs = [AST.ID(node_name_to_out_var_dict[inputsRef[x]]) for x in range(0, inpLen)] + mode = node.attrs['mode'] + assert(mode == 'constant') + seedot_output_ast = AST.UninterpFuncCall(list(value_info[node.outputs[0]][1]), + 'PadONNX', inputs) + + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def Concat(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + if(DEBUG): + print(node) + inputsRef = node.inputs + N = len(inputsRef) + + inputs = [AST.ID(node_name_to_out_var_dict[inputsRef[x]]) for x in range(0, len(inputsRef))] + axis = node.attrs['axis'] + + seedot_output_ast = AST.UninterpFuncCall(list(value_info[node.outputs[0]][1]), + 'Concat'+str(N) + 'T', + inputs + [AST.Int(axis, 32, False)], + outputDiffInpDims=1 + ) + + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def Relu(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + + inputsRef = node.inputs + assert(len(inputsRef)==1) + + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) + 
out_var_count += 1 + + seedot_output_ast = AST.Func(getOperatorsIdx('relu'), AST.ID(reshaped_input_name)) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info, output_name) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + if(DEBUG): + print(node.outputs[0]) + print(onnx_input_shape, '->', seedot_input_shape, '->', onnx_output_shape) + + return (innermost_let_ast_node, out_var_count) + # return AST.Func(getOperatorsIdx('relu'), AST.ID(node_name_to_out_var_dict[inputsRef[0]])) + + def Add(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + if(DEBUG): + print(node) + inputsRef = node.inputs + assert(len(inputsRef) == 2) + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) + out_var_count += 1 + + reshaped_input_name1 = get_new_var_name(out_var_count) + reshaped_input1 = get_reshaped_input_ast(inputsRef[1], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input1, reshaped_input_name1, mtdAST) + out_var_count += 1 + + seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name), + getOperatorsIdx('+'), + AST.ID(reshaped_input_name1) + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, 
output_name, mtdAST) + out_var_count += 1 + + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info, output_name) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + if(DEBUG): + print(node.outputs[0]) + print(onnx_input_shape, onnx_input_shape1, '->', seedot_input_shape, seedot_input_shape1, '->', onnx_output_shape) + + return (innermost_let_ast_node, out_var_count) + + + def Gemm(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + if(DEBUG): + print(node) + inputsRef = node.inputs + assert(len(inputsRef) == 3) + input1AST = AST.ID(node_name_to_out_var_dict[inputsRef[0]]) + input2AST = AST.ID(node_name_to_out_var_dict[inputsRef[1]]) + + if('transA' in node.attrs and node.attrs['transA']): input1AST = AST.Transp(input1AST) + if('transB' in node.attrs and node.attrs['transB']): input2AST = AST.Transp(input2AST) + + # W*x + b + seedot_output_ast = AST.BOp(AST.BOp(input1AST, getOperatorsIdx('*'), input2AST), getOperatorsIdx('+'), AST.ID(node_name_to_out_var_dict[inputsRef[2]])) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def Constant(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + if(DEBUG): + print(node) + # TODO: Use AST.decl for defining a tensor. If used as a parameter for Reshape then we don't need it for now. 
+ return (innermost_let_ast_node, out_var_count) + + def Transpose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + if(DEBUG): + print(node) + + inputsRef = node.inputs + assert(len(inputsRef)==1) + + seedot_output_ast = AST.Transpose(AST.ID(node_name_to_out_var_dict[inputsRef[0]]), node.attrs['perm']) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + # Only supports split into equal parts + def Split(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + inputsRef = node.inputs + output_count = len(node.outputs) + + for cur_count in range(output_count): + seedot_output_ast = AST.UninterpFuncCall(list(value_info[node.outputs[cur_count]][1]), 'Split', + [AST.ID(node_name_to_out_var_dict[inputsRef[0]]), AST.Int(node.attrs['axis'], 32, False), AST.Int(cur_count, 32, False), AST.Int(output_count, 32, False)]) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[cur_count]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def ReduceMean(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + inputsRef = node.inputs + + keepdims = node.attrs['keepdims'] + axes = node.attrs['axes'] + + # currently handling only this case + # currently support only 0 case + assert(keepdims == 0) + assert(len(axes) == 2) + + seedot_output_ast = AST.UninterpFuncCall(value_info[node.outputs[0]][1], 'ReduceMeanONNX', + 
[AST.ID(node_name_to_out_var_dict[inputsRef[0]]), AST.Int(axes[0], 32, False), AST.Int(axes[1], 32, False)]) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = output_name + return (innermost_let_ast_node, out_var_count) + + def BatchNormalization(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + + inputsRef = node.inputs + # Are running mean and var used for something? + assert(len(inputsRef)==5) + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) + out_var_count += 1 + + seedot_output_ast = AST.FusedBatchNorm(AST.ID(reshaped_input_name), + AST.ID(node_name_to_out_var_dict[inputsRef[1]]), + AST.ID(node_name_to_out_var_dict[inputsRef[2]]), + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info, output_name) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + if(DEBUG): + print(node.outputs[0]) + print(onnx_input_shape, '->', seedot_input_shape, '->', onnx_output_shape) + + return (innermost_let_ast_node, out_var_count) + + def Reshape(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + if(DEBUG): + 
print(node) + + inputsRef = node.inputs + assert(len(inputsRef)==2) + # print(list(value_info[node.outputs[0]][1])) + + seedot_output_ast = AST.Reshape(AST.ID(node_name_to_out_var_dict[inputsRef[0]]), list(value_info[node.outputs[0]][1]), None) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def Flatten(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + if(DEBUG): + print(node) + + inputsRef = node.inputs + assert(len(inputsRef)==1) + + seedot_output_ast = AST.Reshape(AST.ID(node_name_to_out_var_dict[inputsRef[0]]), list(value_info[node.outputs[0]][1]), None) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def Conv(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + if(DEBUG): + print(node) + + inputsRef = node.inputs + # since two dimensions represent N: Number of batches and CI: Input channel + inputShape = value_info[inputsRef[0]][1] + spatial_size = len(inputShape)-2 + + if spatial_size == 2: + (innermost_let_ast_node, out_var_count, output_name) = ONNXNodesAST.conv2d(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST) + elif spatial_size == 3: + (innermost_let_ast_node, out_var_count, output_name) = ONNXNodesAST.conv3d(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST) + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = 
get_reshaped_output_ast(node.outputs[0],value_info, output_name) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + return (innermost_let_ast_node, out_var_count) + + def conv2d(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + inputsRef = node.inputs + inputShape = value_info[inputsRef[0]][1] + filterShape = value_info[inputsRef[1]][1] + + stridesUsed = node.attrs['strides'] + + assert(len(inputsRef)==2 or len(inputsRef)==3) + assert(len(stridesUsed)==2) + assert(value_info[node.inputs[1]][1][2:] == tuple(node.attrs['kernel_shape'])) + + group = node.attrs['group'] if 'group' in node.attrs else 1 + [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = node.attrs['pads'] if 'pads' in node.attrs else [0,0,0,0] + # we assume VALID case when the padding is in string format + + options = {} + options[AST.PaddingKeysDict.FH] = filterShape[2] + options[AST.PaddingKeysDict.FW] = filterShape[3] + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideH] = stridesUsed[0] + options[AST.PaddingKeysDict.strideW] = stridesUsed[1] + options[AST.PaddingKeysDict.ConvDim] = 2 + options[AST.PaddingKeysDict.group] = group + + # print(inputShape, filterShape) + assert (inputShape[1] == filterShape[1]*group) + # For Input: + # [N, CI, H, W] is the Onnx order it should be changed to + # [N, H, W, CI] order + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) + 
out_var_count += 1 + + # For filter: + # [CO, CI1, FH, FW] is the Onnx order it should be changed to + # [FH, FW, CI1, CO] order + reshaped_filter_name = get_new_var_name(out_var_count) + reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST) + out_var_count += 1 + + seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name), getOperatorsIdx('#'), AST.ID(reshaped_filter_name), options) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + # If there is bias to be added then reshape and add it + if (len(inputsRef) == 3): + reshaped_bias_name = get_new_var_name(out_var_count) + reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info, node_name_to_out_var_dict, 2) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST) + out_var_count += 1 + + seedot_output_ast = AST.BOp(AST.ID(output_name), getOperatorsIdx('+'), AST.ID(reshaped_bias_name), options) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + return (innermost_let_ast_node, out_var_count, output_name) + + def conv3d(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + inputsRef = node.inputs + inputShape = value_info[inputsRef[0]][1] + filterShape = value_info[inputsRef[1]][1] + stridesUsed = node.attrs['strides'] + + assert(len(inputsRef)==2 or len(inputsRef)==3) + assert(len(stridesUsed)==3) + assert(value_info[node.inputs[1]][1][2:] == tuple(node.attrs['kernel_shape'])) + # verify this order + [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, 
zPadWRight] = node.attrs['pads'] + + options = {} + options[AST.PaddingKeysDict.FD] = filterShape[2] + options[AST.PaddingKeysDict.FH] = filterShape[3] + options[AST.PaddingKeysDict.FW] = filterShape[4] + options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft + options[AST.PaddingKeysDict.zPadDRight] = zPadDRight + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideD] = stridesUsed[0] + options[AST.PaddingKeysDict.strideH] = stridesUsed[1] + options[AST.PaddingKeysDict.strideW] = stridesUsed[2] + options[AST.PaddingKeysDict.ConvDim] = 3 + + assert (inputShape[1] == filterShape[1]) + # For Input: + # [N, CI, D, H, W] is the Onnx order it should be changed to + # [N, D, H, W, CI] order + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) + out_var_count += 1 + + # For filter: + # [CO, CI1, FD, FH, FW] is the Onnx order it should be changed to + # [FD, FH, FW, CI1, CO] order + reshaped_filter_name = get_new_var_name(out_var_count) + reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST) + out_var_count += 1 + + seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name), getOperatorsIdx('#'), AST.ID(reshaped_filter_name), options) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + # If there is bias to be added then reshape and add it + if (len(inputsRef) == 
3): + reshaped_bias_name = get_new_var_name(out_var_count) + reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info, node_name_to_out_var_dict, 3) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST) + out_var_count += 1 + + seedot_output_ast = AST.BOp(AST.ID(output_name), getOperatorsIdx('+'), AST.ID(reshaped_bias_name), options) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + return (innermost_let_ast_node, out_var_count, output_name) + + def MaxPool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + return ONNXNodesAST.helper_processPool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST, 'MAXPOOL') + + def AvgPool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + return ONNXNodesAST.helper_processPool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST, 'AVGPOOL') + + def GlobalAveragePool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + if(DEBUG): + print(node) + inputsRef = node.inputs + assert(len(inputsRef)==1) + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) + out_var_count += 1 + + seedot_output_ast = AST.Pool(AST.Pool.PoolType.AvgPool, + AST.ID(reshaped_input_name), + { + AST.PaddingKeysDict.FH: value_info[inputsRef[0]][1][2], + AST.PaddingKeysDict.FW: value_info[inputsRef[0]][1][3], + AST.PaddingKeysDict.zPadHLeft: 0, + AST.PaddingKeysDict.zPadHRight: 0, + 
AST.PaddingKeysDict.zPadWLeft: 0, + AST.PaddingKeysDict.zPadWRight: 0, + AST.PaddingKeysDict.strideH: 1, + AST.PaddingKeysDict.strideW: 1 + } + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info, output_name) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + return (innermost_let_ast_node, out_var_count) + + def helper_processPool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST, typeOfPool): + node = OnnxNode(node) + if(DEBUG): + print(node) + inputsRef = node.inputs + assert(len(inputsRef)==1) + + stridesUsed = node.attrs['strides'] + strideH = stridesUsed[0] + strideW = stridesUsed[1] + + kSizeUsed = node.attrs['kernel_shape'] + # assert((kSizeUsed[0] == 1) and (kSizeUsed[3] == 1)) + kSizeH = kSizeUsed[0] + kSizeW = kSizeUsed[1] + + inputShape = value_info[inputsRef[0]][1] + # print(inputShape) + imgH = inputShape[2] + imgW = inputShape[3] + + # ONNX 'pads' is [H_begin, W_begin, H_end, W_end]; when the attribute is omitted it defaults to zero padding + [zPadHLeft, zPadWLeft, zPadHRight, zPadWRight] = node.attrs['pads'] if 'pads' in node.attrs else [0,0,0,0] + + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) + out_var_count += 1 + + poolType = None + if typeOfPool=='MAXPOOL': poolType = AST.Pool.PoolType.MaxPool + elif typeOfPool=='AVGPOOL': poolType = AST.Pool.PoolType.AvgPool + else: + print("Unknown type of pooling layer.", file=sys.stderr) + assert(False) + seedot_output_ast = 
AST.Pool(poolType, + AST.ID(reshaped_input_name), + { + AST.PaddingKeysDict.FH: kSizeH, + AST.PaddingKeysDict.FW: kSizeW, + AST.PaddingKeysDict.zPadHLeft: zPadHLeft, + AST.PaddingKeysDict.zPadHRight: zPadHRight, + AST.PaddingKeysDict.zPadWLeft: zPadWLeft, + AST.PaddingKeysDict.zPadWRight: zPadWRight, + AST.PaddingKeysDict.strideH: strideH, + AST.PaddingKeysDict.strideW: strideW + } + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info, output_name) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + return (innermost_let_ast_node, out_var_count) + + def ConvTranspose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + node = OnnxNode(node) + if(DEBUG): + print(node) + + inputsRef = node.inputs + # since two dimensions represent N: Number of batches and CI: Input channel + inputShape = value_info[inputsRef[0]][1] + spatial_size = len(inputShape)-2 + if spatial_size == 2: + (innermost_let_ast_node, out_var_count, output_name) = ONNXNodesAST.conv2dtranspose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST) + elif spatial_size == 3: + (innermost_let_ast_node, out_var_count, output_name) = ONNXNodesAST.conv3dtranspose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST) + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast(node.outputs[0],value_info, output_name) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, 
reshaped_output_name, mtdAST) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + return (innermost_let_ast_node, out_var_count) + + def conv2dtranspose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + inputsRef = node.inputs + inputShape = value_info[inputsRef[0]][1] + filterShape = value_info[inputsRef[1]][1] + stridesUsed = node.attrs['strides'] + outputShape = value_info[node.outputs[0]][1] + + # sometimes there is a bias to be added as well + assert(len(inputsRef)==2 or len(inputsRef)==3) + assert(len(stridesUsed)==2) + assert(value_info[node.inputs[1]][1][2:] == tuple(node.attrs['kernel_shape'])) + # ONNX 'pads' is [H_begin, W_begin, H_end, W_end]; when the attribute is omitted it defaults to zero padding + [zPadHLeft, zPadWLeft, zPadHRight, zPadWRight] = node.attrs['pads'] if 'pads' in node.attrs else [0,0,0,0] + + options = {} + options[AST.PaddingKeysDict.FH] = filterShape[2] + options[AST.PaddingKeysDict.FW] = filterShape[3] + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideH] = stridesUsed[0] + options[AST.PaddingKeysDict.strideW] = stridesUsed[1] + options[AST.PaddingKeysDict.ConvDim] = 2 + options[AST.PaddingKeysDict.outputImgH] = outputShape[2] + options[AST.PaddingKeysDict.outputImgW] = outputShape[3] + + assert (inputShape[1] == filterShape[0]) + # For Input: + # [N, CI, H, W] is the Onnx order it should be changed to + # [N, H, W, CI] order + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) + out_var_count += 1 + # For filter: + # [CI, CO, FH, FW] is the Onnx order it should be changed to + # [FH, FW, CI1, CO] order + reshaped_filter_name = 
get_new_var_name(out_var_count) + reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST) + out_var_count += 1 + + seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name), getOperatorsIdx('#T'), AST.ID(reshaped_filter_name), options) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + # If there is bias to be added then reshape and add it + if (len(inputsRef) == 3): + biasShape = value_info[inputsRef[2]][1] + reshaped_bias_name = get_new_var_name(out_var_count) + reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info, node_name_to_out_var_dict, 2) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST) + out_var_count += 1 + + seedot_output_ast = AST.BOp(AST.ID(output_name), getOperatorsIdx('+'), AST.ID(reshaped_bias_name), options) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + return (innermost_let_ast_node, out_var_count, output_name) + + def conv3dtranspose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): + inputsRef = node.inputs + inputShape = value_info[inputsRef[0]][1] + filterShape = value_info[inputsRef[1]][1] + stridesUsed = node.attrs['strides'] + outputShape = value_info[node.outputs[0]][1] + + # sometimes there is a bias to be added as well + assert(len(inputsRef)==2 or len(inputsRef)==3) + assert(len(stridesUsed)==3) + assert(value_info[node.inputs[1]][1][2:] == tuple(node.attrs['kernel_shape'])) + # ONNX 'pads' is [D_begin, H_begin, W_begin, D_end, H_end, W_end] (all begins, then all ends) + [zPadDLeft, zPadHLeft, zPadWLeft, zPadDRight, zPadHRight, zPadWRight] 
= node.attrs['pads'] + + options = {} + options[AST.PaddingKeysDict.FD] = filterShape[2] + options[AST.PaddingKeysDict.FH] = filterShape[3] + options[AST.PaddingKeysDict.FW] = filterShape[4] + options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft + options[AST.PaddingKeysDict.zPadDRight] = zPadDRight + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideD] = stridesUsed[0] + options[AST.PaddingKeysDict.strideH] = stridesUsed[1] + options[AST.PaddingKeysDict.strideW] = stridesUsed[2] + options[AST.PaddingKeysDict.ConvDim] = 3 + options[AST.PaddingKeysDict.outputImgD] = outputShape[2] + options[AST.PaddingKeysDict.outputImgH] = outputShape[3] + options[AST.PaddingKeysDict.outputImgW] = outputShape[4] + + assert (inputShape[1] == filterShape[0]) + # For Input: + # [N, CI, D, H, W] is the Onnx order it should be changed to + # [N, D, H, W, CI] order + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) + out_var_count += 1 + # For filter: + # [CI, CO, FD, FH, FW] is the Onnx order it should be changed to + # [FD, FH, FW, CI1, CO] order + reshaped_filter_name = get_new_var_name(out_var_count) + reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info, node_name_to_out_var_dict) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST) + out_var_count += 1 + + seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name), getOperatorsIdx('#T'), AST.ID(reshaped_filter_name), options) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = 
update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + # If there is bias to be added then reshape and add it + if (len(inputsRef) == 3): + biasShape = value_info[inputsRef[2]][1] + reshaped_bias_name = get_new_var_name(out_var_count) + reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info, node_name_to_out_var_dict, 3) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST) + out_var_count += 1 + + seedot_output_ast = AST.BOp(AST.ID(output_name), getOperatorsIdx('+'), AST.ID(reshaped_bias_name), options) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + out_var_count += 1 + + + + return (innermost_let_ast_node, out_var_count, output_name) + \ No newline at end of file diff --git a/Athos/ONNXCompiler/Readme.md b/Athos/ONNXCompiler/Readme.md new file mode 100644 index 0000000000000000000000000000000000000000..67449cfd07468d841a1fcb2dcee3e8ceeee10a59 --- /dev/null +++ b/Athos/ONNXCompiler/Readme.md @@ -0,0 +1,46 @@ +# Introduction +This part of the code compiles the onnx model to SeeDot AST. + +A model name must be provided to the `compile.sh` script and the model must be placed in `./models` directory +The script can be run with `./compile.sh model_name.onnx` command on the command line + +1) The script calls `onnx_run.py` to generate a random input of size matching the input size of the model. `onnx_run.py` further runs the model using `onnxruntime` and stores the output result as a `numpy` array. The input is stored as `model_name_input.npy` and the output is stored as `model_name_output.npy` + +2) Then it runs `process_onnx.py`. 
This python code combines `model_name_input.npy` and the values of other variables stored in the model to generate a `model_name_input.h` file which is later fed to the final code as input. `model_name_input.h` has all the values stored as fixed-point integers using the value of scale in the script. + +3) Then it runs `onnx inference` to calculate the input and output size for each onnx node, and it parses the onnx model using `OnnxNodesAST.py` and creates a `SeeDot` AST which is stored as `model_name.pkl` (using pickle) + +4) The `compile.sh` script further converts the SeeDot AST to EzPC code and the `EzPC` code is finally converted to the `CPP` program. This CPP program is compiled and run with the given input. The output is stored as `debug/cpp_output_raw.txt`. Again, using the same scale this raw output is converted to the floating-point output and stored in `debug/cpp_output.txt` for easier manual comparison with the original onnx output. + +# Debugging and Logging +Since debugging the code is an arduous task, several things are logged in the following files + +To log the values of specific variables, the script can be run in debug mode using `./compile.sh model_name.onnx name_of_onnx_node` + +`onnx_seedot_name_map.txt` It stores a map from onnx names to SeeDot names of variables + +`seedot_ezpc_name_map.txt` It stores a map from SeeDot names to EzPC names of variables + +`onnx_ezpc_name_map.txt` The above two maps are combined to create a map that shows the mapping from onnx names to ezpc/cpp names + +`cpp_output_raw.txt` It contains the raw output after running the final code. If the script is run in `debug` mode with a debug name specified, then the output has the values of the selected debug variable instead of the final variable. + +`cpp_output.txt` The above file is parsed and converted into a format where all fixed point integer values are converted to the easily readable floating-point format. 
As above, in `debug` mode the output contains the value of the debug variable. + +`onnx_debug.txt` In debug mode this file contains the value of the selected onnx node computed using onnx runtime. + +`onnx_output.txt` This file contains the value of output computed using onnx runtime. + +`seedot_ast.txt` The output of process_onnx.py is logged in this file. It includes the generated SeeDot AST. + +`seedot_to_ezpc_output.txt` The output of the SeeDot-to-EzPC compilation is logged in this file. + +# Dependency +Other than EzPC dependencies +`onnx` +`onnxruntime` + +# Testing +python3 -m unittest + + diff --git a/Athos/ONNXCompiler/__init__.py b/Athos/ONNXCompiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Athos/ONNXCompiler/common.py b/Athos/ONNXCompiler/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ae4cef33801e7a9c3188f0994b0825de6341c953 --- /dev/null +++ b/Athos/ONNXCompiler/common.py @@ -0,0 +1,109 @@ + +''' + +Authors: Shubham Ugare. + +Copyright: +Copyright (c) 2018 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' +import numpy +import os +import _pickle as pickle +import re + +def proto_val_to_dimension_tuple(proto_val): + return tuple([dim.dim_value for dim in proto_val.type.tensor_type.shape.dim]) + +def numpy_float_array_to_fixed_point_val_str(input_array, scale): + cnt = 0 + chunk = '' + for val in numpy.nditer(input_array): + val = int(val*(2**scale)) + chunk += str(val) + '\n' + cnt += 1 + return (chunk, cnt) + +def numpy_float_array_to_float_val_str(input_array): + chunk = '' + for val in numpy.nditer(input_array): + chunk += str(val) + '\n' + return chunk + +def write_debug_info(node_name_to_out_var_dict): + if not os.path.exists('debug'): + os.makedirs('debug') + + with open('debug/onnx_seedot_name_map.pkl', 'wb') as f: + pickle.dump(node_name_to_out_var_dict, f) + + with open('debug/onnx_seedot_name_map.txt', 'w') as f: + for val in node_name_to_out_var_dict: + f.write(val + ' ' + node_name_to_out_var_dict[val] + '\n') + + +def merge_name_map(): + onnx_seedot_name_map = pickle.load(open('debug/onnx_seedot_name_map.pkl', 'rb')) + seedot_ezpc_name_map = pickle.load(open('debug/seedot_ezpc_name_map.pkl', 'rb')) + + with open('debug/onnx_ezpc_name_map.txt', 'w') as f: + for val in onnx_seedot_name_map: + f.write(val + ' ' + seedot_ezpc_name_map[onnx_seedot_name_map[val]] + '\n') # one "onnx_name ezpc_name" pair per line + +def get_seedot_name_from_onnx_name(onnx_name): + onnx_seedot_name_map = pickle.load(open('debug/onnx_seedot_name_map.pkl', 'rb')) + print(onnx_seedot_name_map[onnx_name]) + +def parse_output(scale): + f = open('debug/cpp_output_raw.txt', 'r') + g = open('debug/cpp_output.txt', 'w') + chunk = '' + for line in f: + if line.rstrip().replace('-','0').isdigit(): + val = float(line.rstrip()) + val = val/(2**scale) + chunk += str(val) + 
'\n' + g.write(chunk) + g.close() + +def extract_txt_to_numpy_array(file): + f = open(file, 'r') + op = [float(line.rstrip()) for line in f] + f.close() + return numpy.array(op, dtype=numpy.float32) + +def match_debug(decimal=4): + a = extract_txt_to_numpy_array('debug/onnx_debug.txt') + b = extract_txt_to_numpy_array('debug/cpp_output.txt') + numpy.testing.assert_almost_equal(a, b, decimal) + +def match_output(decimal=4): + a = extract_txt_to_numpy_array('debug/onnx_output.txt') + b = extract_txt_to_numpy_array('debug/cpp_output.txt') + numpy.testing.assert_almost_equal(a, b, decimal) + +def add_openmp_threading_to_convolution(file): + with open(file, 'r+') as f: + newfilename = file[:-5]+'1.cpp' + g = open(newfilename, 'w') + content = f.read() + content1 = re.sub(r'void Conv3D\(.*',r'\g<0> \n #pragma omp parallel for collapse(5) ', content) + content2 = re.sub(r'void ConvTranspose3D\(.*',r'\g<0> \n #pragma omp parallel for collapse(5) ', content1) + g.write(content2) + g.close() + + diff --git a/Athos/ONNXCompiler/compile.sh b/Athos/ONNXCompiler/compile.sh new file mode 100755 index 0000000000000000000000000000000000000000..55d1ac52b5d9068441523995716fb99df43d68d5 --- /dev/null +++ b/Athos/ONNXCompiler/compile.sh @@ -0,0 +1,128 @@ +#!/bin/bash + +# Authors: Shubham Ugare. + +# Copyright: +# Copyright (c) 2018 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# This script will +# 1) compile the ONNX model to SeeDot AST +# 2) Compile the SeeDot AST to ezpc +# 3) Convert the ezpc code to cpp and then run it on the given dataset + +# Any subsequent(*) commands which fail will cause the shell script to exit immediately +set -e + +modelName=$1 +debugOnnxNode=$2 + +EzPCDir="../../EzPC" +ONNX_dir="../../Athos/ONNXCompiler" +data_dir="debug/"${modelName} +BITLEN="64" +SCALINGFACTOR="24" +COMPILATIONTARGET="CPP" +ezpcOutputFullFileName=${modelName}'.ezpc' +compilationTargetLower=$(echo "$COMPILATIONTARGET" | awk '{print tolower($0)}') +compilationTargetHigher=$(echo "$COMPILATIONTARGET" | awk '{print toupper($0)}') +finalCodeOutputFileName=${modelName}'0.cpp' +finalCodeOutputFileName1=${modelName}'1.cpp' +inputFileName=${modelName}'_input.inp' +seedotASTName=${modelName}'.pkl' + +# modelname_input.npy and modelname_output.npy +onnxInputFileName=${modelName}'_input.npy' +onnxOutputFileName=${modelName}'_output.npy' + +GREEN='\033[0;32m' +NC='\033[0m' # No Color + +mkdir -p debug +mkdir -p ${data_dir} + +# Generating input may take time, hence skip if already generated +if [ -f ${data_dir}"/"${inputFileName} ]; then + echo -e "${GREEN}$inputFileName already exist, skipping process_onnx${NC}" +else + echo "Starting to gemerate random input" + python3 "create_input.py" ${modelName}'.onnx' $SCALINGFACTOR + echo -e "${GREEN}Finished generating input${NC}" +fi + +echo "Starting onnx run" +# can use either 'onnx_run_tf' or 'onnx_run' +# onnx_run is faster and 
has lesser dependencies +# but may not support all operations +python3 "onnx_run.py" ${modelName}'.onnx' ${debugOnnxNode} > "debug/log_onnx_run.txt" +echo -e "${GREEN}Finished onnx run${NC}" + +echo "Starting process_onnx" +echo "output of process_onnx and the resultant seedot ast are logged in debug/seedot_ast.txt" +python3 "process_onnx.py" ${modelName}'.onnx' > "debug/seedot_ast.txt" +echo -e "${GREEN}Finished process_onnx${NC}" + +echo "Starting seedot to ezpc compilation" +echo "output is logged in debug/seedot_to_ezpc_output.txt" + +if [ -z "$debugOnnxNode" ]; then + python3 ../SeeDot/SeeDot.py -p $seedotASTName --astFile ${data_dir}"/"$seedotASTName --outputFileName ${data_dir}"/"${ezpcOutputFullFileName} --consSF ${SCALINGFACTOR} --bitlen "$BITLEN" > "debug/seedot_to_ezpc_output.txt" +else + debugSeedotNode=$(python3 -c "import common; common.get_seedot_name_from_onnx_name(\"${debugOnnxNode}\")") + echo "${debugSeedotNode} is the corresponding SeeDot name" + python3 ../SeeDot/SeeDot.py -p $seedotASTName --astFile ${data_dir}"/"$seedotASTName --outputFileName ${data_dir}"/"${ezpcOutputFullFileName} --consSF ${SCALINGFACTOR} --debugVar ${debugSeedotNode} --bitlen "$BITLEN" > "debug/seedot_to_ezpc_output.txt" +fi +echo -e "${GREEN}Finished seedot to ezpc compilation${NC}" + +python3 -c 'import common; common.merge_name_map()' + + +cat "../TFEzPCLibrary/Library${BITLEN}_cpp.ezpc" "../TFEzPCLibrary/Library${BITLEN}_common.ezpc" ${data_dir}"/"${ezpcOutputFullFileName} > temp +mv temp "$ezpcOutputFullFileName" + +mv "$ezpcOutputFullFileName" "$EzPCDir/EzPC" +cd "$EzPCDir/EzPC" +eval `opam config env` + +echo "Starting with ezpc to cpp compilation" +./ezpc.sh "$ezpcOutputFullFileName" --bitlen "$BITLEN" --codegen "$compilationTargetHigher" --disable-tac +echo -e "${GREEN}Finished ezpc to cpp compilation ${NC}" + +# deleting the generated files +mv "$finalCodeOutputFileName" "$ONNX_dir" +DIREZPC="${EzPCDir}/EzPC/${modelName}" +for file in "$DIREZPC"* +do + rm 
"${file}" +done + +if [ "$compilationTargetLower" == "cpp" ]; then + cd "$ONNX_dir" + mv "$finalCodeOutputFileName" "$data_dir" + + echo "Adding openmp threading instructions to the 3d convolutions" + python3 -c "import common; common.add_openmp_threading_to_convolution('${data_dir}"/"${finalCodeOutputFileName}')" + + echo "compiling generated cpp code" + g++ -O3 -g -w -fopenmp ${data_dir}"/"${finalCodeOutputFileName1} -o ${data_dir}"/"${modelName}".out" + echo -e "${GREEN}compiling done ${NC}" + rm -f "debug/cpp_output_raw.txt" || true + echo "running the final code" + eval './'${data_dir}'/'${modelName}'.out' < ${data_dir}'/'${inputFileName} > "debug/cpp_output_raw.txt" + python3 -c "import common; common.parse_output(${SCALINGFACTOR})" + echo -e "${GREEN}All operations done. ${NC}" +fi diff --git a/Athos/ONNXCompiler/create_input.py b/Athos/ONNXCompiler/create_input.py new file mode 100644 index 0000000000000000000000000000000000000000..f8633a244f029536e1ad7ce52848b38719c05b36 --- /dev/null +++ b/Athos/ONNXCompiler/create_input.py @@ -0,0 +1,105 @@ + +''' + +Authors: Shubham Ugare. + +Copyright: +Copyright (c) 2018 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' + +import numpy.random +import numpy as np +import common +import os, sys +import onnx +from onnx import helper +import math +from onnx import numpy_helper + +def main(): + if (len(sys.argv) < 3): + print("Model file or scaling factor unspecified.", file=sys.stderr) + exit(1) + + file_name = sys.argv[1] + scaling_factor = int(sys.argv[2]) + file_path = 'models/' + file_name + model_name = file_name[:-5] # name without the '.onnx' extension + model = onnx.load(file_path) + graph_def = model.graph + + # Generating input + input_dims = common.proto_val_to_dimension_tuple(model.graph.input[0]) + input_array = numpy.random.random(input_dims) + # input_array = numpy.ones(input_dims, dtype=float) + print('Generated random input of dimension ' + str(input_dims)) + np.save('debug/' + model_name + '/' + model_name + '_input', input_array) + + (chunk, cnt) = common.numpy_float_array_to_fixed_point_val_str(input_array, scaling_factor) + + model_name_to_val_dict = { init_vals.name: numpy_helper.to_array(init_vals).tolist() for init_vals in model.graph.initializer} + + preprocess_batch_normalization(graph_def, model_name_to_val_dict) + + for init_vals in model.graph.initializer: + (chunk_1, cnt_1) = common.numpy_float_array_to_fixed_point_val_str( + np.asarray(model_name_to_val_dict[init_vals.name], dtype=np.float32), scaling_factor) + chunk += chunk_1 + cnt += cnt_1 + + f = open('debug/' + model_name + '/' + model_name + '_input.h', 'w') + f.write(chunk) + f.close() + + print('Total ' + str(cnt) + ' integers were written in ' + model_name + '_input.h') + +def preprocess_batch_normalization(graph_def, model_name_to_val_dict): + # set names to graph nodes if not present + for node in graph_def.node: + node.name = 
node.output[0] + # Update the batch normalization scale and B + # so that mean and var are not required + if(node.op_type == 'BatchNormalization'): + # scale + gamma = model_name_to_val_dict[node.input[1]] + # B + beta = model_name_to_val_dict[node.input[2]] + mean = model_name_to_val_dict[node.input[3]] + var = model_name_to_val_dict[node.input[4]] + for i in range(len(gamma)): + rsigma = 1/math.sqrt(var[i]+1e-5) + gamma[i] = gamma[i]*rsigma + beta[i] = beta[i]-gamma[i]*mean[i] + mean[i] = 0 + var[i] = 1-1e-5 + + # Just testing if the correct values are put + model_name_to_val_dict2 = {} + for init_vals in graph_def.initializer: + # TODO: Remove float_data + model_name_to_val_dict2[init_vals.name] = init_vals.float_data + for node in graph_def.node: + node.name = node.output[0] + if(node.op_type == 'BatchNormalization'): + mean = model_name_to_val_dict[node.input[3]] + for val in mean: + assert(val == 0) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Athos/ONNXCompiler/onnx_run.py b/Athos/ONNXCompiler/onnx_run.py new file mode 100644 index 0000000000000000000000000000000000000000..adc59d815785ca85839ea379fde8c767024e5c54 --- /dev/null +++ b/Athos/ONNXCompiler/onnx_run.py @@ -0,0 +1,67 @@ + +''' + +Authors: Shubham Ugare. + +Copyright: +Copyright (c) 2018 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' + +import numpy as np +import onnxruntime +import common +import os, sys +import onnx +from onnx import helper + +# First read the ONNX file +if (len(sys.argv) < 2): + print("TF python file unspecified.", file=sys.stderr) + exit(1) + +file_name = sys.argv[1] +file_path = 'models/' + file_name +model_name = file_name[:-5] # name without the '.onnx' extension +model = onnx.load(file_path) +sess = onnxruntime.InferenceSession(file_path) + +x = np.load('debug/' + model_name + '/' + model_name + '_input.npy') +x = x.astype(np.float32) + +input_name = model.graph.input[0].name + +if (len(sys.argv) > 2): + intermediate_layer_value_info = helper.ValueInfoProto() + intermediate_layer_value_info.name = sys.argv[2] + model.graph.output.extend([intermediate_layer_value_info]) + onnx.save(model, file_path + '_1') + sess = onnxruntime.InferenceSession(file_path + '_1') + pred = sess.run([intermediate_layer_value_info.name], {input_name: x}) + np.save('debug/' + model_name + '/' + model_name + '_debug', pred) + with open('debug/onnx_debug.txt', 'w') as f: + f.write(common.numpy_float_array_to_float_val_str(pred)) + print("Saving the onnx runtime intermediate output for " + intermediate_layer_value_info.name) + exit() + +pred = sess.run(None, {input_name: x}) +np.save('debug/' + model_name + '/' + model_name + '_output', pred) +with open('debug/onnx_output.txt', 'w') as f: + f.write(common.numpy_float_array_to_float_val_str(pred)) +output_dims = common.proto_val_to_dimension_tuple(model.graph.output[0]) 
+print("Saving the onnx runtime output of dimension " + str(output_dims)) diff --git a/Athos/ONNXCompiler/onnx_run_tf.py b/Athos/ONNXCompiler/onnx_run_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..0986bf656ae75534349dfb6981c9b06d51951dbc --- /dev/null +++ b/Athos/ONNXCompiler/onnx_run_tf.py @@ -0,0 +1,97 @@ + +''' + +Authors: Shubham Ugare. + +Copyright: +Copyright (c) 2018 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +''' + +''' +onnx_run is faster but may not support all operations +onnx_run_tf uses tensorflow backend to run the inference +''' + +import numpy as np +import common +import os, sys +import onnx +from onnx import helper +from onnx_tf.backend import prepare +from onnx import TensorProto + +def main(): + # First read the ONNX file + if (len(sys.argv) < 2): + print("TF python file unspecified.", file=sys.stderr) + exit(1) + + file_name = sys.argv[1] + file_path = 'models/' + file_name + model_name = file_name[:-5] # name without the '.onnx' extension + model = onnx.load(file_path) + model = preprocess_for_tf(model) + + x = np.load('debug/' + model_name + '/' + model_name + '_input.npy') + x = x.astype(np.float32) + + input_name = model.graph.input[0].name + output_name = model.graph.output[0].name + + if (len(sys.argv) > 2): + intermediate_layer_value_info = helper.ValueInfoProto() + intermediate_layer_value_info_name = 'tf_' + sys.argv[2] + intermediate_layer_value_info = helper.make_tensor_value_info(intermediate_layer_value_info_name, TensorProto.FLOAT, []) + model.graph.output.extend([intermediate_layer_value_info]) + output = prepare(model).run(x) + pred = getattr(output, intermediate_layer_value_info_name) + np.save('debug/' + model_name + '/' + model_name + '_debug', pred) + with open('debug/onnx_debug.txt', 'w') as f: + f.write(common.numpy_float_array_to_float_val_str(pred)) + print("Saving the onnx runtime intermediate output for " + intermediate_layer_value_info.name) + exit() + + output = prepare(model).run(x) + pred = getattr(output, output_name) + np.save('debug/' + model_name + '/' + model_name + '_output', pred) + with open('debug/onnx_output.txt', 'w') as f: + f.write(common.numpy_float_array_to_float_val_str(pred)) + output_dims = common.proto_val_to_dimension_tuple(model.graph.output[0]) + print("Saving the onnx runtime output of dimension " + str(output_dims)) + +def preprocess_for_tf(model): + for init_vals in model.graph.initializer: + 
init_vals.name = 'tf_' + init_vals.name + + for inp in model.graph.input: + inp.name = 'tf_' + inp.name + + for op in model.graph.output: + op.name = 'tf_' + op.name + + for node in model.graph.node: + node.name = 'tf_' + node.name + for i in range(len(node.input)): + node.input[i] = 'tf_' + node.input[i] + for i in range(len(node.output)): + node.output[i] = 'tf_' + node.output[i] + return model + +if __name__ == "__main__": + main() diff --git a/Athos/ONNXCompiler/process_onnx.py b/Athos/ONNXCompiler/process_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..d15dccbbe875a04d46acc916e0966f7975ab374d --- /dev/null +++ b/Athos/ONNXCompiler/process_onnx.py @@ -0,0 +1,174 @@ + +''' +Authors: Shubham Ugare. +Copyright: +Copyright (c) 2018 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+''' + +import os, sys + +#Add SeeDot directory to path +sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'SeeDot')) + +# For this warning: https://stackoverflow.com/questions/47068709/your-cpu-supports-instructions-that-this-tensorflow-binary-was-not-compiled-to-u +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + +import _pickle as pickle +import onnx +import onnx.shape_inference +import AST.AST as AST +from ONNXNodesAST import ONNXNodesAST +from onnx.helper import make_tensor_value_info +from onnx import TensorProto +from AST.PrintAST import PrintAST +from AST.MtdAST import MtdAST +import numpy +import common + +import numpy as np +np.set_printoptions(threshold=np.inf) + +DEBUG = False +out_var_prefix = "J" + +def main(): + sys.setrecursionlimit(10000) + # First read the ONNX file + if (len(sys.argv) < 2): + print("TF python file unspecified.", file=sys.stderr) + exit(1) + file_name = sys.argv[1] + file_path = 'models/' + file_name + model_name = file_name[:-5] # name without the '.onnx' extension + + # load the model and extract the graph + model = onnx.load(file_path) + graph_def = model.graph + + print(model.graph.value_info) + # Before shape inference (model.graph.value_info) should have shapes of all the variables and constants + model.graph.value_info.append(make_tensor_value_info(model.graph.input[0].name, TensorProto.FLOAT, common.proto_val_to_dimension_tuple(model.graph.input[0]))) + model.graph.value_info.append(make_tensor_value_info(model.graph.output[0].name, TensorProto.FLOAT, common.proto_val_to_dimension_tuple(model.graph.output[0]))) + + print(model.graph.value_info) + + for init_vals in model.graph.initializer: + model.graph.value_info.append(make_tensor_value_info(init_vals.name, TensorProto.FLOAT, tuple(init_vals.dims))) + + if(DEBUG): + print("Shape inference *****************") + print(model.graph.value_info) + + inferred_model = onnx.shape_inference.infer_shapes(model) + + if(DEBUG): + print("Printing shape ******************") + 
print(inferred_model.graph.value_info) + print("Done ******************") + + # value_info: dictionary of name -> (type, dimension tuple) + value_info = {} + for val in inferred_model.graph.value_info: + value_info[val.name] = (val.type.tensor_type.elem_type, common.proto_val_to_dimension_tuple(val)) + + # Iterate through the ONNX graph nodes and translate them to SeeDot AST nodes + program = None + innermost_let_ast_node = None + node_name_to_out_var_dict = {} + out_var_count = 0 + mtdAST = MtdAST() + + (program, innermost_let_ast_node, out_var_count) = process_input_variables(program, innermost_let_ast_node, node_name_to_out_var_dict, out_var_count, mtdAST, graph_def, value_info) + + process_onnx_nodes(innermost_let_ast_node, node_name_to_out_var_dict, out_var_count, mtdAST, graph_def, value_info) + + PrintAST().visit(program) + + common.write_debug_info(node_name_to_out_var_dict) + + with open('debug/'+model_name+'/' +model_name + '.pkl', 'wb') as f: + pickle.dump(program, f) + +def process_input_variables(program, innermost_let_ast_node, node_name_to_out_var_dict, out_var_count, mtdAST, graph_def, value_info): + node = graph_def.input[0] + curAst = ONNXNodesAST.Input(node, value_info, node_name_to_out_var_dict) + mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : 'Input', + AST.ASTNode.mtdKeyTFNodeName : node.name} + cur_out_var_ast_node = AST.ID(node.name) + + if program: + assert(type(innermost_let_ast_node) is AST.Let) + newNode = AST.Let(cur_out_var_ast_node, curAst, cur_out_var_ast_node) + mtdAST.visit(newNode, mtdForCurAST) + # Updating the innermost Let AST node and the expression for previous Let Node + innermost_let_ast_node.expr = newNode + innermost_let_ast_node = newNode + else: + innermost_let_ast_node = AST.Let(cur_out_var_ast_node, curAst, cur_out_var_ast_node) + mtdAST.visit(innermost_let_ast_node, mtdForCurAST) + innermost_let_ast_node.depth = 0 + program = innermost_let_ast_node + + node_name_to_out_var_dict[node.name] = node.name + + for node in 
graph_def.initializer: + if(DEBUG): + print("Node information") + print(node) + + curAst = ONNXNodesAST.Input(node, value_info, node_name_to_out_var_dict) + mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : 'Input', + AST.ASTNode.mtdKeyTFNodeName : node.name} + if (curAst is None): + continue + + cur_out_var_ast_node = AST.ID(node.name) + + if program: + assert(type(innermost_let_ast_node) is AST.Let) + newNode = AST.Let(cur_out_var_ast_node, curAst, cur_out_var_ast_node) + mtdAST.visit(newNode, mtdForCurAST) + # Updating the innermost Let AST node and the expression for previous Let Node + innermost_let_ast_node.expr = newNode + innermost_let_ast_node = newNode + else: + innermost_let_ast_node = AST.Let(cur_out_var_ast_node, curAst, cur_out_var_ast_node) + mtdAST.visit(innermost_let_ast_node, mtdForCurAST) + innermost_let_ast_node.depth = 0 + program = innermost_let_ast_node + + node_name_to_out_var_dict[node.name] = node.name + return (program, innermost_let_ast_node, out_var_count) + +def process_onnx_nodes(innermost_let_ast_node, node_name_to_out_var_dict, out_var_count, mtdAST, graph_def, value_info): + for node in graph_def.node: + if(DEBUG): + print("Node information") + print(node) + + print("Processing " + node.name + "\n") + mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : node.op_type, + AST.ASTNode.mtdKeyTFNodeName : node.name} + + func = getattr(ONNXNodesAST, node.op_type) + (innermost_let_ast_node, out_var_count) = func(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST) + + assert(type(innermost_let_ast_node) is AST.Let) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Athos/ONNXCompiler/test/__init__.py b/Athos/ONNXCompiler/test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Athos/ONNXCompiler/test/test.py b/Athos/ONNXCompiler/test/test.py new file mode 100644 index 
0000000000000000000000000000000000000000..b57c8c0869a9cc1ab95f4177da997ebc2ed94bb0 --- /dev/null +++ b/Athos/ONNXCompiler/test/test.py @@ -0,0 +1,273 @@ +''' + +Authors: Shubham Ugare. + +Copyright: +Copyright (c) 2018 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +''' + + +import onnx +from onnx import helper, numpy_helper +import unittest +from onnx import TensorProto +import numpy as np +import subprocess +import common +from datetime import date +import time +import hashlib + +class TestNode(unittest.TestCase): + + def _get_rnd_float32(self, low=-1.0, high=1.0, shape=None): + output = np.random.uniform(low, high, shape) + cnt = 1 + for val in shape: cnt*=val + if shape == None: + return np.float32(output) + else: + return output.astype(np.float32).reshape(cnt).tolist() + + def check_result(self, graph, name): + current_milli_time = lambda: str(int(round(time.time() * 1000))) + name = name + "_" + current_milli_time() + model = onnx.helper.make_model(graph, producer_name='onnx-compiler-test') + onnx.save(model, 'models/' + name + '.onnx') + + old_hash = hashlib.md5(open('debug/cpp_output.txt','rb').read()).hexdigest() + + bashCommand = './compile.sh ' + name + process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) + output, error = process.communicate() + + print(output) + print(error) + new_hash = hashlib.md5(open('debug/cpp_output.txt','rb').read()).hexdigest() + + self.assertNotEqual(old_hash, new_hash, 'the compilation did not terminate') + + res_onnx = common.extract_txt_to_numpy_array('debug/onnx_output.txt') + res_cpp = common.extract_txt_to_numpy_array('debug/cpp_output.txt') + + np.save('res_onnx', res_onnx) + np.save('res_cpp', res_cpp) + + self.assertIsNone(error, 'error is non None') + np.testing.assert_almost_equal(res_cpp, res_onnx, decimal=4) + + + def test_conv2d(self): + name = "conv2d" + state_in = helper.make_tensor_value_info('state_in', + TensorProto.FLOAT, [1, 3, 10, 10]) + state_out = helper.make_tensor_value_info('state_out', + TensorProto.FLOAT, [1, 6, 5, 5]) + node_def = helper.make_node("Conv", ['state_in', 'weight'], ['state_out'], + pads=[1, 1, 1, 1], strides=[2, 2], kernel_shape=[3, 3], group=3) + + weight_shape = [6, 1, 3, 3] + weight_val = 
self._get_rnd_float32(shape=weight_shape) + + weight = helper.make_tensor('weight', TensorProto.FLOAT, weight_shape, weight_val) + + graph = helper.make_graph( + [node_def], + name, + [state_in], + [state_out], + [weight] + ) + self.check_result(graph, name) + + + def test_conv3d(self): + name = "conv3d" + state_in = helper.make_tensor_value_info('state_in',TensorProto.FLOAT, [1, 2, 4, 16, 16]) + state_out = helper.make_tensor_value_info('state_out', + TensorProto.FLOAT, [1, 2, 4, 16, 16]) + node_def = helper.make_node("Conv", ['state_in', 'weight'], ['state_out'], + pads=[1, 1, 1, 1, 1, 1], strides=[1, 1, 1], kernel_shape=[3, 3, 3]) + + weight_shape = [2, 2, 3, 3, 3] + weight_val = self._get_rnd_float32(shape=weight_shape) + np.save('weight', weight_val) + + weight = helper.make_tensor('weight', TensorProto.FLOAT, weight_shape, weight_val) + + graph = helper.make_graph( + [node_def], + name, + [state_in], + [state_out], + [weight] + ) + self.check_result(graph, name) + + def test_conv_transpose(self): + name = "conv_transpose" + state_in = helper.make_tensor_value_info('state_in', + TensorProto.FLOAT, [1, 3, 10, 10]) + state_out = helper.make_tensor_value_info('state_out', + TensorProto.FLOAT, [1, 5, 19, 19]) + node_def = helper.make_node("ConvTranspose", ['state_in', 'weight'], ['state_out'], + pads=[1, 1, 1, 1], strides=[2, 2], kernel_shape=[3, 3]) + + weight_shape = [3, 5, 3, 3] + weight_val = self._get_rnd_float32(shape=weight_shape) + + weight = helper.make_tensor('weight', TensorProto.FLOAT, weight_shape, weight_val) + + graph = helper.make_graph( + [node_def], + name, + [state_in], + [state_out], + [weight] + ) + + self.check_result(graph, name) + + # For this to run onnx_run_tf.py should be used in the compile script + # since onnxruntime does not support convtranspose3d + def test_conv_transpose3d(self): + name = "conv3dTranspose" + state_in = helper.make_tensor_value_info('state_in', + TensorProto.FLOAT, [1, 3, 10, 10, 10]) + state_out = 
helper.make_tensor_value_info('state_out', + TensorProto.FLOAT, [1, 5, 19, 19, 19]) + node_def = helper.make_node("ConvTranspose", ['state_in', 'weight', 'bias'], ['state_out'], + # check with pads which are not 1 + pads=[1, 1, 1, 1, 1, 1], strides=[2, 2, 2], kernel_shape=[3, 3, 3]) + + weight_shape = [3, 5, 3, 3, 3] + weight_val = self._get_rnd_float32(shape=weight_shape) + bias_shape = [5] + bias_val = self._get_rnd_float32(shape=bias_shape) + + weight = helper.make_tensor('weight', TensorProto.FLOAT, weight_shape, weight_val) + bias = helper.make_tensor('bias', TensorProto.FLOAT, bias_shape, bias_val) + + graph = helper.make_graph( + [node_def], + name, + [state_in], + [state_out], + [weight, bias] + ) + self.check_result(graph, name) + + def test_relu(self): + name = "relu" + state_in = helper.make_tensor_value_info('state_in', + TensorProto.FLOAT, [1, 3, 10, 10]) + state_out = helper.make_tensor_value_info('state_out', + TensorProto.FLOAT, [1, 3, 10, 10]) + node_def = helper.make_node("Relu", ['state_in'], ['state_out']) + graph = helper.make_graph( + [node_def], + name, + [state_in], + [state_out], + [] + ) + self.check_result(graph, name) + + def test_pad(self): + name = "pad" + state_in = helper.make_tensor_value_info('state_in', TensorProto.FLOAT, [1, 3, 10, 10]) + pads = helper.make_tensor_value_info('pads', TensorProto.INT64, [8]) + pad_init = numpy_helper.from_array(np.array([0,0,1,1,0,0,1,1], dtype=int), name='pads') + const_val = helper.make_tensor_value_info('const_val', TensorProto.FLOAT, [1]) + const_val_init = numpy_helper.from_array(np.array([0.0], dtype=np.float32), name='const_val') + state_out = helper.make_tensor_value_info('state_out', TensorProto.FLOAT, [1,3,12,12]) + node_def = helper.make_node("Pad", ['state_in', 'pads', 'const_val'], ['state_out'], mode="constant") + graph = helper.make_graph([node_def],name,[state_in, pads, const_val],[state_out],initializer=[pad_init, const_val_init]) + self.check_result(graph, name) + + + def 
test_relu3d(self): + name = "relu3d" + state_in = helper.make_tensor_value_info('state_in', + TensorProto.FLOAT, [1, 3, 7, 7, 7]) + state_out = helper.make_tensor_value_info('state_out', + TensorProto.FLOAT, [1, 3, 7, 7, 7]) + node_def = helper.make_node("Relu", ['state_in'], ['state_out']) + graph = helper.make_graph( + [node_def], + name, + [state_in], + [state_out], + [] + ) + self.check_result(graph, name) + + def test_reducemean(self): + name = "reducemean" + state_in = helper.make_tensor_value_info('state_in', + TensorProto.FLOAT, [1, 1024, 7, 7]) + state_out = helper.make_tensor_value_info('state_out', + TensorProto.FLOAT, [1, 1024]) + node_def = helper.make_node("ReduceMean", ['state_in'], ['state_out'], axes=[2,3], keepdims=0) + graph = helper.make_graph( + [node_def], + name, + [state_in], + [state_out], + [] + ) + self.check_result(graph, name) + + def test_batchnormalization(self): + name = "batchnormalization" + state_in = helper.make_tensor_value_info('state_in', + TensorProto.FLOAT, [1, 24, 10, 10]) + state_out = helper.make_tensor_value_info('state_out', + TensorProto.FLOAT, [1, 24, 10, 10]) + node_def = helper.make_node("BatchNormalization", ['state_in', 'weight', 'bias','mean','var'], ['state_out'], + momentum=0.8999999761581421) + + weight_shape = [24] + weight_val = self._get_rnd_float32(shape=weight_shape) + weight = helper.make_tensor('weight', TensorProto.FLOAT, weight_shape, weight_val) + + bias_shape = [24] + bias_val = self._get_rnd_float32(shape=weight_shape) + bias = helper.make_tensor('bias', TensorProto.FLOAT, bias_shape, bias_val) + + mean_shape = [24] + mean_val = self._get_rnd_float32(shape=weight_shape) + mean = helper.make_tensor('mean', TensorProto.FLOAT, mean_shape, mean_val) + + + var_shape = [24] + var_val = self._get_rnd_float32(shape=weight_shape, low=0, high=1) + var = helper.make_tensor('var', TensorProto.FLOAT, var_shape, var_val) + + graph = helper.make_graph( + [node_def], + name, + [state_in], + [state_out], + [weight, 
bias, mean, var] + ) + self.check_result(graph, name) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py index ee8916ae74008f0293ac481df5fb2e3cb46f555e..5df8c3eb98c644cb28d3eae9f2895cb43cc0fb9a 100644 --- a/Athos/SeeDot/AST/AST.py +++ b/Athos/SeeDot/AST/AST.py @@ -29,6 +29,7 @@ OperatorsSymbolDict = { "SUB": '-', "MUL": '*', "CONV": '#', + "CONVTRANSPOSE": "#T", #ConvTranspose "RELU": 'relu', "Equal": '==', "ElemWiseMul":'.*', @@ -46,6 +47,7 @@ class Operators(Enum): SUB = auto() MUL = auto() CONV = auto() + CONVTRANSPOSE = auto() RELU = auto() Equal = auto() ElemWiseMul = auto() @@ -65,15 +67,49 @@ class Operators(Enum): assert(enumStr is not None) return Operators[enumStr] + def findConvTransposePadding(i, i_prime, f, p_total, stride): + # The parameters have the following semantics: + # i = conv input img size + # i_prime = convTranspose input img Size + # f = filter size + # p_total = conv input padding total + # stride = conv input stride + p_total_tr = 2*f - p_total - 2 + ((i + p_total - f)%stride) + stride_tr = 1 + i_prime_tilde = i_prime + (i_prime-1)*(stride-1) + return [p_total_tr, stride_tr, i_prime_tilde] + + def findLeftRightPaddingFromTotalPadding(totalPadding): + leftPadding = totalPadding // 2 + rightPadding = totalPadding - leftPadding + return [leftPadding, rightPadding] + + def findConvOutputImgSize(imgSize, totalPadding, filterSize, stride): + return ((imgSize + totalPadding - filterSize) // stride) + 1 + class PaddingKeysDict: + ConvDim = 2 #2D or 3D convolution, default to 2D ##TODO: Add 1D conv when required + #Also used for convTranpose FH = "FH" FW = "FW" + FD = "FD" zPadHLeft = "zPadHLeft" zPadHRight = "zPadHRight" zPadWLeft = "zPadWLeft" zPadWRight = "zPadWRight" + zPadDLeft = "zPadDLeft" + zPadDRight = "zPadDRight" strideH = "strideH" strideW = "strideW" + strideD = "strideD" + inputImgH = "inputImgH" + inputImgW = "inputImgW" + inputImgD = 
"inputImgD" + outputImgH = "outputImgH" + outputImgW = "outputImgW" + outputImgD = "outputImgD" + paddingUsedStr = "paddingUsedStr" + group = "group" # If this is marked true, each astNode checks the types of its inputs to confirm it satisfies the assumption # Turn this off to get speedup in compilation @@ -143,13 +179,24 @@ class Transp(ASTNode): super().__init__() self.expr = expr +# expr : ASTNode, perm : list of ints +class Transpose(ASTNode): + def __init__(self, expr: ASTNode, perm: list): + if assertInputTypes: + assert isinstance(expr, ASTNode) + for elem in perm: assert isinstance(elem, int) + super().__init__() + self.expr = expr + self.perm = perm + # expr : ASTNode, shape : list of int, order : int : optional class Reshape(ASTNode): - def __init__(self, expr: ASTNode, shape: list, order: int): + def __init__(self, expr: ASTNode, shape: list, order: list): if assertInputTypes: assert isinstance(expr, ASTNode) for elem in shape: assert isinstance(elem, int) - assert isinstance(order, (int,type(None))) + if order: + for elem in order: assert isinstance(elem, int) super().__init__() self.expr = expr self.shape = shape @@ -204,13 +251,16 @@ class UOp(ASTNode): class BOp(ASTNode): # Options is used to convey extra info if the operator needs so # For example, it will be useful for convolution to convey strides etc. 
+ + # IMPORTANT NOTE: The options parameter coming for ConvTranspose is for the conv of which it is an inverse + def __init__(self, expr1: ASTNode, op: Operators, expr2: ASTNode, options=None): if assertInputTypes: assert isinstance(expr1, ASTNode) assert isinstance(op, Operators) assert isinstance(expr2, ASTNode) if options: assert isinstance(options, dict) - if op == Operators.CONV: + if op == Operators.CONV or op == Operators.CONVTRANSPOSE: assert (PaddingKeysDict.FH in options) assert (PaddingKeysDict.FW in options) assert (PaddingKeysDict.zPadHLeft in options) @@ -219,6 +269,21 @@ class BOp(ASTNode): assert (PaddingKeysDict.zPadWRight in options) assert (PaddingKeysDict.strideH in options) assert (PaddingKeysDict.strideW in options) + if PaddingKeysDict.ConvDim in options: + assert(options[PaddingKeysDict.ConvDim]==2 or options[PaddingKeysDict.ConvDim]==3) #1D conv is not supported right now + if options[PaddingKeysDict.ConvDim]==3: + #3D conv - assert over the depth dimension + assert (PaddingKeysDict.FD in options) + assert (PaddingKeysDict.zPadDLeft in options) + assert (PaddingKeysDict.zPadDRight in options) + assert (PaddingKeysDict.strideD in options) + if op == Operators.CONVTRANSPOSE: + # In addition if this op is convTranspose, then + # the output size should also be specified + assert(PaddingKeysDict.outputImgH in options) + assert(PaddingKeysDict.outputImgW in options) + if (PaddingKeysDict.ConvDim in options) and (options[PaddingKeysDict.ConvDim]==3): + assert(PaddingKeysDict.outputImgD in options) super().__init__() self.expr1 = expr1 self.op = op @@ -326,3 +391,4 @@ class FusedBatchNorm(ASTNode): self.expr = expr self.multExpr = multExpr self.addExpr = addExpr + diff --git a/Athos/SeeDot/AST/ASTVisitor.py b/Athos/SeeDot/AST/ASTVisitor.py index f286eb85b1f6207b4bc3fe5ba991a0864edde4dc..4e86acc19cf39399ef4a591339c211cf94a0ba0d 100644 --- a/Athos/SeeDot/AST/ASTVisitor.py +++ b/Athos/SeeDot/AST/ASTVisitor.py @@ -42,6 +42,9 @@ class ASTVisitor: def 
visitTransp(self, node:AST.Transp, args=None): self.visit(node.expr, args) + def visitTranspose(self, node:AST.Transpose, args=None): + self.visit(node.expr, args) + def visitReshape(self, node:AST.Reshape, args=None): self.visit(node.expr, args) @@ -100,6 +103,8 @@ class ASTVisitor: return self.visitDecl(node, args) elif isinstance(node, AST.Transp): return self.visitTransp(node, args) + elif isinstance(node, AST.Transpose): + return self.visitTranspose(node, args) elif isinstance(node, AST.Reshape): return self.visitReshape(node, args) elif isinstance(node, AST.Pool): diff --git a/Athos/SeeDot/AST/MtdAST.py b/Athos/SeeDot/AST/MtdAST.py index d6eb09e642f0ad4ca2fdd61f16ac6f72c5f71c09..24abd2cd5d2d10bd2d7d5e9df035794804ad2010 100644 --- a/Athos/SeeDot/AST/MtdAST.py +++ b/Athos/SeeDot/AST/MtdAST.py @@ -42,6 +42,10 @@ class MtdAST(ASTVisitor): node.metadata.update(mtd) self.visit(node.expr, mtd) + def visitTranspose(self, node:AST.Transpose, mtd:dict): + node.metadata.update(mtd) + self.visit(node.expr, mtd) + def visitReshape(self, node:AST.Reshape, mtd:dict): node.metadata.update(mtd) self.visit(node.expr, mtd) @@ -95,3 +99,5 @@ class MtdAST(ASTVisitor): self.visit(node.expr, mtd) self.visit(node.multExpr, mtd) self.visit(node.addExpr, mtd) + + diff --git a/Athos/SeeDot/AST/PrintAST.py b/Athos/SeeDot/AST/PrintAST.py index f387d07b5e6552976c092dbfe310c6e9278552a8..f9925f26291868541181dc1bdf633eb95b51531e 100644 --- a/Athos/SeeDot/AST/PrintAST.py +++ b/Athos/SeeDot/AST/PrintAST.py @@ -51,6 +51,12 @@ class PrintAST(ASTVisitor): self.visit(node.expr) print("^T", end=' ') + def visitTranspose(self, node:AST.Transpose, args=None): + node.expr.depth = node.depth + 1 + print(indent * node.depth, end=' ') + self.visit(node.expr) + print("^Transpose", end=' ') + def visitReshape(self, node:AST.Reshape, args=None): node.expr.depth = node.depth + 1 print(indent * node.depth, "reshape", end=' ') diff --git a/Athos/SeeDot/Codegen/EzPC.py b/Athos/SeeDot/Codegen/EzPC.py index 
44e994191abeb5dd80c3910c3b2e7a4e4fad3462..e252b0b3d0d86a572b542875de6b76d8c03457bc 100644 --- a/Athos/SeeDot/Codegen/EzPC.py +++ b/Athos/SeeDot/Codegen/EzPC.py @@ -30,10 +30,11 @@ import IR.IRUtil as IRUtil from Codegen.CodegenBase import CodegenBase class EzPC(CodegenBase): - def __init__(self, writer, decls): + def __init__(self, writer, decls, debugVar): self.out = writer self.decls = decls self.consSFUsed = Util.Config.consSF + self.debugVar = debugVar def printAll(self, prog:IR.Prog, expr:IR.Expr): self._out_prefix() @@ -134,7 +135,10 @@ class EzPC(CodegenBase): self.out.printf('\n') def _out_suffix(self, expr:IR.Expr): - self.out.printf('output(CLIENT, ' + expr.idf + ');\n', indent=True) + if self.debugVar is None: + self.out.printf('output(CLIENT, ' + expr.idf + ');\n', indent=True) + else: + self.out.printf('output(CLIENT, ' + self.debugVar + ');\n', indent=True) self.out.decreaseIndent() self.out.printf('}\n', indent=True) diff --git a/Athos/SeeDot/Compiler.py b/Athos/SeeDot/Compiler.py index b49d0ed12c119eff177b2ed8888d45fd14c396b8..119600cce6d12ef4b410f7087bbef03f17d9b63b 100644 --- a/Athos/SeeDot/Compiler.py +++ b/Athos/SeeDot/Compiler.py @@ -40,7 +40,7 @@ import Optimizations.LivenessOpti as LivenessOpti class Compiler: def __init__(self, version, target, sfType, astFile, printASTBool, consSF, bitlen, outputFileName, - disableRMO, disableLivenessOpti, disableAllOpti): + disableRMO, disableLivenessOpti, disableAllOpti, debugVar): assert(version == Util.Version.Fixed) assert(target == Util.Target.EzPC) assert(sfType == Util.SFType.Constant) @@ -60,6 +60,7 @@ class Compiler: Util.Config.disableRMO = disableRMO Util.Config.disableLivenessOpti = disableLivenessOpti Util.Config.disableAllOpti = disableAllOpti + Util.Config.debugVar = debugVar def insertStartEndFunctionCalls(self, res:(IR.Prog, IR.Expr)): prog = res[0] @@ -99,13 +100,17 @@ class Compiler: compiler = IRBuilderCSF() res = compiler.visit(ast) + Util.write_debug_info(compiler.name_mapping) + # 
Insert a generic start_computation and end_computation function call after all input IR statements. res = self.insertStartEndFunctionCalls(res); writer = Writer(Util.Config.outputFileName) + debugVarEzPCName = compiler.name_mapping[Util.Config.debugVar] if (Util.Config.debugVar in compiler.name_mapping) else None + if Util.forEzPC(): - codegen = EzPCCodegen(writer, compiler.decls) + codegen = EzPCCodegen(writer, compiler.decls, debugVarEzPCName) else: assert False diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index b1de4b63e58769eafa2d369bc32ea73ac4251cfe..a448954f8f42b8e6b10c28fe8e172b693f09b5ee 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -40,12 +40,14 @@ class IRBuilderCSF(ASTVisitor): # For tracking temp variables self._var_cnt = 0 self._iter_cnt = 0 - # Global variables self.decls = {} #Mapping of (identifier name (string) -> list of [type, secret/public variable, bitlen of decl]) # The 2nd arg can be either 'secret' or 'public'. # If public/secret unspecified, default to 'secret'. # The 3rd arg is used to specify the bitlen of the decl. 
+ + # Name mapping from SeeDot names to new names is useful for debugging + self.name_mapping = {} def getConsSF(self): return Util.Config.consSF @@ -181,9 +183,41 @@ class IRBuilderCSF(ASTVisitor): prog_2 = IRUtil.prog_merge(prog_1, prog_for) self.decls[expr_2.idf] = [typ_2] - prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, typ_2)]), prog) + prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, typ_2)]), prog_2) return (prog_2, expr_2) + def visitTranspose(self, node:AST.Transpose, args=None): + (inp_prog, inp_arr) = self.visit(node.expr) + inp_type = node.expr.type + out_type = node.type + inp_iters = self.getTempIterators(inp_type.dim) + out_iters = [] + perm = node.perm + for i in perm: + out_iters.append(inp_iters[i]) + out_arr = self.getTempVar() + out_arr_expr = IRUtil.addIndex(out_arr, out_iters) + inp_arr_expr = IRUtil.addIndex(inp_arr, inp_iters) + assign_expr = IR.Assn(out_arr_expr, inp_arr_expr) + loop = IRUtil.loop(inp_type.shape, inp_iters, [assign_expr]) + # Finalize + comment1 = IR.Comment(str(node.metadata)) + comment2 = IR.Comment("transpose(" + inp_arr.idf + ", [" + ', '.join(str(e) for e in inp_type.shape) + "] --> [" + ', '.join(str(e) for e in out_type.shape) + "])") + transpose_prog = IR.Prog([comment1, comment2] + loop) + final_prog = IRUtil.prog_merge(inp_prog, transpose_prog) + + # Update context + self.decls[out_arr.idf] = [out_type] + + # Update declarations + self.decls.update(dict((var.idf, [Type.Int(), 'public']) for var in inp_iters)) + + for var in inp_iters: + final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret="public")]), final_prog) + final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(out_arr.idf, out_type)]), final_prog) + + return (final_prog, out_arr) + def visitReshape(self, node:AST.Reshape, args=None): (prog_1, expr_1) = self.visit(node.expr) @@ -227,16 +261,19 @@ class IRBuilderCSF(ASTVisitor): cmd5 = [IRUtil.incCmd(curr_iter), IR.If(IRUtil.eq(curr_iter, curr_size), 
[IRUtil.initVarToZero(curr_iter)] + cmd5)] # Outer loop + # The iterators are selected based on the selection order specified by the user loopShape = [] loopIters = [] - if node.order: + + if(node.order): for order in node.order: order = order - 1 loopShape.append(typ_2.shape[order]) loopIters.append(iters_2[order]) else: loopShape = typ_2.shape - loopIters = iters_2 + loopIters = iters_2 + loop2 = IRUtil.loop(loopShape, loopIters, [IR.Assn(IRUtil.addIndex(expr_2, iters_2), IRUtil.addIndex(expr_1, iters_1))] + cmd5) @@ -347,12 +384,72 @@ class IRBuilderCSF(ASTVisitor): def visitBOp(self, node:AST.BOp, args=None): op = node.op - if (op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal, AST.Operators.Max]): return self.visitBopAddOrSubLike(node) + if (op in [AST.Operators.ADD, AST.Operators.SUB]): return self.visitBopAddOrSub(node) + elif (op in [AST.Operators.Equal, AST.Operators.Max]): return self.visitBopAddOrSubLike(node) elif (op in [AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]): return self.visitBopElemWiseOp(node) - elif op == AST.Operators.MUL: return self.visitBopMul(node) - elif op == AST.Operators.CONV: return self.visitBopConv(node) + elif op == AST.Operators.MUL: return self.visitBopMul(node) + elif op == AST.Operators.CONV: return self.visitBopConv(node) + elif op == AST.Operators.CONVTRANSPOSE: return self.visitBopConvTranspose(node) else: assert False + def visitBopAddOrSub(self, node:AST.BOp, args=None): + (prog_1, expr_1) = self.visit(node.expr1) + (prog_2, expr_2) = self.visit(node.expr2) + + # op_ir, typ_3 + op = node.op + if (op == AST.Operators.ADD): + (op_ir, op_fn) = (IR.Op.Op['+'], operator.add) + funcName = "MatAdd" + elif (op == AST.Operators.SUB): + (op_ir, op_fn) = (IR.Op.Op['-'], operator.sub) + funcName = "MatSub" + else: + assert False + + typ_3 = node.type + + # e : Int + if Type.isInt(typ_3): + prog_3 = IRUtil.prog_merge(prog_1, prog_2) + expr_3 = IR.IntBop(expr_1, op_ir, expr_2) + # e : Tensor() -- float, or 
Tensor(..) + else: + ## TODO : Hack for techfest + if (node.type.dim != node.expr1.type.dim): + # This needs broadcast of expr1 + assert False # For now this shouldn't occur + if (node.type.dim != node.expr2.type.dim): + # This needs broadcast of expr2 + funcName += 'BroadCast' + + # decl fresh vars + expr_3 = self.getTempVar() + + cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf) + outputShape = typ_3.shape + argsDict = OrderedDict() + inp1_shape = node.expr1.type.shape + inp2_shape = node.expr2.type.shape + for ii,curDimSize in enumerate(inp1_shape): + argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) + for ii,curDimSize in enumerate(inp2_shape): + argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) + for ii,curDimSize in enumerate(outputShape): + argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) + argsDict[expr_1] = "A" + argsDict[expr_2] = "B" + argsDict[expr_3] = "C" + funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)), + argsDict + ) + comment = IR.Comment(str(node.metadata)) + prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, cmd0, funcCall])) + self.decls[expr_3.idf] = [typ_3] + prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3) + + return (prog_3, expr_3) + def visitBopAddOrSubLike(self, node:AST.BOp, args=None): (prog_1, expr_1) = self.visit(node.expr1) (prog_2, expr_2) = self.visit(node.expr2) @@ -431,6 +528,13 @@ class IRBuilderCSF(ASTVisitor): cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf) outputShape = typ_3.shape argsDict = OrderedDict() + inp1_shape = node.expr1.type.shape + inp2_shape = node.expr2.type.shape + print("Input shapes = ", inp1_shape, inp2_shape) + for ii,curDimSize in enumerate(inp1_shape): + argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) + for ii,curDimSize in enumerate(inp2_shape): + argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) for ii,curDimSize in enumerate(outputShape): argsDict[IR.Int(curDimSize, 32)] = 
"size_" + str(ii) argsDict[expr_1] = "A" @@ -546,32 +650,153 @@ class IRBuilderCSF(ASTVisitor): (prog1, expr1) = self.visit(node.expr1) (prog2, expr2) = self.visit(node.expr2) - [N , H , W , CI] = node.expr1.type.shape - [FH, FW, CI, CO] = node.expr2.type.shape + convDim = 2 + if (AST.PaddingKeysDict.ConvDim in node.options): + convDim = node.options[AST.PaddingKeysDict.ConvDim] + + if convDim == 2: + [N, H, W, CI] = node.expr1.type.shape + [FH, FW, CI1, CO] = node.expr2.type.shape + elif convDim == 3: + [N, D, H, W, CI] = node.expr1.type.shape + [FD, FH, FW, CI1, CO] = node.expr2.type.shape + else: + assert(False) returnExpr = self.getTempVar() - comment = IR.Comment(expr1.idf + ' # ' + expr2.idf) + comment = IR.Comment(expr1.idf + ' # ' + expr2.idf + ', convDim = ' + str(convDim)) funcCallArgsDict = OrderedDict() funcCallArgsDict[IR.Int(N, 32)] = "N" + if convDim == 3: + funcCallArgsDict[IR.Int(D, 32)] = "D" funcCallArgsDict[IR.Int(H, 32)] = "H" funcCallArgsDict[IR.Int(W, 32)] = "W" funcCallArgsDict[IR.Int(CI, 32)] = "CI" + if convDim == 3: + funcCallArgsDict[IR.Int(FD, 32)] = "FD" funcCallArgsDict[IR.Int(FH, 32)] = "FH" funcCallArgsDict[IR.Int(FW, 32)] = "FW" funcCallArgsDict[IR.Int(CO, 32)] = "CO" + if convDim == 3: + funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadDLeft], 32)] = "zPadDLeft" + funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadDRight], 32)] = "zPadDRight" funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadHLeft], 32)] = "zPadHLeft" funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadHRight], 32)] = "zPadHRight" funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadWLeft], 32)] = "zPadWLeft" funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadWRight], 32)] = "zPadWRight" + if convDim == 3: + funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideD], 32)] = "strideD" funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideH], 32)] = "strideH" 
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideW], 32)] = "strideW" + + isGroupConv = False + if AST.PaddingKeysDict.group in node.options.keys(): + funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.group], 32)] = "G" + isGroupConv = True + + funcCallArgsDict[expr1] = "input" + funcCallArgsDict[expr2] = "filter" + funcCallArgsDict[IR.Int(Util.Config.consSF, 32)] = "consSF" + funcCallArgsDict[returnExpr] = "output" + + if convDim == 2: + funcCallName = "Conv2DCSF" + else: + funcCallName = "Conv3DCSF" + + if isGroupConv: + funcCallName += "Group" + + funcCall = IR.FuncCall(funcCallName, funcCallArgsDict) + + progConv = IR.Prog([comment, funcCall]) + returnProg = IRUtil.prog_merge(prog1, prog2, progConv) + + self.decls[returnExpr.idf] = [node.type] + returnProg = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg) + return (returnProg, returnExpr) + + def visitBopConvTranspose(self, node:AST.BOp, args=None): + (prog1, expr1) = self.visit(node.expr1) + (prog2, expr2) = self.visit(node.expr2) + + convDim = 2 + if (AST.PaddingKeysDict.ConvDim in node.options): + convDim = node.options[AST.PaddingKeysDict.ConvDim] + + if convDim==2: + [N, H_prime, W_prime, CI1] = node.expr1.type.shape + [FH, FW, CO, CI] = node.expr2.type.shape + elif convDim==3: + [N, D_prime, H_prime, W_prime, CI1] = node.expr1.type.shape + [FD, FH, FW, CO, CI] = node.expr2.type.shape + else: + assert(False) + assert(CI1 == CI) + + H = node.options[AST.PaddingKeysDict.outputImgH] #outputH + W = node.options[AST.PaddingKeysDict.outputImgW] #outputW + pad_h_total = node.options[AST.PaddingKeysDict.zPadHLeft] + node.options[AST.PaddingKeysDict.zPadHRight] + pad_w_total = node.options[AST.PaddingKeysDict.zPadWLeft] + node.options[AST.PaddingKeysDict.zPadWRight] + strideH = node.options[AST.PaddingKeysDict.strideH] + strideW = node.options[AST.PaddingKeysDict.strideW] + [pad_h_tr_total, stride_h_tr, h_prime_tilde] = AST.Operators.findConvTransposePadding(H, 
H_prime, FH, pad_h_total, strideH) + [pad_w_tr_total, stride_w_tr, w_prime_tilde] = AST.Operators.findConvTransposePadding(W, W_prime, FW, pad_w_total, strideW) + + [pad_h_tr_left, pad_h_tr_right] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_h_tr_total) + [pad_w_tr_left, pad_w_tr_right] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_w_tr_total) + + assert(AST.Operators.findConvOutputImgSize(h_prime_tilde, pad_h_tr_total, FH, stride_h_tr) == H) + assert(AST.Operators.findConvOutputImgSize(w_prime_tilde, pad_w_tr_total, FW, stride_w_tr) == W) + + if convDim == 3: + D = node.options[AST.PaddingKeysDict.outputImgD] #outputD + pad_d_total = node.options[AST.PaddingKeysDict.zPadDLeft] + node.options[AST.PaddingKeysDict.zPadDRight] + strideD = node.options[AST.PaddingKeysDict.strideD] + [pad_d_tr_total, stride_d_tr, d_prime_tilde] = AST.Operators.findConvTransposePadding(D, D_prime, FD, pad_d_total, strideD) + [pad_d_tr_left, pad_d_tr_right] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_d_tr_total) + assert(AST.Operators.findConvOutputImgSize(d_prime_tilde, pad_d_tr_total, FD, stride_d_tr) == D) + + returnExpr = self.getTempVar() + comment = IR.Comment(expr1.idf + ' #T ' + expr2.idf + ', convDim = ' + str(convDim)) + funcCallArgsDict = OrderedDict() + funcCallArgsDict[IR.Int(N, 32)] = "N" + if convDim==3: + funcCallArgsDict[IR.Int(D_prime, 32)] = "D_prime" + funcCallArgsDict[IR.Int(H_prime, 32)] = "H_prime" + funcCallArgsDict[IR.Int(W_prime, 32)] = "W_prime" + funcCallArgsDict[IR.Int(CI, 32)] = "CI" + if convDim==3: + funcCallArgsDict[IR.Int(FD, 32)] = "FD" + funcCallArgsDict[IR.Int(FH, 32)] = "FH" + funcCallArgsDict[IR.Int(FW, 32)] = "FW" + funcCallArgsDict[IR.Int(CO, 32)] = "CO" + if convDim==3: + funcCallArgsDict[IR.Int(D, 32)] = "D" + funcCallArgsDict[IR.Int(H, 32)] = "H" + funcCallArgsDict[IR.Int(W, 32)] = "W" + if convDim==3: + funcCallArgsDict[IR.Int(pad_d_tr_left, 32)] = "pad_d_tr_left" + funcCallArgsDict[IR.Int(pad_d_tr_right, 
32)] = "pad_d_tr_right" + funcCallArgsDict[IR.Int(pad_h_tr_left, 32)] = "pad_h_tr_left" + funcCallArgsDict[IR.Int(pad_h_tr_right, 32)] = "pad_h_tr_right" + funcCallArgsDict[IR.Int(pad_w_tr_left, 32)] = "pad_w_tr_left" + funcCallArgsDict[IR.Int(pad_w_tr_right, 32)] = "pad_w_tr_right" + if convDim==3: + funcCallArgsDict[IR.Int(strideD, 32)] = "strideD" + funcCallArgsDict[IR.Int(strideH, 32)] = "strideH" + funcCallArgsDict[IR.Int(strideW, 32)] = "strideW" funcCallArgsDict[expr1] = "input" funcCallArgsDict[expr2] = "filter" funcCallArgsDict[IR.Int(Util.Config.consSF, 32)] = "consSF" funcCallArgsDict[returnExpr] = "output" - funcCall = IR.FuncCall("Conv2DCSF", funcCallArgsDict) + if convDim == 2: + funcCallName = "ConvTranspose2DCSF" + else: + funcCallName = "ConvTranspose3DCSF" + funcCall = IR.FuncCall(funcCallName, funcCallArgsDict) progConv = IR.Prog([comment, funcCall]) returnProg = IRUtil.prog_merge(prog1, prog2, progConv) @@ -634,6 +859,7 @@ class IRBuilderCSF(ASTVisitor): (prog_1, expr_1) = self.visit(node.decl) typ_1 = node.decl.type idf = node.name.name + self.name_mapping[idf] = expr_1.idf (prog_2, expr_2) = self.visit(node.expr) prog_2 = prog_2.subst(idf, expr_1) expr_2 = expr_2.subst(idf, expr_1) @@ -797,4 +1023,4 @@ class IRBuilderCSF(ASTVisitor): self.decls[returnExpr.idf] = [node.type] returnProg = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg) - return (returnProg, returnExpr) + return (returnProg, returnExpr) \ No newline at end of file diff --git a/Athos/SeeDot/Optimizations/LivenessOpti.py b/Athos/SeeDot/Optimizations/LivenessOpti.py index 130b19a93bae6279f520dfcace16cfb47f758d83..5e2b3e643ddc7b055ddfddfdb047ee370b8e04e2 100644 --- a/Athos/SeeDot/Optimizations/LivenessOpti.py +++ b/Athos/SeeDot/Optimizations/LivenessOpti.py @@ -52,6 +52,11 @@ class LivenessAnalysis(ASTVisitor): node.optidict[self.optidictKey] = unboundVars return unboundVars + def visitTranspose(self, node:AST.Transp, args): + unboundVars = 
self.visit(node.expr, args) + node.optidict[self.optidictKey] = unboundVars + return unboundVars + def visitReshape(self, node:AST.Reshape, args): unboundVars = self.visit(node.expr, args) node.optidict[self.optidictKey] = unboundVars diff --git a/Athos/SeeDot/SeeDot.py b/Athos/SeeDot/SeeDot.py index 847d2ffce296e3f0d30e0f457699fe9c7f15cbb7..25b8ef7e67838e66e707fad53cf1d07e88d75558 100644 --- a/Athos/SeeDot/SeeDot.py +++ b/Athos/SeeDot/SeeDot.py @@ -42,6 +42,7 @@ class MainDriver: parser.add_argument("--disableLivenessOpti", default=False, type=bool, help="Disable liveness optimization.") parser.add_argument("--disableAllOpti", default=False, type=bool, help="Disable all optimizations.") parser.add_argument("--outputFileName", help="Name of the output file with extension (Donot include folder path).") + parser.add_argument("--debugVar", help="Name of the onnx node to be debugged") self.args = parser.parse_args() @@ -67,7 +68,8 @@ class MainDriver: self.args.outputFileName, self.args.disableRMO, self.args.disableLivenessOpti, - self.args.disableAllOpti + self.args.disableAllOpti, + self.args.debugVar ) obj.run() diff --git a/Athos/SeeDot/Type.py b/Athos/SeeDot/Type.py index b8173b58c3d0fd194376e9fa59e80bfa8805a3ee..e7a5276c496eaa0fe131eedbe5d5e7bfd8192a6c 100644 --- a/Athos/SeeDot/Type.py +++ b/Athos/SeeDot/Type.py @@ -27,7 +27,6 @@ from functools import reduce import AST.AST as AST from AST.ASTVisitor import ASTVisitor - class Type: pass @@ -101,6 +100,20 @@ class InferType(ASTVisitor): return node.type + def visitTranspose(self, node:AST.Transpose, args=None): + node.expr.gamma = dict(node.gamma) + exprType = self.visit(node.expr) + + assert isTensor(exprType) + + perm = node.perm + shape = exprType.shape + new_shape = [] + for i in perm: + new_shape.append(shape[i]) + node.type = Tensor(new_shape) + return node.type + def visitReshape(self, node:AST.Reshape, args=None): node.expr.gamma = dict(node.gamma) exprType = self.visit(node.expr) @@ -172,6 +185,8 @@ class 
InferType(ASTVisitor): return self.visitBopMul(node, eType, fType) elif node.op == AST.Operators.CONV: return self.visitBopConv(node, eType, fType) + elif node.op == AST.Operators.CONVTRANSPOSE: + return self.visitBopConvTranspose(node, eType, fType) else: assert False @@ -236,13 +251,41 @@ class InferType(ASTVisitor): def visitBopConv(self, node:AST.BOp, eType:Type, fType:Type, args=None): assert isTensor(eType) and isTensor(fType) - assert eType.dim == 4 and fType.dim == 4 + convDim = 2 + group = 1 + if AST.PaddingKeysDict.ConvDim in node.options: + convDim = node.options[AST.PaddingKeysDict.ConvDim] + + if convDim==2: + assert eType.dim == 4 and fType.dim == 4 + elif convDim==3: + assert eType.dim == 5 and fType.dim == 5 + else: + assert(False) - [N, H, W, CI] = eType.shape - [FH, FW, CI1, CO] = fType.shape + N = D = H = W = CI = FD = FH = FW = CI1 = CO = -1 + newD = -1 + if (convDim == 2): + [N, H, W, CI] = eType.shape + [FH, FW, CI1, CO] = fType.shape + elif (convDim == 3): + [N, D, H, W, CI] = eType.shape + [FD, FH, FW, CI1, CO] = fType.shape + assert(FD == node.options[AST.PaddingKeysDict.FD]) + zPadDLeft = node.options[AST.PaddingKeysDict.zPadDLeft] + zPadDRight = node.options[AST.PaddingKeysDict.zPadDRight] + strideD = node.options[AST.PaddingKeysDict.strideD] + + newD = ((D + zPadDLeft + zPadDRight - FD)//strideD) + 1 + else: + assert(False) + + if AST.PaddingKeysDict.group in node.options: + group = node.options[AST.PaddingKeysDict.group] + assert(FH == node.options[AST.PaddingKeysDict.FH]) assert(FW == node.options[AST.PaddingKeysDict.FW]) - assert(CI1 == CI) + assert(CI1*group == CI) zPadHLeft = node.options[AST.PaddingKeysDict.zPadHLeft] zPadHRight = node.options[AST.PaddingKeysDict.zPadHRight] zPadWLeft = node.options[AST.PaddingKeysDict.zPadWLeft] @@ -253,7 +296,47 @@ class InferType(ASTVisitor): newH = ((H + zPadHLeft + zPadHRight - FH)//strideH) + 1 newW = ((W + zPadWLeft + zPadWRight - FW)//strideW) + 1 - shape = [N, newH, newW, CO] + if convDim 
== 2: + shape = [N, newH, newW, CO] + elif convDim == 3: + shape = [N, newD, newH, newW, CO] + node.type = Tensor(shape) + return node.type + + def visitBopConvTranspose(self, node:AST.BOp, eType:Type, fType:Type, args=None): + assert isTensor(eType) and isTensor(fType) + + convDim = 2 + if AST.PaddingKeysDict.ConvDim in node.options: + convDim = node.options[AST.PaddingKeysDict.ConvDim] + + if convDim==2: + [N, HP, WP, CI1] = eType.shape + [FH, FW, CO, CI] = fType.shape + elif convDim==3: + [N, DP, HP, WP, CI1] = eType.shape + [FD, FH, FW, CO, CI] = fType.shape + else: + assert(False) + assert(CI1 == CI) + if convDim==3: + outputImgD = node.options[AST.PaddingKeysDict.outputImgD] + outputImgH = node.options[AST.PaddingKeysDict.outputImgH] + outputImgW = node.options[AST.PaddingKeysDict.outputImgW] + + if convDim==2: + shape = [N, outputImgH, outputImgW, CO] + else: + shape = [N, outputImgD, outputImgH, outputImgW, CO] + + # Logic explanation: + # ConvTranpose can be thought of as the inverse of some convolution for which it is doing the upsampling. + # For calculation of padding in the convTranspose operation, the output image size is required. + # This is why TF also mandates the operator to be specified with output size. + # This conv transpose operation can be thought of as conv between output + # of size shape = [N, outputImgH, outputImgW, CI], and filter of size [FH, FW, CI, CO]. 
+ # Hence, the input for this convTranspose would be [N, HP, WP, CO] + node.type = Tensor(shape) return node.type @@ -357,5 +440,4 @@ class InferType(ASTVisitor): assert(exprType.shape[-1]==C1 and C1==C2) node.type = exprType - return node.type - + return node.type \ No newline at end of file diff --git a/Athos/SeeDot/Util.py b/Athos/SeeDot/Util.py index c9204b32733dcf5e5b6f33a427bb37e5020cc652..867fe05fabdba96ee385de18bda36535eb2beede 100644 --- a/Athos/SeeDot/Util.py +++ b/Athos/SeeDot/Util.py @@ -21,6 +21,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ''' +import os +import _pickle as pickle # Target word length. @@ -50,6 +52,7 @@ class Config: disableRMO = None disableLivenessOpti = None disableAllOpti = None + debugOnnx = None ###### Helper functions ###### def loadASTFromFile(): @@ -66,3 +69,14 @@ def copy_dict(dict_src:dict, diff={}): # z = [y1,y2,..] = [[x1,..], [x2,..], ..] --> [x1,.., x2,.., ..] def flatten(z:list): return [x for y in z for x in y] + +def write_debug_info(name_mapping): + if not os.path.exists('debug'): + os.makedirs('debug') + + with open('debug/seedot_ezpc_name_map.pkl', 'wb') as f: + pickle.dump(name_mapping, f) + + with open('debug/seedot_ezpc_name_map.txt', 'w') as f: + for val in name_mapping: + f.write(val + ' ' + name_mapping[val] + '\n') diff --git a/Athos/TFCompiler/Graph.py b/Athos/TFCompiler/Graph.py index a0029efe5967caf190cc2ea36590f9257a5c768d..2fc1ae5e61d11cc9c6526e228138dfa9281ffcf7 100644 --- a/Athos/TFCompiler/Graph.py +++ b/Athos/TFCompiler/Graph.py @@ -320,6 +320,19 @@ class Tensor: assert(False) self.__valArr = numpy.fromstring(bytes(self.__tensorBytes), dtype).tolist() return self.__valArr + + def getDType(self): + if self.__dtype == DataTypeEnum.DT_FLOAT: + dtype = numpy.dtype('<f4') + elif self.__dtype == DataTypeEnum.DT_BOOL: + dtype = numpy.dtype('bool') + elif self.__dtype == DataTypeEnum.DT_INT32: + dtype = numpy.dtype('int32') + elif self.__dtype == 
DataTypeEnum.DT_INT64: + dtype = numpy.dtype('int64') + else: + assert(False) + return dtype class MultiValue: def __init__(self): @@ -497,6 +510,12 @@ class Node: def getAttrMapRef(self): return self.__attr + def getAttrVal(self, attrName): + qName = '"' + attrName + '"' + if not qName in self.__attr: + return None + return self.__attr[qName] + def readAttrFromFilePointer(self, fileP, cnt): line = fileP.readline() cnt += 1 @@ -576,6 +595,9 @@ class Graph: self.__Nodes = {} # Map of (op, Node) self.__NodesLi = [] # Sequential list of nodes in the order in which its specified in graph_def. + def getAllNodes(self): + return self.__Nodes + def getAllNodesRef(self): return self.__NodesLi @@ -593,7 +615,7 @@ class Graph: curNode = Node() (noPaseError, cnt) = curNode.readFromFilePointer(fileP, cnt) if (noPaseError): - self.__Nodes[curNode.getOp()] = curNode + self.__Nodes[curNode.getName()] = curNode self.__NodesLi.append(curNode) else: print("Error parsing graph dump for node at line =", cnt, file=sys.stderr) diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index ba72228c9b9bd3b21f71bf6bdf692feda035e09b..170936760d8a7e8a7be2e0d8ec6ede2dc255322c 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -583,3 +583,4 @@ class TFNodesAST: # TFNodesAST.UninterpFuncCallNames.Pack.name, # list(map(lambda x : AST.ID(dictNodeNameToOutVarStr[x]), inputsRef)) + [AST.Int(axis)] ) # return (None, retAST) + \ No newline at end of file diff --git a/Athos/TFEzPCLibrary/Library32_common.ezpc b/Athos/TFEzPCLibrary/Library32_common.ezpc index f6184dbf38302d9d631a120d1ae8a5a8c6725621..1832c826a23e20c125f10a756ea4b6c1c16ee839 100644 --- a/Athos/TFEzPCLibrary/Library32_common.ezpc +++ b/Athos/TFEzPCLibrary/Library32_common.ezpc @@ -24,7 +24,7 @@ SOFTWARE. 
(**************************) (* TODO : the 2nd arg should be broadcasted *) -def void MatAddBroadCast2(int32_pl s1, int32_pl s2, int32_al[s1][s2] A, int32_al[s2] B, int32_al[s1][s2] outArr){ +def void MatAddBroadCast2(int32_pl a1, int32_pl a2, int32_pl b1, int32_pl s1, int32_pl s2, int32_al[s1][s2] A, int32_al[s2] B, int32_al[s1][s2] outArr){ for i1=[0:s1]{ for i2=[0:s2]{ outArr[i1][i2] = A[i1][i2] + B[i2]; @@ -32,16 +32,24 @@ def void MatAddBroadCast2(int32_pl s1, int32_pl s2, int32_al[s1][s2] A, int32_al }; } -def void MatAdd2(int32_pl s1, int32_pl s2, int32_al[s1][s2] A, int32_al[s1][s2] B, int32_al[s1][s2] outArr){ - for i1=[0:s1]{ - for i2=[0:s2]{ - outArr[i1][i2] = A[i1][i2] + B[i1][i2]; - }; - }; +def void MatAdd2(int32_pl a1, int32_pl a2, int32_pl b1, int32_pl b2, int32_pl s1, int32_pl s2, int32_al[a1][a2] A, int32_al[b1][b2] B, int32_al[s1][s2] outArr){ + int32_pl aIdx1 = 0; + int32_pl aIdx2 = 0; + int32_pl bIdx1 = 0; + int32_pl bIdx2 = 0; + for i1=[0:s1]{ + aIdx1 = ((a1 == 1) ? 0 : i1); + bIdx1 = ((b1 == 1) ? 0 : i1); + for i2=[0:s2]{ + aIdx2 = ((a2 == 1) ? 0 : i2); + bIdx2 = ((b2 == 1) ? 
0 : i2); + outArr[i1][i2] = A[aIdx1][aIdx2] + B[bIdx1][bIdx2]; + }; + }; } (* TODO : the 2nd arg should be broadcasted *) -def void MatAddBroadCast4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] A, int32_al[s4] B, int32_al[s1][s2][s3][s4] outArr){ +def void MatAddBroadCast4(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl b1, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] A, int32_al[s4] B, int32_al[s1][s2][s3][s4] outArr){ for i1=[0:s1]{ for i2=[0:s2]{ for i3=[0:s3]{ @@ -53,18 +61,82 @@ def void MatAddBroadCast4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in }; } -def void MatAdd4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] A, int32_al[s1][s2][s3][s4] B, int32_al[s1][s2][s3][s4] outArr){ +def void MatAddBroadCast5(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl a5, int32_pl b1, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] A, int32_al[s5] B, int32_al[s1][s2][s3][s4][s5] outArr){ for i1=[0:s1]{ for i2=[0:s2]{ for i3=[0:s3]{ for i4=[0:s4]{ - outArr[i1][i2][i3][i4] = A[i1][i2][i3][i4] + B[i1][i2][i3][i4]; + for i5=[0:s5]{ + outArr[i1][i2][i3][i4][i5] = A[i1][i2][i3][i4][i5] + B[i5]; + }; }; }; }; }; } +def void MatAdd4(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl b1, int32_pl b2, int32_pl b3, int32_pl b4, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[a1][a2][a3][a4] A, int32_al[b1][b2][b3][b4] B, int32_al[s1][s2][s3][s4] outArr){ + int32_pl aIdx1 = 0; + int32_pl aIdx2 = 0; + int32_pl aIdx3 = 0; + int32_pl aIdx4 = 0; + int32_pl bIdx1 = 0; + int32_pl bIdx2 = 0; + int32_pl bIdx3 = 0; + int32_pl bIdx4 = 0; + for i1=[0:s1]{ + aIdx1 = ((a1 == 1) ? 0 : i1); + bIdx1 = ((b1 == 1) ? 0 : i1); + for i2=[0:s2]{ + aIdx2 = ((a2 == 1) ? 0 : i2); + bIdx2 = ((b2 == 1) ? 0 : i2); + for i3=[0:s3]{ + aIdx3 = ((a3 == 1) ? 0 : i3); + bIdx3 = ((b3 == 1) ? 
0 : i3); + for i4=[0:s4]{ + aIdx4 = ((a4 == 1) ? 0 : i4); + bIdx4 = ((b4 == 1) ? 0 : i4); + outArr[i1][i2][i3][i4] = A[aIdx1][aIdx2][aIdx3][aIdx4] + B[bIdx1][bIdx2][bIdx3][bIdx4]; + }; + }; + }; + }; +} + +def void MatAdd5(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl a5, int32_pl b1, int32_pl b2, int32_pl b3, int32_pl b4, int32_pl b5, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[a1][a2][a3][a4][a5] A, int32_al[b1][b2][b3][b4][b5] B, int32_al[s1][s2][s3][s4][s5] outArr){ + int32_pl aIdx1 = 0; + int32_pl aIdx2 = 0; + int32_pl aIdx3 = 0; + int32_pl aIdx4 = 0; + int32_pl aIdx5 = 0; + int32_pl bIdx1 = 0; + int32_pl bIdx2 = 0; + int32_pl bIdx3 = 0; + int32_pl bIdx4 = 0; + int32_pl bIdx5 = 0; + for i1=[0:s1]{ + aIdx1 = ((a1 == 1) ? 0 : i1); + bIdx1 = ((b1 == 1) ? 0 : i1); + for i2=[0:s2]{ + aIdx2 = ((a2 == 1) ? 0 : i2); + bIdx2 = ((b2 == 1) ? 0 : i2); + for i3=[0:s3]{ + aIdx3 = ((a3 == 1) ? 0 : i3); + bIdx3 = ((b3 == 1) ? 0 : i3); + for i4=[0:s4]{ + aIdx4 = ((a4 == 1) ? 0 : i4); + bIdx4 = ((b4 == 1) ? 0 : i4); + for i5=[0:s5]{ + aIdx5 = ((a5 == 1) ? 0 : i5); + bIdx5 = ((b5 == 1) ? 
0 : i5); + outArr[i1][i2][i3][i4][i5] = A[aIdx1][aIdx2][aIdx3][aIdx4][aIdx5] + B[bIdx1][bIdx2][bIdx3][bIdx4][bIdx5]; + }; + }; + }; + }; + }; +} + (**************************) def void CreateTensor1(int32_pl s1, int32_pl val, int32_pl[s1] arr){ for i1=[0:s1]{ @@ -92,6 +164,20 @@ def void CreateTensor4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32 }; } +def void CreateTensor5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl val, int32_pl[s1][s2][s3][s4][s5] arr){ + for i1=[0:s1]{ + for i2=[0:s2]{ + for i3=[0:s3]{ + for i4=[0:s4]{ + for i5=[0:s5]{ + arr[i1][i2][i3][i4][i5] = val; + }; + }; + }; + }; + }; +} + (**************************) def void CopyTensor1(int32_pl s1, int32_al[s1] targetArr, int32_al[s1] fromArr, int32_al[s1] ignore){ for i1=[0:s1]{ @@ -155,6 +241,20 @@ def void CreateCopy2211(int32_pl s1, int32_pl s2, int32_pl inps1, int32_pl inps2 }; } +def void CreateCopy5511(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl inps1, int32_pl inps2, int32_pl inps3, int32_pl inps4, int32_pl inps5, int32_al[inps1][inps2][inps3][inps4][inps5] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int32_al[s1][s2][s3][s4][s5] outArr){ + for i=[0:s1]{ + for j=[0:s2]{ + for k=[0:s3]{ + for l=[0:s4]{ + for m=[0:s5]{ + outArr[i][j][k][l][m] = inArr[beginIdx[0]+i][beginIdx[1]+j][beginIdx[2]+k][beginIdx[3]+l][beginIdx[4]+m]; + }; + }; + }; + }; + }; +} + (**************************) def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int32_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int32_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int32_al[s1][s2][s3][s4] outp){ for i1=[0:s1]{ @@ -227,9 +327,44 @@ def void Concat2T222(int32_pl s1, int32_pl s2, int32_pl inp1s1, int32_pl inp1s2, }; } +(**************************) + 
+def void Split44(int32_pl O1, int32_pl O2, int32_pl O3, int32_pl O4, int32_pl I1, int32_pl I2, int32_pl I3, int32_pl I4, int32_al[I1][I2][I3][I4] inp, int32_pl axis, int32_pl curCount, int32_pl total, int32_al[O1][O2][O3][O4] out){ + +for o1=[0:O1]{ + for o2=[0:O2]{ + for o3=[0:O3]{ + for o4=[0:O4]{ + + int32_pl i1 = o1; + int32_pl i2 = o2; + int32_pl i3 = o3; + int32_pl i4 = o4; + + if(axis == 0){ + i1 = (I1/total)*curCount+o1; + }; + if(axis == 1){ + i2 = (I2/total)*curCount+o2; + }; + if(axis == 2){ + i3 = (I3/total)*curCount+o3; + }; + if(axis == 3){ + i4 = (I4/total)*curCount+o4; + }; + + out[o1][o2][o3][o4] = inp[i1][i2][i3][i4]; + }; + }; + }; +} +} + (**************************) (* Generic implementation of Conv2DCSF *) + def void Conv2DReshapeFilter(int32_pl FH, int32_pl FW, int32_pl CI, int32_pl CO, int32_al[FH][FW][CI][CO] inputArr, int32_al[CO][FH*FW*CI] outputArr){ for co=[0:CO]{ for fh=[0:FH]{ @@ -291,11 +426,6 @@ def void Conv2DReshapeInput(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, int }; } -(* int32_al[N][H][W][CI] inputArr, - int32_al[FH][FW][CI][CO] filterArr, - int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr -*) - def void Conv2DCSF(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, int32_pl FH, int32_pl FW, int32_pl CO, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, @@ -325,6 +455,274 @@ def void Conv2DCSF(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, Conv2DReshapeMatMulOP(N, newH, newW, CO, matmulOP, outArr); } +(* int32_al[N][H][W][CI] inputArr, + int32_al[FH][FW][CI][CO] filterArr, + int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) + +def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, int32_pl G, + 
int32_al[N][H][W][CI] inputArr, + int32_al[FH][FW][CI][CO] filterArr, + int32_pl consSF, + int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv2DLoop(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, G, inputArr, filterArr, consSF, outArr); +} + +(**************************) +(* Generic implementation of Conv2D with Groups *) + + +(* int32_al[N][H][W][CI] inputArr, + int32_al[FH][FW][CI][CO] filterArr, + int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) +def void Conv2DReshapeFilterGroup(int32_pl FH, int32_pl FW, int32_pl CI, int32_pl CO, int32_pl g, int32_pl G, int32_al[FH][FW][CI/G][CO] inputArr, int32_al[CO/G][FH*FW*(CI/G)] outputArr){ + + int32_pl CIG = CI/G; + int32_pl COG = CO/G; + int32_pl startCO = g*COG; + + for co=[0:COG]{ + for fh=[0:FH]{ + for fw=[0:FW]{ + for ci=[0:CIG]{ + int32_pl linIdx = (fh*FW*CIG) + (fw*CIG) + ci; + outputArr[co][linIdx] = inputArr[fh][fw][ci][co+startCO]; + }; + }; + }; + }; +} + +def void Conv2DReshapeMatMulOPGroup(int32_pl N, int32_pl finalH, int32_pl finalW, int32_pl CO, int32_pl g, int32_pl G, int32_al[CO/G][N*finalH*finalW] inputArr, int32_al[N][finalH][finalW][CO] outputArr){ + + int32_pl COG = CO/G; + int32_pl startCO = g*COG; + + for co=[0:COG]{ + for n=[0:N]{ + for h=[0:finalH]{ + for w=[0:finalW]{ + outputArr[n][h][w][co+startCO] = inputArr[co][(n*finalH*finalW) + (h*finalW) + w]; + }; + }; + }; + }; +} + +def void Conv2DReshapeInputGroup(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, int32_pl FH, int32_pl FW, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, int32_pl strideH, int32_pl strideW, int32_pl g, int32_pl G, int32_pl RRows, int32_pl RCols, int32_al[N][H][W][CI] inputArr, 
int32_al[RRows][RCols] outputArr){ + int32_pl linIdxFilterMult = 0; + int32_pl CIG = CI/G; + + for n=[0:N]{ + int32_pl leftTopCornerH = 0 - zPadHLeft; + int32_pl extremeRightBottomCornerH = H - 1 + zPadHRight; + while((leftTopCornerH + FH - 1) <= extremeRightBottomCornerH){ + int32_pl leftTopCornerW = 0 - zPadWLeft; + int32_pl extremeRightBottomCornerW = W - 1 + zPadWRight; + while((leftTopCornerW + FW - 1) <= extremeRightBottomCornerW){ + + for fh=[0:FH]{ + for fw=[0:FW]{ + int32_pl curPosH = leftTopCornerH + fh; + int32_pl curPosW = leftTopCornerW + fw; + int32_al val = 0; + + int32_pl startCI = g*CIG; + + for ci=[0:CIG]{ + if ((((curPosH < 0) || (curPosH >= H)) || ((curPosW < 0) || (curPosW >= W)))){ + val = 0; + } + else{ + val = inputArr[n][curPosH][curPosW][ci+startCI]; + }; + outputArr[(fh*FW*CIG) + (fw*CIG) + ci][linIdxFilterMult] = val; + }; + }; + }; + + linIdxFilterMult = linIdxFilterMult + 1; + leftTopCornerW = leftTopCornerW + strideW; + }; + + leftTopCornerH = leftTopCornerH + strideH; + }; + }; +} + + +def void Conv2DCSFGroup(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, int32_pl G, + int32_al[N][H][W][CI] inputArr, + int32_al[FH][FW][CI/G][CO] filterArr, + int32_pl consSF, + int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl CIG = CI/G; + int32_pl reshapedFilterRows = CO/G; + int32_pl reshapedFilterCols = FH*FW*CIG; + int32_pl reshapedIPRows = FH*FW*CIG; + int32_pl outH = (((H + (zPadHLeft+zPadHRight) - FH)/strideH) + 1); + int32_pl outW = (((W + (zPadWLeft+zPadWRight) - FW)/strideW) + 1); + int32_pl reshapedIPCols = N * outH * outW; + + + for g=[0:G]{ + int32_al[reshapedIPRows][reshapedIPCols] inputReshaped; + int32_al[reshapedFilterRows][reshapedIPCols] matmulOP; + int32_al[reshapedFilterRows][reshapedFilterCols] 
filterReshaped; + + Conv2DReshapeFilterGroup(FH, FW, CI, CO, g, G, filterArr, filterReshaped); + Conv2DReshapeInputGroup(N, H, W, CI, FH, FW, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, g, G, reshapedIPRows, reshapedIPCols, inputArr, inputReshaped); + + MatMulCSF2D(reshapedFilterRows, reshapedFilterCols, reshapedIPCols, filterReshaped, inputReshaped, matmulOP, consSF); + + Conv2DReshapeMatMulOPGroup(N, outH, outW, CO, g, G, matmulOP, outArr); + } + +} + +(**************************) +(* Generic implementation of Conv3DCSF *) + +def void Conv3DReshapeFilter(int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CI, int32_pl CO, int32_al[FD][FH][FW][CI][CO] inputArr, int32_al[CO][FD*FH*FW*CI] outputArr){ + for co=[0:CO]{ + for fd=[0:FD]{ + for fh=[0:FH]{ + for fw=[0:FW]{ + for ci=[0:CI]{ + int32_pl linIdx = (fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI) + ci; + outputArr[co][linIdx] = inputArr[fd][fh][fw][ci][co]; + }; + }; + }; + }; + }; +} + +def void Conv3DReshapeMatMulOP(int32_pl N, int32_pl finalD, int32_pl finalH, int32_pl finalW, int32_pl CO, int32_al[CO][N*finalD*finalH*finalW] inputArr, int32_al[N][finalD][finalH][finalW][CO] outputArr){ + for co=[0:CO]{ + for n=[0:N]{ + for d=[0:finalD]{ + for h=[0:finalH]{ + for w=[0:finalW]{ + outputArr[n][d][h][w][co] = inputArr[co][(n*finalD*finalH*finalW) + (d*finalH*finalW) + (h*finalW) + w]; + }; + }; + }; + }; + }; +} + +def void Conv3DReshapeInput(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, int32_pl FD, int32_pl FH, int32_pl FW, int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, int32_pl strideD, int32_pl strideH, int32_pl strideW, int32_pl RRows, int32_pl RCols, int32_al[N][D][H][W][CI] inputArr, int32_al[RRows][RCols] outputArr){ + int32_pl linIdxFilterMult = 0; + for n=[0:N]{ + int32_pl leftTopCornerD = 0 - zPadDLeft; + int32_pl extremeRightBottomCornerD = D - 1 + zPadDRight; + while((leftTopCornerD + FD - 1) <= 
extremeRightBottomCornerD){ + int32_pl leftTopCornerH = 0 - zPadHLeft; + int32_pl extremeRightBottomCornerH = H - 1 + zPadHRight; + while((leftTopCornerH + FH - 1) <= extremeRightBottomCornerH){ + int32_pl leftTopCornerW = 0 - zPadWLeft; + int32_pl extremeRightBottomCornerW = W - 1 + zPadWRight; + while((leftTopCornerW + FW - 1) <= extremeRightBottomCornerW){ + + for fd=[0:FD]{ + for fh=[0:FH]{ + for fw=[0:FW]{ + int32_pl curPosD = leftTopCornerD + fd; + int32_pl curPosH = leftTopCornerH + fh; + int32_pl curPosW = leftTopCornerW + fw; + int32_al val = 0; + for ci=[0:CI]{ + if ((((curPosD < 0) || (curPosD >= D)) || ((curPosH < 0) || (curPosH >= H)) || ((curPosW < 0) || (curPosW >= W)))){ + val = 0; + } + else{ + val = inputArr[n][curPosD][curPosH][curPosW][ci]; + }; + outputArr[(fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI) + ci][linIdxFilterMult] = val; + }; + }; + }; + }; + + linIdxFilterMult = linIdxFilterMult + 1; + leftTopCornerW = leftTopCornerW + strideW; + }; + + leftTopCornerH = leftTopCornerH + strideH; + }; + + leftTopCornerD = leftTopCornerD + strideD; + }; + }; +} + +(* int32_al[N][D][H][W][CI] inputArr, + int32_al[FD][FH][FW][CI][CO] filterArr, + int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) +(* Loop implementation of convolution runs faster with multithreading *) +def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_al[N][D][H][W][CI] inputArr, + int32_al[FD][FH][FW][CI][CO] filterArr, + int32_pl consSF, + int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outD =
((D-FD+(zPadDLeft+zPadDRight))/strideD)+1; + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv3DLoop(N, D, H, W, CI, FD, FH, FW, CO, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, outD, outH, outW, inputArr, filterArr, consSF, outArr); +} + +(* int32_al[N][D][H][W][CI] inputArr, + int32_al[FD][FH][FW][CI][CO] filterArr, + int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) +def void Conv3DCSF(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_al[N][D][H][W][CI] inputArr, + int32_al[FD][FH][FW][CI][CO] filterArr, + int32_pl consSF, + int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl reshapedFilterRows = CO; + int32_pl reshapedFilterCols = FD*FH*FW*CI; + int32_pl reshapedIPRows = FD*FH*FW*CI; + int32_pl newD = (((D + (zPadDLeft+zPadDRight) - FD)/strideD) + 1); + int32_pl newH = (((H + (zPadHLeft+zPadHRight) - FH)/strideH) + 1); + int32_pl newW = (((W + (zPadWLeft+zPadWRight) - FW)/strideW) + 1); + int32_pl reshapedIPCols = N * newD * newH * newW; + + int32_al[reshapedFilterRows][reshapedFilterCols] filterReshaped; + int32_al[reshapedIPRows][reshapedIPCols] inputReshaped; + int32_al[reshapedFilterRows][reshapedIPCols] matmulOP; + + Conv3DReshapeFilter(FD, FH, FW, CI, CO, filterArr, filterReshaped); + Conv3DReshapeInput(N, D, H, W, CI, FD, FH, FW, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, reshapedIPRows, reshapedIPCols, inputArr, 
inputReshaped); + + MatMulCSF2D(reshapedFilterRows, reshapedFilterCols, reshapedIPCols, filterReshaped, inputReshaped, matmulOP, consSF); + + Conv3DReshapeMatMulOP(N, newD, newH, newW, CO, matmulOP, outArr); +} + (**************************) def void Transpose2(int32_pl s1, int32_pl s2, int32_al[s2][s1] inArr, int32_al[s1][s2] outArr){ for i=[0:s1]{ @@ -360,6 +758,60 @@ def void Pad442(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp }; } +def void Pad552(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl inps1, int32_pl inps2, int32_pl inps3, int32_pl inps4, int32_pl inps5, int32_al[inps1][inps2][inps3][inps4][inps5] inpArr, int32_pl pads1, int32_pl pads2, int32_pl[pads1][pads2] paddings, int32_al[s1][s2][s3][s4][s5] outArr){ + int32_pl lbounds1 = paddings[0][0]; + int32_pl rbounds1excl = s1-paddings[0][1]; + int32_pl lbounds2 = paddings[1][0]; + int32_pl rbounds2excl = s2-paddings[1][1]; + int32_pl lbounds3 = paddings[2][0]; + int32_pl rbounds3excl = s3-paddings[2][1]; + int32_pl lbounds4 = paddings[3][0]; + int32_pl rbounds4excl = s4-paddings[3][1]; + int32_pl lbounds5 = paddings[4][0]; + int32_pl rbounds5excl = s5-paddings[4][1]; + for i=[0:s1]{ + for j=[0:s2]{ + for k=[0:s3]{ + for l=[0:s4]{ + for m=[0:s5]{ + if ((i >= lbounds1) && (i < rbounds1excl) && (j >= lbounds2) && (j < rbounds2excl) && (k >= lbounds3) && (k < rbounds3excl) && (l >= lbounds4) && (l < rbounds4excl) && (m >= lbounds5) && (m < rbounds5excl)){ + outArr[i][j][k][l][m] = inpArr[i-paddings[0][0]][j-paddings[1][0]][k-paddings[2][0]][l-paddings[3][0]][m-paddings[4][0]]; + } + else{ + outArr[i][j][k][l][m] = 0; + }; + }; + }; + }; + }; + }; +} + +def void PadONNX441(int32_pl o1, int32_pl o2, int32_pl o3, int32_pl o4, int32_pl i1, int32_pl i2, int32_pl i3, int32_pl i4, int32_al[i1][i2][i3][i4] inpArr, int32_pl pads, int32_pl[pads] paddings, int32_al[o1][o2][o3][o4] outArr) { + int32_pl lbounds1 = paddings[0]; + int32_pl rbounds1excl = o1 - paddings[4]; + 
int32_pl lbounds2 = paddings[1]; + int32_pl rbounds2excl = o2 - paddings[5]; + int32_pl lbounds3 = paddings[2]; + int32_pl rbounds3excl = o3 - paddings[6]; + int32_pl lbounds4 = paddings[3]; + int32_pl rbounds4excl = o4 - paddings[7]; + for i=[0:o1]{ + for j=[0:o2]{ + for k=[0:o3]{ + for l=[0:o4]{ + if ((i >= lbounds1) && (i < rbounds1excl) && (j >= lbounds2) && (j < rbounds2excl) && (k >= lbounds3) && (k < rbounds3excl) && (l >= lbounds4) && (l < rbounds4excl)){ + outArr[i][j][k][l] = inpArr[i-paddings[0]][j-paddings[1]][k-paddings[2]][l-paddings[3]]; + } + else{ + outArr[i][j][k][l] = 0; + }; + }; + }; + }; + }; +} + (**************************) (* Squeeze where the input is a 4D tensor, output is a 2D tensor and hence 2 dims are getting squeezed. *) def void Squeeze24(int32_pl s1, int32_pl s2, int32_pl dim1, int32_pl dim2, int32_pl ins1, int32_pl ins2, int32_pl ins3, int32_pl ins4, int32_al[ins1][ins2][ins3][ins4] inArr, int32_al[s1][s2] outArr){ @@ -380,6 +832,238 @@ def void Squeeze24(int32_pl s1, int32_pl s2, int32_pl dim1, int32_pl dim2, int32 } +(**************************) +(* Generic implementation of ConvTranspose2D *) + +def void ConvTranspose2DReshapeMatMulOP(int32_pl N, int32_pl finalH, int32_pl finalW, int32_pl CO, int32_al[CO][N*finalH*finalW] inputArr, int32_al[N][finalH][finalW][CO] outputArr){ + + for co=[0:CO]{ + for n=[0:N]{ + for h=[0:finalH]{ + for w=[0:finalW]{ + outputArr[n][h][w][co] = inputArr[co][(n*finalH*finalW) + (h*finalW) + w]; + }; + }; + }; + }; +} + + +def void ConvTranspose2DReshapeFilter(int32_pl FH, int32_pl FW, int32_pl CO, int32_pl CI, int32_al[FH][FW][CO][CI] inputArr, int32_al[CO][FH*FW*CI] outputArr) +{ + for co=[0:CO]{ + for fh=[0:FH]{ + for fw=[0:FW]{ + for ci=[0:CI]{ + int32_pl linIdx = (fh*FW*CI) + (fw*CI) + ci; + outputArr[co][linIdx] = inputArr[FH-1-fh][FW-1-fw][co][ci]; + }; + }; + }; + }; +} + +def void ConvTranspose2DReshapeInput(int32_pl N, int32_pl HPrime, int32_pl WPrime, int32_pl CI, int32_pl FH, int32_pl FW,
int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, int32_pl strideH, int32_pl strideW, int32_pl RRows, int32_pl RCols, int32_al[N][HPrime][WPrime][CI] inputArr, int32_al[RRows][RCols] outputArr){ + int32_pl linIdxFilterMult = 0; + for n=[0:N]{ + int32_pl leftTopCornerH = 0 - zPadTrHLeft; + int32_pl HPrimeTilde = HPrime + ((HPrime-1)*(strideH-1)); + int32_pl extremeRightBottomCornerH = HPrimeTilde - 1 + zPadTrHRight; + while((leftTopCornerH + FH - 1) <= extremeRightBottomCornerH){ + int32_pl leftTopCornerW = 0 - zPadTrWLeft; + int32_pl WPrimeTilde = WPrime + ((WPrime-1)*(strideW-1)); + int32_pl extremeRightBottomCornerW = WPrimeTilde - 1 + zPadTrWRight; + while((leftTopCornerW + FW - 1) <= extremeRightBottomCornerW){ + + for fh=[0:FH]{ + for fw=[0:FW]{ + int32_pl curPosH = leftTopCornerH + fh; + int32_pl curPosW = leftTopCornerW + fw; + int32_al val = 0; + for ci=[0:CI]{ + if ((((curPosH < 0) || (curPosH >= HPrimeTilde)) || ((curPosW < 0) || (curPosW >= WPrimeTilde)))){ + val = 0; + } + else{ + (* curPosH lies between 0 and HPrimeTilde *) + if (((curPosH % strideH) == 0) && ((curPosW % strideW) == 0)) { + int32_pl idxInputH = curPosH / strideH; + int32_pl idxInputW = curPosW / strideW; + val = inputArr[n][idxInputH][idxInputW][ci]; + } + else{ + val = 0; (* This represents fractional stride. 
*) + }; + }; + outputArr[(fh*FW*CI) + (fw*CI) + ci][linIdxFilterMult] = val; + }; + }; + }; + + linIdxFilterMult = linIdxFilterMult + 1; + leftTopCornerW = leftTopCornerW + 1; (* Imp Note: The actual stride is always 1 *) + }; + + leftTopCornerH = leftTopCornerH + 1; (* Imp Note: The actual stride is always 1 *) + }; + }; +} + +(* int32_al[N][HPrime][WPrime][CI] inputArr, + int32_al[FH][FW][CO][CI] filter, + int32_al[N][H][W][CO] outputArr +*) +def void ConvTranspose2DCSF(int32_pl N, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl H, int32_pl W, + int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideH, int32_pl strideW, + int32_al[N][HPrime][WPrime][CI] inputArr, + int32_al[FH][FW][CO][CI] filterArr, + int32_pl consSF, + int32_al[N][H][W][CO] outArr) +{ + int32_pl reshapedFilterRows = CO; + int32_pl reshapedFilterCols = FH*FW*CI; + int32_pl reshapedIPRows = FH*FW*CI; + int32_pl reshapedIPCols = N * H * W; + + int32_al[reshapedFilterRows][reshapedFilterCols] filterReshaped; + int32_al[reshapedIPRows][reshapedIPCols] inputReshaped; + int32_al[reshapedFilterRows][reshapedIPCols] matmulOP; + + ConvTranspose2DReshapeFilter(FH, FW, CO, CI, filterArr, filterReshaped); + ConvTranspose2DReshapeInput(N, HPrime, WPrime, CI, FH, FW, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideH, strideW, reshapedIPRows, reshapedIPCols, inputArr, inputReshaped); + + MatMulCSF2D(reshapedFilterRows, reshapedFilterCols, reshapedIPCols, filterReshaped, inputReshaped, matmulOP, consSF); + + ConvTranspose2DReshapeMatMulOP(N, H, W, CO, matmulOP, outArr); +} + +(**************************) +(* Generic implementation of ConvTranspose3D *) + +def void ConvTranspose3DReshapeFilter(int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, int32_pl CI, int32_al[FD][FH][FW][CO][CI] inputArr, int32_al[CO][FD*FH*FW*CI] outputArr) +{ + for co=[0:CO]{ + for fd=[0:FD]{ + for fh=[0:FH]{ + for
fw=[0:FW]{ + for ci=[0:CI]{ + int32_pl linIdx = (fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI) + ci; + outputArr[co][linIdx] = inputArr[FD-1-fd][FH-1-fh][FW-1-fw][co][ci]; + }; + }; + }; + }; + }; +} + +def void ConvTranspose3DReshapeInput(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, int32_pl FD, int32_pl FH, int32_pl FW, int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, int32_pl strideD, int32_pl strideH, int32_pl strideW, int32_pl RRows, int32_pl RCols, int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, int32_al[RRows][RCols] outputArr){ + int32_pl linIdxFilterMult = 0; + for n=[0:N]{ + int32_pl leftTopCornerD = 0 - zPadTrDLeft; + int32_pl DPrimeTilde = DPrime + ((DPrime-1)*(strideD-1)); + int32_pl extremeRightBottomCornerD = DPrimeTilde - 1 + zPadTrDRight; + while((leftTopCornerD + FD - 1) <= extremeRightBottomCornerD){ + int32_pl leftTopCornerH = 0 - zPadTrHLeft; + int32_pl HPrimeTilde = HPrime + ((HPrime-1)*(strideH-1)); + int32_pl extremeRightBottomCornerH = HPrimeTilde - 1 + zPadTrHRight; + while((leftTopCornerH + FH - 1) <= extremeRightBottomCornerH){ + int32_pl leftTopCornerW = 0 - zPadTrWLeft; + int32_pl WPrimeTilde = WPrime + ((WPrime-1)*(strideW-1)); + int32_pl extremeRightBottomCornerW = WPrimeTilde - 1 + zPadTrWRight; + while((leftTopCornerW + FW - 1) <= extremeRightBottomCornerW){ + + for fd=[0:FD]{ + for fh=[0:FH]{ + for fw=[0:FW]{ + int32_pl curPosD = leftTopCornerD + fd; + int32_pl curPosH = leftTopCornerH + fh; + int32_pl curPosW = leftTopCornerW + fw; + int32_al val = 0; + for ci=[0:CI]{ + if (((curPosD < 0) || (curPosD >= DPrimeTilde)) || ((curPosH < 0) || (curPosH >= HPrimeTilde)) || ((curPosW < 0) || (curPosW >= WPrimeTilde))) { + val = 0; + } + else{ + (* curPosH lies between 0 and HPrimeTilde *) + if (((curPosD % strideD) == 0) && ((curPosH % strideH) == 0) && ((curPosW % strideW) == 0)) { + int32_pl idxInputD = curPosD / strideD; + 
int32_pl idxInputH = curPosH / strideH; + int32_pl idxInputW = curPosW / strideW; + val = inputArr[n][idxInputD][idxInputH][idxInputW][ci]; + } + else{ + val = 0; (* This represents fractional stride. *) + }; + }; + outputArr[(fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI) + ci][linIdxFilterMult] = val; + }; + }; + }; + }; + + linIdxFilterMult = linIdxFilterMult + 1; + leftTopCornerW = leftTopCornerW + 1; (* Imp Note: The actual stride is always 1 *) + }; + + leftTopCornerH = leftTopCornerH + 1; (* Imp Note: The actual stride is always 1 *) + }; + + leftTopCornerD = leftTopCornerD + 1; (* Imp Note: The actual stride is always 1 *) + }; + }; +} + +(* int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int32_al[FD][FH][FW][CO][CI] filter, + int32_al[N][D][H][W][CO] outputArr +*) +def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl D, int32_pl H, int32_pl W, + int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int32_al[FD][FH][FW][CO][CI] filterArr, + int32_pl consSF, + int32_al[N][D][H][W][CO] outArr) +{ + ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, CO, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, D, H, W, inputArr, filterArr, consSF, outArr); +} + +(* int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int32_al[FD][FH][FW][CO][CI] filter, + int32_al[N][D][H][W][CO] outputArr +*) +def void ConvTranspose3DCSF(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl D, int32_pl H, int32_pl W, + int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, 
int32_pl zPadTrWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int32_al[FD][FH][FW][CO][CI] filterArr, + int32_pl consSF, + int32_al[N][D][H][W][CO] outArr) +{ + int32_pl reshapedFilterRows = CO; + int32_pl reshapedFilterCols = FD*FH*FW*CI; + int32_pl reshapedIPRows = FD*FH*FW*CI; + int32_pl reshapedIPCols = N * D * H * W; + + int32_al[reshapedFilterRows][reshapedFilterCols] filterReshaped; + int32_al[reshapedIPRows][reshapedIPCols] inputReshaped; + int32_al[reshapedFilterRows][reshapedIPCols] matmulOP; + + ConvTranspose3DReshapeFilter(FD, FH, FW, CO, CI, filterArr, filterReshaped); + ConvTranspose3DReshapeInput(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, reshapedIPRows, reshapedIPCols, inputArr, inputReshaped); + + MatMulCSF2D(reshapedFilterRows, reshapedFilterCols, reshapedIPCols, filterReshaped, inputReshaped, matmulOP, consSF); + + Conv3DReshapeMatMulOP(N, D, H, W, CO, matmulOP, outArr); +} + (**************************) def void ClearMemPublic(int32_pl x){ return; @@ -387,4 +1071,14 @@ def void ClearMemPublic(int32_pl x){ def void ClearMemPublic1(int32_pl s, int32_pl[s] x){ return; +} + +def void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl[s1][s2][s3][s4] arr) +{ + return; +} + +def void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr) +{ + return; } \ No newline at end of file diff --git a/Athos/TFEzPCLibrary/Library32_cpp.ezpc b/Athos/TFEzPCLibrary/Library32_cpp.ezpc index 4e5911ae232ee2b08a00d970e43fa2a47c083918..57dd0a8fefa1b753d24ede8ec29d7107cc1487c8 100644 --- a/Athos/TFEzPCLibrary/Library32_cpp.ezpc +++ b/Athos/TFEzPCLibrary/Library32_cpp.ezpc @@ -21,7 +21,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*) - (**************************) def void MatMulCSF2D(int32_pl i, int32_pl j, int32_pl k, int32_al[i][j] A, int32_al[j][k] B, int32_al[i][k] C, int32_pl consSF){ for i1=[0:i]{ @@ -35,6 +34,145 @@ def void MatMulCSF2D(int32_pl i, int32_pl j, int32_pl k, int32_al[i][j] A, int32 }; } +(**************************) +(* These loop implementations of convolution run faster with multithreading *) + +def void Conv2DLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, + int32_pl outH, int32_pl outW, int32_pl G, + int32_al[N][H][W][CI] inputArr, + int32_al[FH][FW][CI/G][CO] filterArr, + int32_pl consSF, + int32_al[N][outH][outW][CO] outArr){ + + int32_pl GIS = CI/G; + int32_pl GOS = CO/G; + + for n=[0:N]{ + for cog=[0:GOS]{ + for cig=[0:GIS]{ + for g=[0:G]{ + for h=[0:outH]{ + for w=[0:outW]{ + + int32_al val = 0; + int32_pl ci = GIS*g + cig; + int32_pl co = GOS*g + cog; + int32_pl curPosH = strideH*h-zPadHLeft; + + for fh=[0:FH]{ + int32_pl curPosW = strideW*w-zPadWLeft; + + for fw=[0:FW]{ + if( (curPosH >= 0) && (curPosW >= 0) && (curPosH < H) && (curPosW < W)){ + val = val +_al (inputArr[n][curPosH][curPosW][ci]*filterArr[fh][fw][cig][co]); (* filter dim 3 has size CI/G: index it by the in-group channel cig, not ci/G *) + }; + + curPosW = curPosW + 1; + }; + curPosH = curPosH + 1; + }; + + outArr[n][h][w][co] = outArr[n][h][w][co] +_al (val >> consSF); + }; + }; + }; + }; + }; + }; +} + +(**************************) +def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight,int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_pl outD, int32_pl outH, int32_pl outW, + int32_al[N][D][H][W][CI] inputArr, + int32_al[FD][FH][FW][CI][CO] filterArr, + int32_pl consSF, +
int32_al[N][outD][outH][outW][CO] outArr){ + + for n=[0:N]{ + for co=[0:CO]{ + for d=[0:outD]{ + for h=[0:outH]{ + for w=[0:outW]{ + for ci=[0:CI]{ + int32_al val = 0; + for fd=[d*strideD:d*strideD+FD]{ + for fh=[h*strideH:h*strideH+FH]{ + for fw=[w*strideW:w*strideW+FW]{ + int32_pl curPosD = fd-zPadDLeft; + int32_pl curPosH = fh-zPadHLeft; + int32_pl curPosW = fw-zPadWLeft; + if( (curPosD >= 0) && (curPosH >= 0) && (curPosW >= 0) && (curPosD < D) && (curPosH < H) && (curPosW < W)){ + int32_pl curFilterPosD = fd-(d*strideD); + int32_pl curFilterPosH = fh-(h*strideH); + int32_pl curFilterPosW = fw-(w*strideW); + val = val +_al (inputArr[n][curPosD][curPosH][curPosW][ci]*filterArr[curFilterPosD][curFilterPosH][curFilterPosW][ci][co]); + }; + }; + }; + }; + outArr[n][d][h][w][co] = outArr[n][d][h][w][co] +_al (val >> consSF); + }; + }; + }; + }; + }; + }; +} + + +(**************************) +def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight,int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_pl outD, int32_pl outH, int32_pl outW, + int32_al[N][D][H][W][CI] inputArr, + int32_al[FD][FH][FW][CO][CI] filterArr, + int32_pl consSF, + int32_al[N][outD][outH][outW][CO] outArr){ + + for n=[0:N]{ + for co=[0:CO]{ + for d=[0:outD]{ + for h=[0:outH]{ + for w=[0:outW]{ + for ci=[0:CI]{ + int32_al val = 0; + for fd=[d:d+FD]{ + for fh=[h:h+FH]{ + for fw=[w:w+FW]{ + + int32_pl curPosD = (fd-zPadDLeft)/strideD; + int32_pl curPosH = (fh-zPadHLeft)/strideH; (* each axis is divided by its own stride, matching the % checks below *) + int32_pl curPosW = (fw-zPadWLeft)/strideW; + + if( (curPosD >= 0) && (curPosH >= 0) && (curPosW >= 0) && (curPosD < D) && (curPosH < H) && (curPosW < W) && ((fd-zPadDLeft)%strideD == 0) && ((fh-zPadHLeft)%strideH == 0) && ((fw-zPadWLeft)%strideW == 0)){ + + int32_pl curFilterPosD = FD+d-fd-1; + int32_pl
curFilterPosH = FH+h-fh-1; + int32_pl curFilterPosW = FW+w-fw-1; + val = val +_al (inputArr[n][curPosD][curPosH][curPosW][ci]*filterArr[curFilterPosD][curFilterPosH][curFilterPosW][co][ci]); + }; + }; + }; + }; + outArr[n][d][h][w][co] = outArr[n][d][h][w][co] +_al (val >> consSF); + }; + }; + }; + }; + }; + }; +} + + (**************************) def void ArgMax1(int32_pl outArrS1, int32_pl inArrS1, int32_pl inArrS2, int32_al[inArrS1][inArrS2] inArr, int32_pl dim, int32_al[outArrS1] outArr){ for od=[0:inArrS1]{ @@ -90,37 +228,115 @@ def void Relu4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][ }; } +def void Relu5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] inArr, int32_al[s1][s2][s3][s4][s5] outArr){ + for i1=[0:s1]{ + for i2=[0:s2]{ + for i3=[0:s3]{ + for i4=[0:s4]{ + for i5=[0:s5]{ + outArr[i1][i2][i3][i4][i5] = (inArr[i1][i2][i3][i4][i5] > 0 ? inArr[i1][i2][i3][i4][i5] : 0); + }; + }; + }; + }; + }; +} (**************************) -def void ElemWiseMul2(int32_pl s1, int32_pl s2, int32_al[s1][s2] arr1, int32_al[s1][s2] arr2, int32_al[s1][s2] outArr, int32_pl shrout){ - for i1=[0:s1]{ - for i2=[0:s2]{ - outArr[i1][i2] = ((arr1[i1][i2] * arr2[i1][i2]) >> shrout); - }; - }; +def void ElemWiseMul2(int32_pl a1, int32_pl a2, int32_pl b1, int32_pl b2, int32_pl s1, int32_pl s2, int32_al[a1][a2] A, int32_al[b1][b2] B, int32_al[s1][s2] outArr, int32_pl shrout){ + int32_pl aIdx1 = 0; + int32_pl aIdx2 = 0; + int32_pl bIdx1 = 0; + int32_pl bIdx2 = 0; + for i1=[0:s1]{ + aIdx1 = ((a1 == 1) ? 0 : i1); + bIdx1 = ((b1 == 1) ? 0 : i1); + for i2=[0:s2]{ + aIdx2 = ((a2 == 1) ? 0 : i2); + bIdx2 = ((b2 == 1) ? 
0 : i2); + outArr[i1][i2] = ((A[aIdx1][aIdx2] * B[bIdx1][bIdx2]) >> shrout); + }; + }; } -def void ElemWiseMul4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] arr1, int32_al[s1][s2][s3][s4] arr2, int32_al[s1][s2][s3][s4] outArr, int32_pl shrout){ - for i1=[0:s1]{ - for i2=[0:s2]{ - for i3=[0:s3]{ - for i4=[0:s4]{ - outArr[i1][i2][i3][i4] = ((arr1[i1][i2][i3][i4] * arr2[i1][i2][i3][i4]) >> shrout); - }; - }; - }; - }; +def void ElemWiseMul4(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl b1, int32_pl b2, int32_pl b3, int32_pl b4, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[a1][a2][a3][a4] A, int32_al[b1][b2][b3][b4] B, int32_al[s1][s2][s3][s4] outArr, int32_pl shrout){ + int32_pl aIdx1 = 0; + int32_pl aIdx2 = 0; + int32_pl aIdx3 = 0; + int32_pl aIdx4 = 0; + int32_pl bIdx1 = 0; + int32_pl bIdx2 = 0; + int32_pl bIdx3 = 0; + int32_pl bIdx4 = 0; + for i1=[0:s1]{ + aIdx1 = ((a1 == 1) ? 0 : i1); + bIdx1 = ((b1 == 1) ? 0 : i1); + for i2=[0:s2]{ + aIdx2 = ((a2 == 1) ? 0 : i2); + bIdx2 = ((b2 == 1) ? 0 : i2); + for i3=[0:s3]{ + aIdx3 = ((a3 == 1) ? 0 : i3); + bIdx3 = ((b3 == 1) ? 0 : i3); + for i4=[0:s4]{ + aIdx4 = ((a4 == 1) ? 0 : i4); + bIdx4 = ((b4 == 1) ? 
0 : i4); + outArr[i1][i2][i3][i4] = ((A[aIdx1][aIdx2][aIdx3][aIdx4] * B[bIdx1][bIdx2][bIdx3][bIdx4]) >> shrout); + }; + }; + }; + }; } -(**************************) -def void ElemWiseDiv2(int32_pl s1, int32_pl s2, int32_al[s1][s2] arr1, int32_al[s1][s2] arr2, int32_al[s1][s2] outArr, int32_pl shrout){ - for i1=[0:s1]{ - for i2=[0:s2]{ - outArr[i1][i2] = ((arr1[i1][i2] / arr2[i1][i2]) << shrout); - }; - }; +def void ElemWiseMul5(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl a5, int32_pl b1, int32_pl b2, int32_pl b3, int32_pl b4, int32_pl b5, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[a1][a2][a3][a4][a5] A, int32_al[b1][b2][b3][b4][b5] B, int32_al[s1][s2][s3][s4][s5] outArr, int32_pl shrout){ + int32_pl aIdx1 = 0; + int32_pl aIdx2 = 0; + int32_pl aIdx3 = 0; + int32_pl aIdx4 = 0; + int32_pl aIdx5 = 0; + int32_pl bIdx1 = 0; + int32_pl bIdx2 = 0; + int32_pl bIdx3 = 0; + int32_pl bIdx4 = 0; + int32_pl bIdx5 = 0; + for i1=[0:s1]{ + aIdx1 = ((a1 == 1) ? 0 : i1); + bIdx1 = ((b1 == 1) ? 0 : i1); + for i2=[0:s2]{ + aIdx2 = ((a2 == 1) ? 0 : i2); + bIdx2 = ((b2 == 1) ? 0 : i2); + for i3=[0:s3]{ + aIdx3 = ((a3 == 1) ? 0 : i3); + bIdx3 = ((b3 == 1) ? 0 : i3); + for i4=[0:s4]{ + aIdx4 = ((a4 == 1) ? 0 : i4); + bIdx4 = ((b4 == 1) ? 0 : i4); + for i5=[0:s5]{ + aIdx5 = ((a5 == 1) ? 0 : i5); + bIdx5 = ((b5 == 1) ? 0 : i5); + outArr[i1][i2][i3][i4][i5] = ((A[aIdx1][aIdx2][aIdx3][aIdx4][aIdx5] * B[bIdx1][bIdx2][bIdx3][bIdx4][bIdx5]) >> shrout); + }; + }; + }; + }; + }; } +(**************************) +def void ElemWiseDiv2(int32_pl a1, int32_pl a2, int32_pl b1, int32_pl b2, int32_pl s1, int32_pl s2, int32_al[a1][a2] A, int32_al[b1][b2] B, int32_al[s1][s2] outArr, int32_pl shrout){ + int32_pl aIdx1 = 0; + int32_pl aIdx2 = 0; + int32_pl bIdx1 = 0; + int32_pl bIdx2 = 0; + for i1=[0:s1]{ + aIdx1 = ((a1 == 1) ? 0 : i1); + bIdx1 = ((b1 == 1) ? 0 : i1); + for i2=[0:s2]{ + aIdx2 = ((a2 == 1) ? 0 : i2); + bIdx2 = ((b2 == 1) ? 
0 : i2); + outArr[i1][i2] = ((A[aIdx1][aIdx2] / B[bIdx1][bIdx2]) >> shrout); + }; + }; +} (**************************) def void Floor2(int32_pl s1, int32_pl s2, int32_al[s1][s2] inArr, int32_al[s1][s2] outArr, int32_pl curSF){ for i1=[0:s1]{ @@ -275,10 +491,28 @@ def void FusedBatchNorm4411(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, }; } +def void FusedBatchNorm5511(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] inArr, int32_al[s5] multArr, int32_al[s5] biasArr, int32_pl consSF, int32_al[s1][s2][s3][s4][s5] outputArr){ + for i1=[0:s1]{ + for i2=[0:s2]{ + for i3=[0:s3]{ + for i4=[0:s4]{ + for i5=[0:s5]{ + int32_al t1 = (inArr[i1][i2][i3][i4][i5] *_al multArr[i5]); + int32_al t2 = (t1 >> consSF); + outputArr[i1][i2][i3][i4][i5] = t2 + biasArr[i5]; + }; + }; + }; + }; + }; +} + + +(**************************) def void ReduceMean24(int32_pl outS1, int32_pl outS2, int32_pl inS1, int32_pl inS2, int32_pl inS3, int32_pl inS4, int32_al[inS1][inS2][inS3][inS4] inputArr, - int32_al[2] axes, + int32_pl[2] axes, int32_al[outS1][outS2] outputArr ) { @@ -297,6 +531,29 @@ def void ReduceMean24(int32_pl outS1, int32_pl outS2, }; } +(* This one is used for onnx compilation *) +def void ReduceMeanONNX24(int32_pl outS1, int32_pl outS2, + int32_pl inS1, int32_pl inS2, int32_pl inS3, int32_pl inS4, + int32_al[inS1][inS2][inS3][inS4] inputArr, + int32_pl axis1, int32_pl axis2, + int32_al[outS1][outS2] outputArr + ) +{ + for i1=[0:outS1]{ + for i2=[0:outS2]{ + int32_al summ = 0; + for i=[0:inS3]{ + for j=[0:inS4]{ + summ = summ + inputArr[i1][i2][i][j]; + }; + }; + int32_pl numElem = inS3*inS4; + summ = summ / numElem; + outputArr[i1][i2] = summ; + }; + }; +} + (**************************) def void ClearMemSecret1(int32_pl s1, int32_al[s1] arr) { @@ -318,6 +575,11 @@ def void ClearMemSecret4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int return; } +def void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, 
int32_pl s5, int32_al[s1][s2][s3][s4][s5] arr) +{ + return; +} + def void ClearMemPublic2(int32_pl s1, int32_pl s2, int32_pl[s1][s2] arr) { return; @@ -332,4 +594,4 @@ def void StartComputation() def void EndComputation() { return; -} +} \ No newline at end of file diff --git a/Athos/TFEzPCLibrary/Library32_porthos.ezpc b/Athos/TFEzPCLibrary/Library32_porthos.ezpc index e867e2f5ed5911f46edcbf67bfbdd81e397bb44f..69be6df96ef730ecf5aea965b9243c760c81c84a 100644 --- a/Athos/TFEzPCLibrary/Library32_porthos.ezpc +++ b/Athos/TFEzPCLibrary/Library32_porthos.ezpc @@ -36,6 +36,7 @@ extern void ArgMax3(int32_pl outs1, int32_pl outs2, int32_pl outs3, (**************************) extern void Relu2(int32_pl s1, int32_pl s2, int32_al[s1][s2] inArr, int32_al[s1][s2] outArr); extern void Relu4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] inArr, int32_al[s1][s2][s3][s4] outArr); +extern void Relu5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] inArr, int32_al[s1][s2][s3][s4][s5] outArr); (**************************) extern void ElemWiseMul2(int32_pl s1, int32_pl s2, int32_al[s1][s2] arr1, int32_al[s1][s2] arr2, int32_al[s1][s2] outArr, int32_pl shrout); @@ -75,9 +76,10 @@ extern void ClearMemSecret1(int32_pl s1, int32_al[s1] arr); extern void ClearMemSecret2(int32_pl s1, int32_pl s2, int32_al[s1][s2] arr); extern void ClearMemSecret3(int32_pl s1, int32_pl s2, int32_pl s3, int32_al[s1][s2][s3] arr); extern void ClearMemSecret4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] arr); +extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] arr) extern void ClearMemPublic2(int32_pl s1, int32_pl s2, int32_pl[s1][s2] arr); (**************************) extern void StartComputation(); -extern void EndComputation(); \ No newline at end of file +extern void EndComputation(); diff --git 
a/Athos/TFEzPCLibrary/Library64_common.ezpc b/Athos/TFEzPCLibrary/Library64_common.ezpc index 042cb9ca0eb42e8ceee1e81420a696aa65d940c4..9eda4383ed4a57eeb35d79dda370f18872f27bbd 100644 --- a/Athos/TFEzPCLibrary/Library64_common.ezpc +++ b/Athos/TFEzPCLibrary/Library64_common.ezpc @@ -24,7 +24,7 @@ SOFTWARE. (**************************) (* TODO : the 2nd arg should be broadcasted *) -def void MatAddBroadCast2(int32_pl s1, int32_pl s2, int64_al[s1][s2] A, int64_al[s2] B, int64_al[s1][s2] outArr){ +def void MatAddBroadCast2(int32_pl a1, int32_pl a2, int32_pl b1, int32_pl s1, int32_pl s2, int64_al[s1][s2] A, int64_al[s2] B, int64_al[s1][s2] outArr){ for i1=[0:s1]{ for i2=[0:s2]{ outArr[i1][i2] = A[i1][i2] + B[i2]; @@ -32,16 +32,24 @@ def void MatAddBroadCast2(int32_pl s1, int32_pl s2, int64_al[s1][s2] A, int64_al }; } -def void MatAdd2(int32_pl s1, int32_pl s2, int64_al[s1][s2] A, int64_al[s1][s2] B, int64_al[s1][s2] outArr){ - for i1=[0:s1]{ - for i2=[0:s2]{ - outArr[i1][i2] = A[i1][i2] + B[i1][i2]; - }; - }; +def void MatAdd2(int32_pl a1, int32_pl a2, int32_pl b1, int32_pl b2, int32_pl s1, int32_pl s2, int64_al[a1][a2] A, int64_al[b1][b2] B, int64_al[s1][s2] outArr){ + int32_pl aIdx1 = 0; + int32_pl aIdx2 = 0; + int32_pl bIdx1 = 0; + int32_pl bIdx2 = 0; + for i1=[0:s1]{ + aIdx1 = ((a1 == 1) ? 0 : i1); + bIdx1 = ((b1 == 1) ? 0 : i1); + for i2=[0:s2]{ + aIdx2 = ((a2 == 1) ? 0 : i2); + bIdx2 = ((b2 == 1) ? 
0 : i2); + outArr[i1][i2] = A[aIdx1][aIdx2] + B[bIdx1][bIdx2]; + }; + }; } (* TODO : the 2nd arg should be broadcasted *) -def void MatAddBroadCast4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][s2][s3][s4] A, int64_al[s4] B, int64_al[s1][s2][s3][s4] outArr){ +def void MatAddBroadCast4(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl b1, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][s2][s3][s4] A, int64_al[s4] B, int64_al[s1][s2][s3][s4] outArr){ for i1=[0:s1]{ for i2=[0:s2]{ for i3=[0:s3]{ @@ -53,18 +61,82 @@ def void MatAddBroadCast4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in }; } -def void MatAdd4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][s2][s3][s4] A, int64_al[s1][s2][s3][s4] B, int64_al[s1][s2][s3][s4] outArr){ +def void MatAddBroadCast5(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl a5, int32_pl b1, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] A, int64_al[s5] B, int64_al[s1][s2][s3][s4][s5] outArr){ for i1=[0:s1]{ for i2=[0:s2]{ for i3=[0:s3]{ for i4=[0:s4]{ - outArr[i1][i2][i3][i4] = A[i1][i2][i3][i4] + B[i1][i2][i3][i4]; + for i5=[0:s5]{ + outArr[i1][i2][i3][i4][i5] = A[i1][i2][i3][i4][i5] + B[i5]; + }; }; }; }; }; } +def void MatAdd4(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl b1, int32_pl b2, int32_pl b3, int32_pl b4, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[a1][a2][a3][a4] A, int64_al[b1][b2][b3][b4] B, int64_al[s1][s2][s3][s4] outArr){ + int32_pl aIdx1 = 0; + int32_pl aIdx2 = 0; + int32_pl aIdx3 = 0; + int32_pl aIdx4 = 0; + int32_pl bIdx1 = 0; + int32_pl bIdx2 = 0; + int32_pl bIdx3 = 0; + int32_pl bIdx4 = 0; + for i1=[0:s1]{ + aIdx1 = ((a1 == 1) ? 0 : i1); + bIdx1 = ((b1 == 1) ? 0 : i1); + for i2=[0:s2]{ + aIdx2 = ((a2 == 1) ? 0 : i2); + bIdx2 = ((b2 == 1) ? 0 : i2); + for i3=[0:s3]{ + aIdx3 = ((a3 == 1) ? 0 : i3); + bIdx3 = ((b3 == 1) ? 
0 : i3); + for i4=[0:s4]{ + aIdx4 = ((a4 == 1) ? 0 : i4); + bIdx4 = ((b4 == 1) ? 0 : i4); + outArr[i1][i2][i3][i4] = A[aIdx1][aIdx2][aIdx3][aIdx4] + B[bIdx1][bIdx2][bIdx3][bIdx4]; + }; + }; + }; + }; +} + +def void MatAdd5(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl a5, int32_pl b1, int32_pl b2, int32_pl b3, int32_pl b4, int32_pl b5, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[a1][a2][a3][a4][a5] A, int64_al[b1][b2][b3][b4][b5] B, int64_al[s1][s2][s3][s4][s5] outArr){ + int32_pl aIdx1 = 0; + int32_pl aIdx2 = 0; + int32_pl aIdx3 = 0; + int32_pl aIdx4 = 0; + int32_pl aIdx5 = 0; + int32_pl bIdx1 = 0; + int32_pl bIdx2 = 0; + int32_pl bIdx3 = 0; + int32_pl bIdx4 = 0; + int32_pl bIdx5 = 0; + for i1=[0:s1]{ + aIdx1 = ((a1 == 1) ? 0 : i1); + bIdx1 = ((b1 == 1) ? 0 : i1); + for i2=[0:s2]{ + aIdx2 = ((a2 == 1) ? 0 : i2); + bIdx2 = ((b2 == 1) ? 0 : i2); + for i3=[0:s3]{ + aIdx3 = ((a3 == 1) ? 0 : i3); + bIdx3 = ((b3 == 1) ? 0 : i3); + for i4=[0:s4]{ + aIdx4 = ((a4 == 1) ? 0 : i4); + bIdx4 = ((b4 == 1) ? 0 : i4); + for i5=[0:s5]{ + aIdx5 = ((a5 == 1) ? 0 : i5); + bIdx5 = ((b5 == 1) ? 
0 : i5); + outArr[i1][i2][i3][i4][i5] = A[aIdx1][aIdx2][aIdx3][aIdx4][aIdx5] + B[bIdx1][bIdx2][bIdx3][bIdx4][bIdx5]; + }; + }; + }; + }; + }; +} + (**************************) def void CreateTensor1(int32_pl s1, int64_pl val, int64_pl[s1] arr){ for i1=[0:s1]{ @@ -92,6 +164,20 @@ def void CreateTensor4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64 }; } +def void CreateTensor5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_pl val, int64_pl[s1][s2][s3][s4][s5] arr){ + for i1=[0:s1]{ + for i2=[0:s2]{ + for i3=[0:s3]{ + for i4=[0:s4]{ + for i5=[0:s5]{ + arr[i1][i2][i3][i4][i5] = val; + }; + }; + }; + }; + }; +} + (**************************) def void CopyTensor1(int32_pl s1, int64_al[s1] targetArr, int64_al[s1] fromArr, int64_al[s1] ignore){ for i1=[0:s1]{ @@ -155,6 +241,20 @@ def void CreateCopy2211(int32_pl s1, int32_pl s2, int32_pl inps1, int32_pl inps2 }; } +def void CreateCopy5511(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl inps1, int32_pl inps2, int32_pl inps3, int32_pl inps4, int32_pl inps5, int64_al[inps1][inps2][inps3][inps4][inps5] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int64_al[s1][s2][s3][s4][s5] outArr){ + for i=[0:s1]{ + for j=[0:s2]{ + for k=[0:s3]{ + for l=[0:s4]{ + for m=[0:s5]{ + outArr[i][j][k][l][m] = inArr[beginIdx[0]+i][beginIdx[1]+j][beginIdx[2]+k][beginIdx[3]+l][beginIdx[4]+m]; + }; + }; + }; + }; + }; +} + (**************************) def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int64_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int64_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int64_al[s1][s2][s3][s4] outp){ for i1=[0:s1]{ @@ -227,9 +327,44 @@ def void Concat2T222(int32_pl s1, int32_pl s2, int32_pl inp1s1, int32_pl inp1s2, }; } +(**************************) + 
+def void Split44(int32_pl O1, int32_pl O2, int32_pl O3, int32_pl O4, int32_pl I1, int32_pl I2, int32_pl I3, int32_pl I4, int64_al[I1][I2][I3][I4] inp, int32_pl axis, int32_pl curCount, int32_pl total, int64_al[O1][O2][O3][O4] out){ + +for o1=[0:O1]{ + for o2=[0:O2]{ + for o3=[0:O3]{ + for o4=[0:O4]{ + + int32_pl i1 = o1; + int32_pl i2 = o2; + int32_pl i3 = o3; + int32_pl i4 = o4; + + if(axis == 0){ + i1 = (I1/total)*curCount+o1; + }; + if(axis == 1){ + i2 = (I2/total)*curCount+o2; + }; + if(axis == 2){ + i3 = (I3/total)*curCount+o3; + }; + if(axis == 3){ + i4 = (I4/total)*curCount+o4; + }; + + out[o1][o2][o3][o4] = inp[i1][i2][i3][i4]; + }; + }; + }; +} +} + (**************************) (* Generic implementation of Conv2DCSF *) + def void Conv2DReshapeFilter(int32_pl FH, int32_pl FW, int32_pl CI, int32_pl CO, int64_al[FH][FW][CI][CO] inputArr, int64_al[CO][FH*FW*CI] outputArr){ for co=[0:CO]{ for fh=[0:FH]{ @@ -291,11 +426,6 @@ def void Conv2DReshapeInput(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, int }; } -(* int64_al[N][H][W][CI] inputArr, - int64_al[FH][FW][CI][CO] filterArr, - int64_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr -*) - def void Conv2DCSF(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, int32_pl FH, int32_pl FW, int32_pl CO, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, @@ -325,6 +455,274 @@ def void Conv2DCSF(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, Conv2DReshapeMatMulOP(N, newH, newW, CO, matmulOP, outArr); } +(* int64_al[N][H][W][CI] inputArr, + int64_al[FH][FW][CI][CO] filterArr, + int64_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) + +def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, int32_pl G, + 
int64_al[N][H][W][CI] inputArr, + int64_al[FH][FW][CI][CO] filterArr, + int32_pl consSF, + int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv2DLoop(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, G, inputArr, filterArr, consSF, outArr); +} + +(**************************) +(* Generic implementation of Conv2D with Groups *) + + +(* int64_al[N][H][W][CI] inputArr, + int64_al[FH][FW][CI][CO] filterArr, + int64_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) +def void Conv2DReshapeFilterGroup(int32_pl FH, int32_pl FW, int32_pl CI, int32_pl CO, int32_pl g, int32_pl G, int64_al[FH][FW][CI/G][CO] inputArr, int64_al[CO/G][FH*FW*(CI/G)] outputArr){ + + int32_pl CIG = CI/G; + int32_pl COG = CO/G; + int32_pl startCO = g*COG; + + for co=[0:COG]{ + for fh=[0:FH]{ + for fw=[0:FW]{ + for ci=[0:CIG]{ + int32_pl linIdx = (fh*FW*CIG) + (fw*CIG) + ci; + outputArr[co][linIdx] = inputArr[fh][fw][ci][co+startCO]; + }; + }; + }; + }; +} + +def void Conv2DReshapeMatMulOPGroup(int32_pl N, int32_pl finalH, int32_pl finalW, int32_pl CO, int32_pl g, int32_pl G, int64_al[CO/G][N*finalH*finalW] inputArr, int64_al[N][finalH][finalW][CO] outputArr){ + + int32_pl COG = CO/G; + int32_pl startCO = g*COG; + + for co=[0:COG]{ + for n=[0:N]{ + for h=[0:finalH]{ + for w=[0:finalW]{ + outputArr[n][h][w][co+startCO] = inputArr[co][(n*finalH*finalW) + (h*finalW) + w]; + }; + }; + }; + }; +} + +def void Conv2DReshapeInputGroup(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, int32_pl FH, int32_pl FW, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, int32_pl strideH, int32_pl strideW, int32_pl g, int32_pl G, int32_pl RRows, int32_pl RCols, int64_al[N][H][W][CI] inputArr, 
int64_al[RRows][RCols] outputArr){ + int32_pl linIdxFilterMult = 0; + int32_pl CIG = CI/G; + + for n=[0:N]{ + int32_pl leftTopCornerH = 0 - zPadHLeft; + int32_pl extremeRightBottomCornerH = H - 1 + zPadHRight; + while((leftTopCornerH + FH - 1) <= extremeRightBottomCornerH){ + int32_pl leftTopCornerW = 0 - zPadWLeft; + int32_pl extremeRightBottomCornerW = W - 1 + zPadWRight; + while((leftTopCornerW + FW - 1) <= extremeRightBottomCornerW){ + + for fh=[0:FH]{ + for fw=[0:FW]{ + int32_pl curPosH = leftTopCornerH + fh; + int32_pl curPosW = leftTopCornerW + fw; + int64_al val = 0L; + + int32_pl startCI = g*CIG; + + for ci=[0:CIG]{ + if ((((curPosH < 0) || (curPosH >= H)) || ((curPosW < 0) || (curPosW >= W)))){ + val = 0L; + } + else{ + val = inputArr[n][curPosH][curPosW][ci+startCI]; + }; + outputArr[(fh*FW*CIG) + (fw*CIG) + ci][linIdxFilterMult] = val; + }; + }; + }; + + linIdxFilterMult = linIdxFilterMult + 1; + leftTopCornerW = leftTopCornerW + strideW; + }; + + leftTopCornerH = leftTopCornerH + strideH; + }; + }; +} + + +def void Conv2DCSFGroup(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, int32_pl G, + int64_al[N][H][W][CI] inputArr, + int64_al[FH][FW][CI/G][CO] filterArr, + int32_pl consSF, + int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl CIG = CI/G; + int32_pl reshapedFilterRows = CO/G; + int32_pl reshapedFilterCols = FH*FW*CIG; + int32_pl reshapedIPRows = FH*FW*CIG; + int32_pl outH = (((H + (zPadHLeft+zPadHRight) - FH)/strideH) + 1); + int32_pl outW = (((W + (zPadWLeft+zPadWRight) - FW)/strideW) + 1); + int32_pl reshapedIPCols = N * outH * outW; + + + for g=[0:G]{ + int64_al[reshapedIPRows][reshapedIPCols] inputReshaped; + int64_al[reshapedFilterRows][reshapedIPCols] matmulOP; + 
int64_al[reshapedFilterRows][reshapedFilterCols] filterReshaped; + + Conv2DReshapeFilterGroup(FH, FW, CI, CO, g, G, filterArr, filterReshaped); + Conv2DReshapeInputGroup(N, H, W, CI, FH, FW, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, g, G, reshapedIPRows, reshapedIPCols, inputArr, inputReshaped); + + MatMulCSF2D(reshapedFilterRows, reshapedFilterCols, reshapedIPCols, filterReshaped, inputReshaped, matmulOP, consSF); + + Conv2DReshapeMatMulOPGroup(N, outH, outW, CO, g, G, matmulOP, outArr); + } + +} + +(**************************) +(* Generic implementation of Conv3DCSF *) + +def void Conv3DReshapeFilter(int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CI, int32_pl CO, int64_al[FD][FH][FW][CI][CO] inputArr, int64_al[CO][FD*FH*FW*CI] outputArr){ + for co=[0:CO]{ + for fd=[0:FD]{ + for fh=[0:FH]{ + for fw=[0:FW]{ + for ci=[0:CI]{ + int32_pl linIdx = (fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI) + ci; + outputArr[co][linIdx] = inputArr[fd][fh][fw][ci][co]; + }; + }; + }; + }; + }; +} + +def void Conv3DReshapeMatMulOP(int32_pl N, int32_pl finalD, int32_pl finalH, int32_pl finalW, int32_pl CO, int64_al[CO][N*finalD*finalH*finalW] inputArr, int64_al[N][finalD][finalH][finalW][CO] outputArr){ + for co=[0:CO]{ + for n=[0:N]{ + for d=[0:finalD]{ + for h=[0:finalH]{ + for w=[0:finalW]{ + outputArr[n][d][h][w][co] = inputArr[co][(n*finalD*finalH*finalW) + (d*finalH*finalW) + (h*finalW) + w]; + }; + }; + }; + }; + }; +} + +def void Conv3DReshapeInput(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, int32_pl FD, int32_pl FH, int32_pl FW, int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, int32_pl strideD, int32_pl strideH, int32_pl strideW, int32_pl RRows, int32_pl RCols, int64_al[N][D][H][W][CI] inputArr, int64_al[RRows][RCols] outputArr){ + int32_pl linIdxFilterMult = 0; + for n=[0:N]{ + int32_pl leftTopCornerD = 0 - zPadDLeft; + int32_pl extremeRightBottomCornerD = D - 1 + 
zPadDRight; + while((leftTopCornerD + FD - 1) <= extremeRightBottomCornerD){ + int32_pl leftTopCornerH = 0 - zPadHLeft; + int32_pl extremeRightBottomCornerH = H - 1 + zPadHRight; + while((leftTopCornerH + FH - 1) <= extremeRightBottomCornerH){ + int32_pl leftTopCornerW = 0 - zPadWLeft; + int32_pl extremeRightBottomCornerW = W - 1 + zPadWRight; + while((leftTopCornerW + FW - 1) <= extremeRightBottomCornerW){ + + for fd=[0:FD]{ + for fh=[0:FH]{ + for fw=[0:FW]{ + int32_pl curPosD = leftTopCornerD + fd; + int32_pl curPosH = leftTopCornerH + fh; + int32_pl curPosW = leftTopCornerW + fw; + int64_al val = 0L; + for ci=[0:CI]{ + if ((((curPosD < 0) || (curPosD >= D)) || ((curPosH < 0) || (curPosH >= H)) || ((curPosW < 0) || (curPosW >= W)))){ + val = 0L; + } + else{ + val = inputArr[n][curPosD][curPosH][curPosW][ci]; + }; + outputArr[(fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI) + ci][linIdxFilterMult] = val; + }; + }; + }; + }; + + linIdxFilterMult = linIdxFilterMult + 1; + leftTopCornerW = leftTopCornerW + strideW; + }; + + leftTopCornerH = leftTopCornerH + strideH; + }; + + leftTopCornerD = leftTopCornerD + strideD; + }; + }; +} + +(* int64_al[N][D][H][W][CI] inputArr, + int64_al[FD][FH][FW][CI][CO] filterArr, + int64_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) +(* Loop implementation of convolution run faster with multithreadin *) +def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int64_al[N][D][H][W][CI] inputArr, + int64_al[FD][FH][FW][CI][CO] filterArr, + int32_pl consSF, + 
int64_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outD = ((D-FD+(zPadDLeft+zPadDRight))/strideD)+1; + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv3DLoop(N, D, H, W, CI, FD, FH, FW, CO, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, outD, outH, outW, inputArr, filterArr, consSF, outArr); +} + +(* int64_al[N][D][H][W][CI] inputArr, + int64_al[FD][FH][FW][CI][CO] filterArr, + int64_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) +def void Conv3DCSF(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int64_al[N][D][H][W][CI] inputArr, + int64_al[FD][FH][FW][CI][CO] filterArr, + int32_pl consSF, + int64_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl reshapedFilterRows = CO; + int32_pl reshapedFilterCols = FD*FH*FW*CI; + int32_pl reshapedIPRows = FD*FH*FW*CI; + int32_pl newD = (((D + (zPadDLeft+zPadDRight) - FD)/strideD) + 1); + int32_pl newH = (((H + (zPadHLeft+zPadHRight) - FH)/strideH) + 1); + int32_pl newW = (((W + (zPadWLeft+zPadWRight) - FW)/strideW) + 1); + int32_pl reshapedIPCols = N * newD * newH * newW; + + int64_al[reshapedFilterRows][reshapedFilterCols] filterReshaped; + int64_al[reshapedIPRows][reshapedIPCols] inputReshaped; + int64_al[reshapedFilterRows][reshapedIPCols] matmulOP; + + Conv3DReshapeFilter(FD, FH, FW, CI, CO, filterArr, filterReshaped); + 
Conv3DReshapeInput(N, D, H, W, CI, FD, FH, FW, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, reshapedIPRows, reshapedIPCols, inputArr, inputReshaped); + + MatMulCSF2D(reshapedFilterRows, reshapedFilterCols, reshapedIPCols, filterReshaped, inputReshaped, matmulOP, consSF); + + Conv3DReshapeMatMulOP(N, newD, newH, newW, CO, matmulOP, outArr); +} + (**************************) def void Transpose2(int32_pl s1, int32_pl s2, int64_al[s2][s1] inArr, int64_al[s1][s2] outArr){ for i=[0:s1]{ @@ -360,6 +758,60 @@ def void Pad442(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp }; } +def void Pad552(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl inps1, int32_pl inps2, int32_pl inps3, int32_pl inps4, int32_pl inps5, int64_al[inps1][inps2][inps3][inps4][inps5] inpArr, int32_pl pads1, int32_pl pads2, int32_pl[pads1][pads2] paddings, int64_al[s1][s2][s3][s4][s5] outArr){ + int32_pl lbounds1 = paddings[0][0]; + int32_pl rbounds1excl = s1-paddings[0][1]; + int32_pl lbounds2 = paddings[1][0]; + int32_pl rbounds2excl = s2-paddings[1][1]; + int32_pl lbounds3 = paddings[2][0]; + int32_pl rbounds3excl = s3-paddings[2][1]; + int32_pl lbounds4 = paddings[3][0]; + int32_pl rbounds4excl = s4-paddings[3][1]; + int32_pl lbounds5 = paddings[4][0]; + int32_pl rbounds5excl = s5-paddings[4][1]; + for i=[0:s1]{ + for j=[0:s2]{ + for k=[0:s3]{ + for l=[0:s4]{ + for m=[0:s5]{ + if ((i >= lbounds1) && (i < rbounds1excl) && (j >= lbounds2) && (j < rbounds2excl) && (k >= lbounds3) && (k < rbounds3excl) && (l >= lbounds4) && (l < rbounds4excl) && (m >= lbounds5) && (m < rbounds5excl)){ + outArr[i][j][k][l][m] = inpArr[i-paddings[0][0]][j-paddings[1][0]][k-paddings[2][0]][l-paddings[3][0]][m-paddings[4][0]]; + } + else{ + outArr[i][j][k][l][m] = 0L; + }; + }; + }; + }; + }; + }; +} + +def void PadONNX441(int32_pl o1, int32_pl o2, int32_pl o3, int32_pl o4, int32_pl i1, int32_pl i2, int32_pl i3, int32_pl i4, 
int64_al[i1][i2][i3][i4] inpArr, int32_pl pads, int32_pl[pads] paddings, int64_al[o1][o2][o3][o4] outArr) { + int32_pl lbounds1 = paddings[0]; + int32_pl rbounds1excl = o1 - paddings[4]; + int32_pl lbounds2 = paddings[1]; + int32_pl rbounds2excl = o2 - paddings[5]; + int32_pl lbounds3 = paddings[2]; + int32_pl rbounds3excl = o3 - paddings[6]; + int32_pl lbounds4 = paddings[3]; + int32_pl rbounds4excl = o4 - paddings[7]; + for i=[0:o1]{ + for j=[0:o2]{ + for k=[0:o3]{ + for l=[0:o4]{ + if ((i >= lbounds1) && (i < rbounds1excl) && (j >= lbounds2) && (j < rbounds2excl) && (k >= lbounds3) && (k < rbounds3excl) && (l >= lbounds4) && (l < rbounds4excl)){ + outArr[i][j][k][l] = inpArr[i-paddings[0]][j-paddings[1]][k-paddings[2]][l-paddings[3]]; + } + else{ + outArr[i][j][k][l] = 0L; + }; + }; + }; + }; + }; +} + (**************************) (* Squeeze where the input is a 4D tensor, output is a 2D tensor and hence 2 dims are getting squeezed. *) def void Squeeze24(int32_pl s1, int32_pl s2, int32_pl dim1, int32_pl dim2, int32_pl ins1, int32_pl ins2, int32_pl ins3, int32_pl ins4, int64_al[ins1][ins2][ins3][ins4] inArr, int64_al[s1][s2] outArr){ @@ -380,6 +832,238 @@ def void Squeeze24(int32_pl s1, int32_pl s2, int32_pl dim1, int32_pl dim2, int32 } +(**************************) +(* Generic implementation of ConvTranpose2D *) + +def void ConvTranspose2DReshapeMatMulOP(int32_pl N, int32_pl finalH, int32_pl finalW, int32_pl CO, int64_al[CO][N*finalH*finalW] inputArr, int64_al[N][finalH][finalW][CO] outputArr){ + + for co=[0:CO]{ + for n=[0:N]{ + for h=[0:finalH]{ + for w=[0:finalW]{ + outputArr[n][h][w][co] = inputArr[co][(n*finalH*finalW) + (h*finalW) + w]; + }; + }; + }; + }; +} + + +def void ConvTranspose2DReshapeFilter(int32_pl FH, int32_pl FW, int32_pl CO, int32_pl CI, int64_al[FH][FW][CO][CI] inputArr, int64_al[CO][FH*FW*CI] outputArr) +{ + for co=[0:CO]{ + for fh=[0:FH]{ + for fw=[0:FW]{ + for ci=[0:CI]{ + int32_pl linIdx = (fh*FW*CI) + (fw*CI) + ci; + 
outputArr[co][linIdx] = inputArr[FH-1-fh][FW-1-fw][co][ci]; + }; + }; + }; + }; +} + +def void ConvTranspose2DReshapeInput(int32_pl N, int32_pl HPrime, int32_pl WPrime, int32_pl CI, int32_pl FH, int32_pl FW, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, int32_pl strideH, int32_pl strideW, int32_pl RRows, int32_pl RCols, int64_al[N][HPrime][WPrime][CI] inputArr, int64_al[RRows][RCols] outputArr){ + int32_pl linIdxFilterMult = 0; + for n=[0:N]{ + int32_pl leftTopCornerH = 0 - zPadTrHLeft; + int32_pl HPrimeTilde = HPrime + ((HPrime-1)*(strideH-1)); + int32_pl extremeRightBottomCornerH = HPrimeTilde - 1 + zPadTrHRight; + while((leftTopCornerH + FH - 1) <= extremeRightBottomCornerH){ + int32_pl leftTopCornerW = 0 - zPadTrWLeft; + int32_pl WPrimeTilde = WPrime + ((WPrime-1)*(strideW-1)); + int32_pl extremeRightBottomCornerW = WPrimeTilde - 1 + zPadTrWRight; + while((leftTopCornerW + FW - 1) <= extremeRightBottomCornerW){ + + for fh=[0:FH]{ + for fw=[0:FW]{ + int32_pl curPosH = leftTopCornerH + fh; + int32_pl curPosW = leftTopCornerW + fw; + int64_al val = 0L; + for ci=[0:CI]{ + if ((((curPosH < 0) || (curPosH >= HPrimeTilde)) || ((curPosW < 0) || (curPosW >= WPrimeTilde)))){ + val = 0L; + } + else{ + (* curPosH lies between 0 and HPrimeTilde *) + if (((curPosH % strideH) == 0) && ((curPosW % strideW) == 0)) { + int32_pl idxInputH = curPosH / strideH; + int32_pl idxInputW = curPosW / strideW; + val = inputArr[n][idxInputH][idxInputW][ci]; + } + else{ + val = 0L; (* This represents fractional stride. 
*) + }; + }; + outputArr[(fh*FW*CI) + (fw*CI) + ci][linIdxFilterMult] = val; + }; + }; + }; + + linIdxFilterMult = linIdxFilterMult + 1; + leftTopCornerW = leftTopCornerW + 1; (* Imp Note: The actual stride is always 1 *) + }; + + leftTopCornerH = leftTopCornerH + 1; (* Imp Note: The actual stride is always 1 *) + }; + }; +} + +(* int64_al[N][HPrime][WPrime][CI] inputArr, + int64_al[FH][FW][CO][CI] filter, + int64_al[N][H][W][CO] outputArr +*) +def void ConvTranspose2DCSF(int32_pl N, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl H, int32_pl W, + int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideH, int32_pl strideW, + int64_al[N][HPrime][WPrime][CI] inputArr, + int64_al[FH][FW][CO][CI] filterArr, + int32_pl consSF, + int64_al[N][H][W][CO] outArr) +{ + int32_pl reshapedFilterRows = CO; + int32_pl reshapedFilterCols = FH*FW*CI; + int32_pl reshapedIPRows = FH*FW*CI; + int32_pl reshapedIPCols = N * H * W; + + int64_al[reshapedFilterRows][reshapedFilterCols] filterReshaped; + int64_al[reshapedIPRows][reshapedIPCols] inputReshaped; + int64_al[reshapedFilterRows][reshapedIPCols] matmulOP; + + ConvTranspose2DReshapeFilter(FH, FW, CO, CI, filterArr, filterReshaped); + ConvTranspose2DReshapeInput(N, HPrime, WPrime, CI, FH, FW, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideH, strideW, reshapedIPRows, reshapedIPCols, inputArr, inputReshaped); + + MatMulCSF2D(reshapedFilterRows, reshapedFilterCols, reshapedIPCols, filterReshaped, inputReshaped, matmulOP, consSF); + + ConvTranspose2DReshapeMatMulOP(N, H, W, CO, matmulOP, outArr); +} + +(**************************) +(* Generic implementation of ConvTranpose3D *) + +def void ConvTranspose3DReshapeFilter(int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, int32_pl CI, int64_al[FD][FH][FW][CO][CI] inputArr, int64_al[CO][FD*FH*FW*CI] outputArr) +{ + for co=[0:CO]{ + for fd=[0:FD]{ + for fh=[0:FH]{ + for 
fw=[0:FW]{ + for ci=[0:CI]{ + int32_pl linIdx = (fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI) + ci; + outputArr[co][linIdx] = inputArr[FD-1-fd][FH-1-fh][FW-1-fw][co][ci]; + }; + }; + }; + }; + }; +} + +def void ConvTranspose3DReshapeInput(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, int32_pl FD, int32_pl FH, int32_pl FW, int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, int32_pl strideD, int32_pl strideH, int32_pl strideW, int32_pl RRows, int32_pl RCols, int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, int64_al[RRows][RCols] outputArr){ + int32_pl linIdxFilterMult = 0; + for n=[0:N]{ + int32_pl leftTopCornerD = 0 - zPadTrDLeft; + int32_pl DPrimeTilde = DPrime + ((DPrime-1)*(strideD-1)); + int32_pl extremeRightBottomCornerD = DPrimeTilde - 1 + zPadTrDRight; + while((leftTopCornerD + FD - 1) <= extremeRightBottomCornerD){ + int32_pl leftTopCornerH = 0 - zPadTrHLeft; + int32_pl HPrimeTilde = HPrime + ((HPrime-1)*(strideH-1)); + int32_pl extremeRightBottomCornerH = HPrimeTilde - 1 + zPadTrHRight; + while((leftTopCornerH + FH - 1) <= extremeRightBottomCornerH){ + int32_pl leftTopCornerW = 0 - zPadTrWLeft; + int32_pl WPrimeTilde = WPrime + ((WPrime-1)*(strideW-1)); + int32_pl extremeRightBottomCornerW = WPrimeTilde - 1 + zPadTrWRight; + while((leftTopCornerW + FW - 1) <= extremeRightBottomCornerW){ + + for fd=[0:FD]{ + for fh=[0:FH]{ + for fw=[0:FW]{ + int32_pl curPosD = leftTopCornerD + fd; + int32_pl curPosH = leftTopCornerH + fh; + int32_pl curPosW = leftTopCornerW + fw; + int64_al val = 0L; + for ci=[0:CI]{ + if (((curPosD < 0) || (curPosD >= DPrimeTilde)) || ((curPosH < 0) || (curPosH >= HPrimeTilde)) || ((curPosW < 0) || (curPosW >= WPrimeTilde))) { + val = 0L; + } + else{ + (* curPosH lies between 0 and HPrimeTilde *) + if (((curPosD % strideD) == 0) && ((curPosH % strideH) == 0) && ((curPosW % strideW) == 0)) { + int32_pl idxInputD = curPosD / strideD; + 
int32_pl idxInputH = curPosH / strideH; + int32_pl idxInputW = curPosW / strideW; + val = inputArr[n][idxInputD][idxInputH][idxInputW][ci]; + } + else{ + val = 0L; (* This represents fractional stride. *) + }; + }; + outputArr[(fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI) + ci][linIdxFilterMult] = val; + }; + }; + }; + }; + + linIdxFilterMult = linIdxFilterMult + 1; + leftTopCornerW = leftTopCornerW + 1; (* Imp Note: The actual stride is always 1 *) + }; + + leftTopCornerH = leftTopCornerH + 1; (* Imp Note: The actual stride is always 1 *) + }; + + leftTopCornerD = leftTopCornerD + 1; (* Imp Note: The actual stride is always 1 *) + }; + }; +} + +(* int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int64_al[FD][FH][FW][CO][CI] filter, + int64_al[N][D][H][W][CO] outputArr +*) +def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl D, int32_pl H, int32_pl W, + int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int64_al[FD][FH][FW][CO][CI] filterArr, + int32_pl consSF, + int64_al[N][D][H][W][CO] outArr) +{ + ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, CO, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, D, H, W, inputArr, filterArr, consSF, outArr); +} + +(* int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int64_al[FD][FH][FW][CO][CI] filter, + int64_al[N][D][H][W][CO] outputArr +*) +def void ConvTranspose3DCSF(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl D, int32_pl H, int32_pl W, + int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, 
int32_pl zPadTrWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int64_al[FD][FH][FW][CO][CI] filterArr, + int32_pl consSF, + int64_al[N][D][H][W][CO] outArr) +{ + int32_pl reshapedFilterRows = CO; + int32_pl reshapedFilterCols = FD*FH*FW*CI; + int32_pl reshapedIPRows = FD*FH*FW*CI; + int32_pl reshapedIPCols = N * D * H * W; + + int64_al[reshapedFilterRows][reshapedFilterCols] filterReshaped; + int64_al[reshapedIPRows][reshapedIPCols] inputReshaped; + int64_al[reshapedFilterRows][reshapedIPCols] matmulOP; + + ConvTranspose3DReshapeFilter(FD, FH, FW, CO, CI, filterArr, filterReshaped); + ConvTranspose3DReshapeInput(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, reshapedIPRows, reshapedIPCols, inputArr, inputReshaped); + + MatMulCSF2D(reshapedFilterRows, reshapedFilterCols, reshapedIPCols, filterReshaped, inputReshaped, matmulOP, consSF); + + Conv3DReshapeMatMulOP(N, D, H, W, CO, matmulOP, outArr); +} + (**************************) def void ClearMemPublic(int32_pl x){ return; @@ -387,4 +1071,14 @@ def void ClearMemPublic(int32_pl x){ def void ClearMemPublic1(int32_pl s, int32_pl[s] x){ return; +} + +def void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl[s1][s2][s3][s4] arr) +{ + return; +} + +def void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr) +{ + return; } \ No newline at end of file diff --git a/Athos/TFEzPCLibrary/Library64_cpp.ezpc b/Athos/TFEzPCLibrary/Library64_cpp.ezpc index 3393fa06d5e2123e449f554e8972c07d30af7256..3c2a6e06d3f05922452c4bb3b398142365adb8d7 100644 --- a/Athos/TFEzPCLibrary/Library64_cpp.ezpc +++ b/Athos/TFEzPCLibrary/Library64_cpp.ezpc @@ -34,6 +34,145 @@ def void MatMulCSF2D(int32_pl i, int32_pl j, int32_pl k, int64_al[i][j] A, int64 }; } +(**************************) +(* 
These loop implementations of convolution run faster with multithreading *)
+
+def void Conv2DLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
+                int32_pl FH, int32_pl FW, int32_pl CO,
+                int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
+                int32_pl strideH, int32_pl strideW,
+                int32_pl outH, int32_pl outW, int32_pl G,
+                int64_al[N][H][W][CI] inputArr,
+                int64_al[FH][FW][CI/G][CO] filterArr,
+                int32_pl consSF,
+                int64_al[N][outH][outW][CO] outArr){
+    (* Grouped 2D convolution written as plain loops. For each output element [n][h][w][co] and each input channel ci of the same group, sums inputArr times filterArr over the FH x FW window whose first tap is at row strideH times h minus zPadHLeft and column strideW times w minus zPadWLeft, shifts the sum right by consSF and adds it to outArr. Taps outside [0,H) x [0,W) are skipped, which acts as zero padding. zPadHRight and zPadWRight are accepted but not read here. *)
+    int32_pl GIS = CI/G; (* input channels per group *)
+    int32_pl GOS = CO/G; (* output channels per group *)
+
+    for n=[0:N]{
+        for cog=[0:GOS]{
+            for cig=[0:GIS]{
+                for g=[0:G]{
+                    for h=[0:outH]{
+                        for w=[0:outW]{
+                            (* val collects the window sum for one input-channel / output-channel pair before rescaling *)
+                            int64_al val = 0L;
+                            int32_pl ci = GIS*g + cig;
+                            int32_pl co = GOS*g + cog;
+                            int32_pl curPosH = strideH*h-zPadHLeft;
+
+                            for fh=[0:FH]{
+                                int32_pl curPosW = strideW*w-zPadWLeft;
+                                (* NOTE(review): the filter's third dimension has size CI/G; the index ci/G equals the in-group channel cig when G == 1 or CI == G - confirm it is the intended index for other group counts *)
+                                for fw=[0:FW]{
+                                    if( (curPosH >= 0) && (curPosW >= 0) && (curPosH < H) && (curPosW < W)){
+                                        val = val +_al (inputArr[n][curPosH][curPosW][ci]*filterArr[fh][fw][(ci/G)][co]);
+                                    };
+
+                                    curPosW = curPosW + 1;
+                                };
+                                curPosH = curPosH + 1;
+                            };
+                            (* outArr is added to, not overwritten, so the contributions of every cig sum up; presumably the caller passes it zero-initialised - TODO confirm *)
+                            outArr[n][h][w][co] = outArr[n][h][w][co] +_al (val >> consSF);
+                        };
+                    };
+                };
+            };
+        };
+    };
+}
+
+(**************************)
+def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
+                int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
+                int32_pl zPadDLeft, int32_pl zPadDRight,int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
+                int32_pl strideD, int32_pl strideH, int32_pl strideW,
+                int32_pl outD, int32_pl outH, int32_pl outW,
+                int64_al[N][D][H][W][CI] inputArr,
+                int64_al[FD][FH][FW][CI][CO] filterArr,
+                int32_pl consSF,
+                int64_al[N][outD][outH][outW][CO] outArr){
+    (* 3D convolution written as plain loops, without groups. For each output element [n][d][h][w][co] and each input channel ci, sums inputArr times filterArr over the FD x FH x FW window, shifts the sum right by consSF and adds it to outArr. Taps that fall outside the input volume are skipped, which acts as zero padding. The right-hand pad arguments are accepted but not read here. *)
+    for n=[0:N]{
+        for co=[0:CO]{
+            for d=[0:outD]{
+                for h=[0:outH]{
+                    for w=[0:outW]{
+                        for ci=[0:CI]{
+                            int64_al val = 0L;
+                            for fd=[d*strideD:d*strideD+FD]{ (* fd, fh, fw walk the window in padded input coordinates *)
+                                for fh=[h*strideH:h*strideH+FH]{
+                                    for fw=[w*strideW:w*strideW+FW]{
+                                        int32_pl curPosD = fd-zPadDLeft;
+                                        int32_pl curPosH = fh-zPadHLeft;
+                                        int32_pl curPosW = fw-zPadWLeft;
+                                        if( (curPosD >= 0) && (curPosH >= 0) && (curPosW >= 0) && (curPosD < D) && (curPosH < H) && (curPosW < W)){
+                                            int32_pl curFilterPosD = fd-(d*strideD); (* offset of the tap inside the window *)
+                                            int32_pl curFilterPosH = fh-(h*strideH);
+                                            int32_pl curFilterPosW = fw-(w*strideW);
+                                            val = val +_al (inputArr[n][curPosD][curPosH][curPosW][ci]*filterArr[curFilterPosD][curFilterPosH][curFilterPosW][ci][co]);
+                                        };
+                                    };
+                                };
+                            };
+                            outArr[n][d][h][w][co] = outArr[n][d][h][w][co] +_al (val >> consSF);
+                        };
+                    };
+                };
+            };
+        };
+    };
+}
+
+
+(**************************)
+def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
+                int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
+                int32_pl zPadDLeft, int32_pl zPadDRight,int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
+                int32_pl strideD, int32_pl strideH, int32_pl strideW,
+                int32_pl outD, int32_pl outH, int32_pl outW,
+                int64_al[N][D][H][W][CI] inputArr,
+                int64_al[FD][FH][FW][CO][CI] filterArr,
+                int32_pl consSF,
+                int64_al[N][outD][outH][outW][CO] outArr){
+    (* 3D transposed convolution written as plain loops. For each output element [n][d][h][w][co] and each input channel ci, walks the FD x FH x FW window starting at (d, h, w), maps every tap back to an input position and sums inputArr times filterArr; the sum is shifted right by consSF and added to outArr. The filter is laid out as [FD][FH][FW][CO][CI]. The right-hand pad arguments are accepted but not read here. *)
+    for n=[0:N]{
+        for co=[0:CO]{
+            for d=[0:outD]{
+                for h=[0:outH]{
+                    for w=[0:outW]{
+                        for ci=[0:CI]{
+                            int64_al val = 0L;
+                            for fd=[d:d+FD]{
+                                for fh=[h:h+FH]{
+                                    for fw=[w:w+FW]{
+                                        (* Map the tap back to an input index; every axis is divided by its own stride so the quotient agrees with the remainder test in the guard below *)
+                                        int32_pl curPosD = (fd-zPadDLeft)/strideD;
+                                        int32_pl curPosH = (fh-zPadHLeft)/strideH;
+                                        int32_pl curPosW = (fw-zPadWLeft)/strideW;
+                                        (* A tap contributes only when it lands inside the input and its offset is an exact multiple of the stride on all three axes *)
+                                        if( (curPosD >= 0) && (curPosH >= 0) && (curPosW >= 0) && (curPosD < D) && (curPosH < H) && (curPosW < W) && ((fd-zPadDLeft)%strideD == 0) && ((fh-zPadHLeft)%strideH == 0) && ((fw-zPadWLeft)%strideW == 0)){
+                                            (* The filter is indexed back to front: the first tap of the window uses the last filter element *)
+                                            int32_pl curFilterPosD = FD+d-fd-1;
+                                            int32_pl curFilterPosH = FH+h-fh-1;
+                                            int32_pl curFilterPosW = FW+w-fw-1;
+                                            val = val +_al (inputArr[n][curPosD][curPosH][curPosW][ci]*filterArr[curFilterPosD][curFilterPosH][curFilterPosW][co][ci]);
+                                        };
+                                    };
+                                };
+                            };
+                            outArr[n][d][h][w][co] = outArr[n][d][h][w][co] +_al (val >> consSF);
+                        };
+                    };
+                };
+            };
+        };
+    };
+}
+
+
(**************************)
 def void ArgMax1(int32_pl outArrS1, int32_pl inArrS1, int32_pl inArrS2, int64_al[inArrS1][inArrS2] inArr, int32_pl dim, int64_al[outArrS1] outArr){
     for od=[0:inArrS1]{
@@ -89,37 +228,115 @@ def void Relu4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][
     };
 }
 
+def void Relu5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] inArr, int64_al[s1][s2][s3][s4][s5] outArr){ (* Elementwise ReLU over a 5-D array: outArr holds inArr where it is positive and 0 elsewhere *)
+    for i1=[0:s1]{
+        for i2=[0:s2]{
+            for i3=[0:s3]{
+                for i4=[0:s4]{
+                    for i5=[0:s5]{
+                        outArr[i1][i2][i3][i4][i5] = (inArr[i1][i2][i3][i4][i5] > 0L ? inArr[i1][i2][i3][i4][i5] : 0L);
+                    };
+                };
+            };
+        };
+    };
+}
 
 (**************************)
-def void ElemWiseMul2(int32_pl s1, int32_pl s2, int64_al[s1][s2] arr1, int64_al[s1][s2] arr2, int64_al[s1][s2] outArr, int64_pl shrout){
-    for i1=[0:s1]{
-        for i2=[0:s2]{
-            outArr[i1][i2] = ((arr1[i1][i2] * arr2[i1][i2]) >> shrout);
-        };
-    };
+def void ElemWiseMul2(int32_pl a1, int32_pl a2, int32_pl b1, int32_pl b2, int32_pl s1, int32_pl s2, int64_al[a1][a2] A, int64_al[b1][b2] B, int64_al[s1][s2] outArr, int64_pl shrout){ (* Elementwise product with broadcasting: a dimension of A or B whose size is 1 is always read at index 0; each product is shifted right by shrout *)
+    int32_pl aIdx1 = 0;
+    int32_pl aIdx2 = 0;
+    int32_pl bIdx1 = 0;
+    int32_pl bIdx2 = 0;
+    for i1=[0:s1]{
+        aIdx1 = ((a1 == 1) ? 0 : i1);
+        bIdx1 = ((b1 == 1) ? 0 : i1);
+        for i2=[0:s2]{
+            aIdx2 = ((a2 == 1) ? 0 : i2);
+            bIdx2 = ((b2 == 1) ? 0 : i2);
+            outArr[i1][i2] = ((A[aIdx1][aIdx2] * B[bIdx1][bIdx2]) >> shrout);
+        };
+    };
 }
 
-def void ElemWiseMul4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][s2][s3][s4] arr1, int64_al[s1][s2][s3][s4] arr2, int64_al[s1][s2][s3][s4] outArr, int64_pl shrout){
-    for i1=[0:s1]{
-        for i2=[0:s2]{
-            for i3=[0:s3]{
-                for i4=[0:s4]{
-                    outArr[i1][i2][i3][i4] = ((arr1[i1][i2][i3][i4] * arr2[i1][i2][i3][i4]) >> shrout);
-                };
-            };
-        };
-    };
+def void ElemWiseMul4(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl b1, int32_pl b2, int32_pl b3, int32_pl b4, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[a1][a2][a3][a4] A, int64_al[b1][b2][b3][b4] B, int64_al[s1][s2][s3][s4] outArr, int64_pl shrout){ (* 4-D elementwise product with broadcasting: a dimension of A or B whose size is 1 is always read at index 0; each product is shifted right by shrout *)
+    int32_pl aIdx1 = 0;
+    int32_pl aIdx2 = 0;
+    int32_pl aIdx3 = 0;
+    int32_pl aIdx4 = 0;
+    int32_pl bIdx1 = 0;
+    int32_pl bIdx2 = 0;
+    int32_pl bIdx3 = 0;
+    int32_pl bIdx4 = 0;
+    for i1=[0:s1]{
+        aIdx1 = ((a1 == 1) ? 0 : i1);
+        bIdx1 = ((b1 == 1) ? 0 : i1);
+        for i2=[0:s2]{
+            aIdx2 = ((a2 == 1) ? 0 : i2);
+            bIdx2 = ((b2 == 1) ? 0 : i2);
+            for i3=[0:s3]{
+                aIdx3 = ((a3 == 1) ? 0 : i3);
+                bIdx3 = ((b3 == 1) ? 0 : i3);
+                for i4=[0:s4]{
+                    aIdx4 = ((a4 == 1) ? 0 : i4);
+                    bIdx4 = ((b4 == 1) ? 0 : i4);
+                    outArr[i1][i2][i3][i4] = ((A[aIdx1][aIdx2][aIdx3][aIdx4] * B[bIdx1][bIdx2][bIdx3][bIdx4]) >> shrout);
+                };
+            };
+        };
+    };
 }
 
-(**************************)
-def void ElemWiseDiv2(int32_pl s1, int32_pl s2, int64_al[s1][s2] arr1, int64_al[s1][s2] arr2, int64_al[s1][s2] outArr, int64_pl shrout){
-    for i1=[0:s1]{
-        for i2=[0:s2]{
-            outArr[i1][i2] = ((arr1[i1][i2] / arr2[i1][i2]) << shrout);
-        };
-    };
+def void ElemWiseMul5(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl a5, int32_pl b1, int32_pl b2, int32_pl b3, int32_pl b4, int32_pl b5, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[a1][a2][a3][a4][a5] A, int64_al[b1][b2][b3][b4][b5] B, int64_al[s1][s2][s3][s4][s5] outArr, int64_pl shrout){ (* 5-D elementwise product with broadcasting: a dimension of A or B whose size is 1 is always read at index 0; each product is shifted right by shrout *)
+    int32_pl aIdx1 = 0;
+    int32_pl aIdx2 = 0;
+    int32_pl aIdx3 = 0;
+    int32_pl aIdx4 = 0;
+    int32_pl aIdx5 = 0;
+    int32_pl bIdx1 = 0;
+    int32_pl bIdx2 = 0;
+    int32_pl bIdx3 = 0;
+    int32_pl bIdx4 = 0;
+    int32_pl bIdx5 = 0;
+    for i1=[0:s1]{
+        aIdx1 = ((a1 == 1) ? 0 : i1);
+        bIdx1 = ((b1 == 1) ? 0 : i1);
+        for i2=[0:s2]{
+            aIdx2 = ((a2 == 1) ? 0 : i2);
+            bIdx2 = ((b2 == 1) ? 0 : i2);
+            for i3=[0:s3]{
+                aIdx3 = ((a3 == 1) ? 0 : i3);
+                bIdx3 = ((b3 == 1) ? 0 : i3);
+                for i4=[0:s4]{
+                    aIdx4 = ((a4 == 1) ? 0 : i4);
+                    bIdx4 = ((b4 == 1) ? 0 : i4);
+                    for i5=[0:s5]{
+                        aIdx5 = ((a5 == 1) ? 0 : i5);
+                        bIdx5 = ((b5 == 1) ? 0 : i5);
+                        outArr[i1][i2][i3][i4][i5] = ((A[aIdx1][aIdx2][aIdx3][aIdx4][aIdx5] * B[bIdx1][bIdx2][bIdx3][bIdx4][bIdx5]) >> shrout);
+                    };
+                };
+            };
+        };
+    };
 }
 
+(**************************)
+def void ElemWiseDiv2(int32_pl a1, int32_pl a2, int32_pl b1, int32_pl b2, int32_pl s1, int32_pl s2, int64_al[a1][a2] A, int64_al[b1][b2] B, int64_al[s1][s2] outArr, int64_pl shrout){ (* Elementwise quotient with broadcasting: a dimension of A or B whose size is 1 is always read at index 0 *)
+    int32_pl aIdx1 = 0;
+    int32_pl aIdx2 = 0;
+    int32_pl bIdx1 = 0;
+    int32_pl bIdx2 = 0;
+    for i1=[0:s1]{
+        aIdx1 = ((a1 == 1) ? 0 : i1);
+        bIdx1 = ((b1 == 1) ? 0 : i1);
+        for i2=[0:s2]{
+            aIdx2 = ((a2 == 1) ? 0 : i2);
+            bIdx2 = ((b2 == 1) ? 0 : i2);
+            outArr[i1][i2] = ((A[aIdx1][aIdx2] / B[bIdx1][bIdx2]) >> shrout); (* NOTE(review): the quotient is shifted right here, whereas the implementation this replaces shifted it left by shrout - confirm the direction is intended *)
+        };
+    };
+}
 (**************************)
 def void Floor2(int32_pl s1, int32_pl s2, int64_al[s1][s2] inArr, int64_al[s1][s2] outArr, int64_pl curSF){
     for i1=[0:s1]{
@@ -274,11 +491,28 @@ def void FusedBatchNorm4411(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4,
     };
 }
 
+def void FusedBatchNorm5511(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] inArr, int64_al[s5] multArr, int64_al[s5] biasArr, int32_pl consSF, int64_al[s1][s2][s3][s4][s5] outputArr){ (* Affine map along the last axis of a 5-D array: each element is multiplied by multArr[i5], shifted right by consSF, then biasArr[i5] is added *)
+    for i1=[0:s1]{
+        for i2=[0:s2]{
+            for i3=[0:s3]{
+                for i4=[0:s4]{
+                    for i5=[0:s5]{
+                        int64_al t1 = (inArr[i1][i2][i3][i4][i5] *_al multArr[i5]);
+                        int64_al t2 = (t1 >> consSF);
+                        outputArr[i1][i2][i3][i4][i5] = t2 + biasArr[i5];
+                    };
+                };
+            };
+        };
+    };
+}
+
+
 (**************************)
 def void ReduceMean24(int32_pl outS1, int32_pl outS2,
                       int32_pl inS1, int32_pl inS2, int32_pl inS3, int32_pl inS4,
                       int64_al[inS1][inS2][inS3][inS4] inputArr,
-                      int64_al[2] axes,
+                      int32_pl[2] axes,
                       int64_al[outS1][outS2] outputArr
                       )
 {
@@ -297,6 +531,29 @@ def void ReduceMean24(int32_pl outS1, int32_pl outS2,
     };
 }
 
+(* This one is used for onnx compilation. It averages inputArr over its last two dimensions (inS3 x inS4) for every [i1][i2]; axis1 and axis2 are accepted but never read. *)
+def void ReduceMeanONNX24(int32_pl outS1, int32_pl outS2,
+                      int32_pl inS1, int32_pl inS2, int32_pl inS3, int32_pl inS4,
+                      int64_al[inS1][inS2][inS3][inS4] inputArr,
+                      int32_pl axis1, int32_pl axis2,
+                      int64_al[outS1][outS2] outputArr
+                      )
+{
+    for i1=[0:outS1]{
+        for i2=[0:outS2]{
+            int64_al summ = 0L;
+            for i=[0:inS3]{
+                for j=[0:inS4]{
+                    summ = summ + inputArr[i1][i2][i][j];
+                };
+            };
+            int64_pl numElem = inS3*inS4;
+            summ = summ / numElem;
+            outputArr[i1][i2] = summ;
+        };
+    };
+}
+
 (**************************)
 def void ClearMemSecret1(int32_pl s1, int64_al[s1] arr)
 {
@@ -318,6 +575,11 @@ def void ClearMemSecret4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int
     return;
 }
 
+def void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] arr) (* Empty body: returns immediately without touching arr *)
+{
+    return;
+}
+
 def void ClearMemPublic2(int32_pl s1, int32_pl s2, int32_pl[s1][s2] arr)
 {
     return;
@@ -332,4 +594,4 @@ def void StartComputation()
 def void EndComputation()
 {
     return;
-}
+}
\ No newline at end of file
diff --git a/Athos/TFEzPCLibrary/Library64_porthos.ezpc b/Athos/TFEzPCLibrary/Library64_porthos.ezpc
index f103eae050851d151df71576ed8e6e9be566a780..cc1cf66d87c7718f4d3bd12f851e724692ddc3f9 100644
--- a/Athos/TFEzPCLibrary/Library64_porthos.ezpc
+++ b/Athos/TFEzPCLibrary/Library64_porthos.ezpc
@@ -36,6 +36,7 @@ extern void ArgMax3(int32_pl outs1, int32_pl outs2, int32_pl outs3,
 (**************************)
 extern void Relu2(int32_pl s1, int32_pl s2, int64_al[s1][s2] inArr, int64_al[s1][s2] outArr);
 extern void Relu4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][s2][s3][s4] inArr, int64_al[s1][s2][s3][s4] outArr);
+extern void Relu5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] inArr, int64_al[s1][s2][s3][s4][s5] outArr);
 
 (**************************)
 extern void ElemWiseMul2(int32_pl s1, int32_pl s2, int64_al[s1][s2] arr1, int64_al[s1][s2] arr2, int64_al[s1][s2] outArr, int64_pl shrout);
@@ -75,6 +76,7 @@ extern void ClearMemSecret1(int32_pl s1, int64_al[s1] arr);
 extern void ClearMemSecret2(int32_pl s1, int32_pl s2, int64_al[s1][s2] arr);
 extern void ClearMemSecret3(int32_pl s1, int32_pl s2, int32_pl s3, int64_al[s1][s2][s3] arr);
 extern void ClearMemSecret4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][s2][s3][s4] arr);
+extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] arr);
 extern void ClearMemPublic2(int32_pl s1, int32_pl s2, int32_pl[s1][s2] arr);