diff --git a/flink-ml-parent/flink-ml-lib/pom.xml b/flink-ml-parent/flink-ml-lib/pom.xml
index bc5ca37b38c89..391b2acd86b12 100644
--- a/flink-ml-parent/flink-ml-lib/pom.xml
+++ b/flink-ml-parent/flink-ml-lib/pom.xml
@@ -34,5 +34,10 @@ under the License.
Package private to allow access from {@link MatVecOp} and {@link BLAS}. + */ + int m; + + /** + * Column dimension. + * + *
Package private to allow access from {@link MatVecOp} and {@link BLAS}. + */ + int n; + + /** + * Array for internal storage of elements. + * + *
Package private to allow access from {@link MatVecOp} and {@link BLAS}. + * + *
The matrix data is stored in column major format internally.
+ */
+ double[] data;
+
+ /**
+  * Creates an m-by-n matrix whose elements are all zero.
+  *
+  * @param m Number of rows.
+  * @param n Number of columns.
+  */
+ public DenseMatrix(int m, int n) {
+     // A freshly allocated array is zero-filled; it is handed over as column major storage.
+     this(m, n, new double[m * n], false);
+ }
+
+ /**
+  * Creates an m-by-n matrix backed by the given 1-D array, which is expected
+  * to hold the elements in column major order.
+  *
+  * <p>Delegates to the four-argument constructor with {@code inRowMajor = false}.
+  *
+  * @param m    Number of rows.
+  * @param n    Number of columns.
+  * @param data One-dimensional array of doubles in column major order.
+  */
+ public DenseMatrix(int m, int n, double[] data) {
+     // Column major input needs no layout conversion.
+     this(m, n, data, false);
+ }
+
+ /**
+  * Creates an m-by-n matrix backed by the given 1-D array, whose layout is either
+  * column major or row major, as specified by parameter 'inRowMajor'.
+  *
+  * <p>The array is used directly as the internal storage (no copy is made). When
+  * {@code inRowMajor} is true, its content is rearranged in place into column major order.
+  *
+  * @param m          Number of rows.
+  * @param n          Number of columns.
+  * @param data       One-dimensional array of doubles of length {@code m * n}.
+  * @param inRowMajor Whether the matrix in 'data' is in row major format.
+  * @throws IllegalArgumentException If the length of {@code data} is not {@code m * n}.
+  */
+ public DenseMatrix(int m, int n, double[] data, boolean inRowMajor) {
+     // Validate explicitly: an assert is skipped unless the JVM runs with -ea,
+     // which would let a wrongly sized array through unnoticed.
+     if (data.length != m * n) {
+         throw new IllegalArgumentException(
+             "The length of data must be m * n = " + (m * n) + ", but is " + data.length + ".");
+     }
+     this.m = m;
+     this.n = n;
+     if (inRowMajor) {
+         toColumnMajor(m, n, data);
+     }
+     this.data = data;
+ }
+
+ /**
+  * Creates a matrix from a 2-D array, where {@code data[i][j]} becomes element (i, j).
+  * An array without rows yields a 0-by-0 matrix.
+  *
+  * @param data Two-dimensional array of doubles.
+  * @throws IllegalArgumentException All rows must have the same size
+  */
+ public DenseMatrix(double[][] data) {
+     this.m = data.length;
+     // The first row, if there is one, defines the number of columns.
+     this.n = (this.m == 0) ? 0 : data[0].length;
+     for (double[] row : data) {
+         if (row.length != this.n) {
+             throw new IllegalArgumentException("All rows must have the same size.");
+         }
+     }
+     this.data = new double[this.m * this.n];
+     // Fill column by column, which matches the internal storage order.
+     for (int col = 0; col < this.n; col++) {
+         for (int row = 0; row < this.m; row++) {
+             this.set(row, col, data[row][col]);
+         }
+     }
+ }
+
+ /**
+  * Creates a square identity matrix.
+  *
+  * @param n the number of rows and columns of the identity matrix.
+  * @return an n * n identity matrix.
+  */
+ public static DenseMatrix eye(int n) {
+     // The square case is the rectangular case with equal dimensions.
+     return eye(n, n);
+ }
+
+ /**
+  * Creates an m * n matrix with ones on the main diagonal and zeros elsewhere.
+  *
+  * @param m the row dimension.
+  * @param n the column dimension.
+  * @return the m * n identity matrix.
+  */
+ public static DenseMatrix eye(int m, int n) {
+     final DenseMatrix identity = new DenseMatrix(m, n);
+     final int diagonalLength = Math.min(m, n);
+     // Element (d, d) sits at offset d * m + d of the column major array,
+     // so consecutive diagonal elements are m + 1 apart.
+     for (int d = 0, offset = 0; d < diagonalLength; d++, offset += m + 1) {
+         identity.data[offset] = 1.0;
+     }
+     return identity;
+ }
+
+ /**
+  * Creates a matrix with all elements set to 0.
+  *
+  * @param m the row dimension.
+  * @param n the column dimension.
+  * @return an m * n zero matrix.
+  */
+ public static DenseMatrix zeros(int m, int n) {
+     final DenseMatrix allZeros = new DenseMatrix(m, n);
+     return allZeros;
+ }
+
+ /**
+  * Creates a matrix in which every element is 1.
+  *
+  * @param m the row dimension.
+  * @param n the column dimension.
+  * @return the m * n matrix with all elements set to 1.
+  */
+ public static DenseMatrix ones(int m, int n) {
+     final DenseMatrix allOnes = new DenseMatrix(m, n);
+     final double[] storage = allOnes.data;
+     Arrays.fill(storage, 1.);
+     return allOnes;
+ }
+
+ /**
+  * Creates a matrix whose elements are drawn from {@link Math#random()}.
+  *
+  * @param m the row dimension.
+  * @param n the column dimension.
+  * @return an m * n random matrix.
+  */
+ public static DenseMatrix rand(int m, int n) {
+     final DenseMatrix randomMatrix = new DenseMatrix(m, n);
+     // One Math.random() draw per element, assigned in ascending index order.
+     Arrays.setAll(randomMatrix.data, idx -> Math.random());
+     return randomMatrix;
+ }
+
+ /**
+  * Creates a random symmetric matrix: each element of the upper triangle
+  * (diagonal included) is drawn from {@link Math#random()} and mirrored to the lower triangle.
+  *
+  * @param n the dimension of the symmetric matrix.
+  * @return an n * n random symmetric matrix.
+  */
+ public static DenseMatrix randSymmetric(int n) {
+     final DenseMatrix symmetric = new DenseMatrix(n, n);
+     for (int i = 0; i < n; i++) {
+         // The diagonal element has no mirror image.
+         symmetric.set(i, i, Math.random());
+         for (int j = i + 1; j < n; j++) {
+             final double value = Math.random();
+             symmetric.set(i, j, value);
+             symmetric.set(j, i, value);
+         }
+     }
+     return symmetric;
+ }
+
+ /**
+  * Reads a single element.
+  *
+  * @param i Row index.
+  * @param j Column index.
+  * @return matA(i, j)
+  * @throws ArrayIndexOutOfBoundsException If the offset {@code j * m + i} lies outside the
+  *                                        internal array.
+  */
+ public double get(int i, int j) {
+     // Column major layout: column j starts at offset j * m.
+     final int offset = j * m + i;
+     return data[offset];
+ }
+
+ /**
+  * Returns the internal data array of this matrix. The array itself is
+  * returned, not a copy.
+  *
+  * @return the data array of this matrix.
+  */
+ public double[] getData() {
+     return data;
+ }
+
+ /**
+  * Copies all matrix data into a newly allocated 2-D array indexed as {@code [row][column]}.
+  *
+  * @return all matrix data, returned as a 2-D array.
+  */
+ public double[][] getArrayCopy2D() {
+     final double[][] copy = new double[m][n];
+     // Walk the matrix column by column, i.e. in storage order.
+     for (int col = 0; col < n; col++) {
+         for (int row = 0; row < m; row++) {
+             copy[row][col] = this.get(row, col);
+         }
+     }
+     return copy;
+ }
+
+ /**
+  * Copies all matrix data into a newly allocated 1-D array.
+  *
+  * @param inRowMajor Whether to return data in row major; otherwise column major.
+  * @return all matrix data, returned as a 1-D array.
+  */
+ public double[] getArrayCopy1D(boolean inRowMajor) {
+     if (!inRowMajor) {
+         // The internal storage is already column major, so a plain copy suffices.
+         return this.data.clone();
+     }
+     final double[] rowMajor = new double[m * n];
+     int out = 0;
+     for (int i = 0; i < m; i++) {
+         for (int j = 0; j < n; j++) {
+             // out equals i * n + j at this point.
+             rowMajor[out++] = this.get(i, j);
+         }
+     }
+     return rowMajor;
+ }
+
+ /**
+  * Copies one row into a newly allocated array.
+  *
+  * @param row the row index.
+  * @return the row with the given index.
+  * @throws IllegalArgumentException If {@code row} is negative or not smaller than the
+  *                                  number of rows.
+  */
+ public double[] getRow(int row) {
+     // Checked explicitly rather than with an assert: with assertions disabled, an
+     // out-of-range row could silently read elements that belong to another row.
+     if (row < 0 || row >= m) {
+         throw new IllegalArgumentException("Invalid row index.");
+     }
+     final double[] values = new double[n];
+     // Consecutive elements of a row are m apart in the column major array.
+     for (int j = 0, offset = row; j < n; j++, offset += m) {
+         values[j] = data[offset];
+     }
+     return values;
+ }
+
+ /**
+  * Copies one column into a newly allocated array.
+  *
+  * @param col the column index.
+  * @return the column with the given index.
+  */
+ public double[] getColumn(int col) {
+     assert (col >= 0 && col < n) : "Invalid column index.";
+     final double[] column = new double[m];
+     // A column is a contiguous run of m elements starting at offset col * m.
+     final int start = col * m;
+     System.arraycopy(this.data, start, column, 0, m);
+     return column;
+ }
+
+ /**
+  * Creates a matrix with the same dimensions and a copy of the data array.
+  */
+ @Override
+ public DenseMatrix clone() {
+     final double[] copiedData = this.data.clone();
+     return new DenseMatrix(this.m, this.n, copiedData, false);
+ }
+
+ /**
+  * Creates a new matrix whose i-th row is row {@code rows[i]} of this matrix.
+  *
+  * @param rows the array of row indexes to select.
+  * @return a new matrix by selecting some of the rows.
+  */
+ public DenseMatrix selectRows(int[] rows) {
+     final int selectedCount = rows.length;
+     final DenseMatrix selected = new DenseMatrix(selectedCount, this.n);
+     for (int i = 0; i < selectedCount; i++) {
+         final int sourceRow = rows[i];
+         for (int j = 0; j < this.n; j++) {
+             selected.set(i, j, this.get(sourceRow, j));
+         }
+     }
+     return selected;
+ }
+
+ /**
+  * Copies a rectangular part of this matrix into a new matrix.
+  *
+  * @param m0 the starting row index (inclusive)
+  * @param m1 the ending row index (exclusive)
+  * @param n0 the starting column index (inclusive)
+  * @param n1 the ending column index (exclusive)
+  * @return the specified sub matrix.
+  * @throws IllegalArgumentException If the range is not within this matrix or an ending
+  *                                  index is smaller than its starting index.
+  */
+ public DenseMatrix getSubMatrix(int m0, int m1, int n0, int n1) {
+     // Checked explicitly rather than with an assert, which is skipped unless the JVM
+     // runs with -ea; reversed ranges are rejected as well.
+     if (m0 < 0 || m1 > m || m0 > m1 || n0 < 0 || n1 > n || n0 > n1) {
+         throw new IllegalArgumentException("Invalid index range.");
+     }
+     final int rows = m1 - m0;
+     final int cols = n1 - n0;
+     final DenseMatrix sub = new DenseMatrix(rows, cols);
+     // Both matrices are column major, so each column of the result is a contiguous
+     // run of 'rows' elements in the source array.
+     for (int j = 0; j < cols; j++) {
+         System.arraycopy(this.data, (n0 + j) * m + m0, sub.data, j * rows, rows);
+     }
+     return sub;
+ }
+
+ /**
+  * Set part of the matrix values from the values of another matrix. Every element
+  * (i, j) of {@code sub} is written to element (m0 + i, n0 + j) of this matrix.
+  *
+  * @param sub the matrix whose element values will be assigned to the sub matrix of this matrix.
+  * @param m0 the starting row index (inclusive)
+  * @param m1 the ending row index (exclusive)
+  * @param n0 the starting column index (inclusive)
+  * @param n1 the ending column index (exclusive)
+  * @throws IllegalArgumentException If the range is not within this matrix, or {@code sub}
+  *                                  does not fit into this matrix at (m0, n0).
+  */
+ public void setSubMatrix(DenseMatrix sub, int m0, int m1, int n0, int n1) {
+     // Checked explicitly rather than with an assert, which is skipped unless the JVM
+     // runs with -ea.
+     if (m0 < 0 || m1 > m || n0 < 0 || n1 > n) {
+         throw new IllegalArgumentException("Invalid index range.");
+     }
+     // The copy below is driven by the size of 'sub'; if it reached past the last row,
+     // the column major offsets would silently land in the next column.
+     if (m0 + sub.m > m || n0 + sub.n > n) {
+         throw new IllegalArgumentException("Invalid index range.");
+     }
+     for (int i = 0; i < sub.m; i++) {
+         for (int j = 0; j < sub.n; j++) {
+             this.set(m0 + i, n0 + j, sub.get(i, j));
+         }
+     }
+ }
+
+ /**
+  * Overwrites a single element.
+  *
+  * @param i Row index.
+  * @param j Column index.
+  * @param s The new value of A(i,j).
+  * @throws ArrayIndexOutOfBoundsException If the offset {@code j * m + i} lies outside the
+  *                                        internal array.
+  */
+ public void set(int i, int j, double s) {
+     // Column major layout: column j starts at offset j * m.
+     final int offset = j * m + i;
+     data[offset] = s;
+ }
+
+ /**
+  * Adds the given value to a single element in place.
+  *
+  * @param i Row index.
+  * @param j Column index.
+  * @param s The value added to A(i,j).
+  * @throws ArrayIndexOutOfBoundsException If the offset {@code j * m + i} lies outside the
+  *                                        internal array.
+  */
+ public void add(int i, int j, double s) {
+     final int offset = j * m + i;
+     data[offset] += s;
+ }
+
+ /**
+  * Checks whether the number of rows equals the number of columns.
+  *
+  * @return {@code true} if this matrix is a square matrix, {@code false} otherwise.
+  */
+ public boolean isSquare() {
+     final boolean sameDimensions = (m == n);
+     return sameDimensions;
+ }
+
+ /**
+  * Checks whether the matrix is square and every element (i, j) compares equal
+  * to element (j, i).
+  *
+  * @return {@code true} if this matrix is a symmetric matrix, {@code false} otherwise.
+  */
+ public boolean isSymmetric() {
+     if (m != n) {
+         return false;
+     }
+     // Only the strict upper triangle has to be compared with its mirror image.
+     for (int i = 0; i < n; i++) {
+         for (int j = i + 1; j < n; j++) {
+             final double upper = this.get(i, j);
+             final double lower = this.get(j, i);
+             if (upper != lower) {
+                 return false;
+             }
+         }
+     }
+     return true;
+ }
+
+ /**
+  * Returns the row dimension of this matrix.
+  *
+  * @return the number of rows.
+  */
+ public int numRows() {
+     return this.m;
+ }
+
+ /**
+  * Returns the column dimension of this matrix.
+  *
+  * @return the number of columns.
+  */
+ public int numCols() {
+     return this.n;
+ }
+
+ /**
+  * Adds up all elements of the matrix, in storage order.
+  *
+  * @return the sum of all elements.
+  */
+ public double sum() {
+     double total = 0.;
+     for (double value : this.data) {
+         total += value;
+     }
+     return total;
+ }
+
+ /**
+  * Scale the matrix by value "v" and create a new matrix to store the result.
+  * This matrix is left unchanged: {@code BLAS.scal} is applied to a clone.
+  *
+  * @param v the scaling factor.
+  * @return the new, scaled matrix.
+  */
+ public DenseMatrix scale(double v) {
+     final DenseMatrix scaled = this.clone();
+     BLAS.scal(v, scaled);
+     return scaled;
+ }
+
+ /**
+  * Scale this matrix by value "v" in place, by passing it to {@code BLAS.scal}.
+  *
+  * @param v the scaling factor.
+  */
+ public void scaleEqual(double v) {
+     final DenseMatrix target = this;
+     BLAS.scal(v, target);
+ }
+
+ /**
+  * Create a new matrix by adding another matrix to this one. This matrix is left
+  * unchanged: {@code BLAS.axpy} with factor 1.0 is applied to a clone.
+  *
+  * @param mat the matrix to add.
+  * @return the new matrix holding the result.
+  */
+ public DenseMatrix plus(DenseMatrix mat) {
+     final DenseMatrix result = this.clone();
+     final double factor = 1.0;
+     BLAS.axpy(factor, mat, result);
+     return result;
+ }
+
+ /**
+  * Create a new matrix by adding a constant to every element. This matrix is left unchanged.
+  *
+  * @param alpha the constant to add.
+  * @return the new matrix holding the result.
+  */
+ public DenseMatrix plus(double alpha) {
+     final DenseMatrix result = this.clone();
+     final double[] values = result.data;
+     for (int k = 0; k < values.length; k++) {
+         values[k] = values[k] + alpha;
+     }
+     return result;
+ }
+
+ /**
+  * Add another matrix to this matrix in place, via {@code BLAS.axpy} with factor 1.0.
+  *
+  * @param mat the matrix to add.
+  */
+ public void plusEquals(DenseMatrix mat) {
+     final double factor = 1.0;
+     BLAS.axpy(factor, mat, this);
+ }
+
+ /**
+  * Add a constant to every element of this matrix in place.
+  *
+  * @param alpha the constant to add.
+  */
+ public void plusEquals(double alpha) {
+     final double[] values = this.data;
+     for (int k = 0; k < values.length; k++) {
+         values[k] = values[k] + alpha;
+     }
+ }
+
+ /**
+  * Create a new matrix by subtracting another matrix from this one. This matrix is left
+  * unchanged: {@code BLAS.axpy} with factor -1.0 is applied to a clone.
+  *
+  * @param mat the matrix to subtract.
+  * @return the new matrix holding the result.
+  */
+ public DenseMatrix minus(DenseMatrix mat) {
+     final DenseMatrix result = this.clone();
+     final double factor = -1.0;
+     BLAS.axpy(factor, mat, result);
+     return result;
+ }
+
+ /**
+  * Subtract another matrix from this matrix in place, via {@code BLAS.axpy} with factor -1.0.
+  *
+  * @param mat the matrix to subtract.
+  */
+ public void minusEquals(DenseMatrix mat) {
+     final double factor = -1.0;
+     BLAS.axpy(factor, mat, this);
+ }
+
+ /**
+  * Multiply this matrix with another matrix. The result is written by {@code BLAS.gemm}
+  * into a new matrix with as many rows as this matrix and as many columns as {@code mat}.
+  *
+  * @param mat the right-hand side matrix.
+  * @return the new matrix holding the product.
+  */
+ public DenseMatrix multiplies(DenseMatrix mat) {
+     final DenseMatrix product = new DenseMatrix(this.m, mat.n);
+     // Neither operand is passed as transposed.
+     BLAS.gemm(1.0, this, false, mat, false, 0., product);
+     return product;
+ }
+
+ /**
+  * Multiply this matrix with a dense vector. The result is written by {@code BLAS.gemv}
+  * into a new vector with one element per row of this matrix.
+  *
+  * @param x the dense vector to multiply with.
+  * @return the new vector holding the product.
+  */
+ public DenseVector multiplies(DenseVector x) {
+     final int rowCount = this.numRows();
+     final DenseVector product = new DenseVector(rowCount);
+     BLAS.gemv(1.0, this, false, x, 0.0, product);
+     return product;
+ }
+
+ /**
+  * Multiply this matrix with a sparse vector. For every row, the products of the row's
+  * elements with the stored values of {@code x} are summed up.
+  *
+  * @param x the sparse vector to multiply with.
+  * @return a new dense vector with one element per row of this matrix.
+  * @throws RuntimeException If an index of {@code x} is not smaller than the number of columns.
+  */
+ public DenseVector multiplies(SparseVector x) {
+     final int rowCount = this.numRows();
+     final int colCount = this.numCols();
+     // The sparse entries are the same for every row, so fetch them once
+     // instead of once per row.
+     final int[] indices = x.getIndices();
+     final double[] values = x.getValues();
+     DenseVector y = new DenseVector(rowCount);
+     for (int i = 0; i < rowCount; i++) {
+         double s = 0.;
+         for (int j = 0; j < indices.length; j++) {
+             int index = indices[j];
+             if (index >= colCount) {
+                 throw new RuntimeException("Vector index out of bound:" + index);
+             }
+             s += this.get(i, index) * values[j];
+         }
+         y.set(i, s);
+     }
+     return y;
+ }
+
+ /**
+ * Create a new matrix by transposing current matrix.
+ *
+ *
Use cache-oblivious matrix transpose algorithm. + */ + public DenseMatrix transpose() { + DenseMatrix mat = new DenseMatrix(n, m); + int m0 = m; + int n0 = n; + int barrierSize = 16384; + while (m0 * n0 > barrierSize) { + if (m0 >= n0) { + m0 /= 2; + } else { + n0 /= 2; + } + } + for (int i0 = 0; i0 < m; i0 += m0) { + for (int j0 = 0; j0 < n; j0 += n0) { + for (int i = i0; i < i0 + m0 && i < m; i++) { + for (int j = j0; j < j0 + n0 && j < n; j++) { + mat.set(j, i, this.get(i, j)); + } + } + } + } + return mat; + } + + /** + * Converts the data layout in "data" from row major to column major. + */ + private static void toColumnMajor(int m, int n, double[] data) { + if (m == n) { + for (int i = 0; i < m; i++) { + for (int j = i + 1; j < m; j++) { + int pos0 = j * m + i; + int pos1 = i * m + j; + double t = data[pos0]; + data[pos0] = data[pos1]; + data[pos1] = t; + } + } + } else { + DenseMatrix temp = new DenseMatrix(n, m, data, false); + System.arraycopy(temp.transpose().data, 0, data, 0, data.length); + } + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(); + sbd.append(String.format("mat[%d,%d]:\n", m, n)); + for (int i = 0; i < m; i++) { + sbd.append(" "); + for (int j = 0; j < n; j++) { + if (j > 0) { + sbd.append(","); + } + sbd.append(this.get(i, j)); + } + sbd.append("\n"); + } + return sbd.toString(); + } +} diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseVector.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseVector.java new file mode 100644 index 0000000000000..6c3337b75346d --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseVector.java @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import java.util.Arrays; +import java.util.Random; + +/** + * A dense vector represented by a values array. + */ +public class DenseVector extends Vector { + /** + * The array holding the vector data. + *
+ * Package private to allow access from {@link MatVecOp} and {@link BLAS}.
+ */
+ double[] data;
+
+ /**
+  * Creates a vector of size zero.
+  */
+ public DenseVector() {
+     // Delegates to the sized constructor with n = 0.
+     this(0);
+ }
+
+ /**
+ * Create a size n vector with all elements zero.
+ *
+ * @param n Size of the vector.
+ */
+ public DenseVector(int n) {
+ this.data = new double[n];
+ }
+
+ /**
+  * Creates a dense vector backed by the user provided array. The array is stored
+  * as is; no copy is made.
+  *
+  * @param data The vector data.
+  */
+ public DenseVector(double[] data) {
+     // Keeps a reference to the caller's array.
+     this.data = data;
+ }
+
+ /**
+  * Returns the array holding the vector data. The array itself is returned, not a copy.
+  *
+  * @return the data array.
+  */
+ public double[] getData() {
+     return data;
+ }
+
+ /**
+  * Replaces the array holding the vector data. The given array is stored as is;
+  * no copy is made.
+  *
+  * @param data the new data array.
+  */
+ public void setData(double[] data) {
+     // Keeps a reference to the caller's array.
+     this.data = data;
+ }
+
+ /**
+  * Creates a dense vector in which every element is one.
+  *
+  * @param n Size of the vector.
+  * @return The newly created dense vector.
+  */
+ public static DenseVector ones(int n) {
+     final DenseVector allOnes = new DenseVector(n);
+     final double[] storage = allOnes.data;
+     Arrays.fill(storage, 1.0);
+     return allOnes;
+ }
+
+ /**
+  * Creates a dense vector in which every element is zero.
+  *
+  * @param n Size of the vector.
+  * @return The newly created dense vector.
+  */
+ public static DenseVector zeros(int n) {
+     // The sized constructor allocates a new array, whose elements are already 0.0.
+     final DenseVector allZeros = new DenseVector(n);
+     return allZeros;
+ }
+
+ /**
+  * Creates a dense vector with random values drawn from {@link Random#nextDouble()},
+  * i.e. uniformly distributed in the range [0.0, 1.0).
+  *
+  * @param n Size of the vector.
+  * @return The newly created dense vector.
+  */
+ public static DenseVector rand(int n) {
+     final Random generator = new Random();
+     final DenseVector randomVector = new DenseVector(n);
+     // One draw per element, assigned in ascending index order.
+     Arrays.setAll(randomVector.data, idx -> generator.nextDouble());
+     return randomVector;
+ }
+
+ /**
+  * Creates a new vector backed by a copy of this vector's data array.
+  */
+ @Override
+ public DenseVector clone() {
+     final double[] copiedData = this.data.clone();
+     return new DenseVector(copiedData);
+ }
+
+ /**
+  * Serializes this vector to a string by delegating to {@link VectorUtil#toString(DenseVector)}.
+  */
+ @Override
+ public String toString() {
+     final DenseVector self = this;
+     return VectorUtil.toString(self);
+ }
+
+ /**
+  * Returns the number of elements, which is the length of the data array.
+  */
+ @Override
+ public int size() {
+     return this.data.length;
+ }
+
+ /**
+  * Returns the element at position {@code i} of the data array.
+  */
+ @Override
+ public double get(int i) {
+     return this.data[i];
+ }
+
+ /**
+  * Overwrites the element at position {@code i} of the data array with {@code d}.
+  */
+ @Override
+ public void set(int i, double d) {
+     this.data[i] = d;
+ }
+
+ /**
+  * Adds {@code d} to the element at position {@code i} of the data array, in place.
+  */
+ @Override
+ public void add(int i, double d) {
+     this.data[i] += d;
+ }
+
+ /**
+  * Computes the sum of the absolute values of all elements.
+  */
+ @Override
+ public double normL1() {
+     double absSum = 0;
+     for (int i = 0; i < data.length; i++) {
+         absSum += Math.abs(data[i]);
+     }
+     return absSum;
+ }
+
+ /**
+  * Computes the square root of the sum of the squares of all elements.
+  */
+ @Override
+ public double normL2() {
+     double squareSum = 0;
+     for (int i = 0; i < data.length; i++) {
+         final double value = data[i];
+         squareSum += value * value;
+     }
+     return Math.sqrt(squareSum);
+ }
+
+ /**
+  * Computes the sum of the squares of all elements.
+  */
+ @Override
+ public double normL2Square() {
+     double squareSum = 0;
+     for (int i = 0; i < data.length; i++) {
+         final double value = data[i];
+         squareSum += value * value;
+     }
+     return squareSum;
+ }
+
+ /**
+  * Computes the largest absolute value among all elements; 0 for an empty vector.
+  */
+ @Override
+ public double normInf() {
+     double largest = 0;
+     for (int i = 0; i < data.length; i++) {
+         largest = Math.max(Math.abs(data[i]), largest);
+     }
+     return largest;
+ }
+
+ /**
+  * Creates a new vector whose i-th element is element {@code indices[i]} of this vector.
+  *
+  * @throws RuntimeException If an index is not smaller than the size of this vector.
+  */
+ @Override
+ public DenseVector slice(int[] indices) {
+     final double[] picked = new double[indices.length];
+     int next = 0;
+     for (int index : indices) {
+         if (index >= data.length) {
+             throw new RuntimeException("Index is larger than vector size.");
+         }
+         picked[next++] = data[index];
+     }
+     return new DenseVector(picked);
+ }
+
+ /**
+  * Creates a new vector that holds {@code d} followed by all elements of this vector.
+  */
+ @Override
+ public DenseVector prefix(double d) {
+     final double[] extended = new double[this.size() + 1];
+     // Shift the existing elements one position to the right to make room at the front.
+     System.arraycopy(this.data, 0, extended, 1, this.data.length);
+     extended[0] = d;
+     return new DenseVector(extended);
+ }
+
+ /**
+  * Creates a new vector that holds all elements of this vector followed by {@code d}.
+  */
+ @Override
+ public DenseVector append(double d) {
+     final int oldSize = this.size();
+     final double[] extended = new double[oldSize + 1];
+     System.arraycopy(this.data, 0, extended, 0, this.data.length);
+     // The extra slot at the end receives the appended value.
+     extended[oldSize] = d;
+     return new DenseVector(extended);
+ }
+
+ /**
+  * Scales this vector by {@code d} in place, by passing it to {@code BLAS.scal}.
+  */
+ @Override
+ public void scaleEqual(double d) {
+     final DenseVector target = this;
+     BLAS.scal(d, target);
+ }
+
+ /**
+  * Creates a new vector by applying {@code BLAS.axpy} with factor 1.0 to a clone of this
+  * vector. An argument that is not a {@link DenseVector} is cast to {@link SparseVector}.
+  */
+ @Override
+ public DenseVector plus(Vector other) {
+     final DenseVector result = this.clone();
+     final boolean otherIsDense = other instanceof DenseVector;
+     if (!otherIsDense) {
+         BLAS.axpy(1.0, (SparseVector) other, result);
+     } else {
+         BLAS.axpy(1.0, (DenseVector) other, result);
+     }
+     return result;
+ }
+
+ /**
+  * Creates a new vector by applying {@code BLAS.axpy} with factor -1.0 to a clone of this
+  * vector. An argument that is not a {@link DenseVector} is cast to {@link SparseVector}.
+  */
+ @Override
+ public DenseVector minus(Vector other) {
+     final DenseVector result = this.clone();
+     final boolean otherIsDense = other instanceof DenseVector;
+     if (!otherIsDense) {
+         BLAS.axpy(-1.0, (SparseVector) other, result);
+     } else {
+         BLAS.axpy(-1.0, (DenseVector) other, result);
+     }
+     return result;
+ }
+
+ /**
+  * Creates a new vector by applying {@code BLAS.scal} with factor {@code d} to a clone
+  * of this vector; this vector is left unchanged.
+  */
+ @Override
+ public DenseVector scale(double d) {
+     final DenseVector scaled = this.clone();
+     BLAS.scal(d, scaled);
+     return scaled;
+ }
+
+ /**
+  * Computes the dot product with another vector. For a {@link DenseVector} argument
+  * {@code BLAS.dot} is used; any other argument computes the product itself.
+  */
+ @Override
+ public double dot(Vector vec) {
+     if (!(vec instanceof DenseVector)) {
+         // Let the other vector type handle the mixed case.
+         return vec.dot(this);
+     }
+     return BLAS.dot(this, (DenseVector) vec);
+ }
+
+ /**
+  * Replaces, in place, every element x by {@code (x - mean) * (1.0 / stdvar)}.
+  */
+ @Override
+ public void standardizeEqual(double mean, double stdvar) {
+     // The reciprocal is the same for all elements, so it is computed once.
+     final double inverseStd = 1.0 / stdvar;
+     for (int i = 0; i < data.length; i++) {
+         final double centered = data[i] - mean;
+         data[i] = centered * inverseStd;
+     }
+ }
+
+ /**
+  * Divides, in place, every element by the p-norm of this vector. An infinite {@code p}
+  * selects {@link #normInf()}, 1.0 selects {@link #normL1()} and 2.0 selects
+  * {@link #normL2()}; for any other {@code p} the norm is computed as
+  * {@code pow(sum(pow(abs(x), p)), 1 / p)}.
+  */
+ @Override
+ public void normalizeEqual(double p) {
+     final double norm;
+     if (Double.isInfinite(p)) {
+         norm = normInf();
+     } else if (p == 1.0) {
+         norm = normL1();
+     } else if (p == 2.0) {
+         norm = normL2();
+     } else {
+         double powerSum = 0.0;
+         for (double value : data) {
+             powerSum += Math.pow(Math.abs(value), p);
+         }
+         norm = Math.pow(powerSum, 1 / p);
+     }
+     for (int i = 0; i < data.length; i++) {
+         data[i] = data[i] / norm;
+     }
+ }
+
+ /**
+  * Set the data of the vector the same as those of another vector, by copying the
+  * elements of {@code other} into the data array of this vector.
+  *
+  * @param other the vector to copy the elements from.
+  * @throws IllegalArgumentException If the two vectors differ in size.
+  */
+ public void setEqual(DenseVector other) {
+     // Checked explicitly rather than with an assert: with assertions disabled,
+     // a longer 'other' would silently be truncated to the size of this vector.
+     if (this.size() != other.size()) {
+         throw new IllegalArgumentException("Size of the two vectors mismatched.");
+     }
+     System.arraycopy(other.data, 0, this.data, 0, this.size());
+ }
+
+ /**
+  * Adds another vector to this vector in place, via {@code BLAS.axpy} with factor 1.0.
+  * An argument that is not a {@link DenseVector} is cast to {@link SparseVector}.
+  *
+  * @param other the vector to add.
+  */
+ public void plusEqual(Vector other) {
+     final boolean otherIsDense = other instanceof DenseVector;
+     if (!otherIsDense) {
+         BLAS.axpy(1.0, (SparseVector) other, this);
+         return;
+     }
+     BLAS.axpy(1.0, (DenseVector) other, this);
+ }
+
+ /**
+  * Subtracts another vector from this vector in place, via {@code BLAS.axpy} with
+  * factor -1.0. An argument that is not a {@link DenseVector} is cast to {@link SparseVector}.
+  *
+  * @param other the vector to subtract.
+  */
+ public void minusEqual(Vector other) {
+     final boolean otherIsDense = other instanceof DenseVector;
+     if (!otherIsDense) {
+         BLAS.axpy(-1.0, (SparseVector) other, this);
+         return;
+     }
+     BLAS.axpy(-1.0, (DenseVector) other, this);
+ }
+
+ /**
+  * Adds another vector scaled by "alpha" to this vector in place, via {@code BLAS.axpy}.
+  * An argument that is not a {@link DenseVector} is cast to {@link SparseVector}.
+  *
+  * @param other the vector to add.
+  * @param alpha the factor passed to {@code BLAS.axpy}.
+  */
+ public void plusScaleEqual(Vector other, double alpha) {
+     final boolean otherIsDense = other instanceof DenseVector;
+     if (!otherIsDense) {
+         BLAS.axpy(alpha, (SparseVector) other, this);
+         return;
+     }
+     BLAS.axpy(alpha, (DenseVector) other, this);
+ }
+
+ /**
+  * Computes the outer product of this vector with itself.
+  */
+ @Override
+ public DenseMatrix outer() {
+     final DenseVector self = this;
+     return this.outer(self);
+ }
+
+ /**
+  * Compute the outer product with another vector: element (i, j) of the result is
+  * {@code this[i] * other[j]}.
+  *
+  * @param other the vector providing the column factors.
+  * @return The outer product matrix, with one row per element of this vector and one
+  *         column per element of {@code other}.
+  */
+ public DenseMatrix outer(DenseVector other) {
+     final int nrows = this.size();
+     final int ncols = other.size();
+     final double[] product = new double[nrows * ncols];
+     // The array is filled in column major order: column j starts at offset j * nrows.
+     for (int j = 0; j < ncols; j++) {
+         final int columnStart = j * nrows;
+         for (int i = 0; i < nrows; i++) {
+             product[columnStart + i] = this.data[i] * other.data[j];
+         }
+     }
+     return new DenseMatrix(nrows, ncols, product, false);
+ }
+
+ /**
+  * Two dense vectors are equal if they are of exactly the same class and their
+  * data arrays are equal according to {@link Arrays#equals(double[], double[])}.
+  */
+ @Override
+ public boolean equals(Object o) {
+     if (o == this) {
+         return true;
+     }
+     if (o != null && o.getClass() == getClass()) {
+         final DenseVector that = (DenseVector) o;
+         return Arrays.equals(this.data, that.data);
+     }
+     return false;
+ }
+
+ /**
+  * Computes the hash code from the content of the data array, consistent with
+  * {@link #equals(Object)}.
+  */
+ @Override
+ public int hashCode() {
+     final int contentHash = Arrays.hashCode(this.data);
+     return contentHash;
+ }
+
+ /**
+  * Creates a new iterator over the elements of this vector, starting at index 0.
+  */
+ @Override
+ public VectorIterator iterator() {
+     final VectorIterator fromStart = new DenseVectorIterator();
+     return fromStart;
+ }
+
+ /**
+  * Iterator over the data array of the enclosing vector, visiting the elements in
+  * ascending index order.
+  */
+ private class DenseVectorIterator implements VectorIterator {
+     /** Index of the element the iterator currently points to. */
+     private int cursor = 0;
+
+     /** Returns true while the cursor still points to an element of the data array. */
+     @Override
+     public boolean hasNext() {
+         return data.length > cursor;
+     }
+
+     /** Moves the cursor to the next index; no bounds check is performed here. */
+     @Override
+     public void next() {
+         cursor += 1;
+     }
+
+     /** Returns the index the cursor points to, or fails if it is past the end. */
+     @Override
+     public int getIndex() {
+         ensureInBounds();
+         return cursor;
+     }
+
+     /** Returns the value the cursor points to, or fails if it is past the end. */
+     @Override
+     public double getValue() {
+         ensureInBounds();
+         return data[cursor];
+     }
+
+     /** Throws a RuntimeException when the cursor has moved past the last element. */
+     private void ensureInBounds() {
+         if (cursor >= data.length) {
+             throw new RuntimeException("Iterator out of bound.");
+         }
+     }
+ }
+}
diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/MatVecOp.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/MatVecOp.java
new file mode 100644
index 0000000000000..9bbfbbf714052
--- /dev/null
+++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/MatVecOp.java
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.common.linalg;
+
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
+/**
+ * A utility class that provides operations over {@link DenseVector}, {@link SparseVector}
+ * and {@link DenseMatrix}.
+ */
+public class MatVecOp {
+ /**
+  * Compute vec1 + vec2 by delegating to {@code vec1.plus(vec2)}.
+  *
+  * @param vec1 the left operand.
+  * @param vec2 the right operand.
+  * @return the result of {@code vec1.plus(vec2)}.
+  */
+ public static Vector plus(Vector vec1, Vector vec2) {
+     final Vector sum = vec1.plus(vec2);
+     return sum;
+ }
+
+ /**
+  * Compute vec1 - vec2 by delegating to {@code vec1.minus(vec2)}.
+  *
+  * @param vec1 the left operand.
+  * @param vec2 the right operand.
+  * @return the result of {@code vec1.minus(vec2)}.
+  */
+ public static Vector minus(Vector vec1, Vector vec2) {
+     final Vector difference = vec1.minus(vec2);
+     return difference;
+ }
+
+ /**
+  * Compute vec1 \cdot vec2 by delegating to {@code vec1.dot(vec2)}.
+  *
+  * @param vec1 the left operand.
+  * @param vec2 the right operand.
+  * @return the result of {@code vec1.dot(vec2)}.
+  */
+ public static double dot(Vector vec1, Vector vec2) {
+     final double product = vec1.dot(vec2);
+     return product;
+ }
+
+ /**
+  * Compute || vec1 - vec2 ||_1 . Dispatches to the {@code applySum} overload that matches
+  * the dense/sparse combination of the arguments, with {@code |a - b|} as the function;
+  * an argument that is not a {@link DenseVector} is cast to {@link SparseVector}.
+  *
+  * @param vec1 the first vector.
+  * @param vec2 the second vector.
+  * @return the value returned by {@code applySum}.
+  */
+ public static double sumAbsDiff(Vector vec1, Vector vec2) {
+     final boolean secondIsDense = vec2 instanceof DenseVector;
+     if (vec1 instanceof DenseVector) {
+         final DenseVector dense1 = (DenseVector) vec1;
+         if (secondIsDense) {
+             return MatVecOp.applySum(dense1, (DenseVector) vec2, (a, b) -> Math.abs(a - b));
+         }
+         return MatVecOp.applySum(dense1, (SparseVector) vec2, (a, b) -> Math.abs(a - b));
+     }
+     final SparseVector sparse1 = (SparseVector) vec1;
+     if (secondIsDense) {
+         return MatVecOp.applySum(sparse1, (DenseVector) vec2, (a, b) -> Math.abs(a - b));
+     }
+     return MatVecOp.applySum(sparse1, (SparseVector) vec2, (a, b) -> Math.abs(a - b));
+ }
+
+ /**
+  * Compute || vec1 - vec2 ||_2^2 . Dispatches to the {@code applySum} overload that matches
+  * the dense/sparse combination of the arguments, with {@code (a - b) * (a - b)} as the
+  * function; an argument that is not a {@link DenseVector} is cast to {@link SparseVector}.
+  *
+  * @param vec1 the first vector.
+  * @param vec2 the second vector.
+  * @return the value returned by {@code applySum}.
+  */
+ public static double sumSquaredDiff(Vector vec1, Vector vec2) {
+     final boolean secondIsDense = vec2 instanceof DenseVector;
+     if (vec1 instanceof DenseVector) {
+         final DenseVector dense1 = (DenseVector) vec1;
+         if (secondIsDense) {
+             return MatVecOp.applySum(dense1, (DenseVector) vec2, (a, b) -> (a - b) * (a - b));
+         }
+         return MatVecOp.applySum(dense1, (SparseVector) vec2, (a, b) -> (a - b) * (a - b));
+     }
+     final SparseVector sparse1 = (SparseVector) vec1;
+     if (secondIsDense) {
+         return MatVecOp.applySum(sparse1, (DenseVector) vec2, (a, b) -> (a - b) * (a - b));
+     }
+     return MatVecOp.applySum(sparse1, (SparseVector) vec2, (a, b) -> (a - b) * (a - b));
+ }
+
+ /**
+ * y = func(x).
+ */
+ public static void apply(DenseMatrix x, DenseMatrix y, Function Package private to allow access from {@link MatVecOp} and {@link BLAS}.
+ */
+ int n;
+
+ /**
+ * Column indices.
+ *
+ * Package private to allow access from {@link MatVecOp} and {@link BLAS}.
+ */
+ int[] indices;
+
+ /**
+ * Column values.
+ *
+ * Package private to allow access from {@link MatVecOp} and {@link BLAS}.
+ */
+ double[] values;
+
+ /**
+  * Creates an empty sparse vector with undetermined size, which is represented
+  * by a size of -1.
+  */
+ public SparseVector() {
+     // -1 marks the size as undetermined.
+     this(-1);
+ }
+
+ /**
+  * Creates an empty sparse vector with determined size: it has size {@code n}
+  * but stores no indices and no values.
+  *
+  * @param n Size of the vector.
+  */
+ public SparseVector(int n) {
+     this.indices = new int[0];
+     this.values = new double[0];
+     this.n = n;
+ }
+
+ /**
+  * Creates a sparse vector with the given indices and values. The arrays are stored
+  * as given (no copy is made) and are then passed through
+  * {@code checkSizeAndIndicesRange()} and {@code sortIndices()}.
+  *
+  * @param n       Size of the vector.
+  * @param indices Indices of the stored elements.
+  * @param values  Values of the stored elements.
+  * @throws IllegalArgumentException If size of indices array and values array differ.
+  * @throws IllegalArgumentException If n >= 0 and the indices are out of bound.
+  */
+ public SparseVector(int n, int[] indices, double[] values) {
+     this.values = values;
+     this.indices = indices;
+     this.n = n;
+     // Validation has to run before the indices are sorted.
+     checkSizeAndIndicesRange();
+     sortIndices();
+ }
+
+ /**
+ * Construct a sparse vector with given indices to values map.
+ *
+ * @throws IllegalArgumentException If n >= 0 and the indices are out of bound.
+ */
+ public SparseVector(int n, Map Usage:
+ *
+ * {@link #getIndex()} will then return the index of the
+ * element to which the cursor points.
+ * {@link #getValue()} will then return the value of
+ * the element to which the cursor points.
+ *
+ * The difference to the {@link Iterator#next()} is that this
+ * can avoid the return of boxed type.
+ */
+ void next();
+
+ /**
+ * Returns the index of the element to which the cursor points.
+ *
+ * @return the index of the element to which the cursor points.
+ */
+ int getIndex();
+
+ /**
+ * Returns the value of the element to which the cursor points.
+ *
+ * @return the value of the element to which the cursor points.
+ */
+ double getValue();
+}
diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorUtil.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorUtil.java
new file mode 100644
index 0000000000000..b605f635e61d7
--- /dev/null
+++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorUtil.java
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.common.linalg;
+
+import org.apache.commons.lang3.StringUtils;
+
+/**
+ * Utility class for the operations on {@link Vector} and its subclasses.
+ */
+public class VectorUtil {
+    /**
+     * Delimiter between elements.
+     */
+    private static final char ELEMENT_DELIMITER = ' ';
+    /**
+     * Previous delimiter between the elements of a dense vector. It is only accepted
+     * when parsing; serialization always uses {@link #ELEMENT_DELIMITER}.
+     */
+    private static final char LEGACY_ELEMENT_DELIMITER = ',';
+    /**
+     * Delimiter between vector size and vector data.
+     */
+    private static final char HEADER_DELIMITER = '$';
+    /**
+     * Delimiter between index and value.
+     */
+    private static final char INDEX_VALUE_DELIMITER = ':';
+
+    /**
+     * Parse either a {@link SparseVector} or a {@link DenseVector} from a formatted string.
+     *
+     * <p>The format of a dense vector is space separated values such as "1 2 3 4".
+     * The format of a sparse vector is space separated index-value pairs, such as "0:1 2:3 3:4".
+     * If the sparse vector has determined vector size, the size is prepended to the head. For example,
+     * the string "$4$0:1 2:3 3:4" represents a sparse vector with size 4.
+     *
+     * <p>A string that is null or whitespace only, or that contains ':' or '$', is parsed
+     * as a sparse vector; every other string is parsed as a dense vector.
+     *
+     * @param str A formatted string representing a vector.
+     * @return The parsed vector.
+     */
+    public static Vector parse(String str) {
+        boolean isSparse = org.apache.flink.util.StringUtils.isNullOrWhitespaceOnly(str)
+            || StringUtils.indexOf(str, INDEX_VALUE_DELIMITER) != -1
+            || StringUtils.indexOf(str, HEADER_DELIMITER) != -1;
+        if (isSparse) {
+            return parseSparse(str);
+        } else {
+            return parseDense(str);
+        }
+    }
+
+    /**
+     * Parse the dense vector from a formatted string.
+     *
+     * <p>The format of a dense vector is space separated values such as "1 2 3 4".
+     * To be compatible with the previous delimiter, ',' is accepted as a delimiter as well.
+     * Consecutive delimiters are treated as one. A string that is null or whitespace only
+     * yields a vector of size zero.
+     *
+     * @param str A string of space separated values.
+     * @return The parsed vector.
+     * @throws NumberFormatException If an element cannot be parsed as a double.
+     */
+    public static DenseVector parseDense(String str) {
+        if (org.apache.flink.util.StringUtils.isNullOrWhitespaceOnly(str)) {
+            return new DenseVector();
+        }
+
+        final int len = str.length();
+
+        // First pass: count the elements so that the array can be allocated exactly.
+        int numValues = 0;
+        boolean inToken = false;
+        for (int i = 0; i < len; ++i) {
+            if (isDenseDelimiter(str.charAt(i))) {
+                inToken = false;
+            } else if (!inToken) {
+                inToken = true;
+                numValues++;
+            }
+        }
+
+        // Second pass: parse the elements. It uses the same delimiter test as the first
+        // pass, so that both passes always agree on where an element starts and ends.
+        double[] data = new double[numValues];
+        int pos = 0;
+        int tokenStart = -1;
+        for (int i = 0; i < len; ++i) {
+            if (isDenseDelimiter(str.charAt(i))) {
+                if (tokenStart >= 0) {
+                    data[pos++] = Double.parseDouble(str.substring(tokenStart, i).trim());
+                    tokenStart = -1;
+                }
+            } else if (tokenStart < 0) {
+                tokenStart = i;
+            }
+        }
+
+        // The last element is not followed by a delimiter.
+        if (tokenStart >= 0) {
+            data[pos] = Double.parseDouble(str.substring(tokenStart).trim());
+        }
+
+        return new DenseVector(data);
+    }
+
+    /**
+     * Checks whether the character separates two elements of a dense vector.
+     */
+    private static boolean isDenseDelimiter(char c) {
+        return c == ELEMENT_DELIMITER || c == LEGACY_ELEMENT_DELIMITER;
+    }
+
+    /**
+     * Parse the sparse vector from a formatted string.
+     *
+     * <p>The format of a sparse vector is space separated index-value pairs, such as "0:1 2:3 3:4".
+     * If the sparse vector has determined vector size, the size is prepended to the head. For example,
+     * the string "$4$0:1 2:3 3:4" represents a sparse vector with size 4.
+     * A string that is null or whitespace only yields an empty vector with undetermined size.
+     *
+     * @param str A formatted string representing a sparse vector.
+     * @return The parsed vector.
+     * @throws IllegalArgumentException If the string is of invalid format.
+     */
+    public static SparseVector parseSparse(String str) {
+        try {
+            if (org.apache.flink.util.StringUtils.isNullOrWhitespaceOnly(str)) {
+                return new SparseVector();
+            }
+
+            // The optional header "$size$" precedes the index-value pairs.
+            int n = -1;
+            int firstDollarPos = str.indexOf(HEADER_DELIMITER);
+            int lastDollarPos = -1;
+            if (firstDollarPos >= 0) {
+                lastDollarPos = StringUtils.lastIndexOf(str, HEADER_DELIMITER);
+                String sizeStr = StringUtils.substring(str, firstDollarPos + 1, lastDollarPos);
+                n = Integer.valueOf(sizeStr);
+                if (lastDollarPos == str.length() - 1) {
+                    // Nothing follows the header: the vector has a size but no elements.
+                    return new SparseVector(n);
+                }
+            }
+
+            // Every index-value pair contains exactly one ':'.
+            int numValues = StringUtils.countMatches(str, String.valueOf(INDEX_VALUE_DELIMITER));
+            double[] data = new double[numValues];
+            int[] indices = new int[numValues];
+            int startPos = lastDollarPos + 1;
+            int endPos;
+            for (int i = 0; i < numValues; i++) {
+                int colonPos = StringUtils.indexOf(str, INDEX_VALUE_DELIMITER, startPos);
+                if (colonPos < 0) {
+                    throw new IllegalArgumentException("Format error.");
+                }
+                endPos = StringUtils.indexOf(str, ELEMENT_DELIMITER, colonPos);
+
+                // The last pair is not followed by a delimiter.
+                if (endPos == -1) {
+                    endPos = str.length();
+                }
+                indices[i] = Integer.valueOf(str.substring(startPos, colonPos).trim());
+                data[i] = Double.valueOf(str.substring(colonPos + 1, endPos).trim());
+                startPos = endPos + 1;
+            }
+            return new SparseVector(n, indices, data);
+        } catch (Exception e) {
+            // Any failure is reported as a format problem, with the original cause attached.
+            throw new IllegalArgumentException(
+                String.format("Fail to getVector sparse vector from string: \"%s\".", str), e);
+        }
+    }
+
+    /**
+     * Serialize the vector to a string. A vector that is not a {@link SparseVector}
+     * is cast to {@link DenseVector}.
+     *
+     * @param vector The vector to serialize.
+     * @return The serialized vector.
+     * @see #toString(DenseVector)
+     * @see #toString(SparseVector)
+     */
+    public static String toString(Vector vector) {
+        if (vector instanceof SparseVector) {
+            return toString((SparseVector) vector);
+        }
+        return toString((DenseVector) vector);
+    }
+
+    /**
+     * Serialize the SparseVector to string.
+     *
+     * <p>The format of the returned string is described at {@link #parseSparse(String)}.
+     * The size header is only written if the size of the vector is greater than zero.
+     *
+     * @param sparseVector The sparse vector to serialize.
+     * @return The serialized vector.
+     */
+    public static String toString(SparseVector sparseVector) {
+        StringBuilder sbd = new StringBuilder();
+        if (sparseVector.n > 0) {
+            sbd.append(HEADER_DELIMITER);
+            sbd.append(sparseVector.n);
+            sbd.append(HEADER_DELIMITER);
+        }
+        if (null != sparseVector.indices) {
+            for (int i = 0; i < sparseVector.indices.length; i++) {
+                sbd.append(sparseVector.indices[i]);
+                sbd.append(INDEX_VALUE_DELIMITER);
+                sbd.append(sparseVector.values[i]);
+                // No trailing delimiter after the last pair.
+                if (i < sparseVector.indices.length - 1) {
+                    sbd.append(ELEMENT_DELIMITER);
+                }
+            }
+        }
+
+        return sbd.toString();
+    }
+
+    /**
+     * Serialize the DenseVector to String.
+     *
+     * <p>The format of the returned string is described at {@link #parseDense(String)}.
+     *
+     * @param denseVector The DenseVector to serialize.
+     * @return The serialized vector.
+     */
+    public static String toString(DenseVector denseVector) {
+        StringBuilder sbd = new StringBuilder();
+
+        for (int i = 0; i < denseVector.data.length; i++) {
+            sbd.append(denseVector.data[i]);
+            // No trailing delimiter after the last element.
+            if (i < denseVector.data.length - 1) {
+                sbd.append(ELEMENT_DELIMITER);
+            }
+        }
+        return sbd.toString();
+    }
+
+}
diff --git a/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/NOTICE b/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/NOTICE
new file mode 100644
index 0000000000000..db087abe66c49
--- /dev/null
+++ b/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/NOTICE
@@ -0,0 +1,10 @@
+flink-ml-lib
+Copyright 2014-2019 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+This project bundles the following dependencies under the BSD license.
+See bundled license files for details.
+
+- com.github.fommil.netlib:core:1.1.2
diff --git a/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/licenses/LICENSE.core b/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/licenses/LICENSE.core
new file mode 100644
index 0000000000000..b7d28491e4b6b
--- /dev/null
+++ b/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/licenses/LICENSE.core
@@ -0,0 +1,49 @@
+Copyright (c) 2013 Samuel Halliday
+Copyright (c) 1992-2011 The University of Tennessee and The University
+ of Tennessee Research Foundation. All rights
+ reserved.
+Copyright (c) 2000-2011 The University of California Berkeley. All
+ rights reserved.
+Copyright (c) 2006-2011 The University of Colorado Denver. All rights
+ reserved.
+
+$COPYRIGHT$
+
+Additional copyrights may follow
+
+$HEADER$
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+- Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+- Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer listed
+ in this license in the documentation and/or other materials
+ provided with the distribution.
+
+- Neither the name of the copyright holders nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+The copyright holders provide no reassurances that the source code
+provided does not infringe any patent, copyright, or any other
+intellectual property rights of third parties. The copyright holders
+disclaim any liability to any recipient for claims brought against
+recipient by any third party for infringement of that parties
+intellectual property rights.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseMatrixTest.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseMatrixTest.java
new file mode 100644
index 0000000000000..bb35294ea0033
--- /dev/null
+++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseMatrixTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.common.linalg;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test cases for DenseMatrix.
+ */
+public class DenseMatrixTest {
+
+	private static final double TOL = 1.0e-6;
+
+	private static void assertEqual2D(double[][] matA, double[][] matB) {
+		Assert.assertEquals(matB.length, matA.length);
+		Assert.assertEquals(matB[0].length, matA[0].length);
+		int m = matA.length;
+		int n = matA[0].length;
+		for (int i = 0; i < m; i++) {
+			for (int j = 0; j < n; j++) {
+				Assert.assertEquals(matB[i][j], matA[i][j], TOL);
+			}
+		}
+	}
+
+	private static double[][] simpleMM(double[][] matA, double[][] matB) {
+		int m = matA.length;
+		int n = matB[0].length;
+		int k = matA[0].length;
+		double[][] matC = new double[m][n];
+		for (int i = 0; i < m; i++) {
+			for (int j = 0; j < n; j++) {
+				matC[i][j] = 0.;
+				for (int l = 0; l < k; l++) {
+					matC[i][j] += matA[i][l] * matB[l][j];
+				}
+			}
+		}
+		return matC;
+	}
+
+	private static double[] simpleMV(double[][] matA, double[] x) {
+		int m = matA.length;
+		int n = matA[0].length;
+		Assert.assertEquals(n, x.length);
+		double[] y = new double[m];
+		for (int i = 0; i < m; i++) {
+			y[i] = 0.;
+			for (int j = 0; j < n; j++) {
+				y[i] += matA[i][j] * x[j];
+			}
+		}
+		return y;
+	}
+
+	@Test
+	public void testPlusEquals() throws Exception {
+		DenseMatrix matA = new DenseMatrix(new double[][]{
+			new double[]{1, 3, 5},
+			new double[]{2, 4, 6},
+		});
+		DenseMatrix matB = DenseMatrix.ones(2, 3);
+		matA.plusEquals(matB);
+		Assert.assertArrayEquals(new double[]{2, 3, 4, 5, 6, 7}, matA.getData(), TOL);
+		matA.plusEquals(1.0);
+		Assert.assertArrayEquals(new double[]{3, 4, 5, 6, 7, 8}, matA.getData(), TOL);
+	}
+
+	@Test
+	public void testMinusEquals() throws Exception {
+		DenseMatrix matA = new DenseMatrix(new double[][]{
+			new double[]{1, 3, 5},
+			new double[]{2, 4, 6},
+		});
+		DenseMatrix matB = DenseMatrix.ones(2, 3);
+		matA.minusEquals(matB);
+		Assert.assertArrayEquals(new double[]{0, 1, 2, 3, 4, 5}, matA.getData(), TOL);
+	}
+
+	@Test
+	public void testPlus() throws Exception {
+		DenseMatrix matA = new DenseMatrix(new double[][]{
+			new double[]{1, 3, 5},
+			new double[]{2, 4, 6},
+		});
+		DenseMatrix matB = DenseMatrix.ones(2, 3);
+		DenseMatrix matC = matA.plus(matB);
+		Assert.assertArrayEquals(new double[]{2, 3, 4, 5, 6, 7}, matC.getData(), TOL);
+		DenseMatrix matD = matA.plus(1.0);
+		Assert.assertArrayEquals(new double[]{2, 3, 4, 5, 6, 7}, matD.getData(), TOL);
+	}
+
+	@Test
+	public void testMinus() throws Exception {
+		DenseMatrix matA = new DenseMatrix(new double[][]{
+			new double[]{1, 3, 5},
+			new double[]{2, 4, 6},
+		});
+		DenseMatrix matB = DenseMatrix.ones(2, 3);
+		DenseMatrix matC = matA.minus(matB);
+		Assert.assertArrayEquals(new double[]{0, 1, 2, 3, 4, 5}, matC.getData(), TOL);
+	}
+
+	@Test
+	public void testMM() throws Exception {
+		DenseMatrix matA = DenseMatrix.rand(4, 3);
+		DenseMatrix matB = DenseMatrix.rand(3, 5);
+		DenseMatrix matC = matA.multiplies(matB);
+		assertEqual2D(matC.getArrayCopy2D(), simpleMM(matA.getArrayCopy2D(), matB.getArrayCopy2D()));
+
+		DenseMatrix matD = new DenseMatrix(5, 4);
+		BLAS.gemm(1., matB, true, matA, true, 0., matD);
+		Assert.assertArrayEquals(matC.getData(), matD.transpose().getData(), TOL);
+	}
+
+	@Test
+	public void testMV() throws Exception {
+		DenseMatrix matA = DenseMatrix.rand(4, 3);
+		DenseVector x = DenseVector.ones(3);
+		DenseVector y = matA.multiplies(x);
+		Assert.assertArrayEquals(simpleMV(matA.getArrayCopy2D(), x.getData()), y.getData(), TOL);
+
+		SparseVector x2 = new SparseVector(3, new int[]{0, 1, 2}, new double[]{1, 1, 1});
+		DenseVector y2 = matA.multiplies(x2);
+		Assert.assertArrayEquals(y.getData(), y2.getData(), TOL);
+	}
+
+	@Test
+	public void testDataSelection() throws Exception {
+		DenseMatrix mat = new DenseMatrix(new double[][]{
+			new double[]{1, 2, 3},
+			new double[]{4, 5, 6},
+			new double[]{7, 8, 9},
+		});
+		DenseMatrix sub1 = mat.selectRows(new int[]{1});
+		DenseMatrix sub2 = mat.getSubMatrix(1, 2, 1, 2);
+		Assert.assertEquals(1, sub1.numRows());
+		Assert.assertEquals(3, sub1.numCols());
+		Assert.assertEquals(1, sub2.numRows());
+		Assert.assertEquals(1, sub2.numCols());
+		Assert.assertArrayEquals(new double[]{4, 5, 6}, sub1.getData(), TOL);
+		Assert.assertArrayEquals(new double[]{5}, sub2.getData(), TOL);
+
+		double[] row = mat.getRow(1);
+		double[] col = mat.getColumn(1);
+		Assert.assertArrayEquals(new double[]{4, 5, 6}, row, 0.);
+		Assert.assertArrayEquals(new double[]{2, 5, 8}, col, 0.);
+	}
+
+	@Test
+	public void testSum() throws Exception {
+		DenseMatrix matA = DenseMatrix.ones(3, 2);
+		Assert.assertEquals(6.0, matA.sum(), TOL);
+	}
+
+	@Test
+	public void testRowMajorFormat() throws Exception {
+		double[] data = new double[]{1, 2, 3, 4, 5, 6};
+		DenseMatrix matA = new DenseMatrix(2, 3, data, true);
+		Assert.assertArrayEquals(new double[]{1, 4, 2, 5, 3, 6}, data, 0.);
+		Assert.assertArrayEquals(new double[]{1, 4, 2, 5, 3, 6}, matA.getData(), 0.);
+
+		data = new double[]{1, 2, 3, 4};
+		matA = new DenseMatrix(2, 2, data, true);
+		Assert.assertArrayEquals(new double[]{1, 3, 2, 4}, data, 0.);
+		Assert.assertArrayEquals(new double[]{1, 3, 2, 4}, matA.getData(), 0.);
+	}
+}
diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseVectorTest.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseVectorTest.java
new file mode 100644
index 0000000000000..3e972040f8047
--- /dev/null
+++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseVectorTest.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.common.linalg;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test cases for DenseVector.
+ */
+public class DenseVectorTest {
+	private static final double TOL = 1.0e-6;
+
+	@Test
+	public void testSize() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		Assert.assertEquals(3, vec.size());
+	}
+
+	@Test
+	public void testNormL1() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		Assert.assertEquals(6, vec.normL1(), 0);
+	}
+
+	@Test
+	public void testNormMax() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		Assert.assertEquals(3, vec.normInf(), 0);
+	}
+
+	@Test
+	public void testNormL2() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		Assert.assertEquals(Math.sqrt(1 + 4 + 9), vec.normL2(), TOL);
+	}
+
+	@Test
+	public void testNormL2Square() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		Assert.assertEquals(1 + 4 + 9, vec.normL2Square(), TOL);
+	}
+
+	@Test
+	public void testSlice() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		DenseVector sliced = vec.slice(new int[]{0, 2});
+		Assert.assertArrayEquals(new double[]{1, -3}, sliced.getData(), 0);
+	}
+
+	@Test
+	public void testMinus() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		DenseVector d = new DenseVector(new double[]{1, 2, 1});
+		DenseVector vec2 = vec.minus(d);
+		Assert.assertArrayEquals(new double[]{1, 2, -3}, vec.getData(), 0);
+		Assert.assertArrayEquals(new double[]{0, 0, -4}, vec2.getData(), TOL);
+		vec.minusEqual(d);
+		Assert.assertArrayEquals(new double[]{0, 0, -4}, vec.getData(), TOL);
+	}
+
+	@Test
+	public void testPlus() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		DenseVector d = new DenseVector(new double[]{1, 2, 1});
+		DenseVector vec2 = vec.plus(d);
+		Assert.assertArrayEquals(new double[]{1, 2, -3}, vec.getData(), 0);
+		Assert.assertArrayEquals(new double[]{2, 4, -2}, vec2.getData(), TOL);
+		vec.plusEqual(d);
+		Assert.assertArrayEquals(new double[]{2, 4, -2}, vec.getData(), TOL);
+	}
+
+	@Test
+	public void testPlusScaleEqual() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		DenseVector vec2 = new DenseVector(new double[]{1, 0, 2});
+		vec.plusScaleEqual(vec2, 2.);
+		Assert.assertArrayEquals(new double[]{3, 2, 1}, vec.getData(), TOL);
+	}
+
+	@Test
+	public void testDot() throws Exception {
+		DenseVector vec1 = new DenseVector(new double[]{1, 2, -3});
+		DenseVector vec2 = new DenseVector(new double[]{3, 2, 1});
+		Assert.assertEquals(3 + 4 - 3, vec1.dot(vec2), TOL);
+	}
+
+	@Test
+	public void testPrefix() throws Exception {
+		DenseVector vec1 = new DenseVector(new double[]{1, 2, -3});
+		DenseVector vec2 = vec1.prefix(0);
+		Assert.assertArrayEquals(new double[]{0, 1, 2, -3}, vec2.getData(), 0);
+	}
+
+	@Test
+	public void testAppend() throws Exception {
+		DenseVector vec1 = new DenseVector(new double[]{1, 2, -3});
+		DenseVector vec2 = vec1.append(0);
+		Assert.assertArrayEquals(new double[]{1, 2, -3, 0}, vec2.getData(), 0);
+	}
+
+	@Test
+	public void testOuter() throws Exception {
+		DenseVector vec1 = new DenseVector(new double[]{1, 2, -3});
+		DenseVector vec2 = new DenseVector(new double[]{3, 2, 1});
+		DenseMatrix outer = vec1.outer(vec2);
+		Assert.assertArrayEquals(new double[]{3, 2, 1, 6, 4, 2, -9, -6, -3},
+			outer.getArrayCopy1D(true), TOL);
+	}
+
+	@Test
+	public void testNormalize() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		vec.normalizeEqual(1.0);
+		Assert.assertArrayEquals(new double[]{1. / 6, 2. / 6, -3. / 6}, vec.getData(), TOL);
+	}
+
+	@Test
+	public void testStandardize() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		vec.standardizeEqual(1.0, 1.0);
+		Assert.assertArrayEquals(new double[]{0, 1, -4}, vec.getData(), TOL);
+	}
+
+	@Test
+	public void testIterator() throws Exception {
+		DenseVector vec = new DenseVector(new double[]{1, 2, -3});
+		VectorIterator iterator = vec.iterator();
+		Assert.assertTrue(iterator.hasNext());
+		Assert.assertEquals(0, iterator.getIndex());
+		Assert.assertEquals(1, iterator.getValue(), 0);
+		iterator.next();
+		Assert.assertTrue(iterator.hasNext());
+		Assert.assertEquals(1, iterator.getIndex());
+		Assert.assertEquals(2, iterator.getValue(), 0);
+		iterator.next();
+		Assert.assertTrue(iterator.hasNext());
+		Assert.assertEquals(2, iterator.getIndex());
+		Assert.assertEquals(-3, iterator.getValue(), 0);
+		iterator.next();
+		Assert.assertFalse(iterator.hasNext());
+	}
+
+}
diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/MatVecOpTest.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/MatVecOpTest.java
new file mode 100644
index 0000000000000..2415dcb0c1bcd
--- /dev/null
+++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/MatVecOpTest.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.common.linalg;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Test cases for {@link MatVecOp}.
+ */
+public class MatVecOpTest {
+	private static final double TOL = 1.0e-6;
+	private DenseVector dv;
+	private SparseVector sv;
+
+	@Before
+	public void setUp() throws Exception {
+		dv = new DenseVector(new double[]{1, 2, 3, 4});
+		sv = new SparseVector(4, new int[]{0, 2}, new double[]{1., 1.});
+	}
+
+	@Test
+	public void testPlus() throws Exception {
+		Vector plusResult1 = MatVecOp.plus(dv, sv);
+		Vector plusResult2 = MatVecOp.plus(sv, dv);
+		Vector plusResult3 = MatVecOp.plus(sv, sv);
+		Vector plusResult4 = MatVecOp.plus(dv, dv);
+		Assert.assertTrue(plusResult1 instanceof DenseVector);
+		Assert.assertTrue(plusResult2 instanceof DenseVector);
+		Assert.assertTrue(plusResult3 instanceof SparseVector);
+		Assert.assertTrue(plusResult4 instanceof DenseVector);
+		Assert.assertArrayEquals(new double[]{2, 2, 4, 4}, ((DenseVector) plusResult1).getData(), TOL);
+		Assert.assertArrayEquals(new double[]{2, 2, 4, 4}, ((DenseVector) plusResult2).getData(), TOL);
+		Assert.assertArrayEquals(new int[]{0, 2}, ((SparseVector) plusResult3).getIndices());
+		Assert.assertArrayEquals(new double[]{2., 2.}, ((SparseVector) plusResult3).getValues(), TOL);
+		Assert.assertArrayEquals(new double[]{2, 4, 6, 8}, ((DenseVector) plusResult4).getData(), TOL);
+	}
+
+	@Test
+	public void testMinus() throws Exception {
+		Vector minusResult1 = MatVecOp.minus(dv, sv);
+		Vector minusResult2 = MatVecOp.minus(sv, dv);
+		Vector minusResult3 = MatVecOp.minus(sv, sv);
+		Vector minusResult4 = MatVecOp.minus(dv, dv);
+		Assert.assertTrue(minusResult1 instanceof DenseVector);
+		Assert.assertTrue(minusResult2 instanceof DenseVector);
+		Assert.assertTrue(minusResult3 instanceof SparseVector);
+		Assert.assertTrue(minusResult4 instanceof DenseVector);
+		Assert.assertArrayEquals(new double[]{0, 2, 2, 4}, ((DenseVector) minusResult1).getData(), TOL);
+		Assert.assertArrayEquals(new double[]{0, -2, -2, -4}, ((DenseVector) minusResult2).getData(), TOL);
+		Assert.assertArrayEquals(new int[]{0, 2}, ((SparseVector) minusResult3).getIndices());
+		Assert.assertArrayEquals(new double[]{0., 0.}, ((SparseVector) minusResult3).getValues(), TOL);
+		Assert.assertArrayEquals(new double[]{0, 0, 0, 0}, ((DenseVector) minusResult4).getData(), TOL);
+	}
+
+	@Test
+	public void testDot() throws Exception {
+		Assert.assertEquals(4.0, MatVecOp.dot(dv, sv), TOL);
+		Assert.assertEquals(4.0, MatVecOp.dot(sv, dv), TOL);
+		Assert.assertEquals(2.0, MatVecOp.dot(sv, sv), TOL);
+		Assert.assertEquals(30.0, MatVecOp.dot(dv, dv), TOL);
+	}
+
+	@Test
+	public void testSumAbsDiff() throws Exception {
+		Assert.assertEquals(8.0, MatVecOp.sumAbsDiff(dv, sv), TOL);
+		Assert.assertEquals(8.0, MatVecOp.sumAbsDiff(sv, dv), TOL);
+		Assert.assertEquals(0.0, MatVecOp.sumAbsDiff(sv, sv), TOL);
+		Assert.assertEquals(0.0, MatVecOp.sumAbsDiff(dv, dv), TOL);
+	}
+
+	@Test
+	public void testSumSquaredDiff() throws Exception {
+		Assert.assertEquals(24.0, MatVecOp.sumSquaredDiff(dv, sv), TOL);
+		Assert.assertEquals(24.0, MatVecOp.sumSquaredDiff(sv, dv), TOL);
+		Assert.assertEquals(0.0, MatVecOp.sumSquaredDiff(sv, sv), TOL);
+		Assert.assertEquals(0.0, MatVecOp.sumSquaredDiff(dv, dv), TOL);
+	}
+}
diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/SparseVectorTest.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/SparseVectorTest.java
new file mode 100644
index 0000000000000..7c58681b20221
--- /dev/null
+++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/SparseVectorTest.java
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.common.linalg;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Map;
+import java.util.TreeMap;
+
+/**
+ * Test cases for SparseVector.
+ */
+public class SparseVectorTest {
+ private static final double TOL = 1.0e-6;
+ private SparseVector v1 = new SparseVector(8, new int[]{1, 3, 5, 7}, new double[]{2.0, 2.0, 2.0, 2.0});
+ private SparseVector v2 = new SparseVector(8, new int[]{3, 4, 5}, new double[]{1.0, 1.0, 1.0});
+
+ @Test
+ public void testConstructor() throws Exception {
+ int[] indices = new int[]{3, 7, 2, 1};
+ double[] values = new double[]{3.0, 7.0, 2.0, 1.0};
+ Map
+ * Vector vector = ...;
+ * VectorIterator iterator = vector.iterator();
+ *
+ * while(iterator.hasNext()) {
+ * int index = iterator.getIndex();
+ * double value = iterator.getValue();
+ * iterator.next();
+ * }
+ *
+ */
+public interface VectorIterator extends Serializable {
+
+ /**
+ * Returns {@code true} if the iteration has more elements.
+ * Otherwise, {@code false} will be returned.
+ *
+ * @return {@code true} if the iteration has more elements
+ */
+ boolean hasNext();
+
+ /**
+	 * Move the cursor to the next element of the vector.
+ *
+ *