Skip to content

Commit

Permalink
added more refactorings of BrokerInitializer code to speed up the sta…
Browse files Browse the repository at this point in the history
…rtup
  • Loading branch information
dmytro-landiak committed May 23, 2024
1 parent f5f3a38 commit e1656ed
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.thingsboard.mqtt.broker.actors.client.service;

import com.google.common.collect.Maps;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.context.event.ApplicationReadyEvent;
Expand All @@ -32,7 +33,6 @@
import org.thingsboard.mqtt.broker.actors.config.ActorSystemLifecycle;
import org.thingsboard.mqtt.broker.cluster.ServiceInfoProvider;
import org.thingsboard.mqtt.broker.common.data.ClientSessionInfo;
import org.thingsboard.mqtt.broker.common.data.SessionInfo;
import org.thingsboard.mqtt.broker.common.data.id.ActorType;
import org.thingsboard.mqtt.broker.common.data.subscription.TopicSubscription;
import org.thingsboard.mqtt.broker.exception.QueuePersistenceException;
Expand All @@ -48,14 +48,10 @@
import org.thingsboard.mqtt.broker.service.processing.downlink.basic.BasicDownLinkConsumer;
import org.thingsboard.mqtt.broker.service.processing.downlink.persistent.PersistentDownLinkConsumer;
import org.thingsboard.mqtt.broker.service.subscription.ClientSubscriptionConsumer;
import org.thingsboard.mqtt.broker.util.ClientSessionInfoFactory;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

