Skip to content

Commit

Permalink
[FLINK-7878] Hide GpuResource in ResourceSpec
Browse files Browse the repository at this point in the history
  • Loading branch information
tillrohrmann committed Dec 14, 2017
1 parent 5b9ac95 commit fba72d0
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.flink.util.Preconditions;

import javax.annotation.Nonnull;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -32,9 +33,9 @@
/**
* Describe the different resource factors of the operator with UDF.
*
* The state backend provides the method to estimate memory usages based on state size in the resource.
* <p>The state backend provides the method to estimate memory usages based on state size in the resource.
*
* Resource provides {@link #merge(ResourceSpec)} method for chained operators when generating job graph.
* <p>Resource provides {@link #merge(ResourceSpec)} method for chained operators when generating job graph.
*
* <p>Resource provides {@link #lessThanOrEqual(ResourceSpec)} method to compare these fields in sequence:
* <ol>
Expand All @@ -53,21 +54,21 @@ public class ResourceSpec implements Serializable {

public static final ResourceSpec DEFAULT = new ResourceSpec(0, 0, 0, 0, 0);

private static String GPU_NAME = "GPU";
private static final String GPU_NAME = "GPU";

/** How many cpu cores are needed, use double so we can specify cpu like 0.1 */
/** How many cpu cores are needed, use double so we can specify cpu like 0.1. */
private final double cpuCores;

/** How many java heap memory in mb are needed */
/** How many java heap memory in mb are needed. */
private final int heapMemoryInMB;

/** How many nio direct memory in mb are needed */
/** How many nio direct memory in mb are needed. */
private final int directMemoryInMB;

/** How many native memory in mb are needed */
/** How many native memory in mb are needed. */
private final int nativeMemoryInMB;

/** How many state size in mb are used */
/** How many state size in mb are used. */
private final int stateSizeInMB;

private final Map<String, Resource> extendedResources = new HashMap<>(1);
Expand Down Expand Up @@ -239,8 +240,13 @@ public String toString() {
'}';
}

public static Builder newBuilder() { return new Builder(); }
public static Builder newBuilder() {
return new Builder();
}

/**
* Builder for the {@link ResourceSpec}.
*/
public static class Builder {

public double cpuCores;
Expand Down Expand Up @@ -275,28 +281,40 @@ public Builder setStateSizeInMB(int stateSize) {
return this;
}

public Builder setGPUResource(GPUResource gpuResource) {
this.gpuResource = gpuResource;
public Builder setGPUResource(double gpus) {
this.gpuResource = new GPUResource(gpus);
return this;
}

public ResourceSpec build() {
return new ResourceSpec(cpuCores, heapMemoryInMB, directMemoryInMB, nativeMemoryInMB, stateSizeInMB, gpuResource);
return new ResourceSpec(
cpuCores,
heapMemoryInMB,
directMemoryInMB,
nativeMemoryInMB,
stateSizeInMB,
gpuResource);
}
}

public static abstract class Resource implements Serializable {
/**
* Base class for additional resources one can specify.
*/
protected abstract static class Resource implements Serializable {

private static final long serialVersionUID = 1L;

/**
* Enum defining how resources are aggregated.
*/
public enum ResourceAggregateType {
/**
* Denotes keeping the sum of the values with same name when merging two resource specs for operator chaining
* Denotes keeping the sum of the values with same name when merging two resource specs for operator chaining.
*/
AGGREGATE_TYPE_SUM,

/**
* Denotes keeping the max of the values with same name when merging two resource specs for operator chaining
* Denotes keeping the max of the values with same name when merging two resource specs for operator chaining.
*/
AGGREGATE_TYPE_MAX
}
Expand All @@ -305,7 +323,7 @@ public enum ResourceAggregateType {

private final double value;

final private ResourceAggregateType type;
private final ResourceAggregateType type;

public Resource(String name, double value, ResourceAggregateType type) {
this.name = checkNotNull(name);
Expand Down Expand Up @@ -348,14 +366,14 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
int result = name != null ? name.hashCode() : 0;
int result = name.hashCode();
result = 31 * result + type.ordinal();
result = 31 * result + (int)value;
result = 31 * result + (int) value;
return result;
}

/**
* Create a resource of the same resource type
* Create a resource of the same resource type.
*
* @param value The value of the resource
* @param type The aggregate type of the resource
Expand All @@ -369,6 +387,8 @@ public int hashCode() {
*/
public static class GPUResource extends Resource {

private static final long serialVersionUID = -2276080061777135142L;

public GPUResource(double value) {
this(value, ResourceAggregateType.AGGREGATE_TYPE_SUM);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@

import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.TestLogger;

import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
* Tests for ResourceSpec class, including its all public api: isValid, lessThanOrEqual, equals, hashCode and merge.
Expand All @@ -40,14 +40,14 @@ public void testIsValid() throws Exception {
rs = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1)).
setGPUResource(1).
build();
assertTrue(rs.isValid());

rs = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(-1)).
setGPUResource(-1).
build();
assertFalse(rs.isValid());
}
Expand All @@ -62,15 +62,15 @@ public void testLessThanOrEqual() throws Exception {
ResourceSpec rs3 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1.1)).
setGPUResource(1.1).
build();
assertTrue(rs1.lessThanOrEqual(rs3));
assertFalse(rs3.lessThanOrEqual(rs1));

ResourceSpec rs4 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2)).
setGPUResource(2.2).
build();
assertFalse(rs4.lessThanOrEqual(rs3));
assertTrue(rs3.lessThanOrEqual(rs4));
Expand All @@ -86,19 +86,19 @@ public void testEquals() throws Exception {
ResourceSpec rs3 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2)).
setGPUResource(2.2).
build();
ResourceSpec rs4 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1)).
setGPUResource(1).
build();
assertFalse(rs3.equals(rs4));

