import argparse
from argparse import RawTextHelpFormatter

import json
import os
import os.path
import shlex
import shutil
import sys

import TFCompiler.ProcessTFGraph as Athos
import CompilerScripts.parse_config as parse_config
import CompilerScripts.compile_tf as compile_tf


def parse_args():
  """Parse this script's command-line arguments.

  The only argument is the mandatory ``--config`` option: the path to a JSON
  config file whose layout is described in that option's help text (shown
  verbatim thanks to RawTextHelpFormatter).

  Returns:
    argparse.Namespace with a ``config`` attribute (str).
  """
  parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter)
  parser.add_argument(
    "--config",
    required=True,
    type=str,
    help="""Path to the config json file
Config file should be a json in the following format:
{
  // Mandatory options

  "model_name":"model.pb",  // Tensorflow protobuf file to compile.
  "output_tensors":[
  "output1",
  "output2"
  ],
  "target":"PORTHOS2PC",  // Compilation target. ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC

  // Optional options
  "scale":10,         // Scaling factor to compile for. Defaults to 12.
  "bitlength":64,       // Bit length to compile for. Defaults to 64.
  "save_weights" : true,  // Save model scaled weights in fixed point. Defaults to false.

  "input_tensors":{           // Name and shape of the input tensors
  "actual_input_1":"224,244,3",     // for the model. Not required if the
  "input2":"2,245,234,3"        // placeholder nodes have shape info.
  },
  "modulo" : 32,      // Modulo to be used for shares. Applicable for
              // CPPRING/PORTHOS2PC targets. For
              // PORTHOS2PC + backend=OT => Power of 2
              // PORTHOS2PC + backend=HE => Prime value.

  "backend" : "OT",     // Backend to be used - OT/HE (default OT).
              // Only applicable for the PORTHOS2PC target

  "disable_all_hlil_opts" : false,    // Disable all optimizations in HLIL
  "disable_relu_maxpool_opts" : false,  // Disable Relu-Maxpool optimization
  "disable_garbage_collection" : false,   // Disable Garbage Collection optimization
  "disable_trunc_opts" : false      // Disable truncation placement optimization
}
""",
  )
  return parser.parse_args()

def _param(params, key, default):
  """Return ``params[key]``, or `default` when that entry is None.

  A key that is absent altogether still raises KeyError, exactly like plain
  indexing, so a malformed config fails loudly.
  """
  value = params[key]
  return default if value is None else value


def _shell(command):
  """Run `command` through the shell and warn on stderr if it fails.

  The pipeline stays best-effort (a failing step does not raise), but the
  failure is no longer silent. Returns the raw status from os.system.
  """
  status = os.system(command)
  if status != 0:
    print(
      "Warning: command exited with status {}: {}".format(status, command),
      file=sys.stderr,
    )
  return status


def _copy_into(src, dst_dir):
  """Copy file `src` into directory `dst_dir` and return the new path.

  The copy is skipped when the destination already is `src` (e.g. the model
  directory is `dst_dir` itself), which shutil would reject as SameFileError.
  Raises FileNotFoundError if `src` does not exist.
  """
  dst = os.path.join(dst_dir, os.path.basename(src))
  if os.path.exists(dst) and os.path.samefile(src, dst):
    return dst
  shutil.copy(src, dst)
  return dst


def _prepend_library(target, bitlength, athos_dir, ezpc_abs_path):
  """Prefix the SeeDot-generated EzPC file with the TFEzPCLibrary sources.

  The file at `ezpc_abs_path` is rewritten as the concatenation of the
  target's "pre" library, the common library, the "post" library (C++-style
  libraries only) and finally the originally generated code.
  Raises FileNotFoundError if any of those files is missing.
  """
  # ABY and CPPRING reuse the plain C++ EzPC library.
  library = "cpp" if target in ["ABY", "CPPRING"] else target.lower()
  lib_bitlength = 64 if bitlength > 32 else 32
  library_dir = os.path.join(athos_dir, "TFEzPCLibrary")
  common = "Library{}_common.ezpc".format(lib_bitlength)
  if library == "cpp":
    names = [
      "Library{}_{}_pre.ezpc".format(lib_bitlength, library),
      common,
      "Library{}_{}_post.ezpc".format(lib_bitlength, library),
    ]
  else:
    names = ["Library{}_{}.ezpc".format(lib_bitlength, library), common]
  sources = [os.path.join(library_dir, name) for name in names]
  sources.append(ezpc_abs_path)

  # Read every piece before writing: the output replaces the last input.
  chunks = []
  for path in sources:
    with open(path, "rb") as f:
      chunks.append(f.read())
  temp = os.path.join(os.path.dirname(ezpc_abs_path), "temp.ezpc")
  with open(temp, "wb") as f:
    f.write(b"".join(chunks))
  os.replace(temp, ezpc_abs_path)


