Skip to content

Commit

Permalink
Merge pull request #124 from thingsboard/bug/subscription-matching
Browse files Browse the repository at this point in the history
[1.3.1] Fix subscription matching
  • Loading branch information
dmytro-landiak committed Jun 10, 2024
2 parents 9afe7a2 + 1639590 commit eff45b1
Show file tree
Hide file tree
Showing 12 changed files with 757 additions and 252 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.thingsboard.mqtt.broker.common.data.mqtt.MsgExpiryResult;
import org.thingsboard.mqtt.broker.common.data.subscription.TopicSubscription;
import org.thingsboard.mqtt.broker.common.data.util.CallbackUtil;
import org.thingsboard.mqtt.broker.common.util.BrokerConstants;
import org.thingsboard.mqtt.broker.dao.client.application.ApplicationSharedSubscriptionService;
import org.thingsboard.mqtt.broker.dao.exception.DataValidationException;
import org.thingsboard.mqtt.broker.dao.topic.TopicValidationService;
Expand Down Expand Up @@ -350,11 +351,15 @@ void validateSharedSubscription(TopicSubscription subscription) {
if (shareName != null && shareName.isEmpty()) {
throw new DataValidationException("Shared subscription 'shareName' must be at least one character long");
}
if (!StringUtils.isEmpty(shareName) && (shareName.contains("+") || shareName.contains("#"))) {
if (!StringUtils.isEmpty(shareName) && shareNameContainsWildcards(shareName)) {
throw new DataValidationException("Shared subscription 'shareName' can not contain single lvl (+) or multi lvl (#) wildcards");
}
}

private boolean shareNameContainsWildcards(String shareName) {
return shareName.contains(BrokerConstants.SINGLE_LEVEL_WILDCARD) || shareName.contains(BrokerConstants.MULTI_LEVEL_WILDCARD);
}

boolean isSharedSubscriptionWithNoLocal(TopicSubscription subscription) {
return subscription.isSharedSubscription() && subscription.getOptions().isNoLocal();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ public static MqttSubscribeMsg createMqttSubscribeMsg(UUID sessionId, MqttSubscr

public static String getTopicName(String topicName) {
return isSharedTopic(topicName) ?
topicName.substring(topicName.indexOf("/", BrokerConstants.SHARE_NAME_IDX) + 1) : topicName;
topicName.substring(topicName.indexOf(BrokerConstants.TOPIC_DELIMITER_STR, BrokerConstants.SHARE_NAME_IDX) + 1) : topicName;
}

public static String getShareName(String topicName) {
try {
return isSharedTopic(topicName) ?
topicName.substring(BrokerConstants.SHARE_NAME_IDX, topicName.indexOf("/", BrokerConstants.SHARE_NAME_IDX)) : null;
topicName.substring(BrokerConstants.SHARE_NAME_IDX, topicName.indexOf(BrokerConstants.TOPIC_DELIMITER_STR, BrokerConstants.SHARE_NAME_IDX)) : null;
} catch (IndexOutOfBoundsException e) {
log.error("[{}] Could not extract 'shareName' from shared subscription", topicName, e);
throw new RuntimeException("Could not extract 'shareName' from shared subscription", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,39 +92,42 @@ public List<T> get(String topicFilter) {
}
continue;
}
if (topicPosition.prevDelimiterIndex >= topicFilter.length()) {
if (topicPosition.segmentStartIndex > topicFilter.length()) {
if (value != null) {
result.add(value);
}
continue;
}
String segment = getSegment(topicFilter, topicPosition.prevDelimiterIndex);
int nextDelimiterIndex = topicPosition.prevDelimiterIndex + segment.length() + 1;
String segment = getSegment(topicFilter, topicPosition.segmentStartIndex);
int nextSegmentStartIndex = getNextSegmentStartIndex(topicPosition.segmentStartIndex, segment);
if (segment.equals(BrokerConstants.MULTI_LEVEL_WILDCARD)) {
childNodes.values().stream()
.filter(childNode -> notStartingWith$(topicPosition.prevDelimiterIndex == 0, childNode))
.filter(childNode -> notStartingWith$(topicPosition, childNode))
.forEach(childNode -> topicPositions.add(new TopicPosition<>(0, childNode, true)));
if (value != null) {
result.add(value);
}
} else if (segment.equals(BrokerConstants.SINGLE_LEVEL_WILDCARD)) {
childNodes.values().stream()
.filter(childNode -> notStartingWith$(topicPosition.prevDelimiterIndex == 0, childNode))
.forEach(childNode -> topicPositions.add(new TopicPosition<>(nextDelimiterIndex, childNode, false)));
.filter(childNode -> notStartingWith$(topicPosition, childNode))
.forEach(childNode -> topicPositions.add(new TopicPosition<>(nextSegmentStartIndex, childNode, false)));
} else {
Node<T> segmentNode = childNodes.get(segment);
if (segmentNode != null) {
topicPositions.add(new TopicPosition<>(nextDelimiterIndex, segmentNode, false));
topicPositions.add(new TopicPosition<>(nextSegmentStartIndex, segmentNode, false));
}
}
}
return result;
}

private boolean notStartingWith$(boolean isFirstSegment, Node<T> childNode) {
return !isFirstSegment || childNode.key.charAt(0) != '$';
private boolean notStartingWith$(TopicPosition<T> topicPosition, Node<T> childNode) {
return topicPosition.segmentStartIndex != 0 || childNode.key.isEmpty() || childNode.key.charAt(0) != '$';
}

@AllArgsConstructor
private static class TopicPosition<T> {
private final int prevDelimiterIndex;
private final int segmentStartIndex;
private final Node<T> node;
private final boolean isMultiLevelWildcard;
}
Expand All @@ -143,19 +146,19 @@ public void put(String topic, T val) {
}
}

private void put(Node<T> x, String topic, T val, int prevDelimiterIndex) {
if (prevDelimiterIndex >= topic.length()) {
private void put(Node<T> x, String topic, T val, int segmentStartIndex) {
if (segmentStartIndex > topic.length()) {
T prevValue = x.value.getAndSet(val);
if (prevValue == null) {
size.getAndIncrement();
}
} else {
String segment = getSegment(topic, prevDelimiterIndex);
String segment = getSegment(topic, segmentStartIndex);
Node<T> nextNode = x.children.computeIfAbsent(segment, s -> {
nodesCount.incrementAndGet();
return new Node<>(segment);
});
put(nextNode, topic, val, prevDelimiterIndex + segment.length() + 1);
put(nextNode, topic, val, getNextSegmentStartIndex(segmentStartIndex, segment));
}
}

Expand All @@ -165,7 +168,7 @@ public void delete(String topic) {
if (topic == null) {
throw new IllegalArgumentException("Topic cannot be null");
}
Node<T> x = getNode(root, topic, 0);
Node<T> x = getDeleteNode(root, topic, 0);
if (x != null) {
T prevValue = x.value.getAndSet(null);
if (prevValue != null) {
Expand All @@ -174,13 +177,19 @@ public void delete(String topic) {
}
}

private Node<T> getNode(Node<T> x, String topic, int prevDelimiterIndex) {
if (x == null) return null;
if (prevDelimiterIndex >= topic.length()) {
private Node<T> getDeleteNode(Node<T> x, String topic, int segmentStartIndex) {
if (x == null) {
return null;
}
if (segmentStartIndex > topic.length()) {
return x;
}
String segment = getSegment(topic, prevDelimiterIndex);
return getNode(x.children.get(segment), topic, prevDelimiterIndex + segment.length() + 1);
String segment = getSegment(topic, segmentStartIndex);
return getDeleteNode(x.children.get(segment), topic, getNextSegmentStartIndex(segmentStartIndex, segment));
}

private int getNextSegmentStartIndex(int segmentStartIndex, String segment) {
return segmentStartIndex + segment.length() + 1;
}

@Override
Expand Down Expand Up @@ -240,11 +249,11 @@ private void acquireClearTrieLock() throws RetainMsgTrieClearException {
}
}

private String getSegment(String key, int prevDelimiterIndex) {
int nextDelimitedIndex = key.indexOf(BrokerConstants.TOPIC_DELIMITER, prevDelimiterIndex);
private String getSegment(String key, int segmentStartIndex) {
int nextDelimiterIndex = key.indexOf(BrokerConstants.TOPIC_DELIMITER, segmentStartIndex);

return nextDelimitedIndex == -1 ?
key.substring(prevDelimiterIndex)
: key.substring(prevDelimiterIndex, nextDelimitedIndex);
return nextDelimiterIndex == -1 ?
key.substring(segmentStartIndex)
: key.substring(segmentStartIndex, nextDelimiterIndex);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,23 @@ public List<ValueWithTopicFilter<T>> get(String topic) {
}
List<ValueWithTopicFilter<T>> result = new ArrayList<>();
Stack<TopicPosition<T>> topicPositions = new Stack<>();
topicPositions.add(new TopicPosition<>(BrokerConstants.EMPTY_STR, 0, root));
topicPositions.add(new TopicPosition<>(BrokerConstants.NULL_CHAR_STR, 0, root));

while (!topicPositions.isEmpty()) {
TopicPosition<T> topicPosition = topicPositions.pop();
if (topicPosition.prevDelimiterIndex >= topic.length()) {
if (topicPosition.segmentStartIndex > topic.length()) {
result.addAll(wrapValuesWithTopicFilter(topicPosition.prevTopicFilter, topicPosition.node.values));

Node<T> multiLevelWildcardSubs = topicPosition.node.children.get(BrokerConstants.MULTI_LEVEL_WILDCARD);
if (multiLevelWildcardSubs != null) {
String currentTopicFilter = appendSegment(topicPosition.prevTopicFilter, BrokerConstants.MULTI_LEVEL_WILDCARD);
result.addAll(wrapValuesWithTopicFilter(currentTopicFilter, multiLevelWildcardSubs.values));
}
continue;
}
ConcurrentMap<String, Node<T>> childNodes = topicPosition.node.children;
String segment = getSegment(topic, topicPosition.prevDelimiterIndex);
int nextDelimiterIndex = topicPosition.prevDelimiterIndex + segment.length() + 1;
String segment = getSegment(topic, topicPosition.segmentStartIndex);
int nextSegmentStartIndex = getNextSegmentStartIndex(topicPosition.segmentStartIndex, segment);

if (notStartingWith$(topic, topicPosition)) {
Node<T> multiLevelWildcardSubs = childNodes.get(BrokerConstants.MULTI_LEVEL_WILDCARD);
Expand All @@ -88,21 +94,21 @@ public List<ValueWithTopicFilter<T>> get(String topic) {
Node<T> singleLevelWildcardSubs = childNodes.get(BrokerConstants.SINGLE_LEVEL_WILDCARD);
if (singleLevelWildcardSubs != null) {
String currentTopicFilter = appendSegment(topicPosition.prevTopicFilter, BrokerConstants.SINGLE_LEVEL_WILDCARD);
topicPositions.add(new TopicPosition<>(currentTopicFilter, nextDelimiterIndex, singleLevelWildcardSubs));
topicPositions.add(new TopicPosition<>(currentTopicFilter, nextSegmentStartIndex, singleLevelWildcardSubs));
}
}

Node<T> segmentNode = childNodes.get(segment);
if (segmentNode != null) {
String currentTopicFilter = appendSegment(topicPosition.prevTopicFilter, segment);
topicPositions.add(new TopicPosition<>(currentTopicFilter, nextDelimiterIndex, segmentNode));
topicPositions.add(new TopicPosition<>(currentTopicFilter, nextSegmentStartIndex, segmentNode));
}
}
return result;
}

private boolean notStartingWith$(String topic, TopicPosition<T> topicPosition) {
return topicPosition.prevDelimiterIndex != 0 || topic.charAt(0) != '$';
return topicPosition.segmentStartIndex != 0 || topic.charAt(0) != '$';
}

private List<ValueWithTopicFilter<T>> wrapValuesWithTopicFilter(String topicFilter, Collection<T> values) {
Expand All @@ -129,16 +135,16 @@ public void put(String topicFilter, T val) {
}
}

private void put(Node<T> x, String key, T val, int prevDelimiterIndex) {
if (prevDelimiterIndex >= key.length()) {
private void put(Node<T> x, String key, T val, int segmentStartIndex) {
if (segmentStartIndex > key.length()) {
addOrReplace(x.values, val);
} else {
String segment = getSegment(key, prevDelimiterIndex);
String segment = getSegment(key, segmentStartIndex);
Node<T> nextNode = x.children.computeIfAbsent(segment, s -> {
nodesCount.incrementAndGet();
return new Node<>();
});
put(nextNode, key, val, prevDelimiterIndex + segment.length() + 1);
put(nextNode, key, val, getNextSegmentStartIndex(segmentStartIndex, segment));
}
}

Expand All @@ -159,7 +165,7 @@ public boolean delete(String topicFilter, Predicate<T> deletionFilter) {
if (topicFilter == null || deletionFilter == null) {
throw new IllegalArgumentException("Topic filter or deletionFilter cannot be null");
}
Node<T> x = getNode(root, topicFilter, 0);
Node<T> x = getDeleteNode(root, topicFilter, 0);
if (x != null) {
Set<T> valuesToDelete = x.values.stream().filter(deletionFilter).collect(Collectors.toSet());
if (valuesToDelete.isEmpty()) {
Expand Down Expand Up @@ -230,25 +236,34 @@ private boolean clearEmptyChildren(Node<T> node) {
return isNodeEmpty;
}

private Node<T> getNode(Node<T> x, String key, int prevDelimiterIndex) {
if (x == null) return null;
if (prevDelimiterIndex >= key.length()) {
private Node<T> getDeleteNode(Node<T> x, String key, int segmentStartIndex) {
if (x == null) {
return null;
}
if (segmentStartIndex > key.length()) {
return x;
}
String segment = getSegment(key, prevDelimiterIndex);
return getNode(x.children.get(segment), key, prevDelimiterIndex + segment.length() + 1);
String segment = getSegment(key, segmentStartIndex);
return getDeleteNode(x.children.get(segment), key, getNextSegmentStartIndex(segmentStartIndex, segment));
}

private int getNextSegmentStartIndex(int segmentStartIndex, String segment) {
return segmentStartIndex + segment.length() + 1;
}

private String getSegment(String key, int prevDelimiterIndex) {
int nextDelimitedIndex = key.indexOf(BrokerConstants.TOPIC_DELIMITER, prevDelimiterIndex);
private String getSegment(String key, int segmentStartIndex) {
int nextDelimiterIndex = key.indexOf(BrokerConstants.TOPIC_DELIMITER, segmentStartIndex);

return nextDelimitedIndex == -1 ?
key.substring(prevDelimiterIndex)
: key.substring(prevDelimiterIndex, nextDelimitedIndex);
return nextDelimiterIndex == -1 ?
key.substring(segmentStartIndex)
: key.substring(segmentStartIndex, nextDelimiterIndex);
}

private String appendSegment(String topicFilter, String segment) {
return topicFilter.isEmpty() ? segment : topicFilter + BrokerConstants.TOPIC_DELIMITER + segment;
if (topicFilter.equals(BrokerConstants.NULL_CHAR_STR)) {
return segment;
}
return topicFilter + BrokerConstants.TOPIC_DELIMITER + segment;
}

private static class Node<T> {
Expand Down Expand Up @@ -276,8 +291,8 @@ public int hashCode() {
@AllArgsConstructor
private static class TopicPosition<T> {
private final String prevTopicFilter;
private final int prevDelimiterIndex;
private final int segmentStartIndex;
private final Node<T> node;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@

public interface SubscriptionTrie<T> {

/*
ValueWithTopicFilter<T> is required for persistent DEVICE clients to save timestamps per topicFilter
*/
List<ValueWithTopicFilter<T>> get(String topic);

void put(String topicFilter, T val);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;

@AllArgsConstructor
@Getter
@EqualsAndHashCode
@ToString
public class ValueWithTopicFilter<T> {
private final T value;
private final String topicFilter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,15 @@ public void givenTopicSubscription_whenValidateSharedSubscriptionWithEmptyShareN
}

@Test(expected = DataValidationException.class)
public void givenTopicSubscription_whenValidateSharedSubscriptionWithWildcardShareName_thenFailure() {
public void givenTopicSubscription_whenValidateSharedSubscriptionWithMultiLvlWildcardShareName_thenFailure() {
mqttSubscribeHandler.validateSharedSubscription(new TopicSubscription("tf", 1, "#"));
}

@Test(expected = DataValidationException.class)
public void givenTopicSubscription_whenValidateSharedSubscriptionWithSingleLvlWildcardShareName_thenFailure() {
mqttSubscribeHandler.validateSharedSubscription(new TopicSubscription("tf", 1, "abc+"));
}

@Test
public void givenTopicSubscriptions_whenFilterSameSubscriptionsWithDifferentQos_thenGetExpectedResult() {
BiFunction<TopicSharedSubscription, TopicSharedSubscription, Boolean> f =
Expand Down
Loading

0 comments on commit eff45b1

Please sign in to comment.