Skip to content

Commit

Permalink
[FLINK-11126][YARN][security] Filter out AMRMToken in the TaskManager…
Browse files Browse the repository at this point in the history
… credentials

This closes apache#7895.
  • Loading branch information
link3280 authored and tillrohrmann committed May 26, 2019
1 parent 4f558e4 commit dac3648
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,50 @@
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ResourceManagerOptions;
import org.apache.flink.core.testutils.CommonTestUtils;
import org.apache.flink.runtime.clusterframework.ContaineredTaskManagerParameters;
import org.apache.flink.util.TestLogger;

import org.apache.hadoop.io.Text;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.security.AMRMTokenIdentifier;
import org.apache.log4j.AppenderSkeleton;
import org.apache.log4j.Level;
import org.apache.log4j.spi.LoggingEvent;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

/**
* Tests for various utilities.
*/
public class UtilsTest extends TestLogger {
private static final Logger LOG = LoggerFactory.getLogger(UtilsTest.class);

@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();

@Test
public void testUberjarLocator() {
File dir = YarnTestBase.findFile("..", new YarnTestBase.RootDirFilenameFilter());
Expand Down Expand Up @@ -136,6 +158,72 @@ public void testGetEnvironmentVariablesErroneous() {
Assert.assertEquals(0, res.size());
}

@Test
public void testCreateTaskExecutorCredentials() throws Exception {
	// Prepare a fake "home" directory inside the temporary folder.
	final File rootDir = temporaryFolder.getRoot();
	final File homeDir = new File(rootDir, "home");
	assertTrue(homeDir.mkdir());

	final Configuration flinkConfig = new Configuration();
	final YarnConfiguration yarnConfig = new YarnConfiguration();

	// Minimal launch environment expected by Utils.createTaskExecutorContext.
	final Map<String, String> launchEnv = new HashMap<>();
	launchEnv.put(YarnConfigKeys.ENV_APP_ID, "foo");
	launchEnv.put(YarnConfigKeys.ENV_CLIENT_HOME_DIR, homeDir.getAbsolutePath());
	launchEnv.put(YarnConfigKeys.ENV_CLIENT_SHIP_FILES, "");
	launchEnv.put(YarnConfigKeys.ENV_FLINK_CLASSPATH, "");
	launchEnv.put(YarnConfigKeys.ENV_HADOOP_USER_NAME, "foo");
	launchEnv.put(YarnConfigKeys.FLINK_JAR_PATH, rootDir.toURI().toString());
	final Map<String, String> readOnlyEnv = Collections.unmodifiableMap(launchEnv);

	// Write an AM credential file that contains both an AMRMToken and an
	// HDFS delegation token; only the latter should reach the TaskManager.
	final File tokenFile = temporaryFolder.newFile("container_tokens");
	final Text amRmTokenKind = AMRMTokenIdentifier.KIND_NAME;
	final Text hdfsDelegationTokenKind = new Text("HDFS_DELEGATION_TOKEN");
	final Text tokenService = new Text("test-service");
	final Credentials amCredentials = new Credentials();
	amCredentials.addToken(amRmTokenKind,
		new Token<>(new byte[4], new byte[4], amRmTokenKind, tokenService));
	amCredentials.addToken(hdfsDelegationTokenKind,
		new Token<>(new byte[4], new byte[4], hdfsDelegationTokenKind, tokenService));
	amCredentials.writeTokenStorageFile(
		new org.apache.hadoop.fs.Path(tokenFile.getAbsolutePath()), yarnConfig);

	final ContaineredTaskManagerParameters tmParameters =
		new ContaineredTaskManagerParameters(64, 64, 16, 1, new HashMap<>(1));
	final Configuration taskManagerConfig = new Configuration();

	// Point HADOOP_TOKEN_FILE_LOCATION at the credential file while the launch
	// context is built, restoring the original process environment afterwards.
	final ContainerLaunchContext launchContext;
	final Map<String, String> originalEnv = System.getenv();
	try {
		final Map<String, String> patchedEnv = new HashMap<>(originalEnv);
		patchedEnv.put("HADOOP_TOKEN_FILE_LOCATION", tokenFile.getAbsolutePath());
		CommonTestUtils.setEnv(patchedEnv);
		launchContext = Utils.createTaskExecutorContext(flinkConfig, yarnConfig, readOnlyEnv,
			tmParameters, taskManagerConfig, rootDir.getAbsolutePath(),
			YarnTaskExecutorRunner.class, LOG);
	} finally {
		CommonTestUtils.setEnv(originalEnv);
	}

	// Deserialize the tokens that were attached to the TaskManager launch context.
	final Credentials tmCredentials = new Credentials();
	try (DataInputStream in =
			new DataInputStream(new ByteArrayInputStream(launchContext.getTokens().array()))) {
		tmCredentials.readTokenStorageStream(in);
	}

	// The HDFS delegation token must survive; the AMRMToken must have been filtered out.
	boolean foundHdfsDelegationToken = false;
	boolean foundAmRmToken = false;
	for (Token<? extends TokenIdentifier> token : tmCredentials.getAllTokens()) {
		if (amRmTokenKind.equals(token.getKind())) {
			foundAmRmToken = true;
		} else if (hdfsDelegationTokenKind.equals(token.getKind())) {
			foundHdfsDelegationToken = true;
		}
	}
	assertTrue(foundHdfsDelegationToken);
	assertFalse(foundAmRmToken);
}

//
// --------------- Tools to test if a certain string has been logged with Log4j. -------------
// See : https://stackoverflow.com/questions/3717402/how-to-test-w-junit-that-warning-was-logged-w-log4j
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
import org.apache.flink.test.util.SecureTestEnvironment;
import org.apache.flink.test.util.TestingSecurityContext;

import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;

import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.security.AMRMTokenIdentifier;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.ResourceScheduler;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.fifo.FifoScheduler;
import org.hamcrest.Matchers;
Expand All @@ -39,6 +42,7 @@

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;

/**
Expand Down Expand Up @@ -116,6 +120,21 @@ public void testDetachedMode() throws InterruptedException, IOException {
"The JobManager and the TaskManager should both run with Kerberos.",
jobManagerRunsWithKerberos && taskManagerRunsWithKerberos,
Matchers.is(true));

final List<String> amRMTokens = Lists.newArrayList(AMRMTokenIdentifier.KIND_NAME.toString());
final String jobmanagerContainerId = getContainerIdByLogName("jobmanager.log");
final String taskmanagerContainerId = getContainerIdByLogName("taskmanager.log");
final boolean jobmanagerWithAmRmToken = verifyTokenKindInContainerCredentials(amRMTokens, jobmanagerContainerId);
final boolean taskmanagerWithAmRmToken = verifyTokenKindInContainerCredentials(amRMTokens, taskmanagerContainerId);

Assert.assertThat(
"The JobManager should have AMRMToken.",
jobmanagerWithAmRmToken,
Matchers.is(true));
Assert.assertThat(
"The TaskManager should not have AMRMToken.",
taskmanagerWithAmRmToken,
Matchers.is(false));
}

/* For secure cluster testing, it is enough to run only one test and override below test methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.service.Service;
import org.apache.hadoop.yarn.api.records.ApplicationReport;
import org.apache.hadoop.yarn.api.records.ContainerId;
Expand Down Expand Up @@ -73,6 +76,7 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -464,9 +468,8 @@ public boolean accept(File dir, String name) {
}
File f = new File(dir.getAbsolutePath() + "/" + name);
LOG.info("Searching in {}", f.getAbsolutePath());
try {
try (Scanner scanner = new Scanner(f)) {
Set<String> foundSet = new HashSet<>(mustHave.length);
Scanner scanner = new Scanner(f);
while (scanner.hasNextLine()) {
final String lineFromFile = scanner.nextLine();
for (String str : mustHave) {
Expand All @@ -493,6 +496,53 @@ public boolean accept(File dir, String name) {
}
}

/**
 * Checks whether the credential file of the given container contains all of the
 * given token kinds.
 *
 * @param tokens token kinds (as produced by {@code Token#getKind().toString()}) that must all be present
 * @param containerId ID of the container whose {@code <containerId>.tokens} file is inspected
 * @return {@code true} iff the container's credential file was found and contains every requested kind
 * @throws IOException if the credential file cannot be read
 */
public static boolean verifyTokenKindInContainerCredentials(final Collection<String> tokens, final String containerId)
	throws IOException {
	final File cwd = new File("target/" + YARN_CONFIGURATION.get(TEST_CLUSTER_NAME_KEY));
	if (!cwd.exists() || !cwd.isDirectory()) {
		return false;
	}

	// NodeManagers localize container credentials into a "<containerId>.tokens" file.
	final File containerTokens = findFile(cwd.getAbsolutePath(),
		(dir, name) -> name.equals(containerId + ".tokens"));

	if (containerTokens == null) {
		LOG.warn("Unable to find credential file for container {}", containerId);
		return false;
	}

	LOG.info("Verifying tokens in {}", containerTokens.getAbsolutePath());

	final Credentials tmCredentials = Credentials.readTokenStorageFile(containerTokens, new Configuration());

	// Collect the kinds of all tokens present in the file.
	final Set<String> tokenKinds = new HashSet<>(4);
	for (Token<? extends TokenIdentifier> token : tmCredentials.getAllTokens()) {
		tokenKinds.add(token.getKind().toString());
	}

	return tokenKinds.containsAll(tokens);
}

/**
 * Resolves the ID of the container that wrote the given log file by locating the
 * log file under the test cluster's working directory and taking the name of its
 * parent directory (container logs live in a directory named after the container ID).
 *
 * @param logName file name of the log to look for, e.g. {@code "taskmanager.log"}
 * @return the ID of the container owning the log
 * @throws IllegalStateException if no container has a log with the given name
 */
public static String getContainerIdByLogName(String logName) {
	final File cwd = new File("target/" + YARN_CONFIGURATION.get(TEST_CLUSTER_NAME_KEY));
	final File containerLog = findFile(cwd.getAbsolutePath(),
		(dir, name) -> name.equals(logName));
	if (containerLog == null) {
		throw new IllegalStateException("No container has log named " + logName);
	}
	return containerLog.getParentFile().getName();
}

public static void sleep(int time) {
try {
Thread.sleep(time);
Expand Down
13 changes: 12 additions & 1 deletion flink-yarn/src/main/java/org/apache/flink/yarn/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.apache.hadoop.yarn.api.records.LocalResourceType;
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.security.AMRMTokenIdentifier;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.hadoop.yarn.util.Records;
import org.slf4j.Logger;
Expand Down Expand Up @@ -567,7 +568,17 @@ static ContainerLaunchContext createTaskExecutorContext(
new File(fileLocation),
HadoopUtils.getHadoopConfiguration(flinkConfig));

cred.writeTokenStorageToStream(dob);
// Filter out AMRMToken before setting the tokens to the TaskManager container context.
Credentials taskManagerCred = new Credentials();
Collection<Token<? extends TokenIdentifier>> userTokens = cred.getAllTokens();
for (Token<? extends TokenIdentifier> token : userTokens) {
if (!token.getKind().equals(AMRMTokenIdentifier.KIND_NAME)) {
final Text id = new Text(token.getIdentifier());
taskManagerCred.addToken(id, token);
}
}

taskManagerCred.writeTokenStorageToStream(dob);
ByteBuffer securityTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength());
ctx.setTokens(securityTokens);
} catch (Throwable t) {
Expand Down

0 comments on commit dac3648

Please sign in to comment.