Skip to content
Snippets Groups Projects
Commit cd3356cd authored by Bhatu's avatar Bhatu
Browse files

Update test cases and add timeout

parent 055da16c
No related branches found
No related tags found
No related merge requests found
......@@ -164,7 +164,6 @@ def test_matmul(
config = Config(backend).add_input(a).add_output(output)
if not bisModel:
config.add_input(b)
config.config["scale"] = 12
compiler = Compiler(graph, config, test_dir)
mpc_output = compiler.compile_and_run([a_inp])
assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2)
......
......@@ -45,6 +45,8 @@ from tests.utils import Config, Compiler, assert_almost_equal
)
@pytest.mark.parametrize("dtype", [np.single])
def test_conv(test_dir, backend, tfOp, a_shape, kernel_shape, strides, padding, dtype):
if tfOp == tf.nn.conv3d and backend in ["2PC_HE", "2PC_OT"]:
pytest.skip("[conv3d] Missing Support in SCI")
graph = tf.Graph()
a_inp = dtype(np.random.randn(*a_shape))
kernel_inp = dtype(np.random.randn(*kernel_shape))
......@@ -56,6 +58,7 @@ def test_conv(test_dir, backend, tfOp, a_shape, kernel_shape, strides, padding,
expected_output = sess.run(output, feed_dict={a: a_inp})
config = Config(backend).add_input(a).add_output(output)
config.config["scale"] = 12
compiler = Compiler(graph, config, test_dir)
mpc_output = compiler.compile_and_run([a_inp])
assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2)
......@@ -96,6 +99,8 @@ def test_conv_transpose(
padding,
dtype,
):
if backend in ["2PC_HE", "2PC_OT"]:
pytest.skip("[conv3d] Missing Support in SCI")
graph = tf.Graph()
a_inp = dtype(np.random.randn(*a_shape))
kernel_inp = dtype(np.random.randn(*kernel_shape))
......@@ -107,6 +112,7 @@ def test_conv_transpose(
expected_output = sess.run(output, feed_dict={a: a_inp})
config = Config(backend).add_input(a).add_output(output)
config.config["scale"] = 12
compiler = Compiler(graph, config, test_dir)
mpc_output = compiler.compile_and_run([a_inp])
assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2)
......
......@@ -35,7 +35,7 @@ from tests.utils import Config, Compiler, assert_almost_equal
@pytest.mark.skip(reason="[non-linear] Haven't made non-linear functionalities public")
@pytest.mark.parametrize("a_shape", [(4, 4), (1,), ()])
@pytest.mark.parametrize("a_shape", [[4, 4], [1], []])
@pytest.mark.parametrize("dtype", [np.single])
@pytest.mark.parametrize(
"tfOp",
......@@ -63,7 +63,7 @@ def test_non_linear(test_dir, backend, tfOp, a_shape, dtype):
return
@pytest.mark.skip(reason="[softmax] Haven't made non-linear functionalities public")
@pytest.mark.parametrize("a_shape, axis", [((2, 3), 1), ((1,), 0)])
@pytest.mark.parametrize("a_shape, axis", [([2, 3], 1), ([1], 0)])
@pytest.mark.parametrize("dtype", [np.single])
def test_softmax(test_dir, backend, a_shape, axis, dtype):
graph = tf.Graph()
......
......@@ -40,11 +40,9 @@ from tests.utils import Config, Compiler, assert_almost_equal
([2, 3], [6]),
([6], [2, 3]),
([2, 3], [3, 2]),
([2, 3], [-1]), # Flatten 1-D
pytest.param(
[1], [], marks=pytest.mark.skip(reason="[reshape] dumping weights error")
), # convert to scalar
([3, 2, 3], [2, -1]), # infer -1 as 9
([2, 3], [-1]), # Flatten 1-D,
([1], []), # convert to scalar,
([3, 2, 3], [2, -1]), # infer -1 as 9,
([3, 2, 3], [-1, 9]), # infer -1 as 2
],
)
......@@ -126,7 +124,6 @@ def test_split(test_dir, backend, a_shape, num_or_size_splits, axis, dtype):
# Squeeze
# TODO: also add a squeeze dim example.
@pytest.mark.parametrize(
"a_shape, axis",
[
......
......@@ -47,12 +47,12 @@ from tests.utils import Config, Compiler, assert_almost_equal
),
tf.shape,
tf.identity,
pytest.param(
tf.zeros_like, marks=pytest.mark.skip(reason="[zeros_like] EzPC issue for inp=[2,2]")
),
tf.zeros_like,
],
)
def test_uop(test_dir, backend, tfOp, a_shape, dtype):
if backend.startswith("2PC") and tfOp == tf.math.square:
pytest.skip("[SCI][square] Secret Secret mul not implemented")
graph = tf.Graph()
a_inp = dtype(np.random.randn(*a_shape))
with graph.as_default():
......@@ -80,7 +80,7 @@ def test_uop(test_dir, backend, tfOp, a_shape, dtype):
)
@pytest.mark.parametrize("dtype", [np.single])
@pytest.mark.parametrize("tfOp", [tf.math.reduce_mean, tf.reduce_sum])
@pytest.mark.skip(reason="[reduce] Reduce mean assert shape failure")
#@pytest.mark.skip(reason="[reduce] Reduce mean output mismatch and shape failure")
def test_reduce(test_dir, backend, tfOp, a_shape, axis, keepdims, dtype):
graph = tf.Graph()
a_inp = dtype(np.random.randn(*a_shape))
......@@ -103,10 +103,13 @@ def test_reduce(test_dir, backend, tfOp, a_shape, axis, keepdims, dtype):
([3, 2], None),
([3, 2], 0),
([3, 2], 1),
([3, 2, 3], 1),
([3, 2, 1, 1], 1),
([3, 2], 1),
],
)
@pytest.mark.parametrize("dtype", [np.single])
@pytest.mark.skip(reason="[argmax] Generic argmax not implemented")
@pytest.mark.skip(reason="[argmax] Need support for argmax along arbitrary axis")
def test_argmax(test_dir, backend, a_shape, axis, dtype):
graph = tf.Graph()
a_inp = dtype(np.random.randn(*a_shape))
......@@ -143,6 +146,8 @@ def test_argmax(test_dir, backend, a_shape, axis, dtype):
def test_pool(
test_dir, backend, tfOp, a_shape, ksize, strides, padding, data_format, dtype
):
if backend.startswith("2PC") and tfOp == tf.nn.max_pool:
pytest.skip("[SCI][maxpool] Output mismatch bug")
graph = tf.Graph()
a_inp = dtype(np.random.randn(*a_shape))
with graph.as_default():
......@@ -173,9 +178,10 @@ def test_pool(
"from_dtype, to_dtype",
[
(np.single, np.single),
(
pytest.param(
np.double,
np.single,
marks=pytest.mark.skip(reason="[cast] Support for parsing DOUBLES"),
),
pytest.param(
np.single,
......@@ -212,6 +218,6 @@ def test_fill(test_dir, backend, a_shape, value):
config = Config(backend).add_output(output)
compiler = Compiler(graph, config, test_dir)
mpc_output = compiler.compile_and_run([])
mpc_output = compiler.compile_and_run([], timeoutSeconds=60)
assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2)
return
\ No newline at end of file
'''
"""
Authors: Pratik Bhatu.
......@@ -20,7 +20,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
"""
import tempfile
import sys
import os
......@@ -51,11 +51,13 @@ class Config:
elif mode == "2PC_OT":
self.config["target"] = "PORTHOS2PC"
self.config["bitlength"] = 41
self.config["scale"] = 12
self.config["backend"] = "OT"
elif mode == "2PC_HE":
self.config["target"] = "PORTHOS2PC"
self.config["bitlength"] = 41
self.config["scale"] = 12
self.config["backend"] = "HE"
else:
assert False, "Mode has to be one of CPP/3PC/2PC_OT/2PC_HE"
......@@ -129,7 +131,7 @@ class Program:
self.target = params["target"]
self.test_dir = test_dir
def run(self, inputs):
def run(self, inputs, timeoutSeconds):
# scale input and dump to file
inputs_scaled = os.path.join(
self.test_dir, "input_fixedpt_scale_" + str(self.scale) + ".inp"
......@@ -175,7 +177,10 @@ class Program:
commands = [client_cmd, server_cmd, party2_cmd]
procs = [subprocess.Popen(i, shell=True) for i in commands]
for p in procs:
p.wait()
try:
p.wait(timeoutSeconds)
except subprocess.TimeoutExpired:
p.kill()
elif self.target == "PORTHOS2PC":
util_dir = os.path.dirname(os.path.abspath(__file__))
sci_dir = os.path.join(util_dir, "..", "..", "SCI")
......@@ -195,7 +200,10 @@ class Program:
commands = [client_cmd, server_cmd]
procs = [subprocess.Popen(i, shell=True) for i in commands]
for p in procs:
p.wait()
try:
p.wait(timeoutSeconds)
except subprocess.TimeoutExpired:
p.kill()
return convert_raw_output_to_np(raw_output, self.bitlength, self.scale)
......@@ -205,13 +213,15 @@ class Compiler:
self.config = config.config
self.test_dir = test_dir
def compile_and_run(self, inputs):
def compile_and_run(self, inputs, timeoutSeconds=40):
save_graph(self.graph_def, self.config, self.test_dir)
params = get_params(self.config)
print(params)
(output_program, model_weight_file) = CompileTFGraph.generate_code(params)
(output_program, model_weight_file) = CompileTFGraph.generate_code(
params, debug=False
)
prog = Program(output_program, model_weight_file, params, self.test_dir)
output = prog.run(inputs)
output = prog.run(inputs, timeoutSeconds)
return output
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment