Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the bug of tensors not on the same device when running on CUDA device #59

Merged
merged 10 commits into from
Apr 19, 2023
Merged
Prev Previous commit
Next Next commit
feat: add arg verbose to enhance execute_command() in pypots-cli an…
…d delete pytest cache files finally in dev;
  • Loading branch information
WenjieDu committed Apr 19, 2023
commit 75e4ee0db9c666d7ebec2c8a36af1a6e7d2c2b32
46 changes: 32 additions & 14 deletions pypots/utils/commands/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
"""

Expose the base class for PyPOTS CLI (command line interface).
"""

# Created by Wenjie Du <[email protected]>
# License: GLP-v3


import subprocess
import sys
from abc import ABC, abstractmethod
from argparse import ArgumentParser

from pypots.utils.logging import logger


class BaseCommand(ABC):
@staticmethod
Expand All @@ -18,19 +21,34 @@ def register_subcommand(parser: ArgumentParser):
raise NotImplementedError()

@staticmethod
def execute_command(command):
exec_result = subprocess.run(
command,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True,
)
if exec_result.returncode != 0:
if len(exec_result.stderr) > 0:
raise RuntimeError(exec_result.stderr)
if len(exec_result.stdout) > 0:
raise RuntimeError(exec_result.stdout)
def execute_command(command: str, verbose: bool = True):
if verbose:
exec_result = subprocess.Popen(
command,
stdin=subprocess.PIPE,
close_fds=True,
stderr=sys.stderr,
stdout=sys.stdout,
universal_newlines=True,
shell=True,
bufsize=1,
)
exec_result.communicate()
else:
exec_result = subprocess.run(
command,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True,
)
if exec_result.returncode != 0:
if len(exec_result.stderr) > 0:
logger.error(exec_result.stderr)
if len(exec_result.stdout) > 0:
logger.error(exec_result.stdout)
raise RuntimeError()
return exec_result.returncode

@abstractmethod
def run(self):
Expand Down
10 changes: 4 additions & 6 deletions pypots/utils/commands/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,21 +161,19 @@ def run(self):
else pytest_command
)
logger.info(f"Executing '{command_to_run_test}'...")

self.execute_command(command_to_run_test)
if self._show_coverage:
self.execute_command("coverage report -m")
self.execute_command("rm -rf .coverage")

self.execute_command("rm -rf .pytest_cache")

elif self._lint_code:
logger.info("Reformatting with Black...")
self.execute_command("black .")
logger.info("Linting with Flake8...")
self.execute_command("flake8 .")

except ImportError:
raise ImportError(IMPORT_ERROR_MESSAGE)
except Exception as e:
raise RuntimeError(e)
finally:
shutil.rmtree(".pytest_cache", ignore_errors=True)
if os.path.exists(".coverage"):
os.remove(".coverage")