Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[k8s] Remove SSH jump pod for port-forward mode #3657

Merged
merged 16 commits into from
Jun 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,29 +439,38 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
f'Key {secret_name} does not exist in the cluster, creating it...')
kubernetes.core_api().create_namespaced_secret(namespace, secret)

ssh_jump_name = clouds.Kubernetes.SKY_SSH_JUMP_NAME
private_key_path, _ = get_or_generate_keys()
if network_mode == nodeport_mode:
ssh_jump_name = clouds.Kubernetes.SKY_SSH_JUMP_NAME
service_type = kubernetes_enums.KubernetesServiceType.NODEPORT
# Setup service for SSH jump pod. We create the SSH jump service here
# because we need to know the service IP address and port to set the
# ssh_proxy_command in the autoscaler config.
kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace,
service_type)
ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
ssh_jump_name,
nodeport_mode,
private_key_path=private_key_path,
namespace=namespace)
elif network_mode == port_forward_mode:
# Using `kubectl port-forward` creates a direct tunnel to the pod and
# does not require a ssh jump pod.
kubernetes_utils.check_port_forward_mode_dependencies()
# Using `kubectl port-forward` creates a direct tunnel to jump pod and
# does not require opening any ports on Kubernetes nodes. As a result,
# the service can be a simple ClusterIP service which we access with
# `kubectl port-forward`.
service_type = kubernetes_enums.KubernetesServiceType.CLUSTERIP
# TODO(romilb): This can be further optimized. Instead of using the
# head node as a jump pod for worker nodes, we can also directly
# set the ssh_target to the worker node. However, that requires
# changes in the downstream code to return a mapping of node IPs to
# pod names (to be used as ssh_target) and updating the upstream
# SSHConfigHelper to use a different ProxyCommand for each pod.
# This optimization can reduce SSH time from ~0.35s to ~0.25s, tested
# on GKE.
ssh_target = config['cluster_name'] + '-head'
ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
ssh_target, port_forward_mode, private_key_path=private_key_path)
else:
# This should never happen because we check for this in from_str above.
raise ValueError(f'Unsupported networking mode: {network_mode_str}')
# Setup service for SSH jump pod. We create the SSH jump service here
# because we need to know the service IP address and port to set the
# ssh_proxy_command in the autoscaler config.
kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace, service_type)

ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
PRIVATE_SSH_KEY_PATH, ssh_jump_name, network_mode, namespace,
clouds.Kubernetes.PORT_FORWARD_PROXY_CMD_PATH,
clouds.Kubernetes.PORT_FORWARD_PROXY_CMD_TEMPLATE)

config['auth']['ssh_proxy_command'] = ssh_proxy_cmd

