Skip to content

Commit

Permalink
Add parent pid to persistent connection socket path hash (ansible#33518)
Browse files Browse the repository at this point in the history
* Add parent pid to persistent connection socket path hash

Fixes ansible#33192

*  Add parent pid in persistent connection socket path hash
   to avoid using same socket path for multiple simultaneous
   connection to same remote host.

* Ensure unique persistent socket path for each ansible-playbook run

* Fix CI failures
  • Loading branch information
ganeshrn committed Dec 15, 2017
1 parent 08a2338 commit 2f932d8
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 10 deletions.
12 changes: 7 additions & 5 deletions bin/ansible-connection
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ConnectionProcess(object):
The connection process wraps around a Connection object that manages
the connection to a remote device that persists over the playbook
'''
def __init__(self, fd, play_context, socket_path, original_path):
def __init__(self, fd, play_context, socket_path, original_path, ansible_playbook_pid=None):
self.play_context = play_context
self.socket_path = socket_path
self.original_path = original_path
Expand All @@ -52,6 +52,7 @@ class ConnectionProcess(object):
self.sock = None

self.connection = None
self._ansible_playbook_pid = ansible_playbook_pid

def start(self):
try:
Expand All @@ -65,8 +66,8 @@ class ConnectionProcess(object):
# find it now that our cwd is /
if self.play_context.private_key_file and self.play_context.private_key_file[0] not in '~/':
self.play_context.private_key_file = os.path.join(self.original_path, self.play_context.private_key_file)

self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null')
self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null',
ansible_playbook_pid=self._ansible_playbook_pid)
self.connection.set_options()
self.connection._connect()
self.connection._socket_path = self.socket_path
Expand Down Expand Up @@ -244,7 +245,8 @@ def main():

if rc == 0:
ssh = connection_loader.get('ssh', class_only=True)
cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user, play_context.connection)
ansible_playbook_pid = sys.argv[1]
cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user, play_context.connection, ansible_playbook_pid)

# create the persistent connection dir if need be and create the paths
# which we will be using later
Expand All @@ -268,7 +270,7 @@ def main():
try:
os.close(r)
wfd = os.fdopen(w, 'w')
process = ConnectionProcess(wfd, play_context, socket_path, original_path)
process = ConnectionProcess(wfd, play_context, socket_path, original_path, ansible_playbook_pid)
process.start()
except Exception:
messages.append(traceback.format_exc())
Expand Down
4 changes: 2 additions & 2 deletions lib/ansible/executor/task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def _get_connection(self, variables, templar):

conn_type = self._play_context.connection

connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin)
connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin, ansible_playbook_pid=to_text(os.getppid()))
if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type)

Expand Down Expand Up @@ -800,7 +800,7 @@ def _start_connection(self):
Starts the persistent connection
'''
master, slave = pty.openpty()
p = subprocess.Popen(["ansible-connection"], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
p = subprocess.Popen(["ansible-connection", to_text(os.getppid())], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdin = os.fdopen(master, 'wb', 0)
os.close(slave)

Expand Down
5 changes: 4 additions & 1 deletion lib/ansible/plugins/connection/network_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def __init__(self, play_context, new_stdin, *args, **kwargs):
self._terminal = None
self._cliconf = None

self._ansible_playbook_pid = kwargs.get('ansible_playbook_pid')

if self._play_context.verbosity > 3:
logging.getLogger('paramiko').setLevel(logging.DEBUG)

Expand Down Expand Up @@ -220,7 +222,8 @@ def _update_connection_state(self):
value to None and the _connected value to False
'''
ssh = connection_loader.get('ssh', class_only=True)
cp = ssh._create_control_path(self._play_context.remote_addr, self._play_context.port, self._play_context.remote_user, self._play_context.connection)
cp = ssh._create_control_path(self._play_context.remote_addr, self._play_context.port, self._play_context.remote_user, self._play_context.connection,
self._ansible_playbook_pid)

tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
socket_path = unfrackpath(cp % dict(directory=tmp_path))
Expand Down
2 changes: 1 addition & 1 deletion lib/ansible/plugins/connection/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _start_connection(self):
Starts the persistent connection
'''
master, slave = pty.openpty()
p = subprocess.Popen(["ansible-connection"], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
p = subprocess.Popen(["ansible-connection", to_text(os.getppid())], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdin = os.fdopen(master, 'wb', 0)
os.close(slave)

Expand Down
4 changes: 3 additions & 1 deletion lib/ansible/plugins/connection/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,13 @@ def _connect(self):
return self

@staticmethod
def _create_control_path(host, port, user, connection=None):
def _create_control_path(host, port, user, connection=None, pid=None):
'''Make a hash for the controlpath based on con attributes'''
pstring = '%s-%s-%s' % (host, port, user)
if connection:
pstring += '-%s' % connection
if pid:
pstring += '-%s' % to_text(pid)
m = hashlib.sha1()
m.update(to_bytes(pstring))
digest = m.hexdigest()
Expand Down

0 comments on commit 2f932d8

Please sign in to comment.