
Azure OpenMPI Environment #14118

Draft: jessecambon wants to merge 8 commits into master from azure-mpi-environment

Conversation

jessecambon

What does this PR do?

  • Adds a new AzureOpenMPIEnvironment cluster environment (ClusterEnvironment for Azure/MPI, #14014); a hypothetical usage sketch follows this list.
  • The added local tests for AzureOpenMPIEnvironment pass, but the environment has not yet been tested on Azure to confirm it is detected properly.
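
For context, this is roughly how the new environment would be attached to the Trainer on an Azure ML job (a hypothetical usage sketch based on the repro script later in this thread; the argument values are illustrative and the import only exists on this branch):

# Hypothetical usage sketch; mirrors the commented-out plugin lines in the repro script below.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import AzureOpenMPIEnvironment

trainer = Trainer(
    accelerator="gpu",
    devices=2,  # processes (GPUs) per node
    num_nodes=2,
    plugins=[AzureOpenMPIEnvironment(devices=2)],  # devices is used to derive the node rank
)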

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

@awaelchli

github-actions bot added the pl label (Generic label for PyTorch Lightning package) on Aug 9, 2022
jessecambon changed the title from "DRAFT: Azure OpenMPI Environment" to "Azure OpenMPI Environment" on Aug 9, 2022
awaelchli (Member) left a comment

Thank you for sending this PR!!

but it has not yet been tested on Azure to make sure the environment is properly detected.

Do you have the resources to test this on Azure?

Comment on lines +55 to +35
@property
def main_address(self) -> str:
awaelchli (Member):
Great! If we could follow the style of the other classes and move the properties to the top, followed by all methods, that would be nice!

jessecambon (Author):
This is fixed now

See Azure documentation here: https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-distributed-gpu#mpi
"""

def __init__(self, devices: int = 1) -> None:
awaelchli (Member):

I would prefer not to pass devices in here. Do you think we could define the node_rank in a different way?
If I look at the docs here, I can see env variables for node rank. Is this also getting set on Azure?

jessecambon (Author):

Unfortunately, the OMPI_COMM_WORLD_NODE_RANK variable doesn't appear to correspond to node rank. You can see the note about this in the Azure documentation:

Despite the name, environment variable OMPI_COMM_WORLD_NODE_RANK does not correspond to the NODE_RANK. To use per-node-launcher, set process_count_per_node=1 and use OMPI_COMM_WORLD_RANK as the NODE_RANK.

The Open MPI docs also define this variable as:

OMPI_COMM_WORLD_NODE_RANK - the relative rank of this process on this node looking across ALL jobs.

I just ran a job on a 2-node setup with 2 V100 GPUs per node using Microsoft's docker image mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.3-cudnn8-ubuntu20.04, and it sets OMPI_COMM_WORLD_NODE_RANK and OMPI_COMM_WORLD_LOCAL_RANK to the same values (a quick derivation check follows the dump):

Driver 0

OMPI_COMM_WORLD_RANK = 0
OMPI_COMM_WORLD_LOCAL_RANK = 0
OMPI_COMM_WORLD_SIZE = 4
OMPI_COMM_WORLD_LOCAL_SIZE = 2
OMPI_COMM_WORLD_NODE_RANK = 0
OMPI_UNIVERSE_SIZE = 4

Driver 1

OMPI_COMM_WORLD_RANK = 1
OMPI_COMM_WORLD_LOCAL_RANK = 1
OMPI_COMM_WORLD_SIZE = 4
OMPI_COMM_WORLD_LOCAL_SIZE = 2
OMPI_COMM_WORLD_NODE_RANK = 1
OMPI_UNIVERSE_SIZE = 4

Driver 2

OMPI_COMM_WORLD_RANK = 2
OMPI_COMM_WORLD_LOCAL_RANK = 0
OMPI_COMM_WORLD_SIZE = 4
OMPI_COMM_WORLD_LOCAL_SIZE = 2
OMPI_COMM_WORLD_NODE_RANK = 0
OMPI_UNIVERSE_SIZE = 4

Driver 3

OMPI_COMM_WORLD_RANK = 3
OMPI_COMM_WORLD_LOCAL_RANK = 1
OMPI_COMM_WORLD_SIZE = 4
OMPI_COMM_WORLD_LOCAL_SIZE = 2
OMPI_COMM_WORLD_NODE_RANK = 1
OMPI_UNIVERSE_SIZE = 4
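
As a quick sanity check against the dump above, the node rank can be recovered from the global rank and the per-node process count. A minimal sketch, assuming OMPI_COMM_WORLD_LOCAL_SIZE is reliably set (which I have not verified across Azure configurations):

import os

# OMPI_COMM_WORLD_LOCAL_SIZE could potentially replace the devices argument (assumption).
local_size = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_SIZE", 1))  # 2 in the dump above
global_rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))

node_rank = global_rank // local_size
# Driver 0: 0 // 2 = 0, Driver 1: 1 // 2 = 0, Driver 2: 2 // 2 = 1, Driver 3: 3 // 2 = 1
print(f"node_rank = {node_rank}")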

awaelchli self-assigned this on Aug 11, 2022
awaelchli added the feature (Is an improvement or enhancement) and environment labels on Aug 11, 2022
jessecambon force-pushed the azure-mpi-environment branch 2 times, most recently from afdb689 to d8952f0 on August 16, 2022 15:14
github-actions bot removed the pl label (Generic label for PyTorch Lightning package) on Aug 16, 2022
github-actions bot added the pl label (Generic label for PyTorch Lightning package) on Aug 16, 2022
jessecambon force-pushed the azure-mpi-environment branch 2 times, most recently from 792f162 to b90da1e on August 17, 2022 21:14
jessecambon (Author) commented Aug 18, 2022

@awaelchli I've been testing this on Azure and the new AzureOpenMPIEnvironment is detected and behaves as expected, with the exception of single-node multi-GPU setups. For some reason, when testing with 1 node and 2 GPUs, I get errors related to the master port regardless of what value I set it to. I tried statically setting the master port either in the code or via the MASTER_PORT environment variable. I also tried using the find_free_network_port() function from the LightningEnvironment.

Azure does not set the AZ_BATCH_MASTER_NODE variable in a single-node setting, so you have to find another way to set the main port. This is what the error messages look like:

Driver 0

[W socket.cpp:401] [c10d] The server socket has failed to bind to [::]:6105 (errno: 98 - Address already in use).
[W socket.cpp:401] [c10d] The server socket has failed to bind to ?UNKNOWN? (errno: 98 - Address already in use).
[E socket.cpp:435] [c10d] The server socket has failed to listen on any local network address.

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 129, in _wrapping_function
    self._strategy._worker_setup(process_idx)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp_spawn.py", line 181, in _worker_setup
    init_dist_connection(
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py", line 374, in init_dist_connection
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 595, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/torch/distributed/rendezvous.py", line 232, in _env_rendezvous_handler
    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/torch/distributed/rendezvous.py", line 160, in _create_c10d_store
    return TCPStore(
RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:6105 (errno: 98 - Address already in use). The server socket has failed to bind to ?UNKNOWN? (errno: 98 - Address already in use).

Driver 1:

Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 130, in _wrapping_function
    results = function(*args, **kwargs)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 741, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1118, in _run
    self.__setup_profiler()
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1745, in __setup_profiler
    self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 2218, in log_dir
    dirpath = self.strategy.broadcast(dirpath)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp_spawn.py", line 245, in broadcast
    torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1869, in broadcast_object_list
    broadcast(object_sizes_tensor, src=src, group=group)
  File "/opt/miniconda/envs/ptldev/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1187, in broadcast
    work = default_pg.broadcast([tensor], opts)
RuntimeError: [1] is setting up NCCL communicator and retreiving ncclUniqueId from [0] via c10d key-value store by key '0', but store->get('0') got error: Broken pipe

This may be a broader issue, as I get these errors whether I use an MPIConfiguration or a PyTorchConfiguration and whether I use this new Azure environment or not. Here is my code for reference:

import torch, os, sys
from torch.utils.data import DataLoader, Dataset
#from deepspeed.ops.adam import FusedAdam
#from azureml.core import Run, Workspace
from pytorch_lightning import LightningModule, Trainer, LightningDataModule, seed_everything
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.plugins.environments import AzureOpenMPIEnvironment
#from pytorch_lightning.strategies import DeepSpeedStrategy
from argparse import ArgumentParser

device = 'gpu' if torch.cuda.is_available() else 'cpu'

print(f"device = {device}")
divider_str="-"*40

def get_env_display_text(var_name):
    var_value = os.environ.get(var_name, "")
    return f"{var_name} = {var_value}"

def display_environment(header='Environmental variables'):
    """
    Print a few environment variables of note
    """
    variable_names = [
        "PL_GLOBAL_SEED",
        "PL_SEED_WORKERS",
        "AZ_BATCH_MASTER_NODE",
        "AZ_BATCHAI_MPI_MASTER_NODE",
        "MASTER_ADDR",
        "MASTER_ADDRESS",
        "MASTER_PORT",
        "RANK",
        "NODE_RANK",
        "LOCAL_RANK",
        "GLOBAL_RANK",
        "WORLD_SIZE",
        "NCCL_SOCKET_IFNAME",
        "OMPI_COMM_WORLD_RANK",
        "OMPI_COMM_WORLD_LOCAL_RANK",
        "OMPI_COMM_WORLD_SIZE",
        "OMPI_COMM_WORLD_LOCAL_SIZE",
        "OMPI_COMM_WORLD_NODE_RANK",
        "OMPI_UNIVERSE_SIZE"
    ]

    var_text = "\n".join([get_env_display_text(var) for var in variable_names])
    print(f"\n{header}:\n{divider_str}\n{var_text}\n{divider_str}\n")

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters())
        #return FusedAdam(self.model.parameters())
    
    def setup(self, stage=None) -> None:
        # prevents hanging
        if stage != "fit":
            return

class DataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.num_workers = os.cpu_count()
        print(f"num_workers set to {self.num_workers}")

    def setup(self, stage=None) -> None:
        self._dataloader = DataLoader(
            RandomDataset(32, 64),
            num_workers=self.num_workers,
            batch_size=1,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._dataloader
    
    def test_dataloader(self):
        return self._dataloader

    def val_dataloader(self):
        return self._dataloader

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--num_nodes", default = 1, type=int)
    parser.add_argument("--devices", default = 1, type=int, help='Number of devices per node')
    args = parser.parse_args()

    seed_everything(102938, workers = True)

    model = BoringModel()
    dm = DataModule()

    #os.environ['MASTER_PORT'] = "6105"
    display_environment("__main__")

    trainer = Trainer(
        num_nodes=args.num_nodes,
        accelerator=device,
        devices=args.devices,
        precision=16,
        limit_train_batches=2,
        limit_val_batches=2,
        log_every_n_steps=1,
        logger=False,
        enable_checkpointing=False,
        num_sanity_val_steps=0,
        max_epochs=2,
        enable_model_summary=False,
        #strategy = "deepspeed_stage_3",
        #plugins=[AzureOpenMPIEnvironment(devices=args.devices)],
        # strategy=DDPStrategy(
        #     cluster_environment = AzureOpenMPIEnvironment(devices=args.devices)
        # ),
        # strategy=DeepSpeedStrategy(
        #     stage = 3,
        #     cluster_environment = AzureOpenMPIEnvironment(devices=args.devices)
        # ),
    )

    # extract cluster environment 
    trainer_cluster_environment = trainer._accelerator_connector.cluster_environment
    print(f"trainer cluster environment: {trainer_cluster_environment}")
    print(f"Was Azure OpenMPI environment used? {type(trainer_cluster_environment) == AzureOpenMPIEnvironment}")

    trainer.fit(model, datamodule=dm)

    print(f"""trainer.local_rank: {trainer.local_rank}
trainer.global_rank : {trainer.global_rank}
trainer.world_size : {trainer.world_size}
""")

Environment:

* CUDA:
	- GPU:
		- Tesla K80
		- Tesla K80
		- Tesla K80
		- Tesla K80
	- available:         True
	- version:           11.3
* Lightning:
	- lightning:         2022.8.18
	- lightning-cloud:   0.5.3
	- torch:             1.11.0
	- torchaudio:        0.11.0
	- torchmetrics:      0.9.3
	- torchvision:       0.12.0
* Packages:
	- absl-py:           1.2.0
	- accelerate:        0.12.0
	- adal:              1.2.7
	- aiobotocore:       2.3.4
	- aiohttp:           3.8.1
	- aioitertools:      0.10.0
	- aiosignal:         1.2.0
	- anyio:             3.6.1
	- argcomplete:       2.0.0
	- asgiref:           3.5.2
	- async-timeout:     4.0.2
	- attrs:             22.1.0
	- azure-common:      1.1.28
	- azure-core:        1.25.0
	- azure-graphrbac:   0.61.1
	- azure-identity:    1.10.0
	- azure-mgmt-authorization: 2.0.0
	- azure-mgmt-containerregistry: 10.0.0
	- azure-mgmt-core:   1.3.0
	- azure-mgmt-keyvault: 10.1.0
	- azure-mgmt-resource: 21.1.0
	- azure-mgmt-storage: 20.0.0
	- azure-storage-blob: 12.9.0
	- azureml-core:      1.44.0
	- azureml-dataprep:  4.2.2
	- azureml-dataprep-native: 38.0.0
	- azureml-dataprep-rslex: 2.8.1
	- azureml-dataset-runtime: 1.44.0
	- azureml-defaults:  1.44.0
	- azureml-inference-server-http: 0.7.4
	- azureml-mlflow:    1.44.0
	- backports.tempfile: 1.0
	- backports.weakref: 1.0.post1
	- bcrypt:            3.2.2
	- botocore:          1.24.21
	- brotlipy:          0.7.0
	- cachetools:        5.2.0
	- certifi:           2022.6.15
	- cffi:              1.15.0
	- charset-normalizer: 2.0.4
	- click:             8.1.3
	- cloudpickle:       2.1.0
	- commonmark:        0.9.1
	- configparser:      3.7.4
	- contextlib2:       21.6.0
	- croniter:          1.3.5
	- cryptography:      37.0.1
	- databricks-cli:    0.17.1
	- datasets:          2.4.0
	- deepdiff:          5.8.1
	- deepspeed:         0.7.0
	- dill:              0.3.5.1
	- distro:            1.7.0
	- dnspython:         2.2.1
	- docker:            5.0.3
	- dotnetcore2:       3.1.23
	- email-validator:   1.2.1
	- entrypoints:       0.4
	- fastapi:           0.79.0
	- filelock:          3.8.0
	- flask:             2.1.3
	- flask-cors:        3.0.10
	- frozenlist:        1.3.1
	- fsspec:            2022.7.1
	- fusepy:            3.0.1
	- gitdb:             4.0.9
	- gitpython:         3.1.27
	- google-api-core:   2.8.2
	- google-auth:       2.10.0
	- google-auth-oauthlib: 0.4.6
	- googleapis-common-protos: 1.56.4
	- grpcio:            1.47.0
	- gunicorn:          20.1.0
	- h11:               0.13.0
	- hjson:             3.1.0
	- httptools:         0.4.0
	- huggingface-hub:   0.8.1
	- humanfriendly:     10.0
	- idna:              3.3
	- importlib-metadata: 4.12.0
	- importlib-resources: 5.9.0
	- inference-schema:  1.4.2
	- isodate:           0.6.1
	- itsdangerous:      2.1.2
	- jeepney:           0.8.0
	- jinja2:            3.1.2
	- jmespath:          1.0.0
	- joblib:            1.1.0
	- json-logging-py:   0.2
	- jsonpickle:        2.2.0
	- jsonschema:        4.12.1
	- knack:             0.9.0
	- lightning:         2022.8.18
	- lightning-cloud:   0.5.3
	- markdown:          3.4.1
	- markupsafe:        2.1.1
	- mkl-fft:           1.3.1
	- mkl-random:        1.2.2
	- mkl-service:       2.4.0
	- mlflow-skinny:     1.28.0
	- msal:              1.18.0
	- msal-extensions:   1.0.0
	- msrest:            0.7.1
	- msrestazure:       0.6.4
	- multidict:         6.0.2
	- multiprocess:      0.70.13
	- ndg-httpsclient:   0.5.1
	- ninja:             1.10.2.3
	- numpy:             1.22.3
	- oauthlib:          3.2.0
	- opencensus:        0.11.0
	- opencensus-context: 0.1.3
	- opencensus-ext-azure: 1.1.6
	- ordered-set:       4.1.0
	- orjson:            3.7.12
	- packaging:         21.3
	- pandas:            1.4.3
	- paramiko:          2.11.0
	- pathspec:          0.9.0
	- pillow:            9.0.1
	- pip:               20.0.2
	- pkginfo:           1.8.3
	- pkgutil-resolve-name: 1.3.10
	- portalocker:       2.5.1
	- protobuf:          3.19.4
	- psutil:            5.9.1
	- py-cpuinfo:        8.0.0
	- pyarrow:           9.0.0
	- pyasn1:            0.4.8
	- pyasn1-modules:    0.2.8
	- pycparser:         2.21
	- pydantic:          1.9.2
	- pydeprecate:       0.3.2
	- pygments:          2.13.0
	- pyjwt:             2.4.0
	- pynacl:            1.5.0
	- pyopenssl:         22.0.0
	- pyparsing:         3.0.9
	- pyrsistent:        0.18.1
	- pysocks:           1.7.1
	- python-dateutil:   2.8.2
	- python-dotenv:     0.20.0
	- python-multipart:  0.0.5
	- pytz:              2022.2.1
	- pyyaml:            6.0
	- requests:          2.27.1
	- requests-oauthlib: 1.3.1
	- responses:         0.18.0
	- rich:              12.5.1
	- rsa:               4.9
	- s3fs:              2022.7.1
	- scikit-learn:      1.1.2
	- scipy:             1.9.0
	- secretstorage:     3.3.3
	- sentencepiece:     0.1.97
	- setuptools:        61.2.0
	- six:               1.16.0
	- sklearn:           0.0
	- smmap:             5.0.0
	- sniffio:           1.2.0
	- sqlparse:          0.4.2
	- starlette:         0.20.4
	- starsessions:      1.3.0
	- tabulate:          0.8.10
	- tensorboard:       2.10.0
	- tensorboard-data-server: 0.6.1
	- tensorboard-plugin-wit: 1.8.1
	- threadpoolctl:     3.1.0
	- torch:             1.11.0
	- torchaudio:        0.11.0
	- torchmetrics:      0.9.3
	- torchvision:       0.12.0
	- tqdm:              4.64.0
	- traitlets:         5.3.0
	- typing-extensions: 4.1.1
	- ujson:             5.4.0
	- urllib3:           1.26.9
	- uvicorn:           0.17.6
	- uvloop:            0.16.0
	- watchgod:          0.8.2
	- websocket-client:  1.3.3
	- websockets:        10.3
	- werkzeug:          2.2.2
	- wheel:             0.37.1
	- wrapt:             1.14.1
	- xxhash:            3.0.0
	- yarl:              1.8.1
	- zipp:              3.8.1
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         x86_64
	- python:            3.8.13
	- version:           #38-Ubuntu SMP Sun Mar 22 21:27:21 UTC 2020

Building from the base docker image: mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.3-cudnn8-ubuntu20.04

awaelchli (Member):
@jessecambon Thanks for investigating this. It's unfortunate that the edge case is single-node behavior. Does your Azure OpenMPI environment get selected correctly when running on a single node too?

if "AZ_BATCH_MASTER_NODE" in os.environ:
return int(os.environ.get("AZ_BATCH_MASTER_NODE").split(":")[1])
else:
return int(os.environ.get("MASTER_PORT", find_free_network_port()))
awaelchli (Member):

Need to be careful here: if each process calls find_free_network_port(), every process will end up with a different "generated" port. You may have taken this code from our LightningEnvironment, but it is important to note that we only do it there because we know we call it on the main process first, and then assign

os.environ["MASTER_PORT"] = cluster_env.main_port

Here, this won't be the case because the assumption is that processes get created externally.
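
To make the failure mode concrete, here is a minimal illustration (not code from this PR; the helper below is a stand-in for find_free_network_port()):

import os
import socket


def _find_free_port() -> int:
    # stand-in for Lightning's find_free_network_port(): ask the OS for any free port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


# LightningEnvironment pattern: the single launching process picks the port once and
# exports it, so the worker processes it spawns all inherit the same MASTER_PORT.
if "MASTER_PORT" not in os.environ:
    os.environ["MASTER_PORT"] = str(_find_free_port())

# Under mpirun there is no single launching process: every rank executes this module
# independently, so each rank would "generate" a different port and the ranks could
# never rendezvous on a common address.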

awaelchli (Member):
This could explain why you got
RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:6105 (errno: 98 - Address already in use). The server socket has failed to bind to ?UNKNOWN? (errno: 98 - Address already in use).

awaelchli (Member), Aug 22, 2022:

@jessecambon What if we hardcode a port that we know is available? It could also be made configurable as a parameter to __init__. If you hardcode it now, at least you can verify that everything else works normally in the single-node configuration.
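
A rough sketch of that suggestion (the helper name and default value are placeholders, not part of the PR):

import os

DEFAULT_MAIN_PORT = 6105  # placeholder default; could be exposed as an __init__ parameter


def resolve_main_port(configured_port: int = DEFAULT_MAIN_PORT) -> int:
    # Multi-node jobs: Azure sets AZ_BATCH_MASTER_NODE to "<ip>:<port>".
    if "AZ_BATCH_MASTER_NODE" in os.environ:
        return int(os.environ["AZ_BATCH_MASTER_NODE"].split(":")[1])
    # Single-node jobs: AZ_BATCH_MASTER_NODE is absent, so fall back to MASTER_PORT
    # if the user set it, otherwise to the configured (hardcoded) port.
    return int(os.environ.get("MASTER_PORT", configured_port))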

jessecambon (Author):

@awaelchli Thanks, that makes sense. I removed find_free_network_port() and changed the code back to first look at the MASTER_PORT environment variable and then default to a hardcoded port value. For some reason I still get these errors whether I set MASTER_PORT or let it use the hardcoded port value:

RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:57345 (errno: 98 - Address already in use). The server socket has failed to bind to ?UNKNOWN? (errno: 98 - Address already in use).

justusschock (Member):

@awaelchli could we port this to lite?

awaelchli (Member):

We added an environment to handle MPI here: #16570. The Azure version could rely on this as well.
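
If the Azure variant were rebuilt on top of that MPIEnvironment, it might only need to override the address and port, since MPI already supplies the ranks. A hypothetical sketch (assumes a recent lightning release with MPIEnvironment available and mpi4py installed; the class name is made up):

import os

from lightning.fabric.plugins.environments import MPIEnvironment


class AzureMPIEnvironment(MPIEnvironment):
    @property
    def main_address(self) -> str:
        # Multi-node Azure jobs expose "<ip>:<port>" in AZ_BATCH_MASTER_NODE.
        if "AZ_BATCH_MASTER_NODE" in os.environ:
            return os.environ["AZ_BATCH_MASTER_NODE"].split(":")[0]
        return super().main_address

    @property
    def main_port(self) -> int:
        if "AZ_BATCH_MASTER_NODE" in os.environ:
            return int(os.environ["AZ_BATCH_MASTER_NODE"].split(":")[1])
        return super().main_port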


gitguardian bot commented Jan 16, 2024

⚠️ GitGuardian has uncovered 2 secrets following the scan of your pull request.

Please consider investigating the findings and remediating the incidents. Failure to do so may lead to compromising the associated services or software components.

🔎 Detected hardcoded secrets in your pull request:
  • Generic High Entropy Secret (commit 78fa3af, tests/tests_app/utilities/test_login.py)
  • Base64 Basic Authentication (commit 78fa3af, tests/tests_app/utilities/test_login.py)
🛠 Guidelines to remediate hardcoded secrets
  1. Understand the implications of revoking this secret by investigating where it is used in your code.
  2. Replace and store your secret safely.
  3. Revoke and rotate this secret.
  4. If possible, rewrite git history. Rewriting git history is not a trivial act. You might completely break other contributing developers' workflow and you risk accidentally deleting legitimate data.


Labels: environment, feature (Is an improvement or enhancement), pl (Generic label for PyTorch Lightning package)

3 participants