Skip to content

Commit

Permalink
[Core][pubsub] handle failures when publish failed. (ray-project#33115)
Browse files Browse the repository at this point in the history
Why are these changes needed?
ray-project#32046 indicating that the pubsub might lose data, especially when the subscriber is under load. After examine the protocol it seems one bug is that the publisher fails to handle publish failures. i.e. when we push message in mailbox, we will delete the message being sent regardless of RPC failures.

This PR tries to address the problem by adding monotonically increasing sequence_id to each message, and only delete messages when the subscriber acknowledged a message has been received.

The sequence_id sequences is also generated per publisher, regardless of channels. This means if there exists multiple channels for the same publisher, each channel might not see contiguous sequences. This also assumes the invariant that a subscriber object will only subscribe to one publisher.

We also relies on the pubsub protocol that at most one going push request will be inflight.

This also handles the case gcs failover. We do so by track the publisher_id between both publisher and subscriber. When gcs failover, the publisher_id will be different, thus both the publisher and subscriber will forget the information about previous state.
  • Loading branch information
scv119 committed Apr 19, 2023
1 parent 8c279d5 commit 897a282
Show file tree
Hide file tree
Showing 20 changed files with 519 additions and 122 deletions.
33 changes: 32 additions & 1 deletion python/ray/_private/gcs_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def __init__(self, worker_id: bytes = None):
# SubscriberID / UniqueID, which is 28 (kUniqueIDSize) random bytes.
self._subscriber_id = bytes(bytearray(random.getrandbits(8) for _ in range(28)))
self._last_batch_size = 0
self._max_processed_sequence_id = 0
self._publisher_id = b""

# Batch size of the result from last poll. Used to indicate whether the
# subscriber can keep up.
Expand All @@ -91,7 +93,9 @@ def _subscribe_request(self, channel):

def _poll_request(self):
return gcs_service_pb2.GcsSubscriberPollRequest(
subscriber_id=self._subscriber_id
subscriber_id=self._subscriber_id,
max_processed_sequence_id=self._max_processed_sequence_id,
publisher_id=self._publisher_id,
)

def _unsubscribe_request(self, channels):
Expand Down Expand Up @@ -272,7 +276,21 @@ def _poll_locked(self, timeout=None) -> None:

if fut.done():
self._last_batch_size = len(fut.result().pub_messages)
if fut.result().publisher_id != self._publisher_id:
if self._publisher_id != "":
logger.debug(
f"replied publisher_id {fut.result().publisher_id} "
f"different from {self._publisher_id}, this should "
"only happens during gcs failover."
)
self._publisher_id = fut.result().publisher_id
self._max_processed_sequence_id = 0

for msg in fut.result().pub_messages:
if msg.sequence_id <= self._max_processed_sequence_id:
logger.warn(f"Ignoring out of order message {msg}")
continue
self._max_processed_sequence_id = msg.sequence_id
if msg.channel_type != self._channel:
logger.warn(f"Ignoring message from unsubscribed channel {msg}")
continue
Expand Down Expand Up @@ -538,7 +556,20 @@ async def _poll(self, timeout=None) -> None:
break
try:
self._last_batch_size = len(poll.result().pub_messages)
if poll.result().publisher_id != self._publisher_id:
if self._publisher_id != "":
logger.debug(
f"replied publisher_id {poll.result().publisher_id}"
f"different from {self._publisher_id}, this should "
"only happens during gcs failover."
)
self._publisher_id = poll.result().publisher_id
self._max_processed_sequence_id = 0
for msg in poll.result().pub_messages:
if msg.sequence_id <= self._max_processed_sequence_id:
logger.warn(f"Ignoring out of order message {msg}")
continue
self._max_processed_sequence_id = msg.sequence_id
self._queue.append(msg)
except grpc.RpcError as e:
if self._should_terminate_polling(e):
Expand Down
44 changes: 44 additions & 0 deletions python/ray/tests/test_gcs_fault_tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
wait_for_pid_to_exit,
run_string_as_driver,
)
from ray._private.gcs_pubsub import (
GcsPublisher,
GcsErrorSubscriber,
)
from ray.core.generated.gcs_pb2 import ErrorTableData

import psutil

Expand Down Expand Up @@ -649,6 +654,45 @@ def pid(self):
ray.get_actor("A")


