Skip to content

Commit

Permalink
[FLINK-1349] [runtime] Various cleanups to make scala runtime code in…
Browse files Browse the repository at this point in the history
…teract smoother with java
  • Loading branch information
StephanEwen committed Jan 6, 2015
1 parent 972a7b0 commit cec30ff
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ public class Execution implements Serializable {

private static final int NUM_CANCEL_CALL_TRIES = 3;

public static FiniteDuration timeout = new FiniteDuration(ConfigConstants
.DEFAULT_AKKA_ASK_TIMEOUT, TimeUnit.SECONDS);
public static FiniteDuration timeout = new FiniteDuration(
ConfigConstants.DEFAULT_AKKA_ASK_TIMEOUT, TimeUnit.SECONDS);

// --------------------------------------------------------------------------------------------

Expand Down Expand Up @@ -289,9 +289,9 @@ public void deployToSlot(final AllocatedSlot slot) throws JobException {

@Override
public void onComplete(Throwable failure, Object success) throws Throwable {
if(failure != null){
if (failure != null) {
markFailed(failure);
}else{
} else {
TaskOperationResult result = (TaskOperationResult) success;
if (success == null) {
markFailed(new Exception("Failed to deploy the task to slot " + slot + ": TaskOperationResult was null"));
Expand All @@ -305,8 +305,7 @@ else if (result.success()) {
else {
// deployment failed :(
markFailed(new Exception("Failed to deploy the task " +
getVertexWithAttempt() + " to slot " + slot + ": " + result
.description()));
getVertexWithAttempt() + " to slot " + slot + ": " + result.description()));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ package org.apache.flink.runtime.messages
import org.apache.flink.core.io.InputSplit
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID
import org.apache.flink.runtime.instance.InstanceID
import org.apache.flink.runtime.instance.InstanceID

object TaskManagerMessages {


/**
* Cancels the task associated with [[attemptID]]. The result is sent back to the sender as a
* [[TaskOperationResult]] message.
Expand Down Expand Up @@ -113,4 +114,28 @@ object TaskManagerMessages {
* @param cause reason for the external failure
*/
case class FailTask(executionID: ExecutionAttemptID, cause: Throwable)

// --------------------------------------------------------------------------
// Utility methods to allow simpler case object access from Java
// --------------------------------------------------------------------------

def getNotifyWhenRegisteredAtJobManagerMessage() : AnyRef = {
NotifyWhenRegisteredAtJobManager
}

def getRegisteredAtJobManagerMessage() : AnyRef = {
RegisteredAtJobManager
}

def getRegisterAtJobManagerMessage() : AnyRef = {
RegisterAtJobManager
}

def getSendHeartbeatMessage() : AnyRef = {
SendHeartbeat
}

def getLogMemoryUsageMessage() : AnyRef = {
RegisteredAtJobManager
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,16 @@ class TaskManager(val connectionInfo: InstanceConnectionInfo, val jobManagerAkka

if (registered) {
registrationScheduler.foreach(_.cancel())
} else if (registrationAttempts <= TaskManager.MAX_REGISTRATION_ATTEMPTS) {
}
else if (registrationAttempts <= TaskManager.MAX_REGISTRATION_ATTEMPTS) {

log.info(s"Try to register at master ${jobManagerAkkaURL}. ${registrationAttempts}. " +
s"Attempt")
val jobManager = context.actorSelection(jobManagerAkkaURL)

jobManager ! RegisterTaskManager(connectionInfo, hardwareDescription, numberOfSlots)
} else {
}
else {
log.error("TaskManager could not register at JobManager.");
self ! PoisonPill
}
Expand Down Expand Up @@ -212,6 +214,10 @@ class TaskManager(val connectionInfo: InstanceConnectionInfo, val jobManagerAkka
waitForRegistration.clear()
}
}

case SubmitTask(tdd) => {
submitTask(tdd)
}

case CancelTask(executionID) => {
runningTasks.get(executionID) match {
Expand Down Expand Up @@ -502,7 +508,8 @@ class TaskManager(val connectionInfo: InstanceConnectionInfo, val jobManagerAkka
}

/**
* TaskManager companion object. Contains TaskManager executable entry point, command line parsing, and constants.
* TaskManager companion object. Contains TaskManager executable entry point, command
* line parsing, and constants.
*/
object TaskManager {

Expand Down Expand Up @@ -749,7 +756,9 @@ object TaskManager {
s"NON HEAP: $nonHeapUsed/$nonHeapCommitted/$nonHeapMax MB (used/committed/max)]"
}

private def getGarbageCollectorStatsAsString(gcMXBeans: Iterable[GarbageCollectorMXBean]): String = {
private def getGarbageCollectorStatsAsString(gcMXBeans: Iterable[GarbageCollectorMXBean])
: String =
{
val beans = gcMXBeans map {
bean =>
s"[${bean.getName}, GC TIME (ms): ${bean.getCollectionTime}, " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import akka.pattern.Patterns;
import akka.testkit.JavaTestKit;
import akka.util.Timeout;

import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.GlobalConfiguration;
Expand All @@ -46,15 +47,16 @@
import org.apache.flink.runtime.messages.JobManagerMessages;
import org.apache.flink.runtime.messages.RegistrationMessages;
import org.apache.flink.runtime.messages.TaskManagerMessages.CancelTask;
import org.apache.flink.runtime.messages.TaskManagerMessages.NotifyWhenRegisteredAtJobManager$;
import org.apache.flink.runtime.messages.TaskManagerMessages.SubmitTask;
import org.apache.flink.runtime.messages.TaskManagerMessages.TaskOperationResult;
import org.apache.flink.runtime.messages.TaskManagerMessages;
import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
import org.apache.flink.runtime.testingUtils.TestingUtils;
import org.apache.flink.runtime.types.IntegerRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import scala.concurrent.Await;
import scala.concurrent.Future;
import scala.concurrent.duration.FiniteDuration;
Expand Down Expand Up @@ -162,8 +164,8 @@ protected void run() {

expectMsgEquals(new TaskOperationResult(eid1, true));
expectMsgEquals(new TaskOperationResult(eid2, true));

tm.tell(TestingTaskManagerMessages.RequestRunningTasks$.MODULE$, getRef());
tm.tell(TestingTaskManagerMessages.getRequestRunningTasksMessage(), getRef());

Map<ExecutionAttemptID, Task> runningTasks = expectMsgClass(TestingTaskManagerMessages
.ResponseRunningTasks.class).asJava();
Expand All @@ -187,7 +189,7 @@ protected void run() {

assertEquals(ExecutionState.CANCELED, t1.getExecutionState());

tm.tell(TestingTaskManagerMessages.RequestRunningTasks$.MODULE$, getRef());
tm.tell(TestingTaskManagerMessages.getRequestRunningTasksMessage(), getRef());
runningTasks = expectMsgClass(TestingTaskManagerMessages
.ResponseRunningTasks.class).asJava();

Expand All @@ -206,7 +208,7 @@ protected void run() {

assertEquals(ExecutionState.CANCELED, t2.getExecutionState());

tm.tell(TestingTaskManagerMessages.RequestRunningTasks$.MODULE$, getRef());
tm.tell(TestingTaskManagerMessages.getRequestRunningTasksMessage(), getRef());
runningTasks = expectMsgClass(TestingTaskManagerMessages
.ResponseRunningTasks.class).asJava();

Expand Down Expand Up @@ -276,7 +278,7 @@ protected void run() {
expectMsgEquals(true);
expectMsgEquals(true);

tm.tell(TestingTaskManagerMessages.RequestRunningTasks$.MODULE$, getRef());
tm.tell(TestingTaskManagerMessages.getRequestRunningTasksMessage(), getRef());
Map<ExecutionAttemptID, Task> tasks = expectMsgClass(TestingTaskManagerMessages
.ResponseRunningTasks.class).asJava();

Expand Down Expand Up @@ -337,7 +339,7 @@ protected void run() {
tm.tell(new SubmitTask(tdd1), getRef());
expectMsgEquals(new TaskOperationResult(eid1, true));

tm.tell(TestingTaskManagerMessages.RequestRunningTasks$.MODULE$, getRef());
tm.tell(TestingTaskManagerMessages.getRequestRunningTasksMessage(), getRef());
Map<ExecutionAttemptID, Task> tasks = expectMsgClass(TestingTaskManagerMessages.ResponseRunningTasks
.class).asJava();

Expand All @@ -359,7 +361,7 @@ protected void run() {
assertEquals(ExecutionState.FINISHED, t2.getExecutionState());
}

tm.tell(TestingTaskManagerMessages.RequestRunningTasks$.MODULE$, getRef());
tm.tell(TestingTaskManagerMessages.getRequestRunningTasksMessage(), getRef());
tasks = expectMsgClass(TestingTaskManagerMessages.ResponseRunningTasks
.class).asJava();

Expand Down Expand Up @@ -424,7 +426,7 @@ protected void run() {
expectMsgEquals(new TaskOperationResult(eid2, true));
expectMsgEquals(new TaskOperationResult(eid1, true));

tm.tell(TestingTaskManagerMessages.RequestRunningTasks$.MODULE$, getRef());
tm.tell(TestingTaskManagerMessages.getRequestRunningTasksMessage(), getRef());
Map<ExecutionAttemptID, Task> tasks = expectMsgClass(TestingTaskManagerMessages
.ResponseRunningTasks.class).asJava();

Expand All @@ -450,7 +452,7 @@ protected void run() {
Await.ready(response, d);
}

tm.tell(TestingTaskManagerMessages.RequestRunningTasks$.MODULE$, getRef());
tm.tell(TestingTaskManagerMessages.getRequestRunningTasksMessage(), getRef());
tasks = expectMsgClass(TestingTaskManagerMessages
.ResponseRunningTasks.class).asJava();

Expand Down Expand Up @@ -514,6 +516,7 @@ public void onReceive(Object message) throws Exception{
}
}

@SuppressWarnings("serial")
public static class SimpleLookupJobManagerCreator implements Creator<SimpleLookupJobManager>{
private final ChannelID receiverID;

Expand All @@ -527,6 +530,7 @@ public SimpleLookupJobManager create() throws Exception {
}
}

@SuppressWarnings("serial")
public static class SimpleLookupFailingUpdateJobManagerCreator implements
Creator<SimpleLookupFailingUpdateJobManager>{
private final ChannelID receiverID;
Expand All @@ -550,8 +554,8 @@ public static ActorRef createTaskManager(ActorRef jm) {

ActorRef taskManager = TestingUtils.startTestingTaskManagerWithConfiguration("localhost", cfg, system);

Future<Object> response = Patterns.ask(taskManager, NotifyWhenRegisteredAtJobManager$.MODULE$,
timeout);
Future<Object> response = Patterns.ask(taskManager,
TaskManagerMessages.getNotifyWhenRegisteredAtJobManagerMessage(), timeout);

try {
FiniteDuration d = new FiniteDuration(20, TimeUnit.SECONDS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,30 @@ trait TestingTaskManager extends ActorLogMessages {
}

def receiveTestMessages: Receive = {

case RequestRunningTasks =>
sender ! ResponseRunningTasks(runningTasks.toMap)

case NotifyWhenTaskRemoved(executionID) =>
runningTasks.get(executionID) match {
case Some(_) =>
val set = waitForRemoval.getOrElse(executionID, Set())
waitForRemoval += (executionID -> (set + sender))
case None => sender ! true
}

case UnregisterTask(executionID) =>
super.receiveWithLogMessages(UnregisterTask(executionID))
waitForRemoval.get(executionID) match {
case Some(actors) => for(actor <- actors) actor ! true
case None =>
}

case RequestBroadcastVariablesWithReferences => {
sender ! ResponseBroadcastVariablesWithReferences(
bcVarManager.getNumberOfVariablesWithReferences)
}

case NotifyWhenJobRemoved(jobID) => {
if(runningTasks.values.exists(_.getJobID == jobID)){
val set = waitForJobRemoval.getOrElse(jobID, Set())
Expand All @@ -71,13 +76,14 @@ trait TestingTaskManager extends ActorLogMessages {
}
}
}

case CheckIfJobRemoved(jobID) => {
if(runningTasks.values.forall(_.getJobID != jobID)){
waitForJobRemoval.get(jobID) match {
case Some(listeners) => listeners foreach (_ ! true)
case None =>
}
}else{
} else {
import context.dispatcher
context.system.scheduler.scheduleOnce(200 milliseconds, this.self, CheckIfJobRemoved(jobID))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,36 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID
import org.apache.flink.runtime.jobgraph.JobID
import org.apache.flink.runtime.taskmanager.Task

object TestingTaskManagerMessages{
/**
* Additional messages that the [[TestingTaskManager]] understands.
*/
object TestingTaskManagerMessages {

case class NotifyWhenTaskRemoved(executionID: ExecutionAttemptID)

case object RequestRunningTasks

case class ResponseRunningTasks(tasks: Map[ExecutionAttemptID, Task]){
import collection.JavaConverters._
def asJava: java.util.Map[ExecutionAttemptID, Task] = tasks.asJava
}
case object RequestBroadcastVariablesWithReferences

case class ResponseBroadcastVariablesWithReferences(number: Int)

case class CheckIfJobRemoved(jobID: JobID)

case object RequestRunningTasks

case object RequestBroadcastVariablesWithReferences

// --------------------------------------------------------------------------
// Utility methods to allow simpler case object access from Java
// --------------------------------------------------------------------------

def getRequestRunningTasksMessage() : AnyRef = {
RequestRunningTasks
}

def getRequestBroadcastVariablesWithReferencesMessage() : AnyRef = {
RequestBroadcastVariablesWithReferences
}
}

0 comments on commit cec30ff

Please sign in to comment.