Skip to content

Commit

Permalink
[hotfix] Change signature of MemorySegmentProvider#requestMemorySegme…
Browse files Browse the repository at this point in the history
…nts from requestMemorySegments() to requestMemorySegments(int)

This change makes the interface more flexible and decouples NetworkBufferPool from the concept of exclusive buffer.

This close apache#12994.
  • Loading branch information
wsry authored and zhijiangW committed Aug 13, 2020
1 parent 444fc34 commit a3ba0eb
Show file tree
Hide file tree
Showing 24 changed files with 109 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
* The provider used for requesting and releasing batch of memory segments.
*/
public interface MemorySegmentProvider {
Collection<MemorySegment> requestMemorySegments() throws IOException;
Collection<MemorySegment> requestMemorySegments(int numberOfSegmentsToRequest) throws IOException;

void recycleMemorySegments(Collection<MemorySegment> segments) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ static NettyShuffleEnvironment createNettyShuffleEnvironment(
NetworkBufferPool networkBufferPool = new NetworkBufferPool(
config.numNetworkBuffers(),
config.networkBufferSize(),
config.networkBuffersPerChannel(),
config.getRequestSegmentsTimeout());

registerShuffleMetrics(metricGroup, networkBufferPool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,13 @@ public class NetworkBufferPool implements BufferPoolFactory, MemorySegmentProvid

private int numTotalRequiredBuffers;

private final int numberOfSegmentsToRequest;

private final Duration requestSegmentsTimeout;

private final AvailabilityHelper availabilityHelper = new AvailabilityHelper();

@VisibleForTesting
public NetworkBufferPool(int numberOfSegmentsToAllocate, int segmentSize, int numberOfSegmentsToRequest) {
this(numberOfSegmentsToAllocate, segmentSize, numberOfSegmentsToRequest, Duration.ofMillis(Integer.MAX_VALUE));
public NetworkBufferPool(int numberOfSegmentsToAllocate, int segmentSize) {
this(numberOfSegmentsToAllocate, segmentSize, Duration.ofMillis(Integer.MAX_VALUE));
}

/**
Expand All @@ -94,14 +92,10 @@ public NetworkBufferPool(int numberOfSegmentsToAllocate, int segmentSize, int nu
public NetworkBufferPool(
int numberOfSegmentsToAllocate,
int segmentSize,
int numberOfSegmentsToRequest,
Duration requestSegmentsTimeout) {
this.totalNumberOfMemorySegments = numberOfSegmentsToAllocate;
this.memorySegmentSize = segmentSize;

checkArgument(numberOfSegmentsToRequest > 0, "The number of required buffers should be larger than 0.");
this.numberOfSegmentsToRequest = numberOfSegmentsToRequest;

Preconditions.checkNotNull(requestSegmentsTimeout);
checkArgument(requestSegmentsTimeout.toMillis() > 0,
"The timeout for requesting exclusive buffers should be positive.");
Expand Down Expand Up @@ -161,13 +155,15 @@ public void recycle(MemorySegment segment) {
}

@Override
public List<MemorySegment> requestMemorySegments() throws IOException {
public List<MemorySegment> requestMemorySegments(int numberOfSegmentsToRequest) throws IOException {
checkArgument(numberOfSegmentsToRequest > 0, "Number of buffers to request must be larger than 0.");

synchronized (factoryLock) {
if (isDestroyed) {
throw new IllegalStateException("Network buffer pool has already been destroyed.");
}

tryRedistributeBuffers();
tryRedistributeBuffers(numberOfSegmentsToRequest);
}

final List<MemorySegment> segments = new ArrayList<>(numberOfSegmentsToRequest);
Expand Down Expand Up @@ -427,7 +423,7 @@ public void destroyAllBufferPools() {
}

// Must be called from synchronized block
private void tryRedistributeBuffers() throws IOException {
private void tryRedistributeBuffers(int numberOfSegmentsToRequest) throws IOException {
assert Thread.holdsLock(factoryLock);

if (numTotalRequiredBuffers + numberOfSegmentsToRequest > totalNumberOfMemorySegments) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ private boolean shouldContinueRequest(BufferPool bufferPool) {
/**
* Requests exclusive buffers from the provider.
*/
void requestExclusiveBuffers() throws IOException {
Collection<MemorySegment> segments = globalPool.requestMemorySegments();
void requestExclusiveBuffers(int numExclusiveBuffers) throws IOException {
Collection<MemorySegment> segments = globalPool.requestMemorySegments(numExclusiveBuffers);
checkArgument(!segments.isEmpty(), "The number of exclusive buffers per channel should be larger than 0.");

synchronized (bufferQueue) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ void assignExclusiveSegments() throws IOException {
checkState(bufferManager.unsynchronizedGetAvailableExclusiveBuffers() == 0,
"Bug in input channel setup logic: exclusive buffers have already been set for this input channel.");

bufferManager.requestExclusiveBuffers();
bufferManager.requestExclusiveBuffers(initialCredit);
}

// ------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public InputChannel toInputChannel() throws IOException {
void assignExclusiveSegments() throws IOException {
checkState(!exclusiveBuffersAssigned, "Exclusive buffers should be assigned only once.");

bufferManager.requestExclusiveBuffers();
bufferManager.requestExclusiveBuffers(networkBuffersPerChannel);
exclusiveBuffersAssigned = true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public class RecordWriterDelegateTest extends TestLogger {

@Before
public void setup() {
globalPool = new NetworkBufferPool(numberOfBuffers, memorySegmentSize, numberOfSegmentsToRequest);
globalPool = new NetworkBufferPool(numberOfBuffers, memorySegmentSize);
}

@After
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ public void testBroadcastEmitRecord() throws Exception {
@Test
public void testIsAvailableOrNot() throws Exception {
// setup
final NetworkBufferPool globalPool = new NetworkBufferPool(10, 128, 2);
final NetworkBufferPool globalPool = new NetworkBufferPool(10, 128);
final BufferPool localPool = globalPool.createBufferPool(1, 1, null, 1, Integer.MAX_VALUE);
final ResultPartitionWriter resultPartition = new ResultPartitionBuilder()
.setBufferPoolFactory(p -> localPool)
Expand Down Expand Up @@ -455,7 +455,7 @@ public void testEmitRecordWithPartitionStateRecovery() throws Exception {
final int[] records = {5, 6, 7, 8};
final int bufferSize = states.length * Integer.BYTES;

final NetworkBufferPool globalPool = new NetworkBufferPool(totalBuffers, bufferSize, 1);
final NetworkBufferPool globalPool = new NetworkBufferPool(totalBuffers, bufferSize);
final ChannelStateReader stateReader = new ResultPartitionTest.FiniteChannelStateReader(totalStates, states);
final ResultPartition partition = new ResultPartitionBuilder()
.setNetworkBufferPool(globalPool)
Expand Down Expand Up @@ -507,7 +507,7 @@ public void testEmitRecordWithPartitionStateRecovery() throws Exception {
@Test
public void testIdleTime() throws IOException, InterruptedException {
// setup
final NetworkBufferPool globalPool = new NetworkBufferPool(10, 128, 2);
final NetworkBufferPool globalPool = new NetworkBufferPool(10, 128);
final BufferPool localPool = globalPool.createBufferPool(1, 1, null, 1, Integer.MAX_VALUE);
final ResultPartitionWriter resultPartition = new ResultPartitionBuilder()
.setBufferPoolFactory(p -> localPool)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class BufferPoolFactoryTest {

@Before
public void setupNetworkBufferPool() {
networkBufferPool = new NetworkBufferPool(numBuffers, memorySegmentSize, 1);
networkBufferPool = new NetworkBufferPool(numBuffers, memorySegmentSize);
}

@After
Expand Down Expand Up @@ -244,7 +244,7 @@ public void testUniformDistributionBounded2() throws IOException {

@Test
public void testUniformDistributionBounded3() throws IOException {
NetworkBufferPool globalPool = new NetworkBufferPool(3, 128, 1);
NetworkBufferPool globalPool = new NetworkBufferPool(3, 128);
try {
BufferPool first = globalPool.createBufferPool(0, 10);
assertEquals(3, first.getNumBuffers());
Expand Down Expand Up @@ -277,25 +277,25 @@ public void testUniformDistributionBounded3() throws IOException {
*/
@Test
public void testUniformDistributionBounded4() throws IOException {
NetworkBufferPool globalPool = new NetworkBufferPool(10, 128, 2);
NetworkBufferPool globalPool = new NetworkBufferPool(10, 128);
try {
BufferPool first = globalPool.createBufferPool(0, 10);
assertEquals(10, first.getNumBuffers());

List<MemorySegment> segmentList1 = globalPool.requestMemorySegments();
List<MemorySegment> segmentList1 = globalPool.requestMemorySegments(2);
assertEquals(2, segmentList1.size());
assertEquals(8, first.getNumBuffers());

BufferPool second = globalPool.createBufferPool(0, 10);
assertEquals(4, first.getNumBuffers());
assertEquals(4, second.getNumBuffers());

List<MemorySegment> segmentList2 = globalPool.requestMemorySegments();
List<MemorySegment> segmentList2 = globalPool.requestMemorySegments(2);
assertEquals(2, segmentList2.size());
assertEquals(3, first.getNumBuffers());
assertEquals(3, second.getNumBuffers());

List<MemorySegment> segmentList3 = globalPool.requestMemorySegments();
List<MemorySegment> segmentList3 = globalPool.requestMemorySegments(2);
assertEquals(2, segmentList3.size());
assertEquals(2, first.getNumBuffers());
assertEquals(2, second.getNumBuffers());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public void testDestroyWhileBlockingRequest() throws Exception {
LocalBufferPool localBufferPool = null;

try {
networkBufferPool = new NetworkBufferPool(1, 4096, 1);
networkBufferPool = new NetworkBufferPool(1, 4096);
localBufferPool = new LocalBufferPool(networkBufferPool, 1);

// Drain buffer pool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public class LocalBufferPoolTest extends TestLogger {

@Before
public void setupLocalBufferPool() {
networkBufferPool = new NetworkBufferPool(numBuffers, memorySegmentSize, 1);
networkBufferPool = new NetworkBufferPool(numBuffers, memorySegmentSize);
localBufferPool = new LocalBufferPool(networkBufferPool, 1);

assertEquals(0, localBufferPool.getNumberOfAvailableMemorySegments());
Expand Down
Loading

0 comments on commit a3ba0eb

Please sign in to comment.