Skip to content

Commit

Permalink
Context of computation for crypten (OpenMined#2963)
Browse files Browse the repository at this point in the history
* New message type for crypten party initialization

* Handle CryptenInit message by running local party

* Define a new context of computation for crypten

The new context of computation can run parties that are distributed
across syft workers. The communication of crypten parties remains the
same, however, syft workers handle initialization and the serialization
of return values.

* updated docs

* testing crypten context

* fix: was setting the bad env variable

DISTRIBUTED_BACKEND should have been set and not BACKEND

* test serde of CryptenInit message

* add crypten as core deps, this should change to extra

* delete useless comment

* fix: list CryptenInit in OBJ_SIMPLIFIER_AND_DETAILERS

* fix msgpack tests

* run black

* don't cover empty function
  • Loading branch information
youben11 authored and gmuraru committed Feb 14, 2020
1 parent 0831588 commit 02288cb
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 0 deletions.
1 change: 1 addition & 0 deletions pip-dep/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ torchvision==0.5.0
websocket_client>=0.56.0
websockets>=7.0
zstd>=1.4.0.0
git+https://github.com/facebookresearch/CrypTen.git@68e0364c66df95ddbb98422fb641382c3f58734c#egg=crypten
4 changes: 4 additions & 0 deletions syft/frameworks/crypten/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from syft.frameworks.crypten.context import toy_func, run_party


__all__ = ["toy_func", "run_party"]
152 changes: 152 additions & 0 deletions syft/frameworks/crypten/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import functools
import multiprocessing
import threading
import os
import crypten
from crypten.communicator import DistributedCommunicator
from syft.messaging.message import CryptenInit


def _launch(func, rank, world_size, master_addr, master_port, queue, func_args, func_kwargs):
communicator_args = {
"RANK": rank,
"WORLD_SIZE": world_size,
"RENDEZVOUS": "env:https://",
"MASTER_ADDR": master_addr,
"MASTER_PORT": master_port,
"DISTRIBUTED_BACKEND": "gloo",
}
for key, val in communicator_args.items():
os.environ[key] = str(val)

crypten.init()
return_value = func(*func_args, **func_kwargs)
crypten.uninit()

queue.put(return_value)


def _new_party(func, rank, world_size, master_addr, master_port, func_args, func_kwargs):
queue = multiprocessing.Queue()
process = multiprocessing.Process(
target=_launch,
args=(func, rank, world_size, master_addr, master_port, queue, func_args, func_kwargs),
)
return process, queue


def run_party(func, rank, world_size, master_addr, master_port, func_args, func_kwargs):
"""Start crypten party localy and run computation.
Args:
func (function): computation to be done.
rank (int): rank of the crypten party.
world_size (int): number of crypten parties involved in the computation.
master_addr (str): IP address of the master party (party with rank 0).
master_port (int or str): port of the master party (party with rank 0).
func_args (list): arguments to be passed to func.
func_kwargs (dict): keyword arguments to be passed to func.
Returns:
The return value of func.
"""

process, queue = _new_party(
func, rank, world_size, master_addr, master_port, func_args, func_kwargs
)
was_initialized = DistributedCommunicator.is_initialized()
if was_initialized:
crypten.uninit()
process.start()
process.join()
if was_initialized:
crypten.init()
return queue.get()


def _send_party_info(worker, rank, msg, return_values):
"""Send message to worker with necessary information to run a crypten party.
Add response to return_values dictionary.
Args:
worker (BaseWorker): worker to send the message to.
rank (int): rank of the crypten party.
msg (CryptenInit): message containing the rank, world_size, master_addr and master_port.
return_values (dict): dictionnary holding return values of workers.
"""

response = worker.send_msg(msg, worker)
return_values[rank] = response.contents


def toy_func():
# Toy function to be called by each party
alice_t = crypten.cryptensor([73, 81], src=0)
bob_t = crypten.cryptensor([90, 100], src=1)
out = bob_t.get_plain_text()
return out.tolist() # issues with putting torch tensors into queues


def run_multiworkers(workers: list, master_addr: str, master_port: int = 15987):
"""Defines decorator to run function across multiple workers.
Args:
workers (list): workers (parties) to be involved in the computation.
master_addr (str): IP address of the master party (party with rank 0).
master_port (int, str): port of the master party (party with rank 0), default is 15987.
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# TODO:
# - check if workers are reachable / they can handle the computation
# - check return code of processes for possible failure

world_size = len(workers) + 1
return_values = {rank: None for rank in range(world_size)}

# Start local party
process, queue = _new_party(toy_func, 0, world_size, master_addr, master_port, (), {})
was_initialized = DistributedCommunicator.is_initialized()
if was_initialized:
crypten.uninit()
process.start()
# Run TTP if required
# TODO: run ttp in a specified worker
if crypten.mpc.ttp_required():
ttp_process, _ = _new_party(
crypten.mpc.provider.TTPServer,
world_size,
world_size,
master_addr,
master_port,
(),
{},
)
ttp_process.start()

# Send messages to other workers so they start their parties
threads = []
for i in range(len(workers)):
rank = i + 1
msg = CryptenInit((rank, world_size, master_addr, master_port))
thread = threading.Thread(
target=_send_party_info, args=(workers[i], rank, msg, return_values)
)
thread.start()
threads.append(thread)

# Wait for local party and sender threads
process.join()
return_values[0] = queue.get()
for thread in threads:
thread.join()
if was_initialized:
crypten.init()

return return_values

return wrapper

return decorator
31 changes: 31 additions & 0 deletions syft/messaging/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,37 @@ def __repr__(self):
return self.__str__()


class CryptenInit(Message):
"""Initialize a Crypten party using this message.
Crypten uses processes as parties, those processes need to be initialized with information
so they can communicate and exchange tensors and shares while doing computation. This message
allows the exchange of information such as the ip and port of the master party to connect to,
as well as the rank of the party to run and the number of parties involved."""

def __init__(self, contents):
super().__init__(contents)

@staticmethod
def detail(worker: AbstractWorker, msg_tuple: tuple) -> "CryptenInit":
"""
This function takes the simplified tuple version of this message and converts
it into an CryptenInit. The simplify() method runs the inverse of this method.
Args:
worker (AbstractWorker): a reference to the worker necessary for detailing. Read
syft/serde/serde.py for more information on why this is necessary.
msg_tuple (Tuple): the raw information being detailed.
Returns:
CryptenInit message.
Examples:
message = detail(sy.local_worker, msg_tuple)
"""
return CryptenInit(sy.serde.msgpack.serde._detail(worker, msg_tuple[0]))


class Operation(Message):
"""All syft operations use this message type
Expand Down
2 changes: 2 additions & 0 deletions syft/serde/msgpack/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from syft.messaging.message import ForceObjectDeleteMessage
from syft.messaging.message import SearchMessage
from syft.messaging.message import PlanCommandMessage
from syft.messaging.message import CryptenInit
from syft.serde import compression
from syft.serde.msgpack.native_serde import MAP_NATIVE_SIMPLIFIERS_AND_DETAILERS
from syft.workers.abstract import AbstractWorker
Expand Down Expand Up @@ -128,6 +129,7 @@
ForceObjectDeleteMessage,
SearchMessage,
PlanCommandMessage,
CryptenInit,
GradFunc,
String,
]
Expand Down
17 changes: 17 additions & 0 deletions syft/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
from syft.messaging.message import GetShapeMessage
from syft.messaging.message import PlanCommandMessage
from syft.messaging.message import SearchMessage
from syft.messaging.message import CryptenInit
from syft.messaging.plan import Plan
from syft.workers.abstract import AbstractWorker
from syft.frameworks.crypten import toy_func, run_party

