Skip to content

Commit

Permalink
Raise error if torch-scatter is not installed or wrong version is ins…
Browse files Browse the repository at this point in the history
…talled (#2486)

* automatically download correct torch-scatter version

* raise error if torch-scatter is not installed

* Update Documentation & Code Style

* catch all import errors and fix linter

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
MichelBartels and github-actions[bot] committed May 5, 2022
1 parent 1418f0c commit 5d98810
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
5 changes: 4 additions & 1 deletion docs/_src/tutorials/tutorials/15.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ Make sure you enable the GPU runtime to experience decent speed in this tutorial
!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab]

# The TaPAs-based TableReader requires the torch-scatter library
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
import torch

version = torch.__version__
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{version}.html

# Install pygraphviz for visualization of Pipelines
!apt install libgraphviz-dev
Expand Down
18 changes: 18 additions & 0 deletions haystack/nodes/reader/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
from haystack.nodes.reader.base import BaseReader
from haystack.modeling.utils import initialize_device_settings

torch_scatter_installed = True
torch_scatter_wrong_version = False
try:
import torch_scatter # pylint: disable=unused-import
except ImportError:
torch_scatter_installed = False
except OSError:
torch_scatter_wrong_version = True


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -95,6 +104,15 @@ def __init__(
query + table exceed max_seq_len, the table will be truncated by removing rows until the
input size fits the model.
"""
if not torch_scatter_installed:
raise ImportError(
"Please install torch_scatter to use TableReader. You can follow the instructions here: https://github.com/rusty1s/pytorch_scatter"
)
if torch_scatter_wrong_version:
raise ImportError(
"torch_scatter could not be loaded. This could be caused by a mismatch between your cuda version and the one used by torch_scatter."
"Please try to reinstall torch-scatter. You can follow the instructions here: https://github.com/rusty1s/pytorch_scatter"
)
super().__init__()

self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
Expand Down
5 changes: 4 additions & 1 deletion tutorials/Tutorial15_TableQA.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@
"!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab]\n",
"\n",
"# The TaPAs-based TableReader requires the torch-scatter library\n",
"!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
"import torch\n",
"\n",
"version = torch.__version__\n",
"!pip install torch-scatter -f https://data.pyg.org/whl/torch-{version}.html\n",
"\n",
"# Install pygraphviz for visualization of Pipelines\n",
"!apt install libgraphviz-dev\n",
Expand Down

0 comments on commit 5d98810

Please sign in to comment.