diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/ResourceSpec.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/ResourceSpec.java index f87d9975257f5..4554b544ea11b 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/operators/ResourceSpec.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/ResourceSpec.java @@ -22,6 +22,7 @@ import org.apache.flink.util.Preconditions; import javax.annotation.Nonnull; + import java.io.Serializable; import java.util.HashMap; import java.util.Map; @@ -32,9 +33,9 @@ /** * Describe the different resource factors of the operator with UDF. * - * The state backend provides the method to estimate memory usages based on state size in the resource. + *

The state backend provides the method to estimate memory usages based on state size in the resource. * - * Resource provides {@link #merge(ResourceSpec)} method for chained operators when generating job graph. + *

Resource provides {@link #merge(ResourceSpec)} method for chained operators when generating job graph. * *

Resource provides {@link #lessThanOrEqual(ResourceSpec)} method to compare these fields in sequence: *

    @@ -53,21 +54,21 @@ public class ResourceSpec implements Serializable { public static final ResourceSpec DEFAULT = new ResourceSpec(0, 0, 0, 0, 0); - private static String GPU_NAME = "GPU"; + private static final String GPU_NAME = "GPU"; - /** How many cpu cores are needed, use double so we can specify cpu like 0.1 */ + /** How many cpu cores are needed, use double so we can specify cpu like 0.1. */ private final double cpuCores; - /** How many java heap memory in mb are needed */ + /** How many java heap memory in mb are needed. */ private final int heapMemoryInMB; - /** How many nio direct memory in mb are needed */ + /** How many nio direct memory in mb are needed. */ private final int directMemoryInMB; - /** How many native memory in mb are needed */ + /** How many native memory in mb are needed. */ private final int nativeMemoryInMB; - /** How many state size in mb are used */ + /** How many state size in mb are used. */ private final int stateSizeInMB; private final Map extendedResources = new HashMap<>(1); @@ -239,8 +240,13 @@ public String toString() { '}'; } - public static Builder newBuilder() { return new Builder(); } + public static Builder newBuilder() { + return new Builder(); + } + /** + * Builder for the {@link ResourceSpec}. + */ public static class Builder { public double cpuCores; @@ -275,28 +281,40 @@ public Builder setStateSizeInMB(int stateSize) { return this; } - public Builder setGPUResource(GPUResource gpuResource) { - this.gpuResource = gpuResource; + public Builder setGPUResource(double gpus) { + this.gpuResource = new GPUResource(gpus); return this; } public ResourceSpec build() { - return new ResourceSpec(cpuCores, heapMemoryInMB, directMemoryInMB, nativeMemoryInMB, stateSizeInMB, gpuResource); + return new ResourceSpec( + cpuCores, + heapMemoryInMB, + directMemoryInMB, + nativeMemoryInMB, + stateSizeInMB, + gpuResource); } } - public static abstract class Resource implements Serializable { + /** + * Base class for additional resources one can specify. + */ + protected abstract static class Resource implements Serializable { private static final long serialVersionUID = 1L; + /** + * Enum defining how resources are aggregated. + */ public enum ResourceAggregateType { /** - * Denotes keeping the sum of the values with same name when merging two resource specs for operator chaining + * Denotes keeping the sum of the values with same name when merging two resource specs for operator chaining. */ AGGREGATE_TYPE_SUM, /** - * Denotes keeping the max of the values with same name when merging two resource specs for operator chaining + * Denotes keeping the max of the values with same name when merging two resource specs for operator chaining. */ AGGREGATE_TYPE_MAX } @@ -305,7 +323,7 @@ public enum ResourceAggregateType { private final double value; - final private ResourceAggregateType type; + private final ResourceAggregateType type; public Resource(String name, double value, ResourceAggregateType type) { this.name = checkNotNull(name); @@ -348,14 +366,14 @@ public boolean equals(Object o) { @Override public int hashCode() { - int result = name != null ? name.hashCode() : 0; + int result = name.hashCode(); result = 31 * result + type.ordinal(); - result = 31 * result + (int)value; + result = 31 * result + (int) value; return result; } /** - * Create a resource of the same resource type + * Create a resource of the same resource type. * * @param value The value of the resource * @param type The aggregate type of the resource @@ -369,6 +387,8 @@ public int hashCode() { */ public static class GPUResource extends Resource { + private static final long serialVersionUID = -2276080061777135142L; + public GPUResource(double value) { this(value, ResourceAggregateType.AGGREGATE_TYPE_SUM); } diff --git a/flink-core/src/test/java/org/apache/flink/api/common/operators/ResourceSpecTest.java b/flink-core/src/test/java/org/apache/flink/api/common/operators/ResourceSpecTest.java index 5dfe5d00adc79..5f1e7d1d9dcbf 100644 --- a/flink-core/src/test/java/org/apache/flink/api/common/operators/ResourceSpecTest.java +++ b/flink-core/src/test/java/org/apache/flink/api/common/operators/ResourceSpecTest.java @@ -20,12 +20,12 @@ import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.TestLogger; + import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; /** * Tests for ResourceSpec class, including its all public api: isValid, lessThanOrEqual, equals, hashCode and merge. @@ -40,14 +40,14 @@ public void testIsValid() throws Exception { rs = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(1)). + setGPUResource(1). build(); assertTrue(rs.isValid()); rs = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(-1)). + setGPUResource(-1). build(); assertFalse(rs.isValid()); } @@ -62,7 +62,7 @@ public void testLessThanOrEqual() throws Exception { ResourceSpec rs3 = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(1.1)). + setGPUResource(1.1). build(); assertTrue(rs1.lessThanOrEqual(rs3)); assertFalse(rs3.lessThanOrEqual(rs1)); @@ -70,7 +70,7 @@ public void testLessThanOrEqual() throws Exception { ResourceSpec rs4 = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(2.2)). + setGPUResource(2.2). build(); assertFalse(rs4.lessThanOrEqual(rs3)); assertTrue(rs3.lessThanOrEqual(rs4)); @@ -86,19 +86,19 @@ public void testEquals() throws Exception { ResourceSpec rs3 = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(2.2)). + setGPUResource(2.2). build(); ResourceSpec rs4 = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(1)). + setGPUResource(1). build(); assertFalse(rs3.equals(rs4)); ResourceSpec rs5 = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(2.2)). + setGPUResource(2.2). build(); assertTrue(rs3.equals(rs5)); } @@ -112,28 +112,21 @@ public void testHashCode() throws Exception { ResourceSpec rs3 = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(2.2)). + setGPUResource(2.2). build(); ResourceSpec rs4 = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(1)). + setGPUResource(1). build(); assertFalse(rs3.hashCode() == rs4.hashCode()); ResourceSpec rs5 = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(2.2)). + setGPUResource(2.2). build(); assertEquals(rs3.hashCode(), rs5.hashCode()); - - ResourceSpec rs6 = ResourceSpec.newBuilder(). - setCpuCores(1.0). - setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(2.2, ResourceSpec.Resource.ResourceAggregateType.AGGREGATE_TYPE_MAX)). - build(); - assertFalse(rs6.hashCode() == rs5.hashCode()); } @Test @@ -141,7 +134,7 @@ public void testMerge() throws Exception { ResourceSpec rs1 = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(1.1)). + setGPUResource(1.1). build(); ResourceSpec rs2 = ResourceSpec.newBuilder().setCpuCores(1.0).setHeapMemoryInMB(100).build(); @@ -150,26 +143,6 @@ public void testMerge() throws Exception { ResourceSpec rs4 = rs1.merge(rs3); assertEquals(2.2, rs4.getGPUResource(), 0.000001); - - ResourceSpec rs5 = ResourceSpec.newBuilder(). - setCpuCores(1.0). - setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(1.1, ResourceSpec.Resource.ResourceAggregateType.AGGREGATE_TYPE_MAX)). - build(); - try { - rs4.merge(rs5); - fail("Merge with different aggregate type should fail"); - } catch (IllegalArgumentException ignored) { - } - - ResourceSpec rs6 = ResourceSpec.newBuilder(). - setCpuCores(1.0). - setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(1.5, ResourceSpec.Resource.ResourceAggregateType.AGGREGATE_TYPE_MAX)). - build(); - ResourceSpec rs7 = rs5.merge(rs6); - assertEquals(1.5, rs7.getGPUResource(), 0.000001); - } @Test @@ -177,7 +150,7 @@ public void testSerializable() throws Exception { ResourceSpec rs1 = ResourceSpec.newBuilder(). setCpuCores(1.0). setHeapMemoryInMB(100). - setGPUResource(new ResourceSpec.GPUResource(1.1)). + setGPUResource(1.1). build(); byte[] buffer = InstantiationUtil.serializeObject(rs1); ResourceSpec rs2 = InstantiationUtil.deserializeObject(buffer, ClassLoader.getSystemClassLoader());