@Slf4j
@Component
Expand Down Expand Up @@ -89,13 +85,9 @@ public void onApplicationEvent(ApplicationReadyEvent event) {
log.info("Initializing Client Sessions and Subscriptions.");
try {
Map<String, ClientSessionInfo> allClientSessions = initClientSessions();

initClientSubscriptions(allClientSessions);

clearNonPersistentClientsOnCurrentNode(allClientSessions);

clientSessionService.startListening(clientSessionConsumer);

startSubscriptionListening();

initRetainedMessages();
Expand All @@ -109,17 +101,40 @@ public void onApplicationEvent(ApplicationReadyEvent event) {
}
}

Map<String, ClientSessionInfo> initClientSessions() throws QueuePersistenceException {
Map<String, ClientSessionInfo> allClientSessions = clientSessionConsumer.initLoad();
log.info("Loaded {} stored client sessions from Kafka.", allClientSessions.size());

Map<String, ClientSessionInfo> currentNodeSessions = filterAndDisconnectCurrentNodeSessions(allClientSessions);
allClientSessions.putAll(currentNodeSessions);

clientSessionService.init(allClientSessions);
return allClientSessions;
}

private Map<String, ClientSessionInfo> filterAndDisconnectCurrentNodeSessions(Map<String, ClientSessionInfo> allClientSessions) {
Map<String, ClientSessionInfo> currentNodeSessions = Maps.newHashMapWithExpectedSize(allClientSessions.size());
for (Map.Entry<String, ClientSessionInfo> entry : allClientSessions.entrySet()) {
ClientSessionInfo clientSessionInfo = entry.getValue();

if (sessionWasOnThisNode(clientSessionInfo)) {
if (isCleanSession(clientSessionInfo)) {
clientSessionEventService.requestClientSessionCleanup(clientSessionInfo);
}
ClientSessionInfo disconnectedClientSession = markDisconnected(clientSessionInfo);
currentNodeSessions.put(entry.getKey(), disconnectedClientSession);
}
}
log.info("{} client sessions were on {} node.", currentNodeSessions.size(), serviceInfoProvider.getServiceId());
return currentNodeSessions;
}

private void initRetainedMessages() throws QueuePersistenceException {
Map<String, RetainedMsg> allRetainedMessages = retainedMsgConsumer.initLoad();
log.info("Loaded {} stored retained messages from Kafka.", allRetainedMessages.size());
retainedMsgListenerService.init(allRetainedMessages);
}

private void clearNonPersistentClientsOnCurrentNode(Map<String, ClientSessionInfo> allClientSessions) {
Map<String, ClientSessionInfo> currentNodeSessions = filterSessions(allClientSessions);
clearNonPersistentClients(currentNodeSessions);
}

private void startConsuming() {
clientSessionEventConsumer.startConsuming();
publishMsgConsumerService.startConsuming();
Expand Down Expand Up @@ -149,36 +164,6 @@ private void removeSubscriptionIfSessionIsAbsent(Map<String, ClientSessionInfo>
}
}

Map<String, ClientSessionInfo> initClientSessions() throws QueuePersistenceException {
Map<String, ClientSessionInfo> allClientSessions = clientSessionConsumer.initLoad();
log.info("Loaded {} stored client sessions from Kafka.", allClientSessions.size());

Map<String, ClientSessionInfo> currentNodeSessions = filterSessionsAndDisconnect(allClientSessions);
log.info("{} client sessions were on {} node.", currentNodeSessions.size(), serviceInfoProvider.getServiceId());

allClientSessions.putAll(currentNodeSessions);

clientSessionService.init(allClientSessions);

return allClientSessions;
}

private Map<String, ClientSessionInfo> filterSessions(Map<String, ClientSessionInfo> allClientSessions) {
return filterCurrentNodeSessions(allClientSessions)
.collect(Collectors.toMap(this::getClientId, Function.identity()));
}

private Map<String, ClientSessionInfo> filterSessionsAndDisconnect(Map<String, ClientSessionInfo> allClientSessions) {
return filterCurrentNodeSessions(allClientSessions)
.map(this::markDisconnected)
.collect(Collectors.toMap(this::getClientId, Function.identity()));
}

private Stream<ClientSessionInfo> filterCurrentNodeSessions(Map<String, ClientSessionInfo> allClientSessions) {
return allClientSessions.values().stream()
.filter(this::sessionWasOnThisNode);
}

private void startSubscriptionListening() {
clientSubscriptionConsumer.listen((clientId, serviceId, topicSubscriptions) -> {
if (serviceInfoProvider.getServiceId().equals(serviceId)) {
Expand Down Expand Up @@ -208,15 +193,8 @@ private TbActorRef getActor(String clientId) {
return actorSystem.getActor(new TbTypeActorId(ActorType.CLIENT, clientId));
}

void clearNonPersistentClients(Map<String, ClientSessionInfo> currentNodeSessions) {
currentNodeSessions.values().stream()
.map(ClientSessionInfoFactory::clientSessionInfoToSessionInfo)
.filter(this::isCleanSession)
.forEach(clientSessionEventService::requestSessionCleanup);
}

boolean isCleanSession(SessionInfo sessionInfo) {
return sessionInfo.isCleanSession();
boolean isCleanSession(ClientSessionInfo clientSessionInfo) {
return clientSessionInfo.isCleanSession();
}

private boolean sessionWasOnThisNode(ClientSessionInfo clientSessionInfo) {
Expand All @@ -226,8 +204,4 @@ private boolean sessionWasOnThisNode(ClientSessionInfo clientSessionInfo) {
private ClientSessionInfo markDisconnected(ClientSessionInfo clientSessionInfo) {
return clientSessionInfo.toBuilder().connected(false).build();
}

private String getClientId(ClientSessionInfo clientSessionInfo) {
return clientSessionInfo.getClientId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ public static QueueProtos.SessionInfoProto convertToSessionInfoProto(SessionInfo
getSessionInfoProto(sessionInfo) : getSessionInfoProtoWithSessionExpiryInterval(sessionInfo);
}

public static QueueProtos.SessionInfoProto convertToSessionInfoProto(ClientSessionInfo clientSessionInfo) {
return clientSessionInfo.getSessionExpiryInterval() == -1 ?
getSessionInfoProto(clientSessionInfo) : getSessionInfoProtoWithSessionExpiryInterval(clientSessionInfo);
}

private static QueueProtos.SessionInfoProto getSessionInfoProto(SessionInfo sessionInfo) {
return QueueProtos.SessionInfoProto.newBuilder()
.setServiceInfo(QueueProtos.ServiceInfo.newBuilder().setServiceId(sessionInfo.getServiceId()).build())
Expand All @@ -221,6 +226,17 @@ private static QueueProtos.SessionInfoProto getSessionInfoProto(SessionInfo sess
.build();
}

private static QueueProtos.SessionInfoProto getSessionInfoProto(ClientSessionInfo clientSessionInfo) {
return QueueProtos.SessionInfoProto.newBuilder()
.setServiceInfo(QueueProtos.ServiceInfo.newBuilder().setServiceId(clientSessionInfo.getServiceId()).build())
.setSessionIdMSB(clientSessionInfo.getSessionId().getMostSignificantBits())
.setSessionIdLSB(clientSessionInfo.getSessionId().getLeastSignificantBits())
.setCleanStart(clientSessionInfo.isCleanStart())
.setClientInfo(convertToClientInfoProto(clientSessionInfo))
.setConnectionInfo(convertToConnectionInfoProto(clientSessionInfo))
.build();
}

private static QueueProtos.SessionInfoProto getSessionInfoProtoWithSessionExpiryInterval(SessionInfo sessionInfo) {
return QueueProtos.SessionInfoProto.newBuilder()
.setServiceInfo(QueueProtos.ServiceInfo.newBuilder().setServiceId(sessionInfo.getServiceId()).build())
Expand All @@ -233,6 +249,18 @@ private static QueueProtos.SessionInfoProto getSessionInfoProtoWithSessionExpiry
.build();
}

private static QueueProtos.SessionInfoProto getSessionInfoProtoWithSessionExpiryInterval(ClientSessionInfo clientSessionInfo) {
return QueueProtos.SessionInfoProto.newBuilder()
.setServiceInfo(QueueProtos.ServiceInfo.newBuilder().setServiceId(clientSessionInfo.getServiceId()).build())
.setSessionIdMSB(clientSessionInfo.getSessionId().getMostSignificantBits())
.setSessionIdLSB(clientSessionInfo.getSessionId().getLeastSignificantBits())
.setCleanStart(clientSessionInfo.isCleanStart())
.setClientInfo(convertToClientInfoProto(clientSessionInfo))
.setConnectionInfo(convertToConnectionInfoProto(clientSessionInfo))
.setSessionExpiryInterval(clientSessionInfo.getSessionExpiryInterval())
.build();
}

public static SessionInfo convertToSessionInfo(QueueProtos.SessionInfoProto sessionInfoProto) {
return SessionInfo.builder()
.serviceId(sessionInfoProto.getServiceInfo().getServiceId())
Expand All @@ -252,6 +280,14 @@ public static QueueProtos.ClientInfoProto convertToClientInfoProto(ClientInfo cl
.build() : QueueProtos.ClientInfoProto.getDefaultInstance();
}

public static QueueProtos.ClientInfoProto convertToClientInfoProto(ClientSessionInfo clientSessionInfo) {
return clientSessionInfo != null ? QueueProtos.ClientInfoProto.newBuilder()
.setClientId(clientSessionInfo.getClientId())
.setClientType(clientSessionInfo.getType().toString())
.setClientIpAdr(ByteString.copyFrom(clientSessionInfo.getClientIpAdr()))
.build() : QueueProtos.ClientInfoProto.getDefaultInstance();
}

public static ClientInfo convertToClientInfo(QueueProtos.ClientInfoProto clientInfoProto) {
return clientInfoProto != null ? ClientInfo.builder()
.clientId(clientInfoProto.getClientId())
Expand All @@ -268,6 +304,15 @@ public static QueueProtos.ConnectionInfoProto convertToConnectionInfoProto(Conne
.build() : QueueProtos.ConnectionInfoProto.getDefaultInstance();
}

public static QueueProtos.ConnectionInfoProto convertToConnectionInfoProto(ClientSessionInfo clientSessionInfo) {
return clientSessionInfo != null ? QueueProtos.ConnectionInfoProto.newBuilder()
.setConnectedAt(clientSessionInfo.getConnectedAt())
.setDisconnectedAt(clientSessionInfo.getDisconnectedAt())
.setKeepAlive(clientSessionInfo.getKeepAlive())
.build() : QueueProtos.ConnectionInfoProto.getDefaultInstance();
}


private static ConnectionInfo convertToConnectionInfo(QueueProtos.ConnectionInfoProto connectionInfoProto) {
return connectionInfoProto != null ? ConnectionInfo.builder()
.connectedAt(connectionInfoProto.getConnectedAt())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.thingsboard.mqtt.broker.service.mqtt.client.event;

import org.thingsboard.mqtt.broker.common.data.ClientInfo;
import org.thingsboard.mqtt.broker.common.data.ClientSessionInfo;
import org.thingsboard.mqtt.broker.common.data.SessionInfo;
import org.thingsboard.mqtt.broker.gen.queue.QueueProtos;

Expand All @@ -30,5 +31,7 @@ QueueProtos.ClientSessionEventProto createDisconnectedEventProto(ClientInfo clie

QueueProtos.ClientSessionEventProto createTryClearSessionRequestEventProto(SessionInfo sessionInfo);

QueueProtos.ClientSessionEventProto createTryClearSessionRequestEventProto(ClientSessionInfo clientSessionInfo);

QueueProtos.ClientSessionEventProto createApplicationTopicRemoveRequestProto(String clientId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.springframework.stereotype.Service;
import org.thingsboard.mqtt.broker.adaptor.ProtoConverter;
import org.thingsboard.mqtt.broker.common.data.ClientInfo;
import org.thingsboard.mqtt.broker.common.data.ClientSessionInfo;
import org.thingsboard.mqtt.broker.common.data.SessionInfo;
import org.thingsboard.mqtt.broker.common.util.BrokerConstants;
import org.thingsboard.mqtt.broker.gen.queue.QueueProtos;
Expand Down Expand Up @@ -58,6 +59,14 @@ public QueueProtos.ClientSessionEventProto createTryClearSessionRequestEventProt
.build();
}

@Override
public QueueProtos.ClientSessionEventProto createTryClearSessionRequestEventProto(ClientSessionInfo clientSessionInfo) {
return QueueProtos.ClientSessionEventProto.newBuilder()
.setSessionInfo(ProtoConverter.convertToSessionInfoProto(clientSessionInfo))
.setEventType(ClientSessionEventType.CLEAR_SESSION_REQUEST.toString())
.build();
}

@Override
public QueueProtos.ClientSessionEventProto createApplicationTopicRemoveRequestProto(String clientId) {
return QueueProtos.ClientSessionEventProto.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import com.google.common.util.concurrent.ListenableFuture;
import org.thingsboard.mqtt.broker.common.data.ClientInfo;
import org.thingsboard.mqtt.broker.common.data.ClientSessionInfo;
import org.thingsboard.mqtt.broker.common.data.SessionInfo;
import org.thingsboard.mqtt.broker.queue.TbQueueCallback;

Expand All @@ -32,5 +33,7 @@ public interface ClientSessionEventService {

void requestSessionCleanup(SessionInfo sessionInfo);

void requestClientSessionCleanup(ClientSessionInfo clientSessionInfo);

void requestApplicationTopicRemoved(String clientId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.thingsboard.mqtt.broker.adaptor.ProtoConverter;
import org.thingsboard.mqtt.broker.cluster.ServiceInfoProvider;
import org.thingsboard.mqtt.broker.common.data.ClientInfo;
import org.thingsboard.mqtt.broker.common.data.ClientSessionInfo;
import org.thingsboard.mqtt.broker.common.data.SessionInfo;
import org.thingsboard.mqtt.broker.common.util.ThingsBoardThreadFactory;
import org.thingsboard.mqtt.broker.gen.queue.QueueProtos;
Expand Down Expand Up @@ -133,6 +134,15 @@ public void requestSessionCleanup(SessionInfo sessionInfo) {
null);
}

@Override
public void requestClientSessionCleanup(ClientSessionInfo clientSessionInfo) {
sendEvent(
clientSessionInfo.getClientId(),
eventFactory.createTryClearSessionRequestEventProto(clientSessionInfo),
false,
null);
}

@Override
public void requestApplicationTopicRemoved(String clientId) {
sendEvent(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.thingsboard.mqtt.broker.actors.client.service.subscription.ClientSubscriptionService;
import org.thingsboard.mqtt.broker.cluster.ServiceInfoProvider;
import org.thingsboard.mqtt.broker.common.data.ClientSessionInfo;
import org.thingsboard.mqtt.broker.common.data.SessionInfo;
import org.thingsboard.mqtt.broker.exception.QueuePersistenceException;
import org.thingsboard.mqtt.broker.service.mqtt.client.disconnect.DisconnectClientCommandConsumer;
import org.thingsboard.mqtt.broker.service.mqtt.client.event.ClientSessionEventConsumer;
Expand Down Expand Up @@ -128,21 +127,21 @@ private Map<String, ClientSessionInfo> prepareSessions() {

@Test
public void testIsNotPersistent() {
SessionInfo clientSessionInfo1 = getSessionInfo(false, 0);
SessionInfo clientSessionInfo2 = getSessionInfo(false, 10);
SessionInfo clientSessionInfo3 = getSessionInfo(true, 0);
SessionInfo clientSessionInfo4 = getSessionInfo(true, 10);
ClientSessionInfo clientSessionInfo1 = getSessionInfo(false, 0);
ClientSessionInfo clientSessionInfo2 = getSessionInfo(false, 10);
ClientSessionInfo clientSessionInfo3 = getSessionInfo(true, 0);
ClientSessionInfo clientSessionInfo4 = getSessionInfo(true, 10);

Assert.assertFalse(brokerInitializer.isCleanSession(clientSessionInfo1));
Assert.assertFalse(brokerInitializer.isCleanSession(clientSessionInfo2));
Assert.assertTrue(brokerInitializer.isCleanSession(clientSessionInfo3));
Assert.assertFalse(brokerInitializer.isCleanSession(clientSessionInfo4));
}

private SessionInfo getSessionInfo(boolean cleanStart, int sessionExpiryInterval) {
return SessionInfo.builder()
private ClientSessionInfo getSessionInfo(boolean cleanStart, int sessionExpiryInterval) {
return ClientSessionInfo.builder()
.cleanStart(cleanStart)
.sessionExpiryInterval(sessionExpiryInterval)
.build();
}
}
}

0 comments on commit e1656ed

Please sign in to comment.