Skip to content

Commit

Permalink
[FLINK-12838][network] netty-fy SSL configuration
Browse files Browse the repository at this point in the history
Refactor the SSL configuration done for Netty to have it more like the way
Netty intends it to be: using its SslContextBuilder. This will make it much
easier to set a different Netty SSL engine provider.

[hotfix][network] extract key and trust manager factory creation
  • Loading branch information
NicoK committed Jun 15, 2019
1 parent 500c133 commit 65e822b
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 176 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ public MesosArtifactServer(String prefix, String serverHostname, int configuredP

router = new Router();

final Configuration sslConfig = config;
ChannelInitializer<SocketChannel> initializer = new ChannelInitializer<SocketChannel>() {

@Override
Expand All @@ -135,7 +134,8 @@ protected void initChannel(SocketChannel ch) {

// SSL should be the first handler in the pipeline
if (sslFactory != null) {
ch.pipeline().addLast("ssl", sslFactory.createNettySSLHandler());
ch.pipeline().addLast("ssl",
sslFactory.createNettySSLHandler(ch.alloc()));
}

ch.pipeline()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ protected void initChannel(SocketChannel ch) {

// SSL should be the first handler in the pipeline
if (serverSSLFactory != null) {
ch.pipeline().addLast("ssl", serverSSLFactory.createNettySSLHandler());
ch.pipeline().addLast("ssl",
serverSSLFactory.createNettySSLHandler(ch.alloc()));
}

ch.pipeline()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ public void initChannel(SocketChannel channel) throws Exception {
// SSL handler should be added first in the pipeline
if (clientSSLFactory != null) {
SslHandler sslHandler = clientSSLFactory.createNettySSLHandler(
channel.alloc(),
serverSocketAddress.getAddress().getCanonicalHostName(),
serverSocketAddress.getPort());
channel.pipeline().addLast("ssl", sslHandler);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ int init(final NettyProtocol protocol, NettyBufferPool nettyBufferPool) throws I
@Override
public void initChannel(SocketChannel channel) throws Exception {
if (sslHandlerFactory != null) {
channel.pipeline().addLast("ssl", sslHandlerFactory.createNettySSLHandler());
channel.pipeline().addLast("ssl",
sslHandlerFactory.createNettySSLHandler(channel.alloc()));
}

channel.pipeline().addLast(protocol.getServerChannelHandlers());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@

package org.apache.flink.runtime.io.network.netty;

import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator;
import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslContext;
import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslHandler;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;

import static java.util.Objects.requireNonNull;
Expand All @@ -30,15 +31,7 @@
*/
public class SSLHandlerFactory {

private final SSLContext sslContext;

private final String[] enabledProtocols;

private final String[] enabledCipherSuites;

private final boolean clientMode;

private final boolean clientAuthentication;
private final SslContext sslContext;

private final int handshakeTimeoutMs;

Expand All @@ -54,29 +47,21 @@ public class SSLHandlerFactory {
* default)
*/
public SSLHandlerFactory(
final SSLContext sslContext,
final String[] enabledProtocols,
final String[] enabledCipherSuites,
final boolean clientMode,
final boolean clientAuthentication,
final SslContext sslContext,
final int handshakeTimeoutMs,
final int closeNotifyFlushTimeoutMs) {

this.sslContext = requireNonNull(sslContext, "sslContext must not be null");
this.enabledProtocols = requireNonNull(enabledProtocols, "enabledProtocols must not be null");
this.enabledCipherSuites = requireNonNull(enabledCipherSuites, "cipherSuites must not be null");
this.clientMode = clientMode;
this.clientAuthentication = clientAuthentication;
this.handshakeTimeoutMs = handshakeTimeoutMs;
this.closeNotifyFlushTimeoutMs = closeNotifyFlushTimeoutMs;
}

public SslHandler createNettySSLHandler() {
return createNettySSLHandler(createSSLEngine());
public SslHandler createNettySSLHandler(ByteBufAllocator allocator) {
return createNettySSLHandler(createSSLEngine(allocator));
}

public SslHandler createNettySSLHandler(String hostname, int port) {
return createNettySSLHandler(createSSLEngine(hostname, port));
public SslHandler createNettySSLHandler(ByteBufAllocator allocator, String hostname, int port) {
return createNettySSLHandler(createSSLEngine(allocator, hostname, port));
}

private SslHandler createNettySSLHandler(SSLEngine sslEngine) {
Expand All @@ -91,24 +76,11 @@ private SslHandler createNettySSLHandler(SSLEngine sslEngine) {
return sslHandler;
}

private SSLEngine createSSLEngine() {
final SSLEngine sslEngine = sslContext.createSSLEngine();
configureSSLEngine(sslEngine);
return sslEngine;
private SSLEngine createSSLEngine(ByteBufAllocator allocator) {
return sslContext.newEngine(allocator);
}

private SSLEngine createSSLEngine(String hostname, int port) {
final SSLEngine sslEngine = sslContext.createSSLEngine(hostname, port);
configureSSLEngine(sslEngine);
return sslEngine;
}

private void configureSSLEngine(SSLEngine sslEngine) {
sslEngine.setEnabledProtocols(enabledProtocols);
sslEngine.setEnabledCipherSuites(enabledCipherSuites);
sslEngine.setUseClientMode(clientMode);
if (!clientMode) {
sslEngine.setNeedClientAuth(clientAuthentication);
}
private SSLEngine createSSLEngine(ByteBufAllocator allocator, String hostname, int port) {
return sslContext.newEngine(allocator, hostname, port);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ protected void decode(ChannelHandlerContext context, ByteBuf in, List<Object> ou
}

private void handleSsl(ChannelHandlerContext context) {
SslHandler sslHandler = sslHandlerFactory.createNettySSLHandler();
SslHandler sslHandler = sslHandlerFactory.createNettySSLHandler(context.alloc());
try {
context.pipeline().replace(this, SSL_HANDLER_NAME, sslHandler);
} catch (Throwable t) {
Expand Down
Loading

0 comments on commit 65e822b

Please sign in to comment.