def _compile_program(target, athos_dir, source_file, program_path):
  """Compile the EzPC-generated C++ `source_file` into `program_path`.

  CPP/CPPRING use a plain g++ build. PORTHOS and PORTHOS2PC link against the
  Porthos / SCI libraries, and are skipped with a hint when those libraries
  have not been built. No build is attempted for ABY.
  """
  q = shlex.quote
  if target in ["CPP", "CPPRING"]:
    _shell("g++ -O3 -w {} -o {}".format(q(source_file), q(program_path)))
  elif target == "PORTHOS":
    porthos_src = os.path.join(athos_dir, "..", "Porthos", "src")
    porthos_lib = os.path.join(porthos_src, "build", "lib")
    if not os.path.exists(porthos_lib):
      print("Not compiling generated code. Please follow the readme and build Porthos.")
      return
    _shell(
      "g++ -O3 -fopenmp -pthread -w -march=native -msse4.1 -maes -mpclmul "
      "-mrdseed -fpermissive -fpic -std=c++17 -L {lib} -I {inc} {src} "
      "-lPorthos-Protocols -lssl -lcrypto -lrt -lboost_system "
      "-o {out}".format(
        lib=q(porthos_lib),
        inc=q(porthos_src),
        src=q(source_file),
        out=q(program_path),
      )
    )
  elif target == "PORTHOS2PC":
    sci = os.path.join(athos_dir, "..", "SCI")
    sci_src = os.path.join(sci, "src")
    sci_lib = os.path.join(sci, "build", "lib")
    eigen_path = os.path.join(sci, "extern", "eigen")
    seal_lib_path = os.path.join(sci, "extern", "SEAL", "native", "lib")
    if not os.path.exists(sci_lib):
      print("Not compiling generated code. Please follow the readme and build SCI.")
      return
    # NOTE(review): links SCI-LinearHE regardless of the OT/HE backend —
    # confirm this is intended for backend=OT.
    _shell(
      "g++ -O3 -fpermissive -pthread -w -maes -msse4.1 -mavx -mavx2 -mrdseed "
      "-faligned-new -std=c++17 -fopenmp -I {eigen} -I {sci_src} {src} "
      "-L {sci_lib} -lSCI-LinearHE -L {seal} -lseal -lssl -lcrypto "
      "-o {out}".format(
        eigen=q(eigen_path),
        sci_src=q(sci_src),
        src=q(source_file),
        sci_lib=q(sci_lib),
        seal=q(seal_lib_path),
        out=q(program_path),
      )
    )


