Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

MXNet Java bug fixes and experience improvement #14213

Merged
merged 3 commits into from
Feb 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
this(NDArray.array(arr, shape, ctx))
}

override def toString: String = nd.toString

def serialize(): Array[Byte] = nd.serialize()

/**
Expand Down
18 changes: 12 additions & 6 deletions scala-package/mxnet-demo/java-demo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@
<!--- under the License. -->

# MXNet Java Sample Project
This is an project created to use Maven-published Scala/Java package with two Java examples.
This is a project demonstrating how to use the Maven published Scala/Java MXNet package.
The examples provided include:
* NDArray creation
* NDArray operation
* Object Detection using the Inference API
* Image Classification using the Predictor API

## Setup
You are required to use Maven to build the package with the following commands:
You are required to use Maven to build the package with the following commands under `java-demo`:
```
mvn package
```
Expand All @@ -42,16 +48,16 @@ The `SCALA_PKG_PROFILE` should be chosen from `osx-x86_64-cpu`, `linux-x86_64-cp


## Run
lanking520 marked this conversation as resolved.
Show resolved Hide resolved
### Hello World
The Scala file is being executed using Java. You can execute the helloWorld example as follows:
### NDArrayCreation
The Scala file is being executed using Java. You can execute the `NDArrayCreation` example as follows:
```Bash
bash bin/java_sample.sh
```
You can also run the following command manually:
```Bash
java -cp $CLASSPATH sample.HelloWorld
java -cp $CLASSPATH sample.NDArrayCreation
```
However, you have to define the Classpath before you run the demo code. More information can be found in the `java_sample.sh`.
However, you have to define the Classpath before you run the demo code. More information can be found in `bin/java_sample.sh`.
The `CLASSPATH` should point to the jar file you have downloaded.

It will load the library automatically and run the example
Expand Down
2 changes: 1 addition & 1 deletion scala-package/mxnet-demo/java-demo/bin/java_sample.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
#!/bin/bash
CURR_DIR=$(cd $(dirname $0)/../; pwd)
CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/dependency/*
java -Xmx8G -cp $CLASSPATH mxnet.HelloWorld
java -Xmx8G -cp $CLASSPATH mxnet.NDArrayCreation
2 changes: 1 addition & 1 deletion scala-package/mxnet-demo/java-demo/bin/run_od.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
#!/bin/bash
CURR_DIR=$(cd $(dirname $0)/../; pwd)
CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/dependency/*
java -Xmx8G -cp $CLASSPATH mxnet.ObjectDetection
java -Xmx8G -cp $CLASSPATH mxnet.ObjectDetection
6 changes: 6 additions & 0 deletions scala-package/mxnet-demo/java-demo/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@
<artifactId>mxnet-full_${mxnet.scalaprofile}-${mxnet.profile}</artifactId>
<version>${mxnet.version}</version>
</dependency>
<dependency>
<groupId>org.apache.mxnet</groupId>
<artifactId>mxnet-full_${mxnet.scalaprofile}-${mxnet.profile}</artifactId>
<version>${mxnet.version}</version>
<classifier>sources</classifier>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this for ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Attaching the sources jar

</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mxnet;

import org.apache.commons.io.FileUtils;
import org.apache.mxnet.infer.javaapi.Predictor;
import org.apache.mxnet.javaapi.*;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;

public class ImageClassification {
private static String modelPath;
private static String imagePath;

private static void downloadUrl(String url, String filePath) {
File tmpFile = new File(filePath);
if (!tmpFile.exists()) {
try {
FileUtils.copyURLToFile(new URL(url), tmpFile);
} catch (Exception exception) {
System.err.println(exception);
}
}
}

public static void downloadModelImage() {
String tempDirPath = System.getProperty("java.io.tmpdir");
String baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models";
downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json",
tempDirPath + "/resnet18/resnet-18-symbol.json");
downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params",
tempDirPath + "/resnet18/resnet-18-0000.params");
downloadUrl(baseUrl + "/resnet-18/synset.txt",
tempDirPath + "/resnet18/synset.txt");
downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg");
modelPath = tempDirPath + File.separator + "resnet18/resnet-18";
imagePath = tempDirPath + File.separator +
"inputImages/resnet18/Pug-Cookie.jpg";
}

/**
* Helper class to print the maximum prediction result
* @param probabilities The float array of probability
* @param modelPathPrefix model Path needs to load the synset.txt
*/
private static String printMaximumClass(float[] probabilities,
String modelPathPrefix) throws IOException {
String synsetFilePath = modelPathPrefix.substring(0,
1 + modelPathPrefix.lastIndexOf(File.separator)) + "/synset.txt";
BufferedReader reader = new BufferedReader(new FileReader(synsetFilePath));
ArrayList<String> list = new ArrayList<>();
String line = reader.readLine();

while (line != null){
list.add(line);
line = reader.readLine();
}
reader.close();

int maxIdx = 0;
for (int i = 1;i<probabilities.length;i++) {
if (probabilities[i] > probabilities[maxIdx]) {
maxIdx = i;
}
}

return "Probability : " + probabilities[maxIdx] + " Class : " + list.get(maxIdx) ;
}

