JVM/Scala wrappers for LibTorch.
This project is mature enough to be used regularly in production code. The API exposed is fairly clean
and tries to follow PyTorch syntax as much as possible. The API is a mix of hand-written wrappers and generated
wrappers covering most of Declarations.yaml.
That said, some internal documentation is not quite ready for public consumption yet, though there is enough
documentation that people who are already familiar with Scala and LibTorch can probably figure out what's going on.
Code generation is accomplished through a combination of SWIG and a quick-and-dirty
Python script, bindgen.py, that reads in Declarations.yaml, which provides a language-independent
API for a large part of LibTorch. Declarations.yaml is deprecated, and in the
future we hope to replace bindgen.py with the forthcoming torchgen
tool provided by PyTorch.
One major annoyance with Scala in particular is that at most one overload of a method may declare default
arguments. Currently, bindgen.py uses only the defaults present in the first overload found in Declarations.yaml.
In some cases, clever use of Scala's implicit conversions can hide these headaches, but currently you occasionally have to write
out the defaults where you would not have to in Python. One potential future option is to give overloads
different names, but we have elected not to do that (yet).
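A minimal plain-Scala sketch of the restriction and of the implicit-conversion workaround follows. OverloadedPad, PadInput, and ConvertedPad are hypothetical stand-ins, not Scala-Torch APIs or generated code; the sketch only shows that the second overload cannot carry defaults, while a single method whose first parameter accepts implicitly converted inputs can keep them.

```scala
import scala.language.implicitConversions

// Hypothetical stand-ins for a generated op with two overloads.
object OverloadedPad {
  // Only ONE alternative of an overloaded method may declare default arguments.
  def pad(xs: Seq[Int], width: Int = 4, fill: Int = 0): Seq[Int] =
    xs ++ Seq.fill(math.max(0, width - xs.size))(fill)

  // Declaring `width: Int = 4` here as well would not compile, so callers of
  // this overload must spell out every argument.
  def pad(x: Int, width: Int, fill: Int): Seq[Int] = pad(Seq(x), width, fill)
}

// Hypothetical workaround: a single method with defaults whose first parameter
// accepts anything implicitly convertible to PadInput.
final case class PadInput(values: Seq[Int])

object PadInput {
  // Conversions live in the companion, so they are always in implicit scope.
  implicit def fromInt(x: Int): PadInput = PadInput(Seq(x))
  implicit def fromSeq(xs: Seq[Int]): PadInput = PadInput(xs)
}

object ConvertedPad {
  def pad(input: PadInput, width: Int = 4, fill: Int = 0): Seq[Int] =
    OverloadedPad.pad(input.values, width, fill)
}
```

With this design, ConvertedPad.pad(7) compiles and uses the defaults, whereas OverloadedPad.pad(7) is rejected because the Int overload has no defaults. The trade-off is that every accepted input type needs its own implicit conversion.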
We have not yet published JARs for this project. These are coming soon.
Scala-torch exposes an API that tries to mirror PyTorch as much as Scala syntax allows. For example, taking some snippets from this tutorial:
PyTorch:
import torch
data = [[1, 2],[3, 4]]
x_data = torch.tensor(data)
Scala-Torch:
import com.microsoft.scalatorch.torch
import com.microsoft.scalatorch.torch.syntax._
torch.ReferenceManager.forBlock { implicit rm =>
val data = $($(1, 2), $(3, 4))
val x_data = torch.tensor(data)
}
PyTorch:
tensor = torch.ones(4, 4)
print(f"First row: {tensor[0]}")
print(f"First column: {tensor[:, 0]}")
print(f"Last column: {tensor[..., -1]}")
tensor[:,1] = 0
print(tensor)
Scala-Torch:
torch.ReferenceManager.forBlock { implicit rm =>
  val tensor = torch.ones($(4, 4))
  println(s"First row: ${tensor(0)}")
  println(s"First column: ${tensor(::, 0)}")
  println(s"Last column: ${tensor(---, -1)}")
  tensor(::, 1) = 0
  println(tensor)
}
See this file for a complete translation of the PyTorch tutorial into Scala-Torch.
One big difference between Scala-Torch and PyTorch is in memory management. Because Python and LibTorch both use
reference counting, memory management is fairly transparent to users. However, since the JVM uses garbage collection
and finalizers are not guaranteed to run,
it is not easy to make memory management transparent to the user. Scala-Torch elects to make memory management something
the user must control by providing ReferenceManagers that define the lifetime of any LibTorch-allocated object
added to them. All Scala-Torch methods that allocate objects from LibTorch take an implicit ReferenceManager,
so it is the responsibility of the caller to make sure there is a ReferenceManager in implicit scope (or passed
explicitly) and that this ReferenceManager is closed (by calling close()) when appropriate. See the documentation and uses
of ReferenceManager for more examples.
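The ownership pattern described above can be modeled in plain Scala. This is a hypothetical, simplified sketch with no Scala-Torch dependency: ScopedManager, NativeHandle, and allocate are illustrative names, not Scala-Torch types, and the real ReferenceManager may differ in details such as release order.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for a LibTorch-allocated object; NOT a Scala-Torch type.
final class NativeHandle(val name: String, log: ArrayBuffer[String]) extends AutoCloseable {
  override def close(): Unit = log += s"freed $name"
}

// Hypothetical, simplified analogue of a ReferenceManager: it owns every
// resource added to it and releases them, newest first, when closed.
final class ScopedManager extends AutoCloseable {
  private val owned = ArrayBuffer.empty[AutoCloseable]

  def add[T <: AutoCloseable](resource: T): T = {
    owned += resource
    resource
  }

  override def close(): Unit = {
    owned.reverseIterator.foreach(_.close())
    owned.clear()
  }
}

object ScopedManager {
  // Loan pattern: the manager is closed even if the body throws.
  def forBlock[A](body: ScopedManager => A): A = {
    val manager = new ScopedManager
    try body(manager)
    finally manager.close()
  }
}

object ScopedManagerDemo {
  // Mirrors how allocating methods take the manager as an implicit parameter.
  def allocate(name: String, log: ArrayBuffer[String])(implicit manager: ScopedManager): NativeHandle =
    manager.add(new NativeHandle(name, log))

  def run(): List[String] = {
    val log = ArrayBuffer.empty[String]
    ScopedManager.forBlock { implicit manager =>
      allocate("a", log)
      allocate("b", log)
      log += "end of block"
    }
    log.toList
  }
}
```

Because forBlock closes the manager in a finally clause, both handles are released once the block exits, even on an exception, which is the deterministic guarantee that finalizers cannot give.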
PyTorch provides pre-built binaries for the native code backing it here. We make use of the pre-built dynamic libraries by packaging them up in a jar, much like TensorFlow Scala. Downstream projects have two options for handling the native dependencies: they can either
- Declare a dependency on the native libraries packaged in a jar, using
val osClassifier = System.getProperty("os.name").toLowerCase match {
case os if os.contains("mac") || os.contains("darwin") => "darwin"
case os if os.contains("linux") => "linux"
case os if os.contains("windows") => "windows"
case os => throw new sbt.MessageOnlyException(s"The OS $os is not a supported platform.")
}
libraryDependencies += ("com.microsoft.scalatorch" % "libtorch-jar" % "1.10.0").classifier(osClassifier + "_cpu")
- Ensure that the libtorch dependencies are installed in the OS-dependent way, for example, in
/usr/lib or in a directory on LD_LIBRARY_PATH on Linux, or in a directory on PATH on Windows. Note that on recent versions of macOS, System Integrity Protection (SIP) resets LD_LIBRARY_PATH and DYLD_LIBRARY_PATH when launching processes, so it is very hard to use this approach on macOS.
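For approach 2, a quick way to check whether LibTorch is visible is to scan the relevant search-path variable. The helper below is a hypothetical diagnostic, not part of Scala-Torch. It assumes the main LibTorch library is named torch (so System.mapLibraryName yields libtorch.so, libtorch.dylib, or torch.dll), and it does not inspect default system directories such as /usr/lib.

```scala
import java.io.File

// Hypothetical diagnostic helper (not part of Scala-Torch) for approach 2.
object LibtorchLocator {
  // Search-path variable relevant on each OS, following the text above.
  def searchPathVariable(osName: String): String = osName.toLowerCase match {
    case os if os.contains("windows") => "PATH"
    case os if os.contains("mac") || os.contains("darwin") => "DYLD_LIBRARY_PATH"
    case _ => "LD_LIBRARY_PATH"
  }

  // Directories in `searchPath` that contain the platform-specific file name
  // for `libName` (e.g. libtorch.so on Linux for libName = "torch").
  def findLibrary(libName: String, searchPath: String): Seq[File] = {
    val fileName = System.mapLibraryName(libName)
    searchPath
      .split(File.pathSeparator)
      .toSeq
      .filter(_.nonEmpty)
      .map(path => new File(path))
      .filter(dir => new File(dir, fileName).isFile)
  }

  def main(args: Array[String]): Unit = {
    val variable = searchPathVariable(System.getProperty("os.name"))
    val dirs = findLibrary("torch", sys.env.getOrElse(variable, ""))
    if (dirs.isEmpty) println(s"${System.mapLibraryName("torch")} not found on $variable")
    else dirs.foreach(dir => println(s"found in ${dir.getPath}"))
  }
}
```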
The native binaries for the JNI bindings for all three supported OSes are published in scala-torch-swig.jar, so there
is no need for OS-specific treatment of those libraries.
Approach 1 is convenient because sbt will handle the libtorch native dependency for you and users won't need to install libtorch or set any environment variables. This is the ideal approach for local development.
There are several downsides to approach 1:
- it may unnecessarily duplicate installation of libtorch if, for example, PyTorch is already installed
- jars for GPU builds of libtorch are not provided, so approach 2 is the only option if GPU support is required
- care must be taken when publishing any library that depends on Scala-Torch not to publish the dependency
on the libtorch-jar, since that would force consumers of that library to depend on whatever OS-specific version of the jar was used at build time. See the use of pomPostProcess in build.sbt for how we handle that. Another option is for downstream libraries to exclude the libtorch-jar
using something like
libraryDependencies += ("com.microsoft" % "scala-torch" % "0.1.0").exclude("com.microsoft.scalatorch", "libtorch-jar")
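Dropping libtorch-jar from a published POM can be sketched with a common sbt pattern that rewrites the generated POM using scala-xml's RuleTransformer. This is a hedged build.sbt fragment that assumes sbt's bundled scala-xml; it is not necessarily identical to the pomPostProcess in this project's build.sbt.

```scala
// build.sbt (hypothetical sketch)
import scala.xml.{Node => XmlNode, NodeSeq => XmlNodeSeq}
import scala.xml.transform.{RewriteRule, RuleTransformer}

pomPostProcess := { (root: XmlNode) =>
  val dropLibtorchJar = new RewriteRule {
    override def transform(n: XmlNode): XmlNodeSeq = n match {
      // Remove the OS-specific <dependency> on com.microsoft.scalatorch:libtorch-jar.
      case dep if dep.label == "dependency" &&
          (dep \ "groupId").text == "com.microsoft.scalatorch" &&
          (dep \ "artifactId").text == "libtorch-jar" =>
        XmlNodeSeq.Empty
      case other => other
    }
  }
  new RuleTransformer(dropLibtorchJar).transform(root).head
}
```

Consumers of a library published this way must then supply LibTorch themselves, via either approach.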
Approach 2 is the better option for CI, remote jobs, production, etc.
You will need to have SWIG installed, which you can install on macOS using brew install swig.
git submodule update --init --recursive
cd pytorch
python3 -m tools.codegen.gen -s aten/src/ATen -d torch/share/ATen
cd ..
# Set PYTORCH_VERSION to the LibTorch release matching the pytorch submodule checkout.
curl https://download.pytorch.org/libtorch/cpu/libtorch-macos-${PYTORCH_VERSION}.zip -o libtorch.zip
unzip libtorch.zip
rm -f libtorch.zip
conda env create --name scala-torch --file environment.yml
conda activate scala-torch
export TORCH_DIR=$PWD/libtorch
# This links the JNI shared library against the absolute paths in the libtorch dir instead of
# using an rpath.
export LINK_TO_BUILD_LIB=true
sbt test
A similar setup should work for Linux and Windows.
If you are using Clang 11.0.3, you may run into an error when compiling the SobolEngineOps file. This is most
likely due to a compiler issue, which has already been reported here.
A temporary workaround is to install another version of Clang (e.g., by executing brew install llvm). Another option
is to downgrade Xcode to a version < 11.4.
To upgrade the underlying version of LibTorch:
- Run cd pytorch; git checkout <commit> with the <commit> of the desired release version, best found here.
- Rerun the steps under Local Development.
- Change TORCH_VERSION in run_tests.yml.
- Address compilation errors when running sbt compile. Changes to bindgen.py may be necessary.
Thanks to the following contributors to this project:
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies.