Skip to content

Commit

Permalink
[FLINK-12881][ml] Add more functionalities for ML Params and ParamInf…
Browse files Browse the repository at this point in the history
…o class

Add more functionalities, including the support of aliases, the config of size/clear/isEmpty/contains/fromJason in Params

This closes apache#8776
  • Loading branch information
xuyang1706 authored and shaoxuan-wang committed Jun 30, 2019
1 parent 4c93e68 commit 1660c6b
Show file tree
Hide file tree
Showing 4 changed files with 289 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,9 @@

package org.apache.flink.ml.api.core;

import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.WithParams;
import org.apache.flink.ml.util.param.ExtractParamInfosUtil;

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* Base class for a stage in a pipeline. The interface is only a concept, and does not have any
Expand All @@ -46,11 +41,6 @@ default String toJson() {
}

default void loadJson(String json) {
List<ParamInfo> paramInfos = ExtractParamInfosUtil.extractParamInfos(this);
Map<String, Class<?>> classMap = new HashMap<>();
for (ParamInfo i : paramInfos) {
classMap.put(i.getName(), i.getValueClass());
}
getParams().loadJson(json, classMap);
getParams().loadJson(json);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,26 @@
package org.apache.flink.ml.api.misc.param;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.util.Preconditions;

/**
* Definition of a parameter, including name, type, default value, validator and so on.
*
* <p>This class is provided to unify the interaction with parameters.
* <p>A parameter can either be optional or non-optional.
* <ul>
* <li>
* A non-optional parameter should not have a default value. Instead, its value must be provided by the users.
* </li>
* <li>
* An optional parameter may or may not have a default value.
* </li>
* </ul>
*
* <p>Please see {@link Params#get(ParamInfo)} and {@link Params#contains(ParamInfo)} for more details about the behavior.
*
* <p>A parameter may have aliases in addition to the parameter name for convenience and compatibility purposes. One
* should not set values for both parameter name and an alias. One and only one value should be set either under
* the parameter name or one of the alias.
*
* @param <V> the type of the param value
*/
Expand Down Expand Up @@ -67,6 +82,7 @@ public String getName() {
* @return the aliases of the parameter
*/
public String[] getAlias() {
Preconditions.checkNotNull(alias);
return alias;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,36 +25,104 @@

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* The map-like container class for parameter. This class is provided to unify the interaction with
* parameters.
*/
@PublicEvolving
public class Params implements Serializable {
private final Map<String, Object> paramMap = new HashMap<>();
public class Params implements Serializable, Cloneable {
private static final long serialVersionUID = 1L;

/**
* A mapping from param name to its value.
*
* <p>The value is stored in map using json format.
*/
private final Map<String, String> params;

private transient ObjectMapper mapper;

public Params() {
this.params = new HashMap<>();
}

/**
* Return the number of params.
*
* @return Return the number of params.
*/
public int size() {
return params.size();
}

/**
* Removes all of the params.
* The params will be empty after this call returns.
*/
public void clear() {
params.clear();
}

/**
* Returns <tt>true</tt> if this params contains no mappings.
*
* @return <tt>true</tt> if this map contains no mappings
*/
public boolean isEmpty() {
return params.isEmpty();
}

/**
* Returns the value of the specific parameter, or default value defined in the {@code info} if
* this Params doesn't contain the param.
* this Params doesn't have a value set for the parameter. An exception will be thrown in the
* following cases because no value could be found for the specified parameter.
* <ul>
* <li>
* Non-optional parameter: no value is defined in this params for a non-optional parameter.
* </li>
* <li>
* Optional parameter: no value is defined in this params and no default value is defined.
* </li>
* </ul>
*
* @param info the info of the specific parameter, usually with default value
* @param <V> the type of the specific parameter
* @return the value of the specific parameter, or default value defined in the {@code info} if
* this Params doesn't contain the parameter
* @throws RuntimeException if the Params doesn't contains the specific parameter, while the
* param is not optional but has no default value in the {@code info}
* @throws IllegalArgumentException if no value can be found for specified parameter
*/
@SuppressWarnings("unchecked")
public <V> V get(ParamInfo<V> info) {
V value = (V) paramMap.getOrDefault(info.getName(), info.getDefaultValue());
if (value == null && !info.isOptional() && !info.hasDefaultValue()) {
throw new RuntimeException(info.getName() +
" not exist which is not optional and don't have a default value");
String value = null;
String usedParamName = null;
for (String nameOrAlias : getParamNameAndAlias(info)) {
if (params.containsKey(nameOrAlias)) {
if (usedParamName != null) {
throw new IllegalArgumentException(String.format("Duplicate parameters of %s and %s",
usedParamName, nameOrAlias));
}
usedParamName = nameOrAlias;
value = params.get(nameOrAlias);
}
}

if (usedParamName != null) {
// The param value was set by the user.
return valueFromJson(value, info.getValueClass());
} else {
// The param value was not set by the user.
if (!info.isOptional()) {
throw new IllegalArgumentException("Missing non-optional parameter " + info.getName());
} else if (!info.hasDefaultValue()) {
throw new IllegalArgumentException("Cannot find default value for optional parameter " + info.getName());
}
return info.getDefaultValue();
}
return value;
}

/**
Expand All @@ -69,20 +137,11 @@ public <V> V get(ParamInfo<V> info) {
* evaluated as illegal by the validator
*/
public <V> Params set(ParamInfo<V> info, V value) {
if (!info.isOptional() && value == null) {
throw new RuntimeException(
"Setting " + info.getName() + " as null while it's not a optional param");
}
if (value == null) {
remove(info);
return this;
}

if (info.getValidator() != null && !info.getValidator().validate(value)) {
throw new RuntimeException(
"Setting " + info.getName() + " as a invalid value:" + value);
}
paramMap.put(info.getName(), value);
params.put(info.getName(), valueToJson(value));
return this;
}

Expand All @@ -93,18 +152,20 @@ public <V> Params set(ParamInfo<V> info, V value) {
* @param <V> the type of the specific parameter
*/
public <V> void remove(ParamInfo<V> info) {
paramMap.remove(info.getName());
params.remove(info.getName());
for (String a : info.getAlias()) {
params.remove(a);
}
}

/**
* Creates and returns a deep clone of this Params.
* Check whether this params has a value set for the given {@code info}.
*
* @return a deep clone of this Params
* @return <tt>true</tt> if this params has a value set for the specified {@code info}, false otherwise.
*/
public Params clone() {
Params newParams = new Params();
newParams.paramMap.putAll(this.paramMap);
return newParams;
public <V> boolean contains(ParamInfo<V> info) {
return params.containsKey(info.getName()) ||
Arrays.stream(info.getAlias()).anyMatch(params::containsKey);
}

/**
Expand All @@ -114,38 +175,104 @@ public Params clone() {
* @return a json containing all parameters in this Params
*/
public String toJson() {
ObjectMapper mapper = new ObjectMapper();
Map<String, String> stringMap = new HashMap<>();
assertMapperInited();
try {
for (Map.Entry<String, Object> e : paramMap.entrySet()) {
stringMap.put(e.getKey(), mapper.writeValueAsString(e.getValue()));
}
return mapper.writeValueAsString(stringMap);
return mapper.writeValueAsString(params);
} catch (JsonProcessingException e) {
throw new RuntimeException("Failed to serialize params to json", e);
}
}

/**
* Restores the parameters from the given json. The parameters should be exactly the same with
* the one who was serialized to the input json after the restoration. The class mapping of the
* parameters in the json is required because it is hard to directly restore a param of a user
* defined type. Params will be treated as String if it doesn't exist in the {@code classMap}.
* the one who was serialized to the input json after the restoration.
*
* @param json the json String to restore from
* @param classMap the classes of the parameters contained in the json
* @param json the json String to restore from
*/
@SuppressWarnings("unchecked")
public void loadJson(String json, Map<String, Class<?>> classMap) {
ObjectMapper mapper = new ObjectMapper();
public void loadJson(String json) {
assertMapperInited();
Map<String, String> params;
try {
params = mapper.readValue(json, Map.class);
} catch (IOException e) {
throw new RuntimeException("Failed to deserialize json:" + json, e);
}
this.params.putAll(params);
}

/**
* Factory method for constructing params.
*
* @param json the json string to load
* @return the {@code Params} loaded from the json string.
*/
public static Params fromJson(String json) {
Params params = new Params();
params.loadJson(json);
return params;
}

/**
* Merge other params into this.
*
* @param otherParams other params
* @return this
*/
public Params merge(Params otherParams) {
if (otherParams != null) {
this.params.putAll(otherParams.params);
}
return this;
}

/**
* Creates and returns a deep clone of this Params.
*
* @return a deep clone of this Params
*/
@Override
public Params clone() {
Params newParams = new Params();
newParams.params.putAll(this.params);
return newParams;
}

private void assertMapperInited() {
if (mapper == null) {
mapper = new ObjectMapper();
}
}

private String valueToJson(Object value) {
assertMapperInited();
try {
Map<String, String> m = mapper.readValue(json, Map.class);
for (Map.Entry<String, String> e : m.entrySet()) {
Class<?> valueClass = classMap.getOrDefault(e.getKey(), String.class);
paramMap.put(e.getKey(), mapper.readValue(e.getValue(), valueClass));
if (value == null) {
return null;
}
return mapper.writeValueAsString(value);
} catch (JsonProcessingException e) {
throw new RuntimeException("Failed to serialize to json:" + value, e);
}
}

private <T> T valueFromJson(String json, Class<T> clazz) {
assertMapperInited();
try {
if (json == null) {
return null;
}
return mapper.readValue(json, clazz);
} catch (IOException e) {
throw new RuntimeException("Failed to deserialize json:" + json, e);
}
}

private <V> List<String> getParamNameAndAlias(
ParamInfo <V> info) {
List<String> paramNames = new ArrayList<>(info.getAlias().length + 1);
paramNames.add(info.getName());
paramNames.addAll(Arrays.asList(info.getAlias()));
return paramNames;
}
}
Loading

0 comments on commit 1660c6b

Please sign in to comment.