Skip to content

Commit

Permalink
[FLINK-13751][ml] Add TypeInformation of built-in vector types
Browse files Browse the repository at this point in the history
  • Loading branch information
xuyang1706 authored and becketqin committed Aug 28, 2019
1 parent c038b92 commit 11258f3
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http:https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.flink.ml.common.utils;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.common.linalg.DenseVector;
import org.apache.flink.ml.common.linalg.SparseVector;
import org.apache.flink.ml.common.linalg.Vector;

/**
* Built-in vector types.
*/
public class VectorTypes {
/**
* <code>DenseVector</code> type information.
*/
public static final TypeInformation<DenseVector> DENSE_VECTOR = TypeInformation.of(DenseVector.class);

/**
* <code>SparseVector</code> type information.
*/
public static final TypeInformation<SparseVector> SPARSE_VECTOR = TypeInformation.of(SparseVector.class);

/**
* <code>Vector</code> type information.
* For efficiency, use type information of sub-class <code>DenseVector</code> and <code>SparseVector</code>
* as much as possible. When an operator output both sub-class type of vectors, use this one.
*/
public static final TypeInformation<Vector> VECTOR = TypeInformation.of(Vector.class);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http:https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.flink.ml.common.utils;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.ml.common.linalg.DenseVector;
import org.apache.flink.ml.common.linalg.SparseVector;
import org.apache.flink.ml.common.linalg.Vector;

import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;
import java.util.HashMap;
import java.util.concurrent.ThreadLocalRandom;

/**
* Test cases for VectorTypes.
*/
public class VectorTypesTest {
@SuppressWarnings("unchecked")
private static <V extends Vector> void doVectorSerDeserTest(TypeSerializer ser, V vector) throws IOException {
DataOutputSerializer out = new DataOutputSerializer(1024);
ser.serialize(vector, out);
DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
Vector deserialize = (Vector) ser.deserialize(in);
Assert.assertEquals(vector.getClass(), deserialize.getClass());
Assert.assertEquals(vector, deserialize);
}

@Test
public void testVectorsSerDeser() throws IOException {
// Prepare data
SparseVector sparseVector = new SparseVector(10, new HashMap<Integer, Double>() {{
ThreadLocalRandom rand = ThreadLocalRandom.current();
for (int i = 0; i < 10; i += 2) {
this.put(i, rand.nextDouble());
}
}});
DenseVector denseVector = DenseVector.rand(10);

// Prepare serializer
ExecutionConfig config = new ExecutionConfig();
TypeSerializer<Vector> vecSer = VectorTypes.VECTOR.createSerializer(config);
TypeSerializer<SparseVector> sparseSer = VectorTypes.SPARSE_VECTOR.createSerializer(config);
TypeSerializer<DenseVector> denseSer = VectorTypes.DENSE_VECTOR.createSerializer(config);

// Do tests.
doVectorSerDeserTest(vecSer, sparseVector);
doVectorSerDeserTest(vecSer, denseVector);
doVectorSerDeserTest(sparseSer, sparseVector);
doVectorSerDeserTest(denseSer, denseVector);
}
}

0 comments on commit 11258f3

Please sign in to comment.