@pytest.mark.parametrize(
"ray_start_regular_with_external_redis",
[
generate_system_config_map(
gcs_failover_worker_reconnect_timeout=20,
gcs_rpc_server_reconnect_timeout_s=60,
gcs_server_request_timeout_seconds=10,
)
],
indirect=True,
)
@pytest.mark.skip(
reason="python publisher and subscriber doesn't handle gcs server failover"
)
def test_publish_and_subscribe_error_info(ray_start_regular_with_external_redis):
address_info = ray_start_regular_with_external_redis
gcs_server_addr = address_info["gcs_address"]

subscriber = GcsErrorSubscriber(address=gcs_server_addr)
subscriber.subscribe()

publisher = GcsPublisher(address=gcs_server_addr)
err1 = ErrorTableData(error_message="test error message 1")
err2 = ErrorTableData(error_message="test error message 2")
print("sending error message 1")
publisher.publish_error(b"aaa_id", err1)

ray._private.worker._global_node.kill_gcs_server()
ray._private.worker._global_node.start_gcs_server()

print("sending error message 2")
publisher.publish_error(b"bbb_id", err2)
print("done")

assert subscriber.poll() == (b"bbb_id", err2)

subscriber.close()


@pytest.fixture
def redis_replicas(monkeypatch):
monkeypatch.setenv("TEST_EXTERNAL_REDIS_REPLICAS", "3")
Expand Down
4 changes: 2 additions & 2 deletions src/mock/ray/pubsub/publisher.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class MockPublisherInterface : public PublisherInterface {
const SubscriberID &subscriber_id,
const std::optional<std::string> &key_id),
(override));
MOCK_METHOD(void, Publish, (const rpc::PubMessage &pub_message), (override));
MOCK_METHOD(void, Publish, (rpc::PubMessage pub_message), (override));
MOCK_METHOD(void,
PublishFailure,
(const rpc::ChannelType channel_type, const std::string &key_id),
Expand All @@ -86,7 +86,7 @@ class MockPublisher : public Publisher {
const SubscriberID &subscriber_id,
const std::optional<std::string> &key_id),
(override));
MOCK_METHOD(void, Publish, (const rpc::PubMessage &pub_message), (override));
MOCK_METHOD(void, Publish, (rpc::PubMessage pub_message), (override));
MOCK_METHOD(void,
PublishFailure,
(const rpc::ChannelType channel_type, const std::string &key_id),
Expand Down
5 changes: 3 additions & 2 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
/*periodical_runner=*/&periodical_runner_,
/*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; },
/*subscriber_timeout_ms=*/RayConfig::instance().subscriber_timeout_ms(),
/*publish_batch_size_=*/RayConfig::instance().publish_batch_size());
/*publish_batch_size_=*/RayConfig::instance().publish_batch_size(),
GetWorkerID());
object_info_subscriber_ = std::make_unique<pubsub::Subscriber>(
/*subscriber_id=*/GetWorkerID(),
/*channels=*/
Expand Down Expand Up @@ -3121,7 +3122,7 @@ void CoreWorker::ProcessSubscribeForObjectEviction(
pub_message.mutable_worker_object_eviction_message()->set_object_id(
object_id.Binary());

object_info_publisher_->Publish(pub_message);
object_info_publisher_->Publish(std::move(pub_message));
};

const auto object_id = ObjectID::FromBinary(message.object_id());
Expand Down
4 changes: 2 additions & 2 deletions src/ray/core_worker/reference_count.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ void ReferenceCounter::HandleRefRemoved(const ObjectID &object_id) {
RAY_LOG(DEBUG) << "Publishing WaitForRefRemoved message for " << object_id
<< ", message has " << worker_ref_removed_message->borrowed_refs().size()
<< " borrowed references.";
object_info_publisher_->Publish(pub_message);
object_info_publisher_->Publish(std::move(pub_message));
}

void ReferenceCounter::SetRefRemovedCallback(
Expand Down Expand Up @@ -1459,7 +1459,7 @@ void ReferenceCounter::PushToLocationSubscribers(ReferenceTable::iterator it) {
auto object_locations_msg = pub_message.mutable_worker_object_locations_message();
FillObjectInformationInternal(it, object_locations_msg);

object_info_publisher_->Publish(pub_message);
object_info_publisher_->Publish(std::move(pub_message));
}

Status ReferenceCounter::FillObjectInformation(
Expand Down
5 changes: 3 additions & 2 deletions src/ray/core_worker/test/reference_count_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ class MockDistributedSubscriber : public pubsub::SubscriberInterface {
subscriber_id,
/*get_time_ms=*/[]() { return 1.0; },
/*subscriber_timeout_ms=*/1000,
/*publish_batch_size=*/1000)),
/*publish_batch_size=*/1000,
UniqueID::FromRandom())),
client_factory_(client_factory) {}

