Skip to content

Commit

Permalink
[java] Tapping the Node session when there is WebSocket activity
Browse files Browse the repository at this point in the history
  • Loading branch information
diemol committed Jan 12, 2024
1 parent 19a1813 commit c0ddca6
Showing 1 changed file with 67 additions and 27 deletions.
94 changes: 67 additions & 27 deletions java/src/org/openqa/selenium/grid/node/ProxyNodeWebsockets.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.openqa.selenium.Capabilities;
import org.openqa.selenium.devtools.CdpEndpointFinder;
import org.openqa.selenium.grid.data.Session;
import org.openqa.selenium.internal.Require;
import org.openqa.selenium.remote.SessionId;
import org.openqa.selenium.remote.http.BinaryMessage;
import org.openqa.selenium.remote.http.ClientConfig;
Expand Down Expand Up @@ -75,57 +76,68 @@ public Optional<Consumer<Message>> apply(String uri, Consumer<Message> downstrea
return Optional.empty();
}

String sessionId =
Stream.of(fwdMatch, cdpMatch, bidiMatch, vncMatch)
.filter(Objects::nonNull)
.findFirst()
.get()
.getParameters()
.get("sessionId");
Optional<UrlTemplate.Match> firstMatch =
Stream.of(fwdMatch, cdpMatch, bidiMatch, vncMatch).filter(Objects::nonNull).findFirst();

if (firstMatch.isEmpty()) {
LOG.warning("No session id found in uri " + uri);
return Optional.empty();
}

String sessionId = firstMatch.get().getParameters().get("sessionId");

LOG.fine("Matching websockets for session id: " + sessionId);
SessionId id = new SessionId(sessionId);

if (!node.isSessionOwner(id)) {
LOG.info("Not owner of " + id);
LOG.warning("Not owner of " + id);
return Optional.empty();
}

Session session = node.getSession(id);
Capabilities caps = session.getCapabilities();
LOG.fine("Scanning for endpoint: " + caps);

// Used by the ForwardingListener to notify the node that the session is still active
Consumer<SessionId> sessionConsumer = node::isSessionOwner;

if (bidiMatch != null) {
return findBiDiEndpoint(downstream, caps);
return findBiDiEndpoint(downstream, caps, sessionConsumer, id);
}

if (vncMatch != null) {
return findVncEndpoint(downstream, caps);
// Passing a fake consumer to the ForwardingListener to avoid sending a session notification
// when VNC is used.
sessionConsumer = fakeConsumer -> {};
return findVncEndpoint(downstream, caps, sessionConsumer, id);
}

// This match happens when a user wants to do CDP over Dynamic Grid
if (fwdMatch != null) {
LOG.info("Matched endpoint where CDP connection is being forwarded");
return findCdpEndpoint(downstream, caps);
return findCdpEndpoint(downstream, caps, sessionConsumer, id);
}
if (caps.getCapabilityNames().contains("se:forwardCdp")) {
LOG.info("Found endpoint where CDP connection needs to be forwarded");
return findForwardCdpEndpoint(downstream, caps);
return findForwardCdpEndpoint(downstream, caps, sessionConsumer, id);
}
return findCdpEndpoint(downstream, caps);
return findCdpEndpoint(downstream, caps, sessionConsumer, id);
}

