From 568ef1549e7c27c4ab8d25294e53e210d48b601e Mon Sep 17 00:00:00 2001 From: youben11 Date: Sat, 25 Jan 2020 11:15:56 +0100 Subject: [PATCH 01/13] New message type for crypten party initialization --- syft/messaging/message.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/syft/messaging/message.py b/syft/messaging/message.py index 39c7acb67a8..3dbfc383692 100644 --- a/syft/messaging/message.py +++ b/syft/messaging/message.py @@ -104,6 +104,22 @@ def __repr__(self): return self.__str__() +class CryptenInit(Message): + """Initialize a Crypten party using this message. + + Crypten uses processes as parties, those processes need to be initialized with information + so they can communicate and exchange tensors and shares while doing computation. This message + allows the exchange of information such as the ip and port of the master party to connect to, + as well as the rank of the party to run and the number of parties involved.""" + + def __init__(self, contents): + super().__init__(contents) + + @staticmethod + def detail(worker: AbstractWorker, msg_tuple: tuple) -> "CryptenInit": + return CryptenInit(sy.serde.msgpack.serde._detail(worker, msg_tuple[0])) + + class Operation(Message): """All syft operations use this message type From b5d03de8d3029283621aac720d5570cf622f3653 Mon Sep 17 00:00:00 2001 From: youben11 Date: Sat, 25 Jan 2020 11:16:28 +0100 Subject: [PATCH 02/13] Handle CryptenInit message by running local party --- syft/workers/base.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/syft/workers/base.py b/syft/workers/base.py index 84a780bd489..5331dc1744c 100644 --- a/syft/workers/base.py +++ b/syft/workers/base.py @@ -29,8 +29,10 @@ from syft.messaging.message import GetShapeMessage from syft.messaging.message import PlanCommandMessage from syft.messaging.message import SearchMessage +from syft.messaging.message import CryptenInit from syft.messaging.plan import Plan from syft.workers.abstract import AbstractWorker +from syft.frameworks.crypten import toy_func, run_party from syft.exceptions import GetNotPermittedError from syft.exceptions import WorkerNotFoundException @@ -118,6 +120,7 @@ def __init__( GetShapeMessage: self.get_tensor_shape, SearchMessage: self.search, ForceObjectDeleteMessage: self.force_rm_obj, + CryptenInit: self.run_crypten_party, # TODO: update Message to CryptenInit after implementing it } self._plan_command_router = { @@ -396,6 +399,20 @@ def send( return pointer + def run_crypten_party(self, message: tuple): + """Run crypten party according to the information received. + + Args: + message (CryptenInit): should contain the rank, world_size, master_addr and master_port. + + Returns: + An ObjectMessage containing the return value of the crypten function computed. + """ + + rank, world_size, master_addr, master_port = message + return_value = run_party(toy_func, rank, world_size, master_addr, master_port, (), {}) + return ObjectMessage(return_value) + def execute_command(self, message: tuple) -> PointerTensor: """ Executes commands received from other workers. From 321a617afacc26c149dc7edf89acde14e654aec0 Mon Sep 17 00:00:00 2001 From: youben11 Date: Sat, 25 Jan 2020 11:17:26 +0100 Subject: [PATCH 03/13] Define a new context of computation for crypten The new context of computation can run parties that are distributed across syft workers. The communication of crypten parties remains the same, however, syft workers handle initialization and the serialization of return values. --- syft/frameworks/crypten/__init__.py | 4 + syft/frameworks/crypten/context.py | 150 ++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 syft/frameworks/crypten/__init__.py create mode 100644 syft/frameworks/crypten/context.py diff --git a/syft/frameworks/crypten/__init__.py b/syft/frameworks/crypten/__init__.py new file mode 100644 index 00000000000..f6c3d28efe2 --- /dev/null +++ b/syft/frameworks/crypten/__init__.py @@ -0,0 +1,4 @@ +from syft.frameworks.crypten.context import toy_func, run_party + + +__all__ = ["toy_func", "run_party"] diff --git a/syft/frameworks/crypten/context.py b/syft/frameworks/crypten/context.py new file mode 100644 index 00000000000..7b1c035ff73 --- /dev/null +++ b/syft/frameworks/crypten/context.py @@ -0,0 +1,150 @@ +import functools +import multiprocessing +import threading +import os +import crypten +from crypten.communicator import DistributedCommunicator +from syft.messaging.message import CryptenInit + + +def _launch(func, rank, world_size, master_addr, master_port, queue, func_args, func_kwargs): + communicator_args = { + "RANK": rank, + "WORLD_SIZE": world_size, + "RENDEZVOUS": "env://", + "MASTER_ADDR": master_addr, + "MASTER_PORT": master_port, + "BACKEND": "gloo", + } + for key, val in communicator_args.items(): + os.environ[key] = str(val) + + crypten.init() + return_value = func(*func_args, **func_kwargs) + crypten.uninit() + + queue.put(return_value) + + +def _new_party(func, rank, world_size, master_addr, master_port, func_args, func_kwargs): + queue = multiprocessing.Queue() + process = multiprocessing.Process( + target=_launch, + args=(func, rank, world_size, master_addr, master_port, queue, func_args, func_kwargs), + ) + return process, queue + + +def run_party(func, rank, world_size, master_addr, master_port, func_args, func_kwargs): + """Start crypten party localy and run computation. + + Args: + func (function): computation to be done. + rank (int): rank of the crypten party. + world_size (int): number of crypten parties involved in the computation. + master_addr (str): IP address of the master party (party with rank 0). + master_port (int, str): port of the master party (party with rank 0). + func_args (list): arguments to be passed to func. + func_kwargs (dict): keyword arguments to be passed to func. + + Returns: + The return value of func. + """ + + process, queue = _new_party( + func, rank, world_size, master_addr, master_port, func_args, func_kwargs + ) + was_initialized = DistributedCommunicator.is_initialized() + if was_initialized: + crypten.uninit() + process.start() + process.join() + if was_initialized: + crypten.init() + return queue.get() + + +def _send_party_info(worker, rank, msg, return_values): + """Send message to worker with necessary information to run a crypten party. + Add response to return_values dictionary. + + Args: + worker (BaseWorker): worker to send the message to. + rank (int): rank of the crypten party. + msg (CryptenInit): message containing the rank, world_size, master_addr and master_port. + return_values (dict): dictionnary holding return values of workers. + """ + + response = worker.send_msg(msg, worker) + return_values[rank] = response.contents + + +def toy_func(): + # Toy function to be called by each party + alice_t = crypten.cryptensor([73, 81], src=0) + bob_t = crypten.cryptensor([90, 100], src=1) + out = bob_t.get_plain_text() + return out.tolist() # issues with putting torch tensors into queues + + +def run_multiworkers(workers: list, master_addr: str, master_port: int = 15987): + """Defines decorator to run function across multiple workers. + + Args: + workers (list): workers (parties) to be involved in the computation. + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + # TODO: + # - check if workers are reachable / they can handle the computation + # - check return code of processes for possible failure + + world_size = len(workers) + 1 + return_values = {rank: None for rank in range(world_size)} + + # Start local party + process, queue = _new_party(toy_func, 0, world_size, master_addr, master_port, (), {}) + was_initialized = DistributedCommunicator.is_initialized() + if was_initialized: + crypten.uninit() + process.start() + # Run TTP if required + # TODO: run ttp in a specified worker + if crypten.mpc.ttp_required(): + ttp_process, _ = _new_party( + crypten.mpc.provider.TTPServer, + world_size, + world_size, + master_addr, + master_port, + (), + {}, + ) + ttp_process.start() + + # Send messages to other workers so they start their parties + threads = [] + for i in range(len(workers)): + rank = i + 1 + msg = CryptenInit((rank, world_size, master_addr, master_port)) + thread = threading.Thread( + target=_send_party_info, args=(workers[i], rank, msg, return_values) + ) + thread.start() + threads.append(thread) + + # Wait for local party and sender threads + process.join() + return_values[0] = queue.get() + for thread in threads: + thread.join() + if was_initialized: + crypten.init() + + return return_values + + return wrapper + + return decorator From 262ce1520c122343c9afacda2277e9fc30a1525e Mon Sep 17 00:00:00 2001 From: youben11 Date: Sat, 25 Jan 2020 15:34:14 +0100 Subject: [PATCH 04/13] updated docs --- syft/frameworks/crypten/context.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/syft/frameworks/crypten/context.py b/syft/frameworks/crypten/context.py index 7b1c035ff73..df0af6f604d 100644 --- a/syft/frameworks/crypten/context.py +++ b/syft/frameworks/crypten/context.py @@ -43,7 +43,7 @@ def run_party(func, rank, world_size, master_addr, master_port, func_args, func_ rank (int): rank of the crypten party. world_size (int): number of crypten parties involved in the computation. master_addr (str): IP address of the master party (party with rank 0). - master_port (int, str): port of the master party (party with rank 0). + master_port (int or str): port of the master party (party with rank 0). func_args (list): arguments to be passed to func. func_kwargs (dict): keyword arguments to be passed to func. @@ -92,6 +92,8 @@ def run_multiworkers(workers: list, master_addr: str, master_port: int = 15987): Args: workers (list): workers (parties) to be involved in the computation. + master_addr (str): IP address of the master party (party with rank 0). + master_port (int, str): port of the master party (party with rank 0), default is 15987. """ def decorator(func): From 00bda1676be2c059d54714c4439a578a9a0d07a0 Mon Sep 17 00:00:00 2001 From: youben11 Date: Sun, 26 Jan 2020 10:14:15 +0100 Subject: [PATCH 05/13] testing crypten context --- test/crypten/test_context.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 test/crypten/test_context.py diff --git a/test/crypten/test_context.py b/test/crypten/test_context.py new file mode 100644 index 00000000000..27236e6d673 --- /dev/null +++ b/test/crypten/test_context.py @@ -0,0 +1,24 @@ +import pytest +from syft.frameworks.crypten.context import run_multiworkers + + +def test_context(workers): + # self, alice and bob + n_workers = 3 + alice = workers["alice"] + bob = workers["bob"] + + @run_multiworkers([alice, bob], master_addr="127.0.0.1") + def test_three_parties(): + pass + + return_values = test_three_parties() + # A toy function is ran at each party, and they should all decrypt + # a tensor with value [90., 100.] + expected_value = [90.0, 100.0] + for rank in range(n_workers): + assert ( + return_values[rank] == expected_value + ), "Crypten party with rank {} don't match expected value {} != {}".format( + rank, return_values[rank], expected_value + ) From f349be16a15940612558669ca0b6f2f676b02194 Mon Sep 17 00:00:00 2001 From: youben11 Date: Sun, 26 Jan 2020 17:22:40 +0100 Subject: [PATCH 06/13] fix: was setting the bad env variable DISTRIBUTED_BACKEND should have been set and not BACKEND --- syft/frameworks/crypten/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syft/frameworks/crypten/context.py b/syft/frameworks/crypten/context.py index df0af6f604d..b555f8e3580 100644 --- a/syft/frameworks/crypten/context.py +++ b/syft/frameworks/crypten/context.py @@ -14,7 +14,7 @@ def _launch(func, rank, world_size, master_addr, master_port, queue, func_args, "RENDEZVOUS": "env://", "MASTER_ADDR": master_addr, "MASTER_PORT": master_port, - "BACKEND": "gloo", + "DISTRIBUTED_BACKEND": "gloo", } for key, val in communicator_args.items(): os.environ[key] = str(val) From eec0dc7c4a56958b60ff48c3bba2dd40579f52ca Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 27 Jan 2020 18:28:53 +0100 Subject: [PATCH 07/13] test serde of CryptenInit message --- test/serde/msgpack/test_msgpack_serde_full.py | 1 + test/serde/serde_helpers.py | 27 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/test/serde/msgpack/test_msgpack_serde_full.py b/test/serde/msgpack/test_msgpack_serde_full.py index 76210d090c7..721e6b2f122 100644 --- a/test/serde/msgpack/test_msgpack_serde_full.py +++ b/test/serde/msgpack/test_msgpack_serde_full.py @@ -78,6 +78,7 @@ samples[syft.messaging.message.ForceObjectDeleteMessage] = make_forceobjectdeletemessage samples[syft.messaging.message.SearchMessage] = make_searchmessage samples[syft.messaging.message.PlanCommandMessage] = make_plancommandmessage +samples[syft.messaging.message.CryptenInit] = make_crypteninit samples[syft.frameworks.torch.tensors.interpreters.gradients_core.GradFunc] = make_gradfn diff --git a/test/serde/serde_helpers.py b/test/serde/serde_helpers.py index 302d5eb8e7f..4373b324d26 100644 --- a/test/serde/serde_helpers.py +++ b/test/serde/serde_helpers.py @@ -1318,6 +1318,33 @@ def compare(detailed, original): ] +# syft.messaging.message.CryptenInit +def make_crypteninit(**kwargs): + def compare(detailed, original): + assert type(detailed) == syft.messaging.message.CryptenInit + assert detailed.contents == original.contents + return True + + return [ + { + "value": syft.messaging.message.CryptenInit([0, 2, "127.0.01", 8080]), + "simplified": ( + CODE[syft.messaging.message.CryptenInit], + ((CODE[list], (0, 2, "127.0.01", 8080)),), # (Any) simplified content + ), + "cmp_detailed": compare, + }, + { + "value": syft.messaging.message.CryptenInit((0, 2, "127.0.01", 8080)), + "simplified": ( + CODE[syft.messaging.message.CryptenInit], + ((CODE[tuple], (0, 2, "127.0.01", 8080)),), # (Any) simplified content + ), + "cmp_detailed": compare, + }, + ] + + # syft.messaging.message.Operation def make_operation(**kwargs): bob = kwargs["workers"]["bob"] From c165f9ec256668178593e6821894016712e72aab Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 27 Jan 2020 19:14:16 +0100 Subject: [PATCH 08/13] add crypten as core deps, this should change to extra --- pip-dep/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/pip-dep/requirements.txt b/pip-dep/requirements.txt index 37a20496bc4..4f39ff34890 100644 --- a/pip-dep/requirements.txt +++ b/pip-dep/requirements.txt @@ -13,3 +13,4 @@ Pillow<7 websocket_client>=0.56.0 websockets>=7.0 zstd>=1.4.0.0 +git+https://github.com/facebookresearch/CrypTen.git@68e0364c66df95ddbb98422fb641382c3f58734c#egg=crypten From 66051a33d4241b18bff9454edc157f1f21c7daed Mon Sep 17 00:00:00 2001 From: youben11 Date: Wed, 29 Jan 2020 18:57:45 +0100 Subject: [PATCH 09/13] delete useless comment --- syft/workers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syft/workers/base.py b/syft/workers/base.py index 5331dc1744c..ec729a1768e 100644 --- a/syft/workers/base.py +++ b/syft/workers/base.py @@ -120,7 +120,7 @@ def __init__( GetShapeMessage: self.get_tensor_shape, SearchMessage: self.search, ForceObjectDeleteMessage: self.force_rm_obj, - CryptenInit: self.run_crypten_party, # TODO: update Message to CryptenInit after implementing it + CryptenInit: self.run_crypten_party, } self._plan_command_router = { From 4475674e9455b1f67554d523a9a387aca5b50a13 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 30 Jan 2020 18:27:58 +0100 Subject: [PATCH 10/13] fix: list CryptenInit in OBJ_SIMPLIFIER_AND_DETAILERS --- syft/messaging/message.py | 15 +++++++++++++++ syft/serde/msgpack/serde.py | 2 ++ 2 files changed, 17 insertions(+) diff --git a/syft/messaging/message.py b/syft/messaging/message.py index 3dbfc383692..18f9cff8a5e 100644 --- a/syft/messaging/message.py +++ b/syft/messaging/message.py @@ -117,6 +117,21 @@ def __init__(self, contents): @staticmethod def detail(worker: AbstractWorker, msg_tuple: tuple) -> "CryptenInit": + """ + This function takes the simplified tuple version of this message and converts + it into an CryptenInit. The simplify() method runs the inverse of this method. + + Args: + worker (AbstractWorker): a reference to the worker necessary for detailing. Read + syft/serde/serde.py for more information on why this is necessary. + msg_tuple (Tuple): the raw information being detailed. + + Returns: + CryptenInit message. + + Examples: + message = detail(sy.local_worker, msg_tuple) + """ return CryptenInit(sy.serde.msgpack.serde._detail(worker, msg_tuple[0])) diff --git a/syft/serde/msgpack/serde.py b/syft/serde/msgpack/serde.py index 5bf9f017215..6826649b2ca 100644 --- a/syft/serde/msgpack/serde.py +++ b/syft/serde/msgpack/serde.py @@ -67,6 +67,7 @@ from syft.messaging.message import ForceObjectDeleteMessage from syft.messaging.message import SearchMessage from syft.messaging.message import PlanCommandMessage +from syft.messaging.message import CryptenInit from syft.serde import compression from syft.serde.msgpack.native_serde import MAP_NATIVE_SIMPLIFIERS_AND_DETAILERS from syft.workers.abstract import AbstractWorker @@ -130,6 +131,7 @@ ForceObjectDeleteMessage, SearchMessage, PlanCommandMessage, + CryptenInit, GradFunc, String, ] From 2644a1a977ea8819ed546230a9ea48a59ef542c5 Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 31 Jan 2020 11:32:31 +0100 Subject: [PATCH 11/13] fix msgpack tests --- test/serde/serde_helpers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/serde/serde_helpers.py b/test/serde/serde_helpers.py index 4373b324d26..f86730f8902 100644 --- a/test/serde/serde_helpers.py +++ b/test/serde/serde_helpers.py @@ -1327,18 +1327,18 @@ def compare(detailed, original): return [ { - "value": syft.messaging.message.CryptenInit([0, 2, "127.0.01", 8080]), + "value": syft.messaging.message.CryptenInit([0, 2, "127.0.0.1", 8080]), "simplified": ( CODE[syft.messaging.message.CryptenInit], - ((CODE[list], (0, 2, "127.0.01", 8080)),), # (Any) simplified content + ((CODE[list], (0, 2, (CODE[str], (b"127.0.0.1",)), 8080)),), # (Any) simplified content ), "cmp_detailed": compare, }, { - "value": syft.messaging.message.CryptenInit((0, 2, "127.0.01", 8080)), + "value": syft.messaging.message.CryptenInit((0, 2, "127.0.0.1", 8080)), "simplified": ( CODE[syft.messaging.message.CryptenInit], - ((CODE[tuple], (0, 2, "127.0.01", 8080)),), # (Any) simplified content + ((CODE[tuple], (0, 2, (CODE[str], (b"127.0.0.1",)), 8080)),), # (Any) simplified content ), "cmp_detailed": compare, }, From ff7f45ceaeb123528ebf444a1fe0e02dd00db458 Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 31 Jan 2020 12:01:15 +0100 Subject: [PATCH 12/13] run black --- test/serde/serde_helpers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/serde/serde_helpers.py b/test/serde/serde_helpers.py index f86730f8902..2c45534b05d 100644 --- a/test/serde/serde_helpers.py +++ b/test/serde/serde_helpers.py @@ -1330,7 +1330,9 @@ def compare(detailed, original): "value": syft.messaging.message.CryptenInit([0, 2, "127.0.0.1", 8080]), "simplified": ( CODE[syft.messaging.message.CryptenInit], - ((CODE[list], (0, 2, (CODE[str], (b"127.0.0.1",)), 8080)),), # (Any) simplified content + ( + (CODE[list], (0, 2, (CODE[str], (b"127.0.0.1",)), 8080)), + ), # (Any) simplified content ), "cmp_detailed": compare, }, @@ -1338,7 +1340,9 @@ def compare(detailed, original): "value": syft.messaging.message.CryptenInit((0, 2, "127.0.0.1", 8080)), "simplified": ( CODE[syft.messaging.message.CryptenInit], - ((CODE[tuple], (0, 2, (CODE[str], (b"127.0.0.1",)), 8080)),), # (Any) simplified content + ( + (CODE[tuple], (0, 2, (CODE[str], (b"127.0.0.1",)), 8080)), + ), # (Any) simplified content ), "cmp_detailed": compare, }, From 3c885d89617db62a21ebdde85805d41f4878d448 Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 31 Jan 2020 17:16:18 +0100 Subject: [PATCH 13/13] don't cover empty function --- test/crypten/test_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/crypten/test_context.py b/test/crypten/test_context.py index 27236e6d673..69e789db35c 100644 --- a/test/crypten/test_context.py +++ b/test/crypten/test_context.py @@ -10,7 +10,7 @@ def test_context(workers): @run_multiworkers([alice, bob], master_addr="127.0.0.1") def test_three_parties(): - pass + pass # pragma: no cover return_values = test_three_parties() # A toy function is ran at each party, and they should all decrypt