from syft.exceptions import GetNotPermittedError
from syft.exceptions import WorkerNotFoundException
Expand Down Expand Up @@ -118,6 +120,7 @@ def __init__(
GetShapeMessage: self.get_tensor_shape,
SearchMessage: self.search,
ForceObjectDeleteMessage: self.force_rm_obj,
CryptenInit: self.run_crypten_party,
}

self._plan_command_router = {
Expand Down Expand Up @@ -396,6 +399,20 @@ def send(

return pointer

def run_crypten_party(self, message: tuple):
"""Run crypten party according to the information received.
Args:
message (CryptenInit): should contain the rank, world_size, master_addr and master_port.
Returns:
An ObjectMessage containing the return value of the crypten function computed.
"""

rank, world_size, master_addr, master_port = message
return_value = run_party(toy_func, rank, world_size, master_addr, master_port, (), {})
return ObjectMessage(return_value)

def execute_command(self, message: tuple) -> PointerTensor:
"""
Executes commands received from other workers.
Expand Down
24 changes: 24 additions & 0 deletions test/crypten/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest
from syft.frameworks.crypten.context import run_multiworkers


def test_context(workers):
# self, alice and bob
n_workers = 3
alice = workers["alice"]
bob = workers["bob"]

@run_multiworkers([alice, bob], master_addr="127.0.0.1")
def test_three_parties():
pass # pragma: no cover

return_values = test_three_parties()
# A toy function is ran at each party, and they should all decrypt
# a tensor with value [90., 100.]
expected_value = [90.0, 100.0]
for rank in range(n_workers):
assert (
return_values[rank] == expected_value
), "Crypten party with rank {} don't match expected value {} != {}".format(
rank, return_values[rank], expected_value
)
1 change: 1 addition & 0 deletions test/serde/msgpack/test_msgpack_serde_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
samples[syft.messaging.message.ForceObjectDeleteMessage] = make_forceobjectdeletemessage
samples[syft.messaging.message.SearchMessage] = make_searchmessage
samples[syft.messaging.message.PlanCommandMessage] = make_plancommandmessage
samples[syft.messaging.message.CryptenInit] = make_crypteninit

samples[syft.frameworks.torch.tensors.interpreters.gradients_core.GradFunc] = make_gradfn

Expand Down
31 changes: 31 additions & 0 deletions test/serde/serde_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,37 @@ def compare(detailed, original):
]


# syft.messaging.message.CryptenInit
def make_crypteninit(**kwargs):
def compare(detailed, original):
assert type(detailed) == syft.messaging.message.CryptenInit
assert detailed.contents == original.contents
return True

return [
{
"value": syft.messaging.message.CryptenInit([0, 2, "127.0.0.1", 8080]),
"simplified": (
CODE[syft.messaging.message.CryptenInit],
(
(CODE[list], (0, 2, (CODE[str], (b"127.0.0.1",)), 8080)),
), # (Any) simplified content
),
"cmp_detailed": compare,
},
{
"value": syft.messaging.message.CryptenInit((0, 2, "127.0.0.1", 8080)),
"simplified": (
CODE[syft.messaging.message.CryptenInit],
(
(CODE[tuple], (0, 2, (CODE[str], (b"127.0.0.1",)), 8080)),
), # (Any) simplified content
),
"cmp_detailed": compare,
},
]


# syft.messaging.message.Operation
def make_operation(**kwargs):
bob = kwargs["workers"]["bob"]
Expand Down

0 comments on commit 02288cb

Please sign in to comment.