private Optional<Consumer<Message>> findCdpEndpoint(
Consumer<Message> downstream, Capabilities caps) {
// Using strings here to avoid Node depending upon specific drivers.
Consumer<Message> downstream,
Capabilities caps,
Consumer<SessionId> sessionConsumer,
SessionId sessionId) {

for (String cdpEndpointCap : CDP_ENDPOINT_CAPS) {
Optional<URI> reportedUri = CdpEndpointFinder.getReportedUri(cdpEndpointCap, caps);
Optional<HttpClient> client =
reportedUri.map(uri -> CdpEndpointFinder.getHttpClient(clientFactory, uri));
Optional<URI> cdpUri;

try {
cdpUri = client.flatMap(httpClient -> CdpEndpointFinder.getCdpEndPoint(httpClient));
cdpUri = client.flatMap(CdpEndpointFinder::getCdpEndPoint);
} catch (Exception e) {
try {
client.ifPresent(HttpClient::close);
Expand All @@ -137,7 +149,7 @@ private Optional<Consumer<Message>> findCdpEndpoint(

if (cdpUri.isPresent()) {
LOG.log(getDebugLogLevel(), String.format("Endpoint found in %s", cdpEndpointCap));
return cdpUri.map(cdp -> createWsEndPoint(cdp, downstream));
return cdpUri.map(cdp -> createWsEndPoint(cdp, downstream, sessionConsumer, sessionId));
} else {
try {
client.ifPresent(HttpClient::close);
Expand All @@ -154,30 +166,41 @@ private Optional<Consumer<Message>> findCdpEndpoint(
}

private Optional<Consumer<Message>> findBiDiEndpoint(
Consumer<Message> downstream, Capabilities caps) {
Consumer<Message> downstream,
Capabilities caps,
Consumer<SessionId> sessionConsumer,
SessionId sessionId) {
try {
URI uri = new URI(String.valueOf(caps.getCapability("webSocketUrl")));
return Optional.of(uri).map(bidi -> createWsEndPoint(bidi, downstream));
return Optional.of(uri)
.map(bidi -> createWsEndPoint(bidi, downstream, sessionConsumer, sessionId));
} catch (URISyntaxException e) {
LOG.warning("Unable to create URI from: " + caps.getCapability("webSocketUrl"));
return Optional.empty();
}
}

private Optional<Consumer<Message>> findForwardCdpEndpoint(
Consumer<Message> downstream, Capabilities caps) {
Consumer<Message> downstream,
Capabilities caps,
Consumer<SessionId> sessionConsumer,
SessionId sessionId) {
// When using Dynamic Grid, we need to connect to a container before using the debuggerAddress
try {
URI uri = new URI(String.valueOf(caps.getCapability("se:forwardCdp")));
return Optional.of(uri).map(cdp -> createWsEndPoint(cdp, downstream));
return Optional.of(uri)
.map(cdp -> createWsEndPoint(cdp, downstream, sessionConsumer, sessionId));
} catch (URISyntaxException e) {
LOG.warning("Unable to create URI from: " + caps.getCapability("se:forwardCdp"));
return Optional.empty();
}
}

private Optional<Consumer<Message>> findVncEndpoint(
Consumer<Message> downstream, Capabilities caps) {
Consumer<Message> downstream,
Capabilities caps,
Consumer<SessionId> sessionConsumer,
SessionId sessionId) {
String vncLocalAddress = (String) caps.getCapability("se:vncLocalAddress");
Optional<URI> vncUri;
try {
Expand All @@ -187,40 +210,57 @@ private Optional<Consumer<Message>> findVncEndpoint(
return Optional.empty();
}
LOG.log(getDebugLogLevel(), String.format("Endpoint found in %s", "se:vncLocalAddress"));
return vncUri.map(vnc -> createWsEndPoint(vnc, downstream));
return vncUri.map(vnc -> createWsEndPoint(vnc, downstream, sessionConsumer, sessionId));
}

private Consumer<Message> createWsEndPoint(URI uri, Consumer<Message> downstream) {
Objects.requireNonNull(uri);
private Consumer<Message> createWsEndPoint(
URI uri,
Consumer<Message> downstream,
Consumer<SessionId> sessionConsumer,
SessionId sessionId) {
Require.nonNull("downstream", downstream);
Require.nonNull("uri", uri);
Require.nonNull("sessionConsumer", sessionConsumer);
Require.nonNull("sessionId", sessionId);

LOG.info("Establishing connection to " + uri);

HttpClient client = clientFactory.createClient(ClientConfig.defaultConfig().baseUri(uri));
WebSocket upstream =
client.openSocket(new HttpRequest(GET, uri.toString()), new ForwardingListener(downstream));
client.openSocket(
new HttpRequest(GET, uri.toString()),
new ForwardingListener(downstream, sessionConsumer, sessionId));
return upstream::send;
}

private static class ForwardingListener implements WebSocket.Listener {
private final Consumer<Message> downstream;
private final Consumer<SessionId> sessionConsumer;
private final SessionId sessionId;

public ForwardingListener(Consumer<Message> downstream) {
public ForwardingListener(
Consumer<Message> downstream, Consumer<SessionId> sessionConsumer, SessionId sessionId) {
this.downstream = Objects.requireNonNull(downstream);
this.sessionConsumer = Objects.requireNonNull(sessionConsumer);
this.sessionId = Objects.requireNonNull(sessionId);
}

@Override
public void onBinary(byte[] data) {
downstream.accept(new BinaryMessage(data));
sessionConsumer.accept(sessionId);
}

@Override
public void onClose(int code, String reason) {
downstream.accept(new CloseMessage(code, reason));
sessionConsumer.accept(sessionId);
}

@Override
public void onText(CharSequence data) {
downstream.accept(new TextMessage(data));
sessionConsumer.accept(sessionId);
}

@Override
Expand Down

0 comments on commit c0ddca6

Please sign in to comment.