diff --git a/flink-ml-parent/flink-ml-lib/pom.xml b/flink-ml-parent/flink-ml-lib/pom.xml index bc5ca37b38c89..391b2acd86b12 100644 --- a/flink-ml-parent/flink-ml-lib/pom.xml +++ b/flink-ml-parent/flink-ml-lib/pom.xml @@ -34,5 +34,10 @@ under the License. flink-ml-api ${project.version} + + com.github.fommil.netlib + core + 1.1.2 + diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java new file mode 100644 index 0000000000000..f5d9e8497e8e6 --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +/** + * A utility class that provides BLAS routines over matrices and vectors. + */ +public class BLAS { + private static final com.github.fommil.netlib.BLAS NATIVE_BLAS = com.github.fommil.netlib.BLAS.getInstance(); + private static final com.github.fommil.netlib.BLAS F2J_BLAS = com.github.fommil.netlib.F2jBLAS.getInstance(); + + /** + * y += a * x . + */ + public static void axpy(double a, double[] x, double[] y) { + assert x.length == y.length : "Array dimension mismatched."; + F2J_BLAS.daxpy(x.length, a, x, 1, y, 1); + } + + /** + * y += a * x . + */ + public static void axpy(double a, DenseVector x, DenseVector y) { + assert x.data.length == y.data.length : "Vector dimension mismatched."; + F2J_BLAS.daxpy(x.data.length, a, x.data, 1, y.data, 1); + } + + /** + * y += a * x . + */ + public static void axpy(double a, SparseVector x, DenseVector y) { + for (int i = 0; i < x.indices.length; i++) { + y.data[x.indices[i]] += a * x.values[i]; + } + } + + /** + * y += a * x . + */ + public static void axpy(double a, DenseMatrix x, DenseMatrix y) { + assert x.m == y.m && x.n == y.n : "Matrix dimension mismatched."; + F2J_BLAS.daxpy(x.data.length, a, x.data, 1, y.data, 1); + } + + /** + * x \cdot y . + */ + public static double dot(double[] x, double[] y) { + assert x.length == y.length : "Array dimension mismatched."; + return F2J_BLAS.ddot(x.length, x, 1, y, 1); + } + + /** + * x \cdot y . + */ + public static double dot(DenseVector x, DenseVector y) { + assert x.data.length == y.data.length : "Vector dimension mismatched."; + return F2J_BLAS.ddot(x.data.length, x.data, 1, y.data, 1); + } + + /** + * x = x * a . + */ + public static void scal(double a, double[] x) { + F2J_BLAS.dscal(x.length, a, x, 1); + } + + /** + * x = x * a . + */ + public static void scal(double a, DenseVector x) { + F2J_BLAS.dscal(x.data.length, a, x.data, 1); + } + + /** + * x = x * a . + */ + public static void scal(double a, SparseVector x) { + F2J_BLAS.dscal(x.values.length, a, x.values, 1); + } + + /** + * x = x * a . + */ + public static void scal(double a, DenseMatrix x) { + F2J_BLAS.dscal(x.data.length, a, x.data, 1); + } + + /** + * C := alpha * A * B + beta * C . + */ + public static void gemm(double alpha, DenseMatrix matA, boolean transA, DenseMatrix matB, boolean transB, + double beta, DenseMatrix matC) { + if (transA) { + assert matA.numCols() == matC.numRows() : "The columns of A does not match the rows of C"; + } else { + assert matA.numRows() == matC.numRows() : "The rows of A does not match the rows of C"; + } + if (transB) { + assert matB.numRows() == matC.numCols() : "The rows of B does not match the columns of C"; + } else { + assert matB.numCols() == matC.numCols() : "The columns of B does not match the columns of C"; + } + + final int m = matC.numRows(); + final int n = matC.numCols(); + final int k = transA ? matA.numRows() : matA.numCols(); + final int lda = matA.numRows(); + final int ldb = matB.numRows(); + final int ldc = matC.numRows(); + final String ta = transA ? "T" : "N"; + final String tb = transB ? "T" : "N"; + NATIVE_BLAS.dgemm(ta, tb, m, n, k, alpha, matA.getData(), lda, matB.getData(), ldb, beta, matC.getData(), ldc); + } + + /** + * y := alpha * A * x + beta * y . + */ + public static void gemv(double alpha, DenseMatrix matA, boolean transA, + DenseVector x, double beta, DenseVector y) { + if (transA) { + assert (matA.numCols() == y.size() && matA.numRows() == x.size()) : "Matrix and vector size mismatched."; + } else { + assert (matA.numRows() == y.size() && matA.numCols() == x.size()) : "Matrix and vector size mismatched."; + } + final int m = matA.numRows(); + final int n = matA.numCols(); + final int lda = matA.numRows(); + final String ta = transA ? "T" : "N"; + NATIVE_BLAS.dgemv(ta, m, n, alpha, matA.getData(), lda, x.getData(), 1, beta, y.getData(), 1); + } +} diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseMatrix.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseMatrix.java new file mode 100644 index 0000000000000..2b25aa1b4a727 --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseMatrix.java @@ -0,0 +1,606 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import java.io.Serializable; +import java.util.Arrays; + +/** + * DenseMatrix stores dense matrix data and provides some methods to operate on + * the matrix it represents. + */ +public class DenseMatrix implements Serializable { + + /** + * Row dimension. + * + *

Package private to allow access from {@link MatVecOp} and {@link BLAS}. + */ + int m; + + /** + * Column dimension. + * + *

Package private to allow access from {@link MatVecOp} and {@link BLAS}. + */ + int n; + + /** + * Array for internal storage of elements. + * + *

Package private to allow access from {@link MatVecOp} and {@link BLAS}. + * + *

The matrix data is stored in column major format internally. + */ + double[] data; + + /** + * Construct an m-by-n matrix of zeros. + * + * @param m Number of rows. + * @param n Number of colums. + */ + public DenseMatrix(int m, int n) { + this(m, n, new double[m * n], false); + } + + /** + * Construct a matrix from a 1-D array. The data in the array should organize + * in column major. + * + * @param m Number of rows. + * @param n Number of cols. + * @param data One-dimensional array of doubles. + */ + public DenseMatrix(int m, int n, double[] data) { + this(m, n, data, false); + } + + /** + * Construct a matrix from a 1-D array. The data in the array is organized + * in column major or in row major, which is specified by parameter 'inRowMajor' + * + * @param m Number of rows. + * @param n Number of cols. + * @param data One-dimensional array of doubles. + * @param inRowMajor Whether the matrix in 'data' is in row major format. + */ + public DenseMatrix(int m, int n, double[] data, boolean inRowMajor) { + assert (data.length == m * n); + this.m = m; + this.n = n; + if (inRowMajor) { + toColumnMajor(m, n, data); + } + this.data = data; + } + + /** + * Construct a matrix from a 2-D array. + * + * @param data Two-dimensional array of doubles. + * @throws IllegalArgumentException All rows must have the same size + */ + public DenseMatrix(double[][] data) { + this.m = data.length; + if (this.m == 0) { + this.n = 0; + this.data = new double[0]; + return; + } + this.n = data[0].length; + for (int i = 0; i < m; i++) { + if (data[i].length != n) { + throw new IllegalArgumentException("All rows must have the same size."); + } + } + this.data = new double[m * n]; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + this.set(i, j, data[i][j]); + } + } + } + + /** + * Create an identity matrix. + * + * @param n the dimension of the eye matrix. + * @return an identity matrix. + */ + public static DenseMatrix eye(int n) { + return eye(n, n); + } + + /** + * Create a m * n identity matrix. + * + * @param m the row dimension. + * @param n the column dimension.e + * @return the m * n identity matrix. + */ + public static DenseMatrix eye(int m, int n) { + DenseMatrix mat = new DenseMatrix(m, n); + int k = Math.min(m, n); + for (int i = 0; i < k; i++) { + mat.data[i * m + i] = 1.0; + } + return mat; + } + + /** + * Create a zero matrix. + * + * @param m the row dimension. + * @param n the column dimension. + * @return a m * n zero matrix. + */ + public static DenseMatrix zeros(int m, int n) { + return new DenseMatrix(m, n); + } + + /** + * Create a matrix with all elements set to 1. + * + * @param m the row dimension + * @param n the column dimension + * @return the m * n matrix with all elements set to 1. + */ + public static DenseMatrix ones(int m, int n) { + DenseMatrix mat = new DenseMatrix(m, n); + Arrays.fill(mat.data, 1.); + return mat; + } + + /** + * Create a random matrix. + * + * @param m the row dimension + * @param n the column dimension. + * @return a m * n random matrix. + */ + public static DenseMatrix rand(int m, int n) { + DenseMatrix mat = new DenseMatrix(m, n); + for (int i = 0; i < mat.data.length; i++) { + mat.data[i] = Math.random(); + } + return mat; + } + + /** + * Create a random symmetric matrix. + * + * @param n the dimension of the symmetric matrix. + * @return a n * n random symmetric matrix. + */ + public static DenseMatrix randSymmetric(int n) { + DenseMatrix mat = new DenseMatrix(n, n); + for (int i = 0; i < n; i++) { + for (int j = i; j < n; j++) { + double r = Math.random(); + mat.set(i, j, r); + if (i != j) { + mat.set(j, i, r); + } + } + } + return mat; + } + + /** + * Get a single element. + * + * @param i Row index. + * @param j Column index. + * @return matA(i, j) + * @throws ArrayIndexOutOfBoundsException + */ + public double get(int i, int j) { + return data[j * m + i]; + } + + /** + * Get the data array of this matrix. + * + * @return the data array of this matrix. + */ + public double[] getData() { + return this.data; + } + + /** + * Get all the matrix data, returned as a 2-D array. + * + * @return all matrix data, returned as a 2-D array. + */ + public double[][] getArrayCopy2D() { + double[][] arrayData = new double[m][n]; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + arrayData[i][j] = this.get(i, j); + } + } + return arrayData; + } + + /** + * Get all matrix data, returned as a 1-D array. + * + * @param inRowMajor Whether to return data in row major. + * @return all matrix data, returned as a 1-D array. + */ + public double[] getArrayCopy1D(boolean inRowMajor) { + if (inRowMajor) { + double[] arrayData = new double[m * n]; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + arrayData[i * n + j] = this.get(i, j); + } + } + return arrayData; + } else { + return this.data.clone(); + } + } + + /** + * Get one row. + * + * @param row the row index. + * @return the row with the given index. + */ + public double[] getRow(int row) { + assert (row >= 0 && row < m) : "Invalid row index."; + double[] r = new double[n]; + for (int i = 0; i < n; i++) { + r[i] = this.get(row, i); + } + return r; + } + + /** + * Get one column. + * + * @param col the column index. + * @return the column with the given index. + */ + public double[] getColumn(int col) { + assert (col >= 0 && col < n) : "Invalid column index."; + double[] columnData = new double[m]; + System.arraycopy(this.data, col * m, columnData, 0, m); + return columnData; + } + + /** + * Clone the Matrix object. + */ + @Override + public DenseMatrix clone() { + return new DenseMatrix(this.m, this.n, this.data.clone(), false); + } + + /** + * Create a new matrix by selecting some of the rows. + * + * @param rows the array of row indexes to select. + * @return a new matrix by selecting some of the rows. + */ + public DenseMatrix selectRows(int[] rows) { + DenseMatrix sub = new DenseMatrix(rows.length, this.n); + for (int i = 0; i < rows.length; i++) { + for (int j = 0; j < this.n; j++) { + sub.set(i, j, this.get(rows[i], j)); + } + } + return sub; + } + + /** + * Get sub matrix. + * + * @param m0 the starting row index (inclusive) + * @param m1 the ending row index (exclusive) + * @param n0 the starting column index (inclusive) + * @param n1 the ending column index (exclusive) + * @return the specified sub matrix. + */ + public DenseMatrix getSubMatrix(int m0, int m1, int n0, int n1) { + assert (m0 >= 0 && m1 <= m) && (n0 >= 0 && n1 <= n) : "Invalid index range."; + DenseMatrix sub = new DenseMatrix(m1 - m0, n1 - n0); + for (int i = 0; i < sub.m; i++) { + for (int j = 0; j < sub.n; j++) { + sub.set(i, j, this.get(m0 + i, n0 + j)); + } + } + return sub; + } + + /** + * Set part of the matrix values from the values of another matrix. + * + * @param sub the matrix whose element values will be assigned to the sub matrix of this matrix. + * @param m0 the starting row index (inclusive) + * @param m1 the ending row index (exclusive) + * @param n0 the starting column index (inclusive) + * @param n1 the ending column index (exclusive) + */ + public void setSubMatrix(DenseMatrix sub, int m0, int m1, int n0, int n1) { + assert (m0 >= 0 && m1 <= m) && (n0 >= 0 && n1 <= n) : "Invalid index range."; + for (int i = 0; i < sub.m; i++) { + for (int j = 0; j < sub.n; j++) { + this.set(m0 + i, n0 + j, sub.get(i, j)); + } + } + } + + /** + * Set a single element. + * + * @param i Row index. + * @param j Column index. + * @param s A(i,j). + * @throws ArrayIndexOutOfBoundsException + */ + public void set(int i, int j, double s) { + data[j * m + i] = s; + } + + /** + * Add the given value to a single element. + * + * @param i Row index. + * @param j Column index. + * @param s A(i,j). + * @throws ArrayIndexOutOfBoundsException + */ + public void add(int i, int j, double s) { + data[j * m + i] += s; + } + + /** + * Check whether the matrix is square matrix. + * + * @return true if this matrix is a square matrix, false otherwise. + */ + public boolean isSquare() { + return m == n; + } + + /** + * Check whether the matrix is symmetric matrix. + * + * @return true if this matrix is a symmetric matrix, false otherwise. + */ + public boolean isSymmetric() { + if (m != n) { + return false; + } + for (int i = 0; i < n; i++) { + for (int j = i + 1; j < n; j++) { + if (this.get(i, j) != this.get(j, i)) { + return false; + } + } + } + return true; + } + + /** + * Get the number of rows. + * + * @return the number of rows. + */ + public int numRows() { + return m; + } + + /** + * Get the number of columns. + * + * @return the number of columns. + */ + public int numCols() { + return n; + } + + /** + * Sum of all elements of the matrix. + */ + public double sum() { + double s = 0.; + for (int i = 0; i < this.data.length; i++) { + s += this.data[i]; + } + return s; + } + + /** + * Scale the vector by value "v" and create a new matrix to store the result. + */ + public DenseMatrix scale(double v) { + DenseMatrix r = this.clone(); + BLAS.scal(v, r); + return r; + } + + /** + * Scale the matrix by value "v". + */ + public void scaleEqual(double v) { + BLAS.scal(v, this); + } + + /** + * Create a new matrix by plussing another matrix. + */ + public DenseMatrix plus(DenseMatrix mat) { + DenseMatrix r = this.clone(); + BLAS.axpy(1.0, mat, r); + return r; + } + + /** + * Create a new matrix by plussing a constant. + */ + public DenseMatrix plus(double alpha) { + DenseMatrix r = this.clone(); + for (int i = 0; i < r.data.length; i++) { + r.data[i] += alpha; + } + return r; + } + + /** + * Plus with another matrix. + */ + public void plusEquals(DenseMatrix mat) { + BLAS.axpy(1.0, mat, this); + } + + /** + * Plus with a constant. + */ + public void plusEquals(double alpha) { + for (int i = 0; i < this.data.length; i++) { + this.data[i] += alpha; + } + } + + /** + * Create a new matrix by subtracting another matrix. + */ + public DenseMatrix minus(DenseMatrix mat) { + DenseMatrix r = this.clone(); + BLAS.axpy(-1.0, mat, r); + return r; + } + + /** + * Minus with another vector. + */ + public void minusEquals(DenseMatrix mat) { + BLAS.axpy(-1.0, mat, this); + } + + /** + * Multiply with another matrix. + */ + public DenseMatrix multiplies(DenseMatrix mat) { + DenseMatrix r = new DenseMatrix(this.m, mat.n); + BLAS.gemm(1.0, this, false, mat, false, 0., r); + return r; + } + + /** + * Multiply with a dense vector. + */ + public DenseVector multiplies(DenseVector x) { + DenseVector y = new DenseVector(this.numRows()); + BLAS.gemv(1.0, this, false, x, 0.0, y); + return y; + } + + /** + * Multiply with a sparse vector. + */ + public DenseVector multiplies(SparseVector x) { + DenseVector y = new DenseVector(this.numRows()); + for (int i = 0; i < this.numRows(); i++) { + double s = 0.; + int[] indices = x.getIndices(); + double[] values = x.getValues(); + for (int j = 0; j < indices.length; j++) { + int index = indices[j]; + if (index >= this.numCols()) { + throw new RuntimeException("Vector index out of bound:" + index); + } + s += this.get(i, index) * values[j]; + } + y.set(i, s); + } + return y; + } + + /** + * Create a new matrix by transposing current matrix. + * + *

Use cache-oblivious matrix transpose algorithm. + */ + public DenseMatrix transpose() { + DenseMatrix mat = new DenseMatrix(n, m); + int m0 = m; + int n0 = n; + int barrierSize = 16384; + while (m0 * n0 > barrierSize) { + if (m0 >= n0) { + m0 /= 2; + } else { + n0 /= 2; + } + } + for (int i0 = 0; i0 < m; i0 += m0) { + for (int j0 = 0; j0 < n; j0 += n0) { + for (int i = i0; i < i0 + m0 && i < m; i++) { + for (int j = j0; j < j0 + n0 && j < n; j++) { + mat.set(j, i, this.get(i, j)); + } + } + } + } + return mat; + } + + /** + * Converts the data layout in "data" from row major to column major. + */ + private static void toColumnMajor(int m, int n, double[] data) { + if (m == n) { + for (int i = 0; i < m; i++) { + for (int j = i + 1; j < m; j++) { + int pos0 = j * m + i; + int pos1 = i * m + j; + double t = data[pos0]; + data[pos0] = data[pos1]; + data[pos1] = t; + } + } + } else { + DenseMatrix temp = new DenseMatrix(n, m, data, false); + System.arraycopy(temp.transpose().data, 0, data, 0, data.length); + } + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(); + sbd.append(String.format("mat[%d,%d]:\n", m, n)); + for (int i = 0; i < m; i++) { + sbd.append(" "); + for (int j = 0; j < n; j++) { + if (j > 0) { + sbd.append(","); + } + sbd.append(this.get(i, j)); + } + sbd.append("\n"); + } + return sbd.toString(); + } +} diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseVector.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseVector.java new file mode 100644 index 0000000000000..6c3337b75346d --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseVector.java @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import java.util.Arrays; +import java.util.Random; + +/** + * A dense vector represented by a values array. + */ +public class DenseVector extends Vector { + /** + * The array holding the vector data. + *

+ * Package private to allow access from {@link MatVecOp} and {@link BLAS}. + */ + double[] data; + + /** + * Create a zero size vector. + */ + public DenseVector() { + this(0); + } + + /** + * Create a size n vector with all elements zero. + * + * @param n Size of the vector. + */ + public DenseVector(int n) { + this.data = new double[n]; + } + + /** + * Create a dense vector with the user provided data. + * + * @param data The vector data. + */ + public DenseVector(double[] data) { + this.data = data; + } + + /** + * Get the data array. + */ + public double[] getData() { + return this.data; + } + + /** + * Set the data array. + */ + public void setData(double[] data) { + this.data = data; + } + + /** + * Create a dense vector with all elements one. + * + * @param n Size of the vector. + * @return The newly created dense vector. + */ + public static DenseVector ones(int n) { + DenseVector r = new DenseVector(n); + Arrays.fill(r.data, 1.0); + return r; + } + + /** + * Create a dense vector with all elements zero. + * + * @param n Size of the vector. + * @return The newly created dense vector. + */ + public static DenseVector zeros(int n) { + DenseVector r = new DenseVector(n); + Arrays.fill(r.data, 0.0); + return r; + } + + /** + * Create a dense vector with random values uniformly distributed in the range of [0.0, 1.0]. + * + * @param n Size of the vector. + * @return The newly created dense vector. + */ + public static DenseVector rand(int n) { + Random random = new Random(); + DenseVector v = new DenseVector(n); + for (int i = 0; i < n; i++) { + v.data[i] = random.nextDouble(); + } + return v; + } + + @Override + public DenseVector clone() { + return new DenseVector(this.data.clone()); + } + + @Override + public String toString() { + return VectorUtil.toString(this); + } + + @Override + public int size() { + return data.length; + } + + @Override + public double get(int i) { + return data[i]; + } + + @Override + public void set(int i, double d) { + data[i] = d; + } + + @Override + public void add(int i, double d) { + data[i] += d; + } + + @Override + public double normL1() { + double d = 0; + for (double t : data) { + d += Math.abs(t); + } + return d; + } + + @Override + public double normL2() { + double d = 0; + for (double t : data) { + d += t * t; + } + return Math.sqrt(d); + } + + @Override + public double normL2Square() { + double d = 0; + for (double t : data) { + d += t * t; + } + return d; + } + + @Override + public double normInf() { + double d = 0; + for (double t : data) { + d = Math.max(Math.abs(t), d); + } + return d; + } + + @Override + public DenseVector slice(int[] indices) { + double[] values = new double[indices.length]; + for (int i = 0; i < indices.length; ++i) { + if (indices[i] >= data.length) { + throw new RuntimeException("Index is larger than vector size."); + } + values[i] = data[indices[i]]; + } + return new DenseVector(values); + } + + @Override + public DenseVector prefix(double d) { + double[] data = new double[this.size() + 1]; + data[0] = d; + System.arraycopy(this.data, 0, data, 1, this.data.length); + return new DenseVector(data); + } + + @Override + public DenseVector append(double d) { + double[] data = new double[this.size() + 1]; + System.arraycopy(this.data, 0, data, 0, this.data.length); + data[this.size()] = d; + return new DenseVector(data); + } + + @Override + public void scaleEqual(double d) { + BLAS.scal(d, this); + } + + @Override + public DenseVector plus(Vector other) { + DenseVector r = this.clone(); + if (other instanceof DenseVector) { + BLAS.axpy(1.0, (DenseVector) other, r); + } else { + BLAS.axpy(1.0, (SparseVector) other, r); + } + return r; + } + + @Override + public DenseVector minus(Vector other) { + DenseVector r = this.clone(); + if (other instanceof DenseVector) { + BLAS.axpy(-1.0, (DenseVector) other, r); + } else { + BLAS.axpy(-1.0, (SparseVector) other, r); + } + return r; + } + + @Override + public DenseVector scale(double d) { + DenseVector r = this.clone(); + BLAS.scal(d, r); + return r; + } + + @Override + public double dot(Vector vec) { + if (vec instanceof DenseVector) { + return BLAS.dot(this, (DenseVector) vec); + } else { + return vec.dot(this); + } + } + + @Override + public void standardizeEqual(double mean, double stdvar) { + int size = data.length; + for (int i = 0; i < size; i++) { + data[i] -= mean; + data[i] *= (1.0 / stdvar); + } + } + + @Override + public void normalizeEqual(double p) { + double norm = 0.0; + if (Double.isInfinite(p)) { + norm = normInf(); + } else if (p == 1.0) { + norm = normL1(); + } else if (p == 2.0) { + norm = normL2(); + } else { + for (int i = 0; i < data.length; i++) { + norm += Math.pow(Math.abs(data[i]), p); + } + norm = Math.pow(norm, 1 / p); + } + for (int i = 0; i < data.length; i++) { + data[i] /= norm; + } + } + + /** + * Set the data of the vector the same as those of another vector. + */ + public void setEqual(DenseVector other) { + assert this.size() == other.size() : "Size of the two vectors mismatched."; + System.arraycopy(other.data, 0, this.data, 0, this.size()); + } + + /** + * Plus with another vector. + */ + public void plusEqual(Vector other) { + if (other instanceof DenseVector) { + BLAS.axpy(1.0, (DenseVector) other, this); + } else { + BLAS.axpy(1.0, (SparseVector) other, this); + } + } + + /** + * Minus with another vector. + */ + public void minusEqual(Vector other) { + if (other instanceof DenseVector) { + BLAS.axpy(-1.0, (DenseVector) other, this); + } else { + BLAS.axpy(-1.0, (SparseVector) other, this); + } + } + + /** + * Plus with another vector scaled by "alpha". + */ + public void plusScaleEqual(Vector other, double alpha) { + if (other instanceof DenseVector) { + BLAS.axpy(alpha, (DenseVector) other, this); + } else { + BLAS.axpy(alpha, (SparseVector) other, this); + } + } + + @Override + public DenseMatrix outer() { + return this.outer(this); + } + + /** + * Compute the outer product with another vector. + * + * @return The outer product matrix. + */ + public DenseMatrix outer(DenseVector other) { + int nrows = this.size(); + int ncols = other.size(); + double[] data = new double[nrows * ncols]; + int pos = 0; + for (int j = 0; j < ncols; j++) { + for (int i = 0; i < nrows; i++) { + data[pos++] = this.data[i] * other.data[j]; + } + } + return new DenseMatrix(nrows, ncols, data, false); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DenseVector that = (DenseVector) o; + return Arrays.equals(data, that.data); + } + + @Override + public int hashCode() { + return Arrays.hashCode(data); + } + + @Override + public VectorIterator iterator() { + return new DenseVectorIterator(); + } + + private class DenseVectorIterator implements VectorIterator { + private int cursor = 0; + + @Override + public boolean hasNext() { + return cursor < data.length; + } + + @Override + public void next() { + cursor++; + } + + @Override + public int getIndex() { + if (cursor >= data.length) { + throw new RuntimeException("Iterator out of bound."); + } + return cursor; + } + + @Override + public double getValue() { + if (cursor >= data.length) { + throw new RuntimeException("Iterator out of bound."); + } + return data[cursor]; + } + } +} diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/MatVecOp.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/MatVecOp.java new file mode 100644 index 0000000000000..9bbfbbf714052 --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/MatVecOp.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import java.util.function.BiFunction; +import java.util.function.Function; + +/** + * A utility class that provides operations over {@link DenseVector}, {@link SparseVector} + * and {@link DenseMatrix}. + */ +public class MatVecOp { + /** + * compute vec1 + vec2 . + */ + public static Vector plus(Vector vec1, Vector vec2) { + return vec1.plus(vec2); + } + + /** + * compute vec1 - vec2 . + */ + public static Vector minus(Vector vec1, Vector vec2) { + return vec1.minus(vec2); + } + + /** + * Compute vec1 \cdot vec2 . + */ + public static double dot(Vector vec1, Vector vec2) { + return vec1.dot(vec2); + } + + /** + * Compute || vec1 - vec2 ||_1 . + */ + public static double sumAbsDiff(Vector vec1, Vector vec2) { + if (vec1 instanceof DenseVector) { + if (vec2 instanceof DenseVector) { + return MatVecOp.applySum((DenseVector) vec1, (DenseVector) vec2, (a, b) -> Math.abs(a - b)); + } else { + return MatVecOp.applySum((DenseVector) vec1, (SparseVector) vec2, (a, b) -> Math.abs(a - b)); + } + } else { + if (vec2 instanceof DenseVector) { + return MatVecOp.applySum((SparseVector) vec1, (DenseVector) vec2, (a, b) -> Math.abs(a - b)); + } else { + return MatVecOp.applySum((SparseVector) vec1, (SparseVector) vec2, (a, b) -> Math.abs(a - b)); + } + } + } + + /** + * Compute || vec1 - vec2 ||_2^2 . + */ + public static double sumSquaredDiff(Vector vec1, Vector vec2) { + if (vec1 instanceof DenseVector) { + if (vec2 instanceof DenseVector) { + return MatVecOp.applySum((DenseVector) vec1, (DenseVector) vec2, (a, b) -> (a - b) * (a - b)); + } else { + return MatVecOp.applySum((DenseVector) vec1, (SparseVector) vec2, (a, b) -> (a - b) * (a - b)); + } + } else { + if (vec2 instanceof DenseVector) { + return MatVecOp.applySum((SparseVector) vec1, (DenseVector) vec2, (a, b) -> (a - b) * (a - b)); + } else { + return MatVecOp.applySum((SparseVector) vec1, (SparseVector) vec2, (a, b) -> (a - b) * (a - b)); + } + } + } + + /** + * y = func(x). + */ + public static void apply(DenseMatrix x, DenseMatrix y, Function func) { + assert (x.m == y.m && x.n == y.n) : "x and y size mismatched."; + double[] xdata = x.data; + double[] ydata = y.data; + for (int i = 0; i < xdata.length; i++) { + ydata[i] = func.apply(xdata[i]); + } + } + + /** + * y = func(x1, x2). + */ + public static void apply( + DenseMatrix x1, + DenseMatrix x2, + DenseMatrix y, + BiFunction func) { + + assert (x1.m == y.m && x1.n == y.n) : "x1 and y size mismatched."; + assert (x2.m == y.m && x2.n == y.n) : "x2 and y size mismatched."; + double[] x1data = x1.data; + double[] x2data = x2.data; + double[] ydata = y.data; + for (int i = 0; i < ydata.length; i++) { + ydata[i] = func.apply(x1data[i], x2data[i]); + } + } + + /** + * y = func(x). + */ + public static void apply(DenseVector x, DenseVector y, Function func) { + assert (x.data.length == y.data.length) : "x and y size mismatched."; + for (int i = 0; i < x.data.length; i++) { + y.data[i] = func.apply(x.data[i]); + } + } + + /** + * y = func(x1, x2). + */ + public static void apply( + DenseVector x1, + DenseVector x2, + DenseVector y, + BiFunction func) { + + assert (x1.data.length == y.data.length) : "x1 and y size mismatched."; + assert (x2.data.length == y.data.length) : "x1 and y size mismatched."; + for (int i = 0; i < y.data.length; i++) { + y.data[i] = func.apply(x1.data[i], x2.data[i]); + } + } + + /** + * Create a new {@link SparseVector} by element wise operation between two {@link SparseVector}s. + * y = func(x1, x2). + */ + public static SparseVector apply(SparseVector x1, SparseVector x2, BiFunction func) { + assert (x1.size() == x2.size()) : "x1 and x2 size mismatched."; + + int totNnz = x1.values.length + x2.values.length; + int p0 = 0; + int p1 = 0; + while (p0 < x1.values.length && p1 < x2.values.length) { + if (x1.indices[p0] == x2.indices[p1]) { + totNnz--; + p0++; + p1++; + } else if (x1.indices[p0] < x2.indices[p1]) { + p0++; + } else { + p1++; + } + } + + SparseVector r = new SparseVector(x1.size()); + r.indices = new int[totNnz]; + r.values = new double[totNnz]; + p0 = p1 = 0; + int pos = 0; + while (pos < totNnz) { + if (p0 < x1.values.length && p1 < x2.values.length) { + if (x1.indices[p0] == x2.indices[p1]) { + r.indices[pos] = x1.indices[p0]; + r.values[pos] = func.apply(x1.values[p0], x2.values[p1]); + p0++; + p1++; + } else if (x1.indices[p0] < x2.indices[p1]) { + r.indices[pos] = x1.indices[p0]; + r.values[pos] = func.apply(x1.values[p0], 0.0); + p0++; + } else { + r.indices[pos] = x2.indices[p1]; + r.values[pos] = func.apply(0.0, x2.values[p1]); + p1++; + } + pos++; + } else { + if (p0 < x1.values.length) { + r.indices[pos] = x1.indices[p0]; + r.values[pos] = func.apply(x1.values[p0], 0.0); + p0++; + pos++; + continue; + } + if (p1 < x2.values.length) { + r.indices[pos] = x2.indices[p1]; + r.values[pos] = func.apply(0.0, x2.values[p1]); + p1++; + pos++; + continue; + } + } + } + + return r; + } + + /** + * \sum_i func(x1_i, x2_i) . + */ + public static double applySum(DenseVector x1, DenseVector x2, BiFunction func) { + assert x1.size() == x2.size() : "x1 and x2 size mismatched."; + double[] x1data = x1.getData(); + double[] x2data = x2.getData(); + double s = 0.; + for (int i = 0; i < x1data.length; i++) { + s += func.apply(x1data[i], x2data[i]); + } + return s; + } + + /** + * \sum_i func(x1_i, x2_i) . + */ + public static double applySum(SparseVector x1, SparseVector x2, BiFunction func) { + double s = 0.; + int p1 = 0; + int p2 = 0; + int[] x1Indices = x1.getIndices(); + double[] x1Values = x1.getValues(); + int[] x2Indices = x2.getIndices(); + double[] x2Values = x2.getValues(); + int nnz1 = x1Indices.length; + int nnz2 = x2Indices.length; + while (p1 < nnz1 || p2 < nnz2) { + if (p1 < nnz1 && p2 < nnz2) { + if (x1Indices[p1] == x2Indices[p2]) { + s += func.apply(x1Values[p1], x2Values[p2]); + p1++; + p2++; + } else if (x1Indices[p1] < x2Indices[p2]) { + s += func.apply(x1Values[p1], 0.); + p1++; + } else { + s += func.apply(0., x2Values[p2]); + p2++; + } + } else { + if (p1 < nnz1) { + s += func.apply(x1Values[p1], 0.); + p1++; + } else { // p2 < nnz2 + s += func.apply(0., x2Values[p2]); + p2++; + } + } + } + return s; + } + + /** + * \sum_i func(x1_i, x2_i) . + */ + public static double applySum(DenseVector x1, SparseVector x2, BiFunction func) { + assert x1.size() == x2.size() : "x1 and x2 size mismatched."; + double s = 0.; + int p2 = 0; + int[] x2Indices = x2.getIndices(); + double[] x2Values = x2.getValues(); + int nnz2 = x2Indices.length; + double[] x1data = x1.getData(); + for (int i = 0; i < x1data.length; i++) { + if (p2 < nnz2 && x2Indices[p2] == i) { + s += func.apply(x1data[i], x2Values[p2]); + p2++; + } else { + s += func.apply(x1data[i], 0.); + } + } + return s; + } + + /** + * \sum_i func(x1_i, x2_i) . + */ + public static double applySum(SparseVector x1, DenseVector x2, BiFunction func) { + assert x1.size() == x2.size() : "x1 and x2 size mismatched."; + double s = 0.; + int p1 = 0; + int[] x1Indices = x1.getIndices(); + double[] x1Values = x1.getValues(); + int nnz1 = x1Indices.length; + double[] x2data = x2.getData(); + for (int i = 0; i < x2data.length; i++) { + if (p1 < nnz1 && x1Indices[p1] == i) { + s += func.apply(x1Values[p1], x2data[i]); + p1++; + } else { + s += func.apply(0., x2data[i]); + } + } + return s; + } + + /** + * \sum_ij func(x1_ij, x2_ij) . + */ + public static double applySum(DenseMatrix x1, DenseMatrix x2, BiFunction func) { + assert (x1.m == x2.m && x1.n == x2.n) : "x1 and x2 size mismatched."; + double[] x1data = x1.data; + double[] x2data = x2.data; + double s = 0.; + for (int i = 0; i < x1data.length; i++) { + s += func.apply(x1data[i], x2data[i]); + } + return s; + } +} diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/SparseVector.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/SparseVector.java new file mode 100644 index 0000000000000..572226ee692ba --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/SparseVector.java @@ -0,0 +1,595 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; + +/** + * A sparse vector represented by an indices array and a values array. + */ +public class SparseVector extends Vector { + + /** + * Size of the vector. n = -1 indicates that the vector size is undetermined. + * + *

Package private to allow access from {@link MatVecOp} and {@link BLAS}. + */ + int n; + + /** + * Column indices. + *

+ * Package private to allow access from {@link MatVecOp} and {@link BLAS}. + */ + int[] indices; + + /** + * Column values. + *

+ * Package private to allow access from {@link MatVecOp} and {@link BLAS}. + */ + double[] values; + + /** + * Construct an empty sparse vector with undetermined size. + */ + public SparseVector() { + this(-1); + } + + /** + * Construct an empty sparse vector with determined size. + */ + public SparseVector(int n) { + this.n = n; + this.indices = new int[0]; + this.values = new double[0]; + } + + /** + * Construct a sparse vector with the given indices and values. + * + * @throws IllegalArgumentException If size of indices array and values array differ. + * @throws IllegalArgumentException If n >= 0 and the indices are out of bound. + */ + public SparseVector(int n, int[] indices, double[] values) { + this.n = n; + this.indices = indices; + this.values = values; + checkSizeAndIndicesRange(); + sortIndices(); + } + + /** + * Construct a sparse vector with given indices to values map. + * + * @throws IllegalArgumentException If n >= 0 and the indices are out of bound. + */ + public SparseVector(int n, Map kv) { + this.n = n; + int nnz = kv.size(); + int[] indices = new int[nnz]; + double[] values = new double[nnz]; + + int pos = 0; + for (Map.Entry entry : kv.entrySet()) { + indices[pos] = entry.getKey(); + values[pos] = entry.getValue(); + pos++; + } + + this.indices = indices; + this.values = values; + checkSizeAndIndicesRange(); + + if (!(kv instanceof TreeMap)) { + sortIndices(); + } + } + + /** + * Check whether the indices array and values array are of the same size, + * and whether vector indices are in valid range. + */ + private void checkSizeAndIndicesRange() { + if (indices.length != values.length) { + throw new IllegalArgumentException("Indices size and values size should be the same."); + } + for (int i = 0; i < indices.length; i++) { + if (indices[i] < 0 || (n >= 0 && indices[i] >= n)) { + throw new IllegalArgumentException("Index out of bound."); + } + } + } + + /** + * Sort the indices and values using quick sort. + */ + private static void sortImpl(int[] indices, double[] values, int low, int high) { + int pivot = indices[high]; + int pos = low - 1; + for (int i = low; i <= high; i++) { + if (indices[i] <= pivot) { + pos++; + int tempI = indices[pos]; + double tempD = values[pos]; + indices[pos] = indices[i]; + values[pos] = values[i]; + indices[i] = tempI; + values[i] = tempD; + } + } + if (pos - 1 > low) { + sortImpl(indices, values, low, pos - 1); + } + if (high > pos + 1) { + sortImpl(indices, values, pos + 1, high); + } + } + + /** + * Sort the indices and values if the indices are out of order. + */ + private void sortIndices() { + boolean outOfOrder = false; + for (int i = 0; i < this.indices.length - 1; i++) { + if (this.indices[i] > this.indices[i + 1]) { + outOfOrder = true; + break; + } + } + if (outOfOrder) { + sortImpl(this.indices, this.values, 0, this.indices.length - 1); + } + } + + @Override + public SparseVector clone() { + SparseVector vec = new SparseVector(this.n); + vec.indices = this.indices.clone(); + vec.values = this.values.clone(); + return vec; + } + + @Override + public SparseVector prefix(double d) { + int[] indices = new int[this.indices.length + 1]; + double[] values = new double[this.values.length + 1]; + int n = (this.n >= 0) ? this.n + 1 : this.n; + + indices[0] = 0; + values[0] = d; + + for (int i = 0; i < this.indices.length; i++) { + indices[i + 1] = this.indices[i] + 1; + values[i + 1] = this.values[i]; + } + + return new SparseVector(n, indices, values); + } + + @Override + public SparseVector append(double d) { + int[] indices = new int[this.indices.length + 1]; + double[] values = new double[this.values.length + 1]; + int n = (this.n >= 0) ? this.n + 1 : this.n; + + System.arraycopy(this.indices, 0, indices, 0, this.indices.length); + System.arraycopy(this.values, 0, values, 0, this.values.length); + + indices[this.indices.length] = n - 1; + values[this.values.length] = d; + + return new SparseVector(n, indices, values); + } + + /** + * Get the indices array. + */ + public int[] getIndices() { + return indices; + } + + /** + * Get the values array. + */ + public double[] getValues() { + return values; + } + + @Override + public int size() { + return n; + } + + @Override + public double get(int i) { + int pos = Arrays.binarySearch(indices, i); + if (pos >= 0) { + return values[pos]; + } + return 0.; + } + + /** + * Set the size of the vector. + */ + public void setSize(int n) { + this.n = n; + } + + /** + * Get number of values in this vector. + */ + public int numberOfValues() { + return this.values.length; + } + + @Override + public void set(int i, double val) { + int pos = Arrays.binarySearch(indices, i); + if (pos >= 0) { + this.values[pos] = val; + } else { + pos = -(pos + 1); + insert(pos, i, val); + } + } + + @Override + public void add(int i, double val) { + int pos = Arrays.binarySearch(indices, i); + if (pos >= 0) { + this.values[pos] += val; + } else { + pos = -(pos + 1); + insert(pos, i, val); + } + } + + /** + * Insert value "val" in the position "pos" with index "index". + */ + private void insert(int pos, int index, double val) { + double[] newValues = new double[this.values.length + 1]; + int[] newIndices = new int[this.values.length + 1]; + System.arraycopy(this.values, 0, newValues, 0, pos); + System.arraycopy(this.indices, 0, newIndices, 0, pos); + newValues[pos] = val; + newIndices[pos] = index; + System.arraycopy(this.values, pos, newValues, pos + 1, this.values.length - pos); + System.arraycopy(this.indices, pos, newIndices, pos + 1, this.values.length - pos); + this.values = newValues; + this.indices = newIndices; + } + + @Override + public String toString() { + return VectorUtil.toString(this); + } + + @Override + public double normL2() { + double d = 0; + for (double t : values) { + d += t * t; + } + return Math.sqrt(d); + } + + @Override + public double normL1() { + double d = 0; + for (double t : values) { + d += Math.abs(t); + } + return d; + } + + @Override + public double normInf() { + double d = 0; + for (double t : values) { + d = Math.max(Math.abs(t), d); + } + return d; + } + + @Override + public double normL2Square() { + double d = 0; + for (double t : values) { + d += t * t; + } + return d; + } + + @Override + public SparseVector slice(int[] indices) { + SparseVector sliced = new SparseVector(indices.length); + int nnz = 0; + sliced.indices = new int[indices.length]; + sliced.values = new double[indices.length]; + + for (int i = 0; i < indices.length; i++) { + int pos = Arrays.binarySearch(this.indices, indices[i]); + if (pos >= 0) { + sliced.indices[nnz] = i; + sliced.values[nnz] = this.values[pos]; + nnz++; + } + } + + if (nnz < sliced.indices.length) { + sliced.indices = Arrays.copyOf(sliced.indices, nnz); + sliced.values = Arrays.copyOf(sliced.values, nnz); + } + + return sliced; + } + + @Override + public Vector plus(Vector vec) { + if (this.size() != vec.size()) { + throw new IllegalArgumentException("The size of the two vectors are different."); + } + + if (vec instanceof DenseVector) { + DenseVector r = ((DenseVector) vec).clone(); + for (int i = 0; i < this.indices.length; i++) { + r.add(this.indices[i], this.values[i]); + } + return r; + } else { + return MatVecOp.apply(this, (SparseVector) vec, ((a, b) -> a + b)); + } + } + + @Override + public Vector minus(Vector vec) { + if (this.size() != vec.size()) { + throw new IllegalArgumentException("The size of the two vectors are different."); + } + + if (vec instanceof DenseVector) { + DenseVector r = ((DenseVector) vec).scale(-1.0); + for (int i = 0; i < this.indices.length; i++) { + r.add(this.indices[i], this.values[i]); + } + return r; + } else { + return MatVecOp.apply(this, (SparseVector) vec, ((a, b) -> a - b)); + } + } + + @Override + public SparseVector scale(double d) { + SparseVector r = this.clone(); + BLAS.scal(d, r); + return r; + } + + @Override + public void scaleEqual(double d) { + BLAS.scal(d, this); + } + + /** + * Remove all zero values away from this vector. + */ + public void removeZeroValues() { + if (this.values.length != 0) { + List idxs = new ArrayList<>(); + for (int i = 0; i < values.length; i++) { + if (0 != values[i]) { + idxs.add(i); + } + } + int[] newIndices = new int[idxs.size()]; + double[] newValues = new double[newIndices.length]; + for (int i = 0; i < newIndices.length; i++) { + newIndices[i] = indices[idxs.get(i)]; + newValues[i] = values[idxs.get(i)]; + } + this.indices = newIndices; + this.values = newValues; + } + } + + private double dot(SparseVector other) { + if (this.size() != other.size()) { + throw new RuntimeException("the size of the two vectors are different"); + } + + double d = 0; + int p0 = 0; + int p1 = 0; + while (p0 < this.values.length && p1 < other.values.length) { + if (this.indices[p0] == other.indices[p1]) { + d += this.values[p0] * other.values[p1]; + p0++; + p1++; + } else if (this.indices[p0] < other.indices[p1]) { + p0++; + } else { + p1++; + } + } + return d; + } + + private double dot(DenseVector other) { + if (this.size() != other.size()) { + throw new RuntimeException( + "The size of the two vectors are different: " + this.size() + " vs " + other.size()); + } + double s = 0.; + for (int i = 0; i < this.indices.length; i++) { + s += this.values[i] * other.get(this.indices[i]); + } + return s; + } + + @Override + public double dot(Vector other) { + if (other instanceof DenseVector) { + return dot((DenseVector) other); + } else { + return dot((SparseVector) other); + } + } + + @Override + public DenseMatrix outer() { + return this.outer(this); + } + + /** + * Compute the outer product with another vector. + * + * @return The outer product matrix. + */ + public DenseMatrix outer(SparseVector other) { + int nrows = this.size(); + int ncols = other.size(); + double[] data = new double[ncols * nrows]; + for (int i = 0; i < this.values.length; i++) { + for (int j = 0; j < other.values.length; j++) { + data[this.indices[i] + other.indices[j] * nrows] = this.values[i] * other.values[j]; + } + } + return new DenseMatrix(nrows, ncols, data); + } + + /** + * Convert to a dense vector. + */ + public DenseVector toDenseVector() { + if (n >= 0) { + DenseVector r = new DenseVector(n); + for (int i = 0; i < this.indices.length; i++) { + r.set(this.indices[i], this.values[i]); + } + return r; + } else { + if (this.indices.length == 0) { + return new DenseVector(); + } else { + int n = this.indices[this.indices.length - 1] + 1; + DenseVector r = new DenseVector(n); + for (int i = 0; i < this.indices.length; i++) { + r.set(this.indices[i], this.values[i]); + } + return r; + } + } + } + + @Override + public void standardizeEqual(double mean, double stdvar) { + for (int i = 0; i < indices.length; i++) { + values[i] -= mean; + values[i] *= (1.0 / stdvar); + } + } + + @Override + public void normalizeEqual(double p) { + double norm = 0.0; + if (Double.isInfinite(p)) { + norm = normInf(); + } else if (p == 1.0) { + norm = normL1(); + } else if (p == 2.0) { + norm = normL2(); + } else { + for (int i = 0; i < indices.length; i++) { + norm += Math.pow(values[i], p); + } + norm = Math.pow(norm, 1 / p); + } + + for (int i = 0; i < indices.length; i++) { + values[i] /= norm; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SparseVector that = (SparseVector) o; + return n == that.n && + Arrays.equals(indices, that.indices) && + Arrays.equals(values, that.values); + } + + @Override + public int hashCode() { + int result = Objects.hash(n); + result = 31 * result + Arrays.hashCode(indices); + result = 31 * result + Arrays.hashCode(values); + return result; + } + + @Override + public VectorIterator iterator() { + return new SparseVectorVectorIterator(); + } + + private class SparseVectorVectorIterator implements VectorIterator { + private int cursor = 0; + + @Override + public boolean hasNext() { + return cursor < values.length; + } + + @Override + public void next() { + cursor++; + } + + @Override + public int getIndex() { + if (cursor >= values.length) { + throw new RuntimeException("Iterator out of bound."); + } + return indices[cursor]; + } + + @Override + public double getValue() { + if (cursor >= values.length) { + throw new RuntimeException("Iterator out of bound."); + } + return values[cursor]; + } + } +} diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/Vector.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/Vector.java new file mode 100644 index 0000000000000..887da6a5c64a6 --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/Vector.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import java.io.Serializable; + +/** + * The Vector class defines some common methods for both DenseVector and + * SparseVector. + */ +public abstract class Vector implements Serializable { + /** + * Get the size of the vector. + */ + public abstract int size(); + + /** + * Get the i-th element of the vector. + */ + public abstract double get(int i); + + /** + * Set the i-th element of the vector to value "val". + */ + public abstract void set(int i, double val); + + /** + * Add the i-th element of the vector by value "val". + */ + public abstract void add(int i, double val); + + /** + * Return the L1 norm of the vector. + */ + public abstract double normL1(); + + /** + * Return the Inf norm of the vector. + */ + public abstract double normInf(); + + /** + * Return the L2 norm of the vector. + */ + public abstract double normL2(); + + /** + * Return the square of L2 norm of the vector. + */ + public abstract double normL2Square(); + + /** + * Scale the vector by value "v" and create a new vector to store the result. + */ + public abstract Vector scale(double v); + + /** + * Scale the vector by value "v". + */ + public abstract void scaleEqual(double v); + + /** + * Normalize the vector. + */ + public abstract void normalizeEqual(double p); + + /** + * Standardize the vector. + */ + public abstract void standardizeEqual(double mean, double stdvar); + + /** + * Create a new vector by adding an element to the head of the vector. + */ + public abstract Vector prefix(double v); + + /** + * Create a new vector by adding an element to the end of the vector. + */ + public abstract Vector append(double v); + + /** + * Create a new vector by plussing another vector. + */ + public abstract Vector plus(Vector vec); + + /** + * Create a new vector by subtracting another vector. + */ + public abstract Vector minus(Vector vec); + + /** + * Compute the dot product with another vector. + */ + public abstract double dot(Vector vec); + + /** + * Get the iterator of the vector. + */ + public abstract VectorIterator iterator(); + + /** + * Slice the vector. + */ + public abstract Vector slice(int[] indexes); + + /** + * Compute the outer product with itself. + * + * @return The outer product matrix. + */ + public abstract DenseMatrix outer(); +} diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorIterator.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorIterator.java new file mode 100644 index 0000000000000..8828aca1882fc --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorIterator.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * An iterator over the elements of a vector. + * + *

Usage: + * + * + * Vector vector = ...; + * VectorIterator iterator = vector.iterator(); + * + * while(iterator.hasNext()) { + * int index = iterator.getIndex(); + * double value = iterator.getValue(); + * iterator.next(); + * } + * + */ +public interface VectorIterator extends Serializable { + + /** + * Returns {@code true} if the iteration has more elements. + * Otherwise, {@code false} will be returned. + * + * @return {@code true} if the iteration has more elements + */ + boolean hasNext(); + + /** + * Trigger the cursor points to the next element of the vector. + * + *

The {@link #getIndex()} while returns the index of the + * element which the cursor points. + * The {@link #getValue()} ()} while returns the value of + * the element which the cursor points. + * + *

The difference to the {@link Iterator#next()} is that this + * can avoid the return of boxed type. + */ + void next(); + + /** + * Returns the index of the element which the cursor points. + * + * @returnthe the index of the element which the cursor points. + */ + int getIndex(); + + /** + * Returns the value of the element which the cursor points. + * + * @returnthe the value of the element which the cursor points. + */ + double getValue(); +} diff --git a/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorUtil.java b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorUtil.java new file mode 100644 index 0000000000000..b605f635e61d7 --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorUtil.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import org.apache.commons.lang3.StringUtils; + +/** + * Utility class for the operations on {@link Vector} and its subclasses. + */ +public class VectorUtil { + /** + * Delimiter between elements. + */ + private static final char ELEMENT_DELIMITER = ' '; + /** + * Delimiter between vector size and vector data. + */ + private static final char HEADER_DELIMITER = '$'; + /** + * Delimiter between index and value. + */ + private static final char INDEX_VALUE_DELIMITER = ':'; + + /** + * Parse either a {@link SparseVector} or a {@link DenseVector} from a formatted string. + * + *

The format of a dense vector is space separated values such as "1 2 3 4". + * The format of a sparse vector is space separated index-value pairs, such as "0:1 2:3 3:4". + * If the sparse vector has determined vector size, the size is prepended to the head. For example, + * the string "$4$0:1 2:3 3:4" represents a sparse vector with size 4. + * + * @param str A formatted string representing a vector. + * @return The parsed vector. + */ + public static Vector parse(String str) { + boolean isSparse = org.apache.flink.util.StringUtils.isNullOrWhitespaceOnly(str) + || StringUtils.indexOf(str, INDEX_VALUE_DELIMITER) != -1 + || StringUtils.indexOf(str, HEADER_DELIMITER) != -1; + if (isSparse) { + return parseSparse(str); + } else { + return parseDense(str); + } + } + + /** + * Parse the dense vector from a formatted string. + * + *

The format of a dense vector is space separated values such as "1 2 3 4". + * + * @param str A string of space separated values. + * @return The parsed vector. + */ + public static DenseVector parseDense(String str) { + if (org.apache.flink.util.StringUtils.isNullOrWhitespaceOnly(str)) { + return new DenseVector(); + } + + int len = str.length(); + + int inDataBuffPos = 0; + boolean isInBuff = false; + + for (int i = 0; i < len; ++i) { + char c = str.charAt(i); + + if (c == ELEMENT_DELIMITER + // to be compatible with previous delimiter + || c == ',') { + if (isInBuff) { + inDataBuffPos++; + } + + isInBuff = false; + } else { + isInBuff = true; + } + } + + if (isInBuff) { + inDataBuffPos++; + } + + double[] data = new double[inDataBuffPos]; + int lastestInCharBuffPos = 0; + + inDataBuffPos = 0; + isInBuff = false; + + for (int i = 0; i < len; ++i) { + char c = str.charAt(i); + + if (c == ELEMENT_DELIMITER) { + if (isInBuff) { + data[inDataBuffPos++] = Double.parseDouble( + StringUtils.substring(str, lastestInCharBuffPos, i).trim() + ); + + lastestInCharBuffPos = i + 1; + } + + isInBuff = false; + } else { + isInBuff = true; + } + } + + if (isInBuff) { + data[inDataBuffPos] = Double.valueOf( + StringUtils.substring(str, lastestInCharBuffPos).trim() + ); + } + + return new DenseVector(data); + } + + /** + * Parse the sparse vector from a formatted string. + * + *

The format of a sparse vector is space separated index-value pairs, such as "0:1 2:3 3:4". + * If the sparse vector has determined vector size, the size is prepended to the head. For example, + * the string "$4$0:1 2:3 3:4" represents a sparse vector with size 4. + * + * @param str A formatted string representing a sparse vector. + * @throws IllegalArgumentException If the string is of invalid format. + */ + public static SparseVector parseSparse(String str) { + try { + if (org.apache.flink.util.StringUtils.isNullOrWhitespaceOnly(str)) { + return new SparseVector(); + } + + int n = -1; + int firstDollarPos = str.indexOf(HEADER_DELIMITER); + int lastDollarPos = -1; + if (firstDollarPos >= 0) { + lastDollarPos = StringUtils.lastIndexOf(str, HEADER_DELIMITER); + String sizeStr = StringUtils.substring(str, firstDollarPos + 1, lastDollarPos); + n = Integer.valueOf(sizeStr); + if (lastDollarPos == str.length() - 1) { + return new SparseVector(n); + } + } + + int numValues = StringUtils.countMatches(str, String.valueOf(INDEX_VALUE_DELIMITER)); + double[] data = new double[numValues]; + int[] indices = new int[numValues]; + int startPos = lastDollarPos + 1; + int endPos; + for (int i = 0; i < numValues; i++) { + int colonPos = StringUtils.indexOf(str, INDEX_VALUE_DELIMITER, startPos); + if (colonPos < 0) { + throw new IllegalArgumentException("Format error."); + } + endPos = StringUtils.indexOf(str, ELEMENT_DELIMITER, colonPos); + + if (endPos == -1) { + endPos = str.length(); + } + indices[i] = Integer.valueOf(str.substring(startPos, colonPos).trim()); + data[i] = Double.valueOf(str.substring(colonPos + 1, endPos).trim()); + startPos = endPos + 1; + } + return new SparseVector(n, indices, data); + } catch (Exception e) { + throw new IllegalArgumentException( + String.format("Fail to getVector sparse vector from string: \"%s\".", str), e); + } + } + + /** + * Serialize the vector to a string. + * + * @param vector The vector to serialize. + * @see #toString(DenseVector) + * @see #toString(SparseVector) + */ + public static String toString(Vector vector) { + if (vector instanceof SparseVector) { + return toString((SparseVector) vector); + } + return toString((DenseVector) vector); + } + + /** + * Serialize the SparseVector to string. + * + *

The format of the returned is described at {@link #parseSparse(String)} + * + * @param sparseVector The sparse vector to serialize. + */ + public static String toString(SparseVector sparseVector) { + StringBuilder sbd = new StringBuilder(); + if (sparseVector.n > 0) { + sbd.append(HEADER_DELIMITER); + sbd.append(sparseVector.n); + sbd.append(HEADER_DELIMITER); + } + if (null != sparseVector.indices) { + for (int i = 0; i < sparseVector.indices.length; i++) { + sbd.append(sparseVector.indices[i]); + sbd.append(INDEX_VALUE_DELIMITER); + sbd.append(sparseVector.values[i]); + if (i < sparseVector.indices.length - 1) { + sbd.append(ELEMENT_DELIMITER); + } + } + } + + return sbd.toString(); + } + + /** + * Serialize the DenseVector to String. + * + *

The format of the returned is described at {@link #parseDense(String)} + * + * @param denseVector The DenseVector to serialize. + */ + public static String toString(DenseVector denseVector) { + StringBuilder sbd = new StringBuilder(); + + for (int i = 0; i < denseVector.data.length; i++) { + sbd.append(denseVector.data[i]); + if (i < denseVector.data.length - 1) { + sbd.append(ELEMENT_DELIMITER); + } + } + return sbd.toString(); + } + +} diff --git a/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/NOTICE b/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/NOTICE new file mode 100644 index 0000000000000..db087abe66c49 --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/NOTICE @@ -0,0 +1,10 @@ +flink-ml-lib +Copyright 2014-2019 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +This project bundles the following dependencies under the BSD license. +See bundled license files for details. + +- com.github.fommil.netlib:core:1.1.2 diff --git a/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/licenses/LICENSE.core b/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/licenses/LICENSE.core new file mode 100644 index 0000000000000..b7d28491e4b6b --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/main/resources/META-INF/licenses/LICENSE.core @@ -0,0 +1,49 @@ +Copyright (c) 2013 Samuel Halliday +Copyright (c) 1992-2011 The University of Tennessee and The University + of Tennessee Research Foundation. All rights + reserved. +Copyright (c) 2000-2011 The University of California Berkeley. All + rights reserved. +Copyright (c) 2006-2011 The University of Colorado Denver. All rights + reserved. + +$COPYRIGHT$ + +Additional copyrights may follow + +$HEADER$ + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +- Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer listed + in this license in the documentation and/or other materials + provided with the distribution. + +- Neither the name of the copyright holders nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +The copyright holders provide no reassurances that the source code +provided does not infringe any patent, copyright, or any other +intellectual property rights of third parties. The copyright holders +disclaim any liability to any recipient for claims brought against +recipient by any third party for infringement of that parties +intellectual property rights. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseMatrixTest.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseMatrixTest.java new file mode 100644 index 0000000000000..bb35294ea0033 --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseMatrixTest.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Test cases for DenseMatrix. + */ +public class DenseMatrixTest { + + private static final double TOL = 1.0e-6; + + private static void assertEqual2D(double[][] matA, double[][] matB) { + assert (matA.length == matB.length); + assert (matA[0].length == matB[0].length); + int m = matA.length; + int n = matA[0].length; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + Assert.assertEquals(matA[i][j], matB[i][j], TOL); + } + } + } + + private static double[][] simpleMM(double[][] matA, double[][] matB) { + int m = matA.length; + int n = matB[0].length; + int k = matA[0].length; + double[][] matC = new double[m][n]; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + matC[i][j] = 0.; + for (int l = 0; l < k; l++) { + matC[i][j] += matA[i][l] * matB[l][j]; + } + } + } + return matC; + } + + private static double[] simpleMV(double[][] matA, double[] x) { + int m = matA.length; + int n = matA[0].length; + assert (n == x.length); + double[] y = new double[m]; + for (int i = 0; i < m; i++) { + y[i] = 0.; + for (int j = 0; j < n; j++) { + y[i] += matA[i][j] * x[j]; + } + } + return y; + } + + @Test + public void testPlusEquals() throws Exception { + DenseMatrix matA = new DenseMatrix(new double[][]{ + new double[]{1, 3, 5}, + new double[]{2, 4, 6}, + }); + DenseMatrix matB = DenseMatrix.ones(2, 3); + matA.plusEquals(matB); + Assert.assertArrayEquals(matA.getData(), new double[]{2, 3, 4, 5, 6, 7}, TOL); + matA.plusEquals(1.0); + Assert.assertArrayEquals(matA.getData(), new double[]{3, 4, 5, 6, 7, 8}, TOL); + } + + @Test + public void testMinusEquals() throws Exception { + DenseMatrix matA = new DenseMatrix(new double[][]{ + new double[]{1, 3, 5}, + new double[]{2, 4, 6}, + }); + DenseMatrix matB = DenseMatrix.ones(2, 3); + matA.minusEquals(matB); + Assert.assertArrayEquals(matA.getData(), new double[]{0, 1, 2, 3, 4, 5}, TOL); + } + + @Test + public void testPlus() throws Exception { + DenseMatrix matA = new DenseMatrix(new double[][]{ + new double[]{1, 3, 5}, + new double[]{2, 4, 6}, + }); + DenseMatrix matB = DenseMatrix.ones(2, 3); + DenseMatrix matC = matA.plus(matB); + Assert.assertArrayEquals(matC.getData(), new double[]{2, 3, 4, 5, 6, 7}, TOL); + DenseMatrix matD = matA.plus(1.0); + Assert.assertArrayEquals(matD.getData(), new double[]{2, 3, 4, 5, 6, 7}, TOL); + } + + @Test + public void testMinus() throws Exception { + DenseMatrix matA = new DenseMatrix(new double[][]{ + new double[]{1, 3, 5}, + new double[]{2, 4, 6}, + }); + DenseMatrix matB = DenseMatrix.ones(2, 3); + DenseMatrix matC = matA.minus(matB); + Assert.assertArrayEquals(matC.getData(), new double[]{0, 1, 2, 3, 4, 5}, TOL); + } + + @Test + public void testMM() throws Exception { + DenseMatrix matA = DenseMatrix.rand(4, 3); + DenseMatrix matB = DenseMatrix.rand(3, 5); + DenseMatrix matC = matA.multiplies(matB); + assertEqual2D(matC.getArrayCopy2D(), simpleMM(matA.getArrayCopy2D(), matB.getArrayCopy2D())); + + DenseMatrix matD = new DenseMatrix(5, 4); + BLAS.gemm(1., matB, true, matA, true, 0., matD); + Assert.assertArrayEquals(matD.transpose().getData(), matC.data, TOL); + } + + @Test + public void testMV() throws Exception { + DenseMatrix matA = DenseMatrix.rand(4, 3); + DenseVector x = DenseVector.ones(3); + DenseVector y = matA.multiplies(x); + Assert.assertArrayEquals(y.getData(), simpleMV(matA.getArrayCopy2D(), x.getData()), TOL); + + SparseVector x2 = new SparseVector(3, new int[]{0, 1, 2}, new double[]{1, 1, 1}); + DenseVector y2 = matA.multiplies(x2); + Assert.assertArrayEquals(y2.getData(), y.getData(), TOL); + } + + @Test + public void testDataSelection() throws Exception { + DenseMatrix mat = new DenseMatrix(new double[][]{ + new double[]{1, 2, 3}, + new double[]{4, 5, 6}, + new double[]{7, 8, 9}, + }); + DenseMatrix sub1 = mat.selectRows(new int[]{1}); + DenseMatrix sub2 = mat.getSubMatrix(1, 2, 1, 2); + Assert.assertEquals(sub1.numRows(), 1); + Assert.assertEquals(sub1.numCols(), 3); + Assert.assertEquals(sub2.numRows(), 1); + Assert.assertEquals(sub2.numCols(), 1); + Assert.assertArrayEquals(sub1.getData(), new double[]{4, 5, 6}, TOL); + Assert.assertArrayEquals(sub2.getData(), new double[]{5}, TOL); + + double[] row = mat.getRow(1); + double[] col = mat.getColumn(1); + Assert.assertArrayEquals(row, new double[]{4, 5, 6}, 0.); + Assert.assertArrayEquals(col, new double[]{2, 5, 8}, 0.); + } + + @Test + public void testSum() throws Exception { + DenseMatrix matA = DenseMatrix.ones(3, 2); + Assert.assertEquals(matA.sum(), 6.0, TOL); + } + + @Test + public void testRowMajorFormat() throws Exception { + double[] data = new double[]{1, 2, 3, 4, 5, 6}; + DenseMatrix matA = new DenseMatrix(2, 3, data, true); + Assert.assertArrayEquals(data, new double[]{1, 4, 2, 5, 3, 6}, 0.); + Assert.assertArrayEquals(matA.getData(), new double[]{1, 4, 2, 5, 3, 6}, 0.); + + data = new double[]{1, 2, 3, 4}; + matA = new DenseMatrix(2, 2, data, true); + Assert.assertArrayEquals(data, new double[]{1, 3, 2, 4}, 0.); + Assert.assertArrayEquals(matA.getData(), new double[]{1, 3, 2, 4}, 0.); + } +} diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseVectorTest.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseVectorTest.java new file mode 100644 index 0000000000000..3e972040f8047 --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseVectorTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Test cases for DenseVector. + */ +public class DenseVectorTest { + private static final double TOL = 1.0e-6; + + @Test + public void testSize() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + Assert.assertEquals(vec.size(), 3); + } + + @Test + public void testNormL1() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + Assert.assertEquals(vec.normL1(), 6, 0); + } + + @Test + public void testNormMax() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + Assert.assertEquals(vec.normInf(), 3, 0); + } + + @Test + public void testNormL2() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + Assert.assertEquals(vec.normL2(), Math.sqrt(1 + 4 + 9), TOL); + } + + @Test + public void testNormL2Square() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + Assert.assertEquals(vec.normL2Square(), 1 + 4 + 9, TOL); + } + + @Test + public void testSlice() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + DenseVector sliced = vec.slice(new int[]{0, 2}); + Assert.assertArrayEquals(new double[]{1, -3}, sliced.getData(), 0); + } + + @Test + public void testMinus() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + DenseVector d = new DenseVector(new double[]{1, 2, 1}); + DenseVector vec2 = vec.minus(d); + Assert.assertArrayEquals(vec.getData(), new double[]{1, 2, -3}, 0); + Assert.assertArrayEquals(vec2.getData(), new double[]{0, 0, -4}, TOL); + vec.minusEqual(d); + Assert.assertArrayEquals(vec.getData(), new double[]{0, 0, -4}, TOL); + } + + @Test + public void testPlus() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + DenseVector d = new DenseVector(new double[]{1, 2, 1}); + DenseVector vec2 = vec.plus(d); + Assert.assertArrayEquals(vec.getData(), new double[]{1, 2, -3}, 0); + Assert.assertArrayEquals(vec2.getData(), new double[]{2, 4, -2}, TOL); + vec.plusEqual(d); + Assert.assertArrayEquals(vec.getData(), new double[]{2, 4, -2}, TOL); + } + + @Test + public void testPlusScaleEqual() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + DenseVector vec2 = new DenseVector(new double[]{1, 0, 2}); + vec.plusScaleEqual(vec2, 2.); + Assert.assertArrayEquals(vec.getData(), new double[]{3, 2, 1}, TOL); + } + + @Test + public void testDot() throws Exception { + DenseVector vec1 = new DenseVector(new double[]{1, 2, -3}); + DenseVector vec2 = new DenseVector(new double[]{3, 2, 1}); + Assert.assertEquals(vec1.dot(vec2), 3 + 4 - 3, TOL); + } + + @Test + public void testPrefix() throws Exception { + DenseVector vec1 = new DenseVector(new double[]{1, 2, -3}); + DenseVector vec2 = vec1.prefix(0); + Assert.assertArrayEquals(vec2.getData(), new double[]{0, 1, 2, -3}, 0); + } + + @Test + public void testAppend() throws Exception { + DenseVector vec1 = new DenseVector(new double[]{1, 2, -3}); + DenseVector vec2 = vec1.append(0); + Assert.assertArrayEquals(vec2.getData(), new double[]{1, 2, -3, 0}, 0); + } + + @Test + public void testOuter() throws Exception { + DenseVector vec1 = new DenseVector(new double[]{1, 2, -3}); + DenseVector vec2 = new DenseVector(new double[]{3, 2, 1}); + DenseMatrix outer = vec1.outer(vec2); + Assert.assertArrayEquals(outer.getArrayCopy1D(true), + new double[]{3, 2, 1, 6, 4, 2, -9, -6, -3}, TOL); + } + + @Test + public void testNormalize() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + vec.normalizeEqual(1.0); + Assert.assertArrayEquals(vec.getData(), new double[]{1. / 6, 2. / 6, -3. / 6}, TOL); + } + + @Test + public void testStandardize() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + vec.standardizeEqual(1.0, 1.0); + Assert.assertArrayEquals(vec.getData(), new double[]{0, 1, -4}, TOL); + } + + @Test + public void testIterator() throws Exception { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + VectorIterator iterator = vec.iterator(); + Assert.assertTrue(iterator.hasNext()); + Assert.assertEquals(iterator.getIndex(), 0); + Assert.assertEquals(iterator.getValue(), 1, 0); + iterator.next(); + Assert.assertTrue(iterator.hasNext()); + Assert.assertEquals(iterator.getIndex(), 1); + Assert.assertEquals(iterator.getValue(), 2, 0); + iterator.next(); + Assert.assertTrue(iterator.hasNext()); + Assert.assertEquals(iterator.getIndex(), 2); + Assert.assertEquals(iterator.getValue(), -3, 0); + iterator.next(); + Assert.assertFalse(iterator.hasNext()); + } + +} diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/MatVecOpTest.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/MatVecOpTest.java new file mode 100644 index 0000000000000..2415dcb0c1bcd --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/MatVecOpTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link MatVecOp}. + */ +public class MatVecOpTest { + private static final double TOL = 1.0e-6; + private DenseVector dv; + private SparseVector sv; + + @Before + public void setUp() throws Exception { + dv = new DenseVector(new double[]{1, 2, 3, 4}); + sv = new SparseVector(4, new int[]{0, 2}, new double[]{1., 1.}); + } + + @Test + public void testPlus() throws Exception { + Vector plusResult1 = MatVecOp.plus(dv, sv); + Vector plusResult2 = MatVecOp.plus(sv, dv); + Vector plusResult3 = MatVecOp.plus(sv, sv); + Vector plusResult4 = MatVecOp.plus(dv, dv); + Assert.assertTrue(plusResult1 instanceof DenseVector); + Assert.assertTrue(plusResult2 instanceof DenseVector); + Assert.assertTrue(plusResult3 instanceof SparseVector); + Assert.assertTrue(plusResult4 instanceof DenseVector); + Assert.assertArrayEquals(((DenseVector) plusResult1).getData(), new double[]{2, 2, 4, 4}, TOL); + Assert.assertArrayEquals(((DenseVector) plusResult2).getData(), new double[]{2, 2, 4, 4}, TOL); + Assert.assertArrayEquals(((SparseVector) plusResult3).getIndices(), new int[]{0, 2}); + Assert.assertArrayEquals(((SparseVector) plusResult3).getValues(), new double[]{2., 2.}, TOL); + Assert.assertArrayEquals(((DenseVector) plusResult4).getData(), new double[]{2, 4, 6, 8}, TOL); + } + + @Test + public void testMinus() throws Exception { + Vector minusResult1 = MatVecOp.minus(dv, sv); + Vector minusResult2 = MatVecOp.minus(sv, dv); + Vector minusResult3 = MatVecOp.minus(sv, sv); + Vector minusResult4 = MatVecOp.minus(dv, dv); + Assert.assertTrue(minusResult1 instanceof DenseVector); + Assert.assertTrue(minusResult2 instanceof DenseVector); + Assert.assertTrue(minusResult3 instanceof SparseVector); + Assert.assertTrue(minusResult4 instanceof DenseVector); + Assert.assertArrayEquals(((DenseVector) minusResult1).getData(), new double[]{0, 2, 2, 4}, TOL); + Assert.assertArrayEquals(((DenseVector) minusResult2).getData(), new double[]{0, -2, -2, -4}, TOL); + Assert.assertArrayEquals(((SparseVector) minusResult3).getIndices(), new int[]{0, 2}); + Assert.assertArrayEquals(((SparseVector) minusResult3).getValues(), new double[]{0., 0.}, TOL); + Assert.assertArrayEquals(((DenseVector) minusResult4).getData(), new double[]{0, 0, 0, 0}, TOL); + } + + @Test + public void testDot() throws Exception { + Assert.assertEquals(MatVecOp.dot(dv, sv), 4.0, TOL); + Assert.assertEquals(MatVecOp.dot(sv, dv), 4.0, TOL); + Assert.assertEquals(MatVecOp.dot(sv, sv), 2.0, TOL); + Assert.assertEquals(MatVecOp.dot(dv, dv), 30.0, TOL); + } + + @Test + public void testSumAbsDiff() throws Exception { + Assert.assertEquals(MatVecOp.sumAbsDiff(dv, sv), 8.0, TOL); + Assert.assertEquals(MatVecOp.sumAbsDiff(sv, dv), 8.0, TOL); + Assert.assertEquals(MatVecOp.sumAbsDiff(sv, sv), 0.0, TOL); + Assert.assertEquals(MatVecOp.sumAbsDiff(dv, dv), 0.0, TOL); + } + + @Test + public void testSumSquaredDiff() throws Exception { + Assert.assertEquals(MatVecOp.sumSquaredDiff(dv, sv), 24.0, TOL); + Assert.assertEquals(MatVecOp.sumSquaredDiff(sv, dv), 24.0, TOL); + Assert.assertEquals(MatVecOp.sumSquaredDiff(sv, sv), 0.0, TOL); + Assert.assertEquals(MatVecOp.sumSquaredDiff(dv, dv), 0.0, TOL); + } +} diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/SparseVectorTest.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/SparseVectorTest.java new file mode 100644 index 0000000000000..7c58681b20221 --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/SparseVectorTest.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.Map; +import java.util.TreeMap; + +/** + * Test cases for SparseVector. + */ +public class SparseVectorTest { + private static final double TOL = 1.0e-6; + private SparseVector v1 = new SparseVector(8, new int[]{1, 3, 5, 7}, new double[]{2.0, 2.0, 2.0, 2.0}); + private SparseVector v2 = new SparseVector(8, new int[]{3, 4, 5}, new double[]{1.0, 1.0, 1.0}); + + @Test + public void testConstructor() throws Exception { + int[] indices = new int[]{3, 7, 2, 1}; + double[] values = new double[]{3.0, 7.0, 2.0, 1.0}; + Map map = new TreeMap<>(); + for (int i = 0; i < indices.length; i++) { + map.put(indices[i], values[i]); + } + SparseVector v = new SparseVector(8, map); + Assert.assertArrayEquals(v.getIndices(), new int[]{1, 2, 3, 7}); + Assert.assertArrayEquals(v.getValues(), new double[]{1, 2, 3, 7}, TOL); + } + + @Test + public void testSize() throws Exception { + Assert.assertEquals(v1.size(), 8); + } + + @Test + public void testSet() throws Exception { + SparseVector v = v1.clone(); + v.set(2, 2.0); + v.set(3, 3.0); + Assert.assertEquals(v.get(2), 2.0, TOL); + Assert.assertEquals(v.get(3), 3.0, TOL); + } + + @Test + public void testAdd() throws Exception { + SparseVector v = v1.clone(); + v.add(2, 2.0); + v.add(3, 3.0); + Assert.assertEquals(v.get(2), 2.0, TOL); + Assert.assertEquals(v.get(3), 5.0, TOL); + } + + @Test + public void testPrefix() throws Exception { + SparseVector prefixed = v1.prefix(0.2); + Assert.assertArrayEquals(prefixed.getIndices(), new int[]{0, 2, 4, 6, 8}); + Assert.assertArrayEquals(prefixed.getValues(), new double[]{0.2, 2, 2, 2, 2}, 0); + } + + @Test + public void testAppend() throws Exception { + SparseVector prefixed = v1.append(0.2); + Assert.assertArrayEquals(prefixed.getIndices(), new int[]{1, 3, 5, 7, 8}); + Assert.assertArrayEquals(prefixed.getValues(), new double[]{2, 2, 2, 2, 0.2}, 0); + } + + @Test + public void testSortIndices() throws Exception { + int n = 8; + int[] indices = new int[]{7, 5, 3, 1}; + double[] values = new double[]{7, 5, 3, 1}; + v1 = new SparseVector(n, indices, values); + Assert.assertArrayEquals(values, new double[]{1, 3, 5, 7}, 0.); + Assert.assertArrayEquals(v1.getValues(), new double[]{1, 3, 5, 7}, 0.); + Assert.assertArrayEquals(indices, new int[]{1, 3, 5, 7}); + Assert.assertArrayEquals(v1.getIndices(), new int[]{1, 3, 5, 7}); + } + + @Test + public void testNormL2Square() throws Exception { + Assert.assertEquals(v2.normL2Square(), 3.0, TOL); + } + + @Test + public void testMinus() throws Exception { + Vector d = v2.minus(v1); + Assert.assertEquals(d.get(0), 0.0, TOL); + Assert.assertEquals(d.get(1), -2.0, TOL); + Assert.assertEquals(d.get(2), 0.0, TOL); + Assert.assertEquals(d.get(3), -1.0, TOL); + Assert.assertEquals(d.get(4), 1.0, TOL); + } + + @Test + public void testPlus() throws Exception { + Vector d = v1.plus(v2); + Assert.assertEquals(d.get(0), 0.0, TOL); + Assert.assertEquals(d.get(1), 2.0, TOL); + Assert.assertEquals(d.get(2), 0.0, TOL); + Assert.assertEquals(d.get(3), 3.0, TOL); + + DenseVector dv = DenseVector.ones(8); + dv = dv.plus(v2); + Assert.assertArrayEquals(dv.getData(), new double[]{1, 1, 1, 2, 2, 2, 1, 1}, TOL); + } + + @Test + public void testDot() throws Exception { + Assert.assertEquals(v1.dot(v2), 4.0, TOL); + } + + @Test + public void testGet() throws Exception { + Assert.assertEquals(v1.get(5), 2.0, TOL); + Assert.assertEquals(v1.get(6), 0.0, TOL); + } + + @Test + public void testSlice() throws Exception { + int n = 8; + int[] indices = new int[]{1, 3, 5, 7}; + double[] values = new double[]{2.0, 3.0, 4.0, 5.0}; + SparseVector v = new SparseVector(n, indices, values); + + int[] indices1 = new int[]{5, 4, 3}; + SparseVector vec1 = v.slice(indices1); + Assert.assertEquals(vec1.size(), 3); + Assert.assertArrayEquals(vec1.getIndices(), new int[]{0, 2}); + Assert.assertArrayEquals(vec1.getValues(), new double[]{4.0, 3.0}, 0.0); + + int[] indices2 = new int[]{3, 5}; + SparseVector vec2 = v.slice(indices2); + Assert.assertArrayEquals(vec2.getIndices(), new int[]{0, 1}); + Assert.assertArrayEquals(vec2.getValues(), new double[]{3.0, 4.0}, 0.0); + + int[] indices3 = new int[]{2, 4}; + SparseVector vec3 = v.slice(indices3); + Assert.assertEquals(vec3.size(), 2); + Assert.assertArrayEquals(vec3.getIndices(), new int[]{}); + Assert.assertArrayEquals(vec3.getValues(), new double[]{}, 0.0); + + int[] indices4 = new int[]{2, 2, 4, 4}; + SparseVector vec4 = v.slice(indices4); + Assert.assertEquals(vec4.size(), 4); + Assert.assertArrayEquals(vec4.getIndices(), new int[]{}); + Assert.assertArrayEquals(vec4.getValues(), new double[]{}, 0.0); + } + + @Test + public void testToDenseVector() throws Exception { + int[] indices = new int[]{1, 3, 5}; + double[] values = new double[]{1.0, 3.0, 5.0}; + SparseVector v = new SparseVector(-1, indices, values); + DenseVector dv = v.toDenseVector(); + Assert.assertEquals(dv.size(), 6); + Assert.assertArrayEquals(dv.getData(), new double[]{0, 1, 0, 3, 0, 5}, TOL); + } + + @Test + public void testRemoveZeroValues() throws Exception { + int[] indices = new int[]{1, 3, 5}; + double[] values = new double[]{0.0, 3.0, 0.0}; + SparseVector v = new SparseVector(6, indices, values); + v.removeZeroValues(); + Assert.assertArrayEquals(v.getIndices(), new int[]{3}); + Assert.assertArrayEquals(v.getValues(), new double[]{3}, TOL); + } + + @Test + public void testOuter() throws Exception { + DenseMatrix outerProduct = v1.outer(v2); + Assert.assertEquals(outerProduct.numRows(), 8); + Assert.assertEquals(outerProduct.numCols(), 8); + Assert.assertArrayEquals(outerProduct.getRow(0), new double[]{0, 0, 0, 0, 0, 0, 0, 0}, TOL); + Assert.assertArrayEquals(outerProduct.getRow(1), new double[]{0, 0, 0, 2, 2, 2, 0, 0}, TOL); + Assert.assertArrayEquals(outerProduct.getRow(2), new double[]{0, 0, 0, 0, 0, 0, 0, 0}, TOL); + Assert.assertArrayEquals(outerProduct.getRow(3), new double[]{0, 0, 0, 2, 2, 2, 0, 0}, TOL); + Assert.assertArrayEquals(outerProduct.getRow(4), new double[]{0, 0, 0, 0, 0, 0, 0, 0}, TOL); + Assert.assertArrayEquals(outerProduct.getRow(5), new double[]{0, 0, 0, 2, 2, 2, 0, 0}, TOL); + Assert.assertArrayEquals(outerProduct.getRow(6), new double[]{0, 0, 0, 0, 0, 0, 0, 0}, TOL); + Assert.assertArrayEquals(outerProduct.getRow(7), new double[]{0, 0, 0, 2, 2, 2, 0, 0}, TOL); + } + + @Test + public void testIterator() throws Exception { + VectorIterator iterator = v1.iterator(); + Assert.assertTrue(iterator.hasNext()); + Assert.assertEquals(iterator.getIndex(), 1); + Assert.assertEquals(iterator.getValue(), 2, 0); + iterator.next(); + Assert.assertTrue(iterator.hasNext()); + Assert.assertEquals(iterator.getIndex(), 3); + Assert.assertEquals(iterator.getValue(), 2, 0); + iterator.next(); + Assert.assertTrue(iterator.hasNext()); + Assert.assertEquals(iterator.getIndex(), 5); + Assert.assertEquals(iterator.getValue(), 2, 0); + iterator.next(); + Assert.assertTrue(iterator.hasNext()); + Assert.assertEquals(iterator.getIndex(), 7); + Assert.assertEquals(iterator.getValue(), 2, 0); + iterator.next(); + Assert.assertFalse(iterator.hasNext()); + } +} diff --git a/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/VectorUtilTest.java b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/VectorUtilTest.java new file mode 100644 index 0000000000000..7ab3bd4c1c2c5 --- /dev/null +++ b/flink-ml-parent/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/VectorUtilTest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Test cases for VectorUtil. + */ +public class VectorUtilTest { + @Test + public void testParseDenseAndToString() { + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + String str = VectorUtil.toString(vec); + Assert.assertEquals(str, "1.0 2.0 -3.0"); + Assert.assertArrayEquals(vec.getData(), VectorUtil.parseDense(str).getData(), 0); + } + + @Test + public void testParseDenseWithSpace() { + DenseVector vec1 = VectorUtil.parseDense("1 2 -3"); + DenseVector vec2 = VectorUtil.parseDense(" 1 2 -3 "); + DenseVector vec = new DenseVector(new double[]{1, 2, -3}); + Assert.assertArrayEquals(vec1.getData(), vec.getData(), 0); + Assert.assertArrayEquals(vec2.getData(), vec.getData(), 0); + } + + @Test + public void testSparseToString() { + SparseVector v1 = new SparseVector(8, new int[]{1, 3, 5, 7}, new double[]{2.0, 2.0, 2.0, 2.0}); + Assert.assertEquals(VectorUtil.toString(v1), "$8$1:2.0 3:2.0 5:2.0 7:2.0"); + } + + @Test + public void testParseSparse() { + SparseVector vec1 = VectorUtil.parseSparse("0:1 2:-3"); + SparseVector vec3 = VectorUtil.parseSparse("$4$0:1 2:-3"); + SparseVector vec4 = VectorUtil.parseSparse("$4$"); + SparseVector vec5 = VectorUtil.parseSparse(""); + Assert.assertEquals(vec1.get(0), 1., 0.); + Assert.assertEquals(vec1.get(2), -3., 0.); + Assert.assertArrayEquals(vec3.toDenseVector().getData(), new double[]{1, 0, -3, 0}, 0); + Assert.assertEquals(vec3.size(), 4); + Assert.assertArrayEquals(vec4.toDenseVector().getData(), new double[]{0, 0, 0, 0}, 0); + Assert.assertEquals(vec4.size(), 4); + Assert.assertEquals(vec5.size(), -1); + } + + @Test + public void testParseAndToStringOfVector() { + Vector sparse = VectorUtil.parseSparse("0:1 2:-3"); + Vector dense = VectorUtil.parseDense("1 0 -3"); + + Assert.assertEquals(VectorUtil.toString(sparse), "0:1.0 2:-3.0"); + Assert.assertEquals(VectorUtil.toString(dense), "1.0 0.0 -3.0"); + Assert.assertTrue(VectorUtil.parse("$4$0:1 2:-3") instanceof SparseVector); + Assert.assertTrue(VectorUtil.parse("1 0 -3") instanceof DenseVector); + } +}