~MockDistributedSubscriber() = default;
Expand Down Expand Up @@ -249,7 +250,7 @@ class MockDistributedPublisher : public pubsub::PublisherInterface {
void PublishFailure(const rpc::ChannelType channel_type,
const std::string &key_id_binary) {}

void Publish(const rpc::PubMessage &pub_message) {
void Publish(rpc::PubMessage pub_message) {
if (pub_message.channel_type() == rpc::ChannelType::WORKER_OBJECT_LOCATIONS_CHANNEL) {
// TODO(swang): Test object locations pubsub too.
return;
Expand Down
3 changes: 3 additions & 0 deletions src/ray/gcs/gcs_client/gcs_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,14 @@ void GcsSubscriberClient::PubsubLongPolling(
const rpc::ClientCallback<rpc::PubsubLongPollingReply> &callback) {
rpc::GcsSubscriberPollRequest req;
req.set_subscriber_id(request.subscriber_id());
req.set_max_processed_sequence_id(request.max_processed_sequence_id());
req.set_publisher_id(request.publisher_id());
rpc_client_->GcsSubscriberPoll(
req,
[callback](const Status &status, const rpc::GcsSubscriberPollReply &poll_reply) {
rpc::PubsubLongPollingReply reply;
*reply.mutable_pub_messages() = poll_reply.pub_messages();
*reply.mutable_publisher_id() = poll_reply.publisher_id();
callback(status, reply);
});
}
Expand Down
3 changes: 2 additions & 1 deletion src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config,
/*periodical_runner=*/&pubsub_periodical_runner_,
/*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; },
/*subscriber_timeout_ms=*/RayConfig::instance().subscriber_timeout_ms(),
/*publish_batch_size_=*/RayConfig::instance().publish_batch_size());
/*publish_batch_size_=*/RayConfig::instance().publish_batch_size(),
/*publisher_id=*/NodeID::FromRandom());

gcs_publisher_ = std::make_shared<GcsPublisher>(std::move(inner_publisher));
}
Expand Down
4 changes: 4 additions & 0 deletions src/ray/gcs/gcs_server/pubsub_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ void InternalPubSubHandler::HandleGcsPublish(rpc::GcsPublishRequest request,
nullptr);
return;
}
RAY_LOG(DEBUG) << "received publish request: " << request.DebugString();
for (const auto &msg : request.pub_messages()) {
gcs_publisher_->GetPublisher()->Publish(msg);
}
Expand All @@ -63,6 +64,8 @@ void InternalPubSubHandler::HandleGcsSubscriberPoll(
}
rpc::PubsubLongPollingRequest pubsub_req;
pubsub_req.set_subscriber_id(request.subscriber_id());
pubsub_req.set_publisher_id(request.publisher_id());
pubsub_req.set_max_processed_sequence_id(request.max_processed_sequence_id());
auto pubsub_reply = std::make_shared<rpc::PubsubLongPollingReply>();
auto pubsub_reply_ptr = pubsub_reply.get();
gcs_publisher_->GetPublisher()->ConnectToSubscriber(
Expand All @@ -74,6 +77,7 @@ void InternalPubSubHandler::HandleGcsSubscriberPoll(
std::function<void()> success_cb,
std::function<void()> failure_cb) {
reply->mutable_pub_messages()->Swap(pubsub_reply->mutable_pub_messages());
reply->set_publisher_id(std::move(*pubsub_reply->mutable_publisher_id()));
reply_cb(std::move(status), std::move(success_cb), std::move(failure_cb));
});
}
Expand Down
10 changes: 5 additions & 5 deletions src/ray/gcs/pubsub/gcs_pub_sub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Status GcsPublisher::PublishActor(const ActorID &id,
msg.set_channel_type(rpc::ChannelType::GCS_ACTOR_CHANNEL);
msg.set_key_id(id.Binary());
*msg.mutable_actor_message() = message;
publisher_->Publish(msg);
publisher_->Publish(std::move(msg));
if (done != nullptr) {
done(Status::OK());
}
Expand All @@ -40,7 +40,7 @@ Status GcsPublisher::PublishJob(const JobID &id,
msg.set_channel_type(rpc::ChannelType::GCS_JOB_CHANNEL);
msg.set_key_id(id.Binary());
*msg.mutable_job_message() = message;
publisher_->Publish(msg);
publisher_->Publish(std::move(msg));
if (done != nullptr) {
done(Status::OK());
}
Expand All @@ -54,7 +54,7 @@ Status GcsPublisher::PublishNodeInfo(const NodeID &id,
msg.set_channel_type(rpc::ChannelType::GCS_NODE_INFO_CHANNEL);
msg.set_key_id(id.Binary());
*msg.mutable_node_info_message() = message;
publisher_->Publish(msg);
publisher_->Publish(std::move(msg));
if (done != nullptr) {
done(Status::OK());
}
Expand All @@ -68,7 +68,7 @@ Status GcsPublisher::PublishWorkerFailure(const WorkerID &id,
msg.set_channel_type(rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL);
msg.set_key_id(id.Binary());
*msg.mutable_worker_delta_message() = message;
publisher_->Publish(msg);
publisher_->Publish(std::move(msg));
if (done != nullptr) {
done(Status::OK());
}
Expand All @@ -82,7 +82,7 @@ Status GcsPublisher::PublishError(const std::string &id,
msg.set_channel_type(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
msg.set_key_id(id);
*msg.mutable_error_info_message() = message;
publisher_->Publish(msg);
publisher_->Publish(std::move(msg));
if (done != nullptr) {
done(Status::OK());
}
Expand Down
9 changes: 9 additions & 0 deletions src/ray/protobuf/gcs_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,20 @@ message GcsPublishReply {
message GcsSubscriberPollRequest {
/// The id of the subscriber.
bytes subscriber_id = 1;
/// The max squence_id that has been processed by the subscriber. The Publisher
/// will drop queued messages with smaller sequence_id for this subscriber.
int64 max_processed_sequence_id = 2;
/// The expected publisher_id. The publisher will ignore the
/// max_processed_sequence_id if the publisher_id doesn't match.
/// This usuall happens when gcs failover.
bytes publisher_id = 3;
}

message GcsSubscriberPollReply {
/// The messages that are published.
repeated PubMessage pub_messages = 1;
/// The publisher's id.
bytes publisher_id = 2;
// Not populated.
GcsStatus status = 100;
}
Expand Down
11 changes: 11 additions & 0 deletions src/ray/protobuf/pubsub.proto
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ message PubMessage {
// The message that indicates the given key id is not available anymore.
FailureMessage failure_message = 6;
}
/// A monotonically increasing sequence_id generated by the publisher.
int64 sequence_id = 16;
}

message WorkerObjectEvictionMessage {
Expand Down Expand Up @@ -202,11 +204,20 @@ message WorkerObjectLocationsSubMessage {
message PubsubLongPollingRequest {
/// The id of the subscriber.
bytes subscriber_id = 1;
/// The max squence_id that has been processed by the subscriber. The Publisher
/// will drop queued messages with smaller sequence_id for this subscriber.
int64 max_processed_sequence_id = 2;
/// The expected publisher_id. The publisher will ignore the
/// max_processed_sequence_id if the publisher_id doesn't match.
/// This usuall happens when gcs failover.
bytes publisher_id = 3;
}

message PubsubLongPollingReply {
/// The messages that are published.
repeated PubMessage pub_messages = 1;
/// The publisher_id.
bytes publisher_id = 2;
}

message PubsubCommandBatchRequest {
Expand Down
2 changes: 1 addition & 1 deletion src/ray/pubsub/mock_pubsub.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class MockPublisher : public pubsub::PublisherInterface {
const pubsub::SubscriberID &subscriber_id,
const std::optional<std::string> &key_id));

MOCK_METHOD1(Publish, void(const rpc::PubMessage &pub_message));
MOCK_METHOD1(Publish, void(rpc::PubMessage pub_message));

MOCK_METHOD3(UnregisterSubscription,
bool(const rpc::ChannelType channel_type,
Expand Down
Loading

0 comments on commit 897a282

Please sign in to comment.