[FLINK-18050][task][checkpointing] Use CloseableIterator to write ResultSubpartition state

Currently, buffers passed to ChannelStateWriterImpl can be recycled
twice: once in the normal case after writing, and a second time in
CheckpointInProgressRequest.cancel (called from
ChannelStateWriteRequestDispatcher and other places).

This change prevents the double recycling by using a CloseableIterator,
which distinguishes used elements from unused ones.
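For illustration, a minimal sketch of an "ofElements"-style iterator; the class and field names below are made up for this note and this is not Flink's actual CloseableIterator implementation — the point is only how close() can tell used elements from unused ones:

import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.Consumer;

// Sketch only: elements returned by next() become the caller's responsibility,
// while close() releases just the unused tail, so nothing is recycled twice.
final class ArrayCloseableIterator<T> implements Iterator<T>, AutoCloseable {

	private final T[] elements;
	private final Consumer<T> closer; // e.g. Buffer::recycleBuffer
	private int next = 0;

	ArrayCloseableIterator(Consumer<T> closer, T[] elements) {
		this.closer = closer;
		this.elements = elements;
	}

	@Override
	public boolean hasNext() {
		return next < elements.length;
	}

	@Override
	public T next() {
		if (!hasNext()) {
			throw new NoSuchElementException();
		}
		return elements[next++]; // ownership passes to the caller (the writer)
	}

	@Override
	public void close() {
		while (next < elements.length) {
			closer.accept(elements[next++]); // release only what was never consumed
		}
	}
}

Because consumed elements leave the iterator's ownership, a buffer already recycled by the writer after a successful write is never touched again by close().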
rkhachatryan authored and zhijiangW committed Jun 8, 2020
1 parent 44af789 commit ed7b0b1
Showing 4 changed files with 87 additions and 29 deletions.
@@ -22,6 +22,7 @@
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.BiConsumerWithException;
import org.apache.flink.util.function.ThrowingConsumer;

import java.util.concurrent.atomic.AtomicReference;
@@ -31,6 +32,7 @@
import static org.apache.flink.runtime.checkpoint.channel.CheckpointInProgressRequestState.EXECUTING;
import static org.apache.flink.runtime.checkpoint.channel.CheckpointInProgressRequestState.FAILED;
import static org.apache.flink.runtime.checkpoint.channel.CheckpointInProgressRequestState.NEW;
import static org.apache.flink.util.CloseableIterator.ofElements;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;

@@ -48,8 +50,20 @@ static CheckpointInProgressRequest completeOutput(long checkpointId) {
}

static ChannelStateWriteRequest write(long checkpointId, InputChannelInfo info, CloseableIterator<Buffer> iterator) {
return buildWriteRequest(checkpointId, "writeInput", iterator, (writer, buffer) -> writer.writeInput(info, buffer));
}

static ChannelStateWriteRequest write(long checkpointId, ResultSubpartitionInfo info, Buffer... buffers) {
return buildWriteRequest(checkpointId, "writeOutput", ofElements(Buffer::recycleBuffer, buffers), (writer, buffer) -> writer.writeOutput(info, buffer));
}

static ChannelStateWriteRequest buildWriteRequest(
long checkpointId,
String name,
CloseableIterator<Buffer> iterator,
BiConsumerWithException<ChannelStateCheckpointWriter, Buffer, Exception> bufferConsumer) {
return new CheckpointInProgressRequest(
"writeInput",
name,
checkpointId,
writer -> {
while (iterator.hasNext()) {
@@ -60,17 +74,13 @@ static ChannelStateWriteRequest write(long checkpointId, InputChannelInfo info,
buffer.recycleBuffer();
throw e;
}
writer.writeInput(info, buffer);
bufferConsumer.accept(writer, buffer);
}
},
throwable -> iterator.close(),
false);
}

static ChannelStateWriteRequest write(long checkpointId, ResultSubpartitionInfo info, Buffer... flinkBuffers) {
return new CheckpointInProgressRequest("writeOutput", checkpointId, writer -> writer.writeOutput(info, flinkBuffers), recycle(flinkBuffers), false);
}

static ChannelStateWriteRequest start(long checkpointId, ChannelStateWriteResult targetResult, CheckpointStorageLocationReference locationReference) {
return new CheckpointStartRequest(checkpointId, targetResult, locationReference);
}
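To sketch what wrapping the output buffers in ofElements (as the new write/buildWriteRequest above does) buys on the cancel path, a small stand-alone check; PartialConsumptionSketch is a made-up class name, and it assumes CloseableIterator.ofElements, NetworkBuffer and FreeingBufferRecycler behave as they are used elsewhere in this diff:

import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.util.CloseableIterator;

public class PartialConsumptionSketch {

	public static void main(String[] args) throws Exception {
		NetworkBuffer consumed = buffer();
		NetworkBuffer leftover = buffer();

		CloseableIterator<Buffer> iterator =
			CloseableIterator.ofElements(Buffer::recycleBuffer, consumed, leftover);

		// The "writer" consumes the first buffer and recycles it after writing,
		// then (say) the checkpoint is aborted before the second buffer is written.
		iterator.next().recycleBuffer();

		// The cancel path closes the iterator: only the unconsumed buffer is
		// recycled here, so the first buffer is not recycled a second time.
		iterator.close();

		System.out.println(consumed.isRecycled());  // true (recycled once, by the writer)
		System.out.println(leftover.isRecycled());  // true (recycled once, by close())
	}

	private static NetworkBuffer buffer() {
		return new NetworkBuffer(MemorySegmentFactory.allocateUnpooledSegment(10), FreeingBufferRecycler.INSTANCE);
	}
}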
@@ -129,7 +129,7 @@ public void addOutputData(long checkpointId, ResultSubpartitionInfo info, int st
info,
startSeqNum,
data == null ? 0 : data.length);
enqueue(write(checkpointId, info, checkBufferType(data)), false);
enqueue(write(checkpointId, info, data), false);
}

@Override
@@ -196,27 +196,6 @@ private void enqueue(ChannelStateWriteRequest request, boolean atTheFront) {
}
}

private static Buffer[] checkBufferType(Buffer... data) {
if (data == null) {
return new Buffer[0];
}
try {
for (Buffer buffer : data) {
if (!buffer.isBuffer()) {
throw new IllegalArgumentException(buildBufferTypeErrorMessage(buffer));
}
}
} catch (Exception e) {
for (Buffer buffer : data) {
if (buffer.isBuffer()) {
buffer.recycleBuffer();
}
}
throw e;
}
return data;
}

private static String buildBufferTypeErrorMessage(Buffer buffer) {
try {
AbstractEvent event = EventSerializer.fromBuffer(buffer, ChannelStateWriterImpl.class.getClassLoader());
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.checkpoint.channel;

import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
import org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorage;

import org.junit.Test;

import java.util.function.Function;

import static org.apache.flink.util.CloseableIterator.ofElements;
import static org.junit.Assert.assertTrue;

/**
* {@link ChannelStateWriteRequestDispatcherImpl} test.
*/
public class ChannelStateWriteRequestDispatcherImplTest {

@Test
public void testPartialInputChannelStateWrite() throws Exception {
testBuffersRecycled(buffers -> ChannelStateWriteRequest.write(1L, new InputChannelInfo(1, 2), ofElements(Buffer::recycleBuffer, buffers)));
}

@Test
public void testPartialResultSubpartitionStateWrite() throws Exception {
testBuffersRecycled(buffers -> ChannelStateWriteRequest.write(1L, new ResultSubpartitionInfo(1, 2), buffers));
}

private void testBuffersRecycled(Function<NetworkBuffer[], ChannelStateWriteRequest> requestBuilder) throws Exception {
ChannelStateWriteRequestDispatcher dispatcher = new ChannelStateWriteRequestDispatcherImpl(new MemoryBackendCheckpointStorage(new JobID(), null, null, 1), new ChannelStateSerializerImpl());
ChannelStateWriteResult result = new ChannelStateWriteResult();
dispatcher.dispatch(ChannelStateWriteRequest.start(1L, result, CheckpointStorageLocationReference.getDefault()));

result.getResultSubpartitionStateHandles().completeExceptionally(new TestException());
result.getInputChannelStateHandles().completeExceptionally(new TestException());

NetworkBuffer[] buffers = new NetworkBuffer[]{buffer(), buffer()};
dispatcher.dispatch(requestBuilder.apply(buffers));
for (NetworkBuffer buffer : buffers) {
assertTrue(buffer.isRecycled());
}
}

private NetworkBuffer buffer() {
return new NetworkBuffer(MemorySegmentFactory.allocateUnpooledSegment(10), FreeingBufferRecycler.INSTANCE);
}
}
@@ -95,7 +95,7 @@ private static ChannelStateWriteRequest writeIn() {
}

private static ChannelStateWriteRequest writeOut() {
return write(CHECKPOINT_ID, new ResultSubpartitionInfo(1, 1));
return write(CHECKPOINT_ID, new ResultSubpartitionInfo(1, 1), new NetworkBuffer(MemorySegmentFactory.allocateUnpooledSegment(1), FreeingBufferRecycler.INSTANCE));
}

private static CheckpointStartRequest start() {
