Skip to content

Commit

Permalink
Install torch from pip (#1058)
Browse files Browse the repository at this point in the history
* Install torch packages from pip.

Conda takes too long to converge.

* Remove basemap

Basemap has been deprecated in favor of cartopy: https://github.com/matplotlib/basemap#basemap

* remove conda command for torch in gpu file

* specify cudatoolkit version

* Use * vars in torch version
  • Loading branch information
rosbo committed Jul 30, 2021
1 parent b7d1851 commit 47a089a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
9 changes: 4 additions & 5 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,15 @@ ENV PROJ_LIB=/opt/conda/share/proj
# Using the same global consistent ordered list of channels
RUN conda config --add channels conda-forge && \
conda config --add channels nvidia && \
conda config --add channels pytorch && \
conda config --add channels rapidsai && \
# ^ rapidsai is the highest priority channel, default lowest, conda-forge 2nd lowest.
# b/182405233 pyproj 3.x is not compatible with basemap 1.2.1
# b/161473620#comment7 pin required to prevent resolver from picking pysal 1.x., pysal 2.2.x is also downloading data on import.
conda install basemap cartopy imagemagick pyproj "pysal==2.1.0" && \
conda install "pytorch=1.7" "torchvision=0.8" "torchaudio=0.7" "torchtext=0.8" cpuonly && \
conda install cartopy=0.19 imagemagick=7.0 pyproj==3.1.0 pysal==2.1.0 && \
/tmp/clean-layer.sh

RUN pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 torchtext==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html && \
/tmp/clean-layer.sh

# The anaconda base image includes outdated versions of these packages. Update them to include the latest version.
RUN pip install seaborn python-dateutil dask python-igraph && \
pip install pyyaml joblib husl geopy ml_metrics mne pyshp && \
pip install pandas && \
Expand Down
9 changes: 6 additions & 3 deletions gpu.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,12 @@ RUN apt-get install -y ocl-icd-libopencl1 clinfo libboost-all-dev && \
# the remaining pip commands: https://www.anaconda.com/using-pip-in-a-conda-environment/
# However, because this image is based on the CPU image, this isn't possible but better
# to put them at the top of this file to minize conflicts.
RUN conda remove --force -y pytorch torchvision torchaudio torchtext cpuonly && \
conda install "pytorch=1.7" "torchvision=0.8" "torchaudio=0.7" "torchtext=0.8" cudatoolkit=$CUDA_VERSION && \
conda install "cudf=21.06" "cuml=21.06" && \
RUN conda install cudf=21.06 cuml=21.06 cudatoolkit=$CUDA_VERSION && \
/tmp/clean-layer.sh

# Install Pytorch and torchvision with GPU support.
# Note: torchtext and torchaudio do not require a separate GPU package.
RUN pip install torch==1.7.1+cu$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION torchvision==0.8.2+cu$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -f https://download.pytorch.org/whl/torch_stable.html && \
/tmp/clean-layer.sh

# Install LightGBM with GPU
Expand Down
6 changes: 0 additions & 6 deletions tests/test_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,9 @@
import matplotlib.pyplot as plt
import numpy as np

from mpl_toolkits.basemap import Basemap

class TestMatplotlib(unittest.TestCase):
def test_plot(self):
plt.plot(np.linspace(0,1,50), np.random.rand(50))
plt.savefig("plot1.png")

self.assertTrue(os.path.isfile("plot1.png"))

def test_basemap(self):
m = Basemap(width=100,height=100,projection='aeqd', lat_0=40,lon_0=-105)
self.assertEqual(0, m.xmin)

0 comments on commit 47a089a

Please sign in to comment.