ResourceSpec rs5 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2)).
setGPUResource(2.2).
build();
assertTrue(rs3.equals(rs5));
}
Expand All @@ -112,36 +112,29 @@ public void testHashCode() throws Exception {
ResourceSpec rs3 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2)).
setGPUResource(2.2).
build();
ResourceSpec rs4 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1)).
setGPUResource(1).
build();
assertFalse(rs3.hashCode() == rs4.hashCode());

ResourceSpec rs5 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2)).
setGPUResource(2.2).
build();
assertEquals(rs3.hashCode(), rs5.hashCode());

ResourceSpec rs6 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2, ResourceSpec.Resource.ResourceAggregateType.AGGREGATE_TYPE_MAX)).
build();
assertFalse(rs6.hashCode() == rs5.hashCode());
}

@Test
public void testMerge() throws Exception {
ResourceSpec rs1 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1.1)).
setGPUResource(1.1).
build();
ResourceSpec rs2 = ResourceSpec.newBuilder().setCpuCores(1.0).setHeapMemoryInMB(100).build();

Expand All @@ -150,34 +143,14 @@ public void testMerge() throws Exception {

ResourceSpec rs4 = rs1.merge(rs3);
assertEquals(2.2, rs4.getGPUResource(), 0.000001);

ResourceSpec rs5 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1.1, ResourceSpec.Resource.ResourceAggregateType.AGGREGATE_TYPE_MAX)).
build();
try {
rs4.merge(rs5);
fail("Merge with different aggregate type should fail");
} catch (IllegalArgumentException ignored) {
}

ResourceSpec rs6 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1.5, ResourceSpec.Resource.ResourceAggregateType.AGGREGATE_TYPE_MAX)).
build();
ResourceSpec rs7 = rs5.merge(rs6);
assertEquals(1.5, rs7.getGPUResource(), 0.000001);

}

@Test
public void testSerializable() throws Exception {
ResourceSpec rs1 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1.1)).
setGPUResource(1.1).
build();
byte[] buffer = InstantiationUtil.serializeObject(rs1);
ResourceSpec rs2 = InstantiationUtil.deserializeObject(buffer, ClassLoader.getSystemClassLoader());
Expand Down

0 comments on commit fba72d0

Please sign in to comment.