public static void main(String[] args) {
// Download the model and Image
downloadModelImage();

// Prepare the model
List<Context> context = new ArrayList<Context>();
context.add(Context.cpu());
List<DataDesc> inputDesc = new ArrayList<>();
Shape inputShape = new Shape(new int[]{1, 3, 224, 224});
inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
Predictor predictor = new Predictor(modelPath, inputDesc, context,0);

// Prepare data
NDArray nd = Image.imRead(imagePath, 1, true);
nd = Image.imResize(nd, 224, 224, null);
nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0]; // HWC to CHW
nd = NDArray.expand_dims(nd, 0, null)[0]; // Add N -> NCHW
nd = nd.asType(DType.Float32()); // Inference with Float32

// Predict directly
float[][] result = predictor.predict(new float[][]{nd.toArray()});
try {
System.out.println("Predict with Float input");
System.out.println(printMaximumClass(result[0], modelPath));
} catch (IOException e) {
System.err.println(e);
}

// predict with NDArray
List<NDArray> ndList = new ArrayList<>();
ndList.add(nd);
List<NDArray> ndResult = predictor.predictWithNDArray(ndList);
try {
System.out.println("Predict with NDArray");
System.out.println(printMaximumClass(ndResult.get(0).toArray(), modelPath));
} catch (IOException e) {
System.err.println(e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mxnet;

import org.apache.mxnet.javaapi.*;

public class NDArrayCreation {
static NDArray$ NDArray = NDArray$.MODULE$;
public static void main(String[] args) {

// Create new NDArray
NDArray nd = new NDArray(new float[]{2.0f, 3.0f}, new Shape(new int[]{1, 2}), Context.cpu());
System.out.println(nd);

// create new Double NDArray
NDArray ndDouble = new NDArray(new double[]{2.0d, 3.0d}, new Shape(new int[]{2, 1}), Context.cpu());
System.out.println(ndDouble);

// create ones
NDArray ones = NDArray.ones(Context.cpu(), new int[] {1, 2, 3});
System.out.println(ones);

// random
NDArray random = NDArray.random_uniform(
NDArray.new random_uniformParam()
.setLow(0.0f)
.setHigh(2.0f)
.setShape(new Shape(new int[]{10, 10}))
)[0];
System.out.println(random);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,31 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mxnet;

import org.apache.mxnet.javaapi.*;
import java.util.Arrays;

public class HelloWorld {
public class NDArrayOperation {
static NDArray$ NDArray = NDArray$.MODULE$;

public static void main(String[] args) {
System.out.println("Hello World!");
NDArray nd = new NDArray(new float[]{2.0f, 3.0f}, new Shape(new int[]{1, 2}), Context.cpu());
System.out.println(nd.shape());
NDArray nd2 = NDArray.dot(NDArray.new dotParam(nd, nd.T()))[0];
System.out.println(Arrays.toString(nd2.toArray()));

// Transpose
NDArray ndT = nd.T();
System.out.println(nd);
System.out.println(ndT);

// change Data Type
NDArray ndInt = nd.asType(DType.Int32());
System.out.println(ndInt);

// element add
NDArray eleAdd = NDArray.elemwise_add(nd, nd, null)[0];
System.out.println(eleAdd);

// norm (L2 Norm)
NDArray normed = NDArray.norm(NDArray.new normParam(nd))[0];
System.out.println(normed);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,18 @@ public static void downloadModelImage() {

public static void main(String[] args) {
List<Context> context = new ArrayList<Context>();
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
context.add(Context.gpu());
} else {
context.add(Context.cpu());
}
context.add(Context.cpu());
downloadModelImage();

List<List<ObjectDetectorOutput>> output
= runObjectDetectionSingle(modelPath, imagePath, context);

Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
Shape outputShape = new Shape(new int[] {1, 6132, 6});
int width = inputShape.get(2);
int height = inputShape.get(3);
List<List<ObjectDetectorOutput>> output
= runObjectDetectionSingle(modelPath, imagePath, context);
String outputStr = "\n";

for (List<ObjectDetectorOutput> ele : output) {
for (ObjectDetectorOutput i : ele) {
outputStr += "Class: " + i.getClassName() + "\n";
Expand All @@ -98,4 +96,4 @@ public static void main(String[] args) {
}
System.out.println(outputStr);
}
}
}