def generate_code(params):
  """Compile a TensorFlow model to an EzPC program and build the binary.

  Pipeline:
    1. compile_tf.compile    -> graph metadata (and optionally scaled weights)
    2. Athos.process_tf_graph -> astOutput.pkl in the model directory
    3. SeeDot                -> <model>_<bitlen>_<target>.ezpc
    4. prepend the TFEzPCLibrary sources for the target
    5. ezpc.sh (run inside the EzPC directory) -> generated .cpp, copied
       back next to the model
    6. g++ build of that .cpp (target dependent; skipped for ABY)

  Args:
    params: config dict (see parse_args help). Optional entries are read
      with None meaning "use the default" — presumably parse_config fills
      unset keys with None; confirm against CompilerScripts.parse_config.

  Returns:
    (program_path, weights_path): absolute path of the compiled binary
    (it may not exist if the build was skipped or failed) and whatever
    compile_tf.compile returned.

  Raises:
    AssertionError: bitlength outside [1, 64] or unsupported target.
    KeyError: an expected config key is missing.
    FileNotFoundError: a library file, or an intermediate file that a
      previous step should have produced, is missing.
  """
  q = shlex.quote

  # Mandatory options.
  model_name = params["model_name"]
  input_tensor_info = params["input_tensors"]
  output_tensors = params["output_tensors"]
  target = params["target"]
  # Optional options and their defaults.
  scale = _param(params, "scale", 12)
  bitlength = _param(params, "bitlength", 64)
  save_weights = _param(params, "save_weights", False)
  disable_all_hlil_opts = _param(params, "disable_all_hlil_opts", False)
  disable_relu_maxpool_opts = _param(params, "disable_relu_maxpool_opts", False)
  disable_garbage_collection = _param(params, "disable_garbage_collection", False)
  disable_trunc_opts = _param(params, "disable_trunc_opts", False)
  modulo = params["modulo"]
  backend = _param(params, "backend", "OT")

  assert 1 <= bitlength <= 64, "Bitlen must be >= 1 and <= 64"
  assert target in [
    "PORTHOS",
    "PORTHOS2PC",
    "ABY",
    "CPP",
    "CPPRING",
  ], "Target must be any of ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC"

  cwd = os.getcwd()
  athos_dir = os.path.dirname(os.path.abspath(__file__))
  model_abs_path = os.path.abspath(model_name)
  model_abs_dir = os.path.dirname(model_abs_path)

  # Generate graphdef and sizeInfo metadata.
  weights_path = compile_tf.compile(
    model_name, input_tensor_info, output_tensors, scale, save_weights
  )

  # Compile to SeeDot. Generates the AST in the model directory.
  Athos.process_tf_graph(model_abs_path)

  # Compile the SeeDot AST to EzPC.
  model_base_name = os.path.splitext(os.path.basename(model_abs_path))[0]
  ezpc_stem = "{mname}_{bl}_{target}".format(
    mname=model_base_name, bl=bitlength, target=target.lower()
  )
  ezpc_file_name = ezpc_stem + ".ezpc"
  ezpc_abs_path = os.path.join(model_abs_dir, ezpc_file_name)
  seedot_script = os.path.join(athos_dir, "SeeDot", "SeeDot.py")
  seedot_cmd = (
    "python3 {script} --astFile {ast} --consSF {sf} --bitlen {bl} "
    "--outputFileName {out} --disableAllOpti {all_opts} --disableRMO {rmo} "
    "--disableLivenessOpti {gc} --disableTruncOpti {trunc}".format(
      script=q(seedot_script),
      ast=q(os.path.join(model_abs_dir, "astOutput.pkl")),
      sf=q(str(scale)),
      bl=q(str(bitlength)),
      out=q(ezpc_abs_path),
      all_opts=q(str(disable_all_hlil_opts)),
      rmo=q(str(disable_relu_maxpool_opts)),
      gc=q(str(disable_garbage_collection)),
      trunc=q(str(disable_trunc_opts)),
    )
  )
  print(seedot_cmd)
  _shell(seedot_cmd)

  # Add library functions.
  _prepend_library(target, bitlength, athos_dir, ezpc_abs_path)

  # Assemble the EzPC compiler invocation.
  ezpc_dir = os.path.join(athos_dir, "../EzPC/EzPC/")
  ezpc_args = "--bitlen {} --codegen {} --disable-tac ".format(
    q(str(bitlength)), q(target)
  )
  output_name = ezpc_stem + "0.cpp"
  if modulo is not None:
    ezpc_args += "--modulo {} ".format(q(str(modulo)))
  if target == "PORTHOS2PC":
    ezpc_args += "--backend {} ".format(q(backend.upper()))
    output_name = ezpc_stem + "_{}0.cpp".format(backend.upper())
  if target == "PORTHOS":
    ezpc_args += "--sf {} ".format(q(str(scale)))

  output_file = os.path.join(model_abs_dir, output_name)
  if target == "PORTHOS2PC":
    program_name = model_base_name + "_" + target + "_" + backend + ".out"
  else:
    program_name = model_base_name + "_" + target + ".out"
  program_path = os.path.join(model_abs_dir, program_name)

  # Copy generated code to the ezpc directory.
  _copy_into(ezpc_abs_path, ezpc_dir)
  try:
    # ezpc.sh is invoked relative to ezpc_dir, and its output is expected there.
    os.chdir(ezpc_dir)
    _shell(
      "eval `opam config env`; ./ezpc.sh {} ".format(q(ezpc_file_name)) + ezpc_args
    )
    _copy_into(os.path.join(ezpc_dir, output_name), model_abs_dir)
    os.chdir(model_abs_dir)
    _compile_program(target, athos_dir, output_file, program_path)
  finally:
    # Restore the caller's working directory even if a step above failed.
    os.chdir(cwd)
  return (program_path, weights_path)

if __name__ == "__main__":
  # CLI entry point: read --config, load it, then run the full pipeline.
  config_path = parse_args().config
  generate_code(parse_config.get_params(config_path))