return config
Expand Down
6 changes: 6 additions & 0 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,12 @@ def ssh_credential_from_yaml(
ssh_private_key = auth_section.get('ssh_private_key')
ssh_control_name = config.get('cluster_name', '__default__')
ssh_proxy_command = auth_section.get('ssh_proxy_command')

# Update the ssh_user placeholder in proxy command, if required
if (ssh_proxy_command is not None and
constants.SKY_SSH_USER_PLACEHOLDER in ssh_proxy_command):
ssh_proxy_command = ssh_proxy_command.replace(
constants.SKY_SSH_USER_PLACEHOLDER, ssh_user)
credentials = {
'ssh_user': ssh_user,
'ssh_private_key': ssh_private_key,
Expand Down
5 changes: 4 additions & 1 deletion sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3065,7 +3065,10 @@ def _update_after_cluster_provisioned(
)
usage_lib.messages.usage.update_final_cluster_status(
status_lib.ClusterStatus.UP)
auth_config = common_utils.read_yaml(handle.cluster_yaml)['auth']
auth_config = backend_utils.ssh_credential_from_yaml(
handle.cluster_yaml,
ssh_user=handle.ssh_user,
docker_user=handle.docker_user)
Comment on lines +3068 to +3071
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we test this on other clouds with an image_id of the form docker:xxx, just to make sure this change does not affect SSH for those?

Or, if we already pass ssh_user and docker_user here, should we remove the handle.docker_user and handle.ssh_user arguments from the add_cluster call below?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point - tested with pytest tests/test_smoke.py::test_job_queue_with_docker --gcp.

backend_utils.SSHConfigHelper.add_cluster(handle.cluster_name,
ip_list, auth_config,
ssh_port_list,
Expand Down
4 changes: 1 addition & 3 deletions sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ class Kubernetes(clouds.Cloud):

SKY_SSH_KEY_SECRET_NAME = 'sky-ssh-keys'
SKY_SSH_JUMP_NAME = 'sky-ssh-jump-pod'
PORT_FORWARD_PROXY_CMD_TEMPLATE = \
'kubernetes-port-forward-proxy-command.sh.j2'
PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/port-forward-proxy-cmd.sh'
# Timeout for resource provisioning. This timeout determines how long to
# wait for pod to be in pending status before giving up.
# Larger timeout may be required for autoscaling clusters, since autoscaler
Expand Down Expand Up @@ -323,6 +320,7 @@ def make_deploy_resources_variables(
'k8s_namespace':
kubernetes_utils.get_current_kube_config_context_namespace(),
'k8s_port_mode': port_mode.value,
'k8s_networking_mode': network_utils.get_networking_mode().value,
'k8s_ssh_key_secret_name': self.SKY_SSH_KEY_SECRET_NAME,
'k8s_acc_label_key': k8s_acc_label_key,
'k8s_acc_label_value': k8s_acc_label_value,
Expand Down
7 changes: 6 additions & 1 deletion sky/provision/kubernetes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from sky.adaptors import kubernetes
from sky.provision import common
from sky.provision.kubernetes import network_utils
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.utils import kubernetes_enums

logger = logging.getLogger(__name__)

Expand All @@ -25,7 +27,10 @@ def bootstrap_instances(

_configure_services(namespace, config.provider_config)

config = _configure_ssh_jump(namespace, config)
networking_mode = network_utils.get_networking_mode(
config.provider_config.get('networking_mode'))
if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
config = _configure_ssh_jump(namespace, config)

requested_service_account = config.node_config['spec']['serviceAccountName']
if (requested_service_account ==
Expand Down
17 changes: 11 additions & 6 deletions sky/provision/kubernetes/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sky.provision import common
from sky.provision import docker_utils
from sky.provision.kubernetes import config as config_lib
from sky.provision.kubernetes import network_utils
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.utils import command_runner
from sky.utils import common_utils
Expand Down Expand Up @@ -493,14 +494,18 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
if head_pod_name is None:
head_pod_name = pod.metadata.name

# Adding the jump pod to the new_nodes list as well so it can be
# checked if it's scheduled and running along with other pods.
ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump']
jump_pod = kubernetes.core_api().read_namespaced_pod(
ssh_jump_pod_name, namespace)
wait_pods_dict = _filter_pods(namespace, tags, ['Pending'])
wait_pods = list(wait_pods_dict.values())
wait_pods.append(jump_pod)

networking_mode = network_utils.get_networking_mode(
config.provider_config.get('networking_mode'))
if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
# Adding the jump pod to the new_nodes list as well so it can be
# checked if it's scheduled and running along with other pods.
ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump']
jump_pod = kubernetes.core_api().read_namespaced_pod(
ssh_jump_pod_name, namespace)
wait_pods.append(jump_pod)
provision_timeout = provider_config['timeout']

wait_str = ('indefinitely'
Expand Down
17 changes: 17 additions & 0 deletions sky/provision/kubernetes/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ def get_port_mode(
return port_mode


def get_networking_mode(
    mode_str: Optional[str] = None
) -> kubernetes_enums.KubernetesNetworkingMode:
    """Resolve the Kubernetes networking mode to use.

    Args:
        mode_str: Networking mode as a string. When it is None or empty,
            the value is looked up under ('kubernetes', 'networking_mode')
            via skypilot_config, falling back to the PORTFORWARD mode.

    Returns:
        The KubernetesNetworkingMode parsed from the resolved string.

    Raises:
        ValueError: If the resolved string is rejected by
            KubernetesNetworkingMode.from_str. The message is extended
            with a pointer to ~/.sky/config.yaml.
    """
    if not mode_str:
        # Nothing was passed in explicitly: consult the SkyPilot config,
        # with port-forward as the default.
        fallback = kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD.value
        mode_str = skypilot_config.get_nested(
            ('kubernetes', 'networking_mode'), fallback)

    mode_enum = kubernetes_enums.KubernetesNetworkingMode
    try:
        networking_mode = mode_enum.from_str(mode_str)
    except ValueError as e:
        message = str(e) + ' Please check: ~/.sky/config.yaml.'
        # Re-raise without the traceback or the chained original error so
        # the user only sees the actionable message.
        with ux_utils.print_exception_no_traceback():
            raise ValueError(message) from None
    return networking_mode


def fill_loadbalancer_template(namespace: str, service_name: str,
ports: List[int], selector_key: str,
selector_value: str) -> Dict:
Expand Down
108 changes: 71 additions & 37 deletions sky/provision/kubernetes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
import os
import re
import shutil
import subprocess
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from urllib.parse import urlparse
Expand All @@ -16,6 +17,7 @@
from sky import skypilot_config
from sky.adaptors import kubernetes
from sky.provision.kubernetes import network_utils
from sky.skylet import constants
from sky.utils import common_utils
from sky.utils import env_options
from sky.utils import kubernetes_enums
Expand Down Expand Up @@ -53,6 +55,10 @@

KIND_CONTEXT_NAME = 'kind-skypilot' # Context name used by sky local up

# Port-forward proxy command constants
PORT_FORWARD_PROXY_CMD_TEMPLATE = 'kubernetes-port-forward-proxy-command.sh'
PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/kubernetes-port-forward-proxy-command.sh'

logger = sky_logging.init_logger(__name__)


Expand Down Expand Up @@ -911,30 +917,38 @@ def __str__(self):
return self.name


def construct_ssh_jump_command(private_key_path: str,
ssh_jump_ip: str,
ssh_jump_port: Optional[int] = None,
proxy_cmd_path: Optional[str] = None) -> str:
def construct_ssh_jump_command(
private_key_path: str,
ssh_jump_ip: str,
ssh_jump_port: Optional[int] = None,
ssh_jump_user: str = 'sky',
proxy_cmd_path: Optional[str] = None,
proxy_cmd_target_pod: Optional[str] = None) -> str:
ssh_jump_proxy_command = (f'ssh -tt -i {private_key_path} '
'-o StrictHostKeyChecking=no '
'-o UserKnownHostsFile=/dev/null '
f'-o IdentitiesOnly=yes '
f'-W %h:%p sky@{ssh_jump_ip}')
f'-W %h:%p {ssh_jump_user}@{ssh_jump_ip}')
if ssh_jump_port is not None:
ssh_jump_proxy_command += f' -p {ssh_jump_port} '
if proxy_cmd_path is not None:
proxy_cmd_path = os.path.expanduser(proxy_cmd_path)
# adding execution permission to the proxy command script
os.chmod(proxy_cmd_path, os.stat(proxy_cmd_path).st_mode | 0o111)
ssh_jump_proxy_command += f' -o ProxyCommand=\'{proxy_cmd_path}\' '
ssh_jump_proxy_command += (f' -o ProxyCommand=\'{proxy_cmd_path} '
f'{proxy_cmd_target_pod}\' ')
return ssh_jump_proxy_command


def get_ssh_proxy_command(
private_key_path: str, ssh_jump_name: str,
network_mode: kubernetes_enums.KubernetesNetworkingMode, namespace: str,
port_fwd_proxy_cmd_path: str, port_fwd_proxy_cmd_template: str) -> str:
"""Generates the SSH proxy command to connect through the SSH jump pod.
k8s_ssh_target: str,
network_mode: kubernetes_enums.KubernetesNetworkingMode,
private_key_path: Optional[str] = None,
namespace: Optional[str] = None) -> str:
"""Generates the SSH proxy command to connect to the pod.

Uses a jump pod if the network mode is NODEPORT, and direct port-forwarding
if the network mode is PORTFORWARD.

By default, establishing an SSH connection creates a communication
channel to a remote node by setting up a TCP connection. When a
Expand All @@ -950,57 +964,77 @@ def get_ssh_proxy_command(

With the NodePort networking mode, a NodePort service is launched. This
service opens an external port on the node which redirects to the desired
port within the pod. When establishing an SSH session in this mode, the
port to a SSH jump pod. When establishing an SSH session in this mode, the
ProxyCommand makes use of this external port to create a communication
channel directly to port 22, which is the default port ssh server listens
on, of the jump pod.

With Port-forward mode, instead of directly exposing an external port,
'kubectl port-forward' sets up a tunnel between a local port
(127.0.0.1:23100) and port 22 of the jump pod. Then we establish a TCP
(127.0.0.1:23100) and port 22 of the provisioned pod. Then we establish TCP
connection to the local end of this tunnel, 127.0.0.1:23100, using 'socat'.
This is setup in the inner ProxyCommand of the nested ProxyCommand, and the
rest is the same as NodePort approach, which the outer ProxyCommand
establishes a communication channel between 127.0.0.1:23100 and port 22 on
the jump pod. Consequently, any stdin provided on the local machine is
forwarded through this tunnel to the application (SSH server) listening in
the pod. Similarly, any output from the application in the pod is tunneled
back and displayed in the terminal on the local machine.
All of this is done in a ProxyCommand script. Any stdin provided on the
local machine is forwarded through this tunnel to the application
(SSH server) listening in the pod. Similarly, any output from the
application in the pod is tunneled back and displayed in the terminal on
the local machine.

Args:
private_key_path: str; Path to the private key to use for SSH.
This key must be authorized to access the SSH jump pod.
ssh_jump_name: str; Name of the SSH jump service to use
k8s_ssh_target: str; The Kubernetes object that will be used as the
target for SSH. If network_mode is NODEPORT, this is the name of the
service. If network_mode is PORTFORWARD, this is the pod name.
network_mode: KubernetesNetworkingMode; networking mode for ssh
session. It is either 'NODEPORT' or 'PORTFORWARD'
namespace: Kubernetes namespace to use
port_fwd_proxy_cmd_path: str; path to the script used as Proxycommand
with 'kubectl port-forward'
port_fwd_proxy_cmd_template: str; template used to create
'kubectl port-forward' Proxycommand
private_key_path: str; Path to the private key to use for SSH.
This key must be authorized to access the SSH jump pod.
Required for NODEPORT networking mode.
namespace: Kubernetes namespace to use.
Required for NODEPORT networking mode.
"""
# Fetch IP to connect to for the jump svc
ssh_jump_ip = get_external_ip(network_mode)
assert private_key_path is not None, 'Private key path must be provided'
if network_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
ssh_jump_port = get_port(ssh_jump_name, namespace)
assert namespace is not None, 'Namespace must be provided for NodePort'
ssh_jump_port = get_port(k8s_ssh_target, namespace)
ssh_jump_proxy_command = construct_ssh_jump_command(
private_key_path, ssh_jump_ip, ssh_jump_port=ssh_jump_port)
# Setting kubectl port-forward/socat to establish ssh session using
# ClusterIP service to disallow any ports opened
else:
vars_to_fill = {
'ssh_jump_name': ssh_jump_name,
}
common_utils.fill_template(port_fwd_proxy_cmd_template,
vars_to_fill,
output_path=port_fwd_proxy_cmd_path)
ssh_jump_proxy_command_path = create_proxy_command_script()
ssh_jump_proxy_command = construct_ssh_jump_command(
private_key_path,
ssh_jump_ip,
proxy_cmd_path=port_fwd_proxy_cmd_path)
ssh_jump_user=constants.SKY_SSH_USER_PLACEHOLDER,
proxy_cmd_path=ssh_jump_proxy_command_path,
proxy_cmd_target_pod=k8s_ssh_target)
return ssh_jump_proxy_command


def create_proxy_command_script() -> str:
    """Install the port-forward ProxyCommand script for the current user.

    Copies the bundled template script to PORT_FORWARD_PROXY_CMD_PATH
    (with ``~`` expanded), creating the parent directory if needed, and
    makes the copy accessible to its owner only. The script is what sets
    up the `kubectl port-forward` tunnel between a local port and the SSH
    server in the pod.

    Returns:
        str: Expanded path of the installed ProxyCommand script.
    """
    script_path = os.path.expanduser(PORT_FORWARD_PROXY_CMD_PATH)
    script_dir = os.path.dirname(script_path)
    os.makedirs(script_dir, mode=0o700, exist_ok=True)

    # The templates directory sits three directory levels above this file.
    module_dir = os.path.dirname(__file__)
    parent_dir = os.path.dirname(module_dir)
    root_dir = os.path.dirname(parent_dir)
    template_path = os.path.join(root_dir, 'templates',
                                 PORT_FORWARD_PROXY_CMD_TEMPLATE)

    # Each user gets a private copy, so several users can share a single
    # SkyPilot installation without sharing a proxy command script.
    shutil.copy(template_path, script_path)
    # 0o700: only the owner may read, write, or execute the script.
    os.chmod(script_path, 0o700)
    return script_path


def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str,
service_type: kubernetes_enums.KubernetesServiceType):
"""Sets up Kubernetes service resource to access for SSH jump pod.
Expand Down
4 changes: 4 additions & 0 deletions sky/skylet/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,7 @@
SKYPILOT_NODE_IPS = 'SKYPILOT_NODE_IPS'
SKYPILOT_NUM_GPUS_PER_NODE = 'SKYPILOT_NUM_GPUS_PER_NODE'
SKYPILOT_NODE_RANK = 'SKYPILOT_NODE_RANK'

# Placeholder for the SSH user in proxy command, replaced when the ssh_user is
# known after provisioning.
SKY_SSH_USER_PLACEHOLDER = 'skypilot:ssh_user'
Loading
Loading