Fix leading slashes being stripped from mount src (#24013)

* Tidy mount module for testing

Fix spelling mistakes in comments. I *think* the example for omitting parents
root has the wrong parent ID.

Make mountinfo file a parameter for testing.

* Don't strip leading slash from mounts

The current code does not follow the example, it produces src=tmp/aaa instead
of src=/tmp/aaa. This causes problems with bind mounts under /rootfs.

* Use dictionary to store mounts by ID

Instead of looping over each one to check if the ID matches. This does not
preserve the order of the output on < Python3.6, but that is not necessary.

* Make linux_mounts a dict

Always accessed by 'dst', so avoid looping by just making it a key.

* Add test case for get_linux_mounts
This commit is contained in:
Kai 2017-08-29 14:16:53 +01:00 committed by Martin Krizek
parent a914a39975
commit 3251aecd95
2 changed files with 60 additions and 49 deletions

View file

@ -446,19 +446,12 @@ def is_bind_mounted(module, linux_mounts, dest, src=None, fstype=None):
if get_platform() == 'Linux' and linux_mounts is not None:
if src is None:
# That's for unmounted/absent
for m in linux_mounts:
if m['dst'] == dest:
is_mounted = True
else:
mounted_src = None
for m in linux_mounts:
if m['dst'] == dest:
mounted_src = m['src']
# That's for mounted
if mounted_src is not None and mounted_src == src:
if dest in linux_mounts:
is_mounted = True
else:
if dest in linux_mounts:
is_mounted = linux_mounts[dest]['src'] == src
else:
bin_path = module.get_bin_path('mount', required=True)
cmd = '%s -l' % bin_path
@ -483,11 +476,9 @@ def is_bind_mounted(module, linux_mounts, dest, src=None, fstype=None):
return is_mounted
def get_linux_mounts(module):
def get_linux_mounts(module, mntinfo_file="/proc/self/mountinfo"):
"""Gather mount information"""
mntinfo_file = "/proc/self/mountinfo"
try:
f = open(mntinfo_file)
except IOError:
@ -500,7 +491,7 @@ def get_linux_mounts(module):
except IOError:
module.fail_json(msg="Cannot close file %s" % mntinfo_file)
mntinfo = []
mntinfo = {}
for line in lines:
fields = line.split()
@ -515,40 +506,35 @@ def get_linux_mounts(module):
'src': fields[-2]
}
mntinfo.append(record)
mntinfo[record['id']] = record
mounts = []
mounts = {}
for mnt in mntinfo:
src = mnt['src']
for mnt in mntinfo.values():
if mnt['parent_id'] != 1 and mnt['parent_id'] in mntinfo:
m = mntinfo[mnt['parent_id']]
if (
len(m['root']) > 1 and
mnt['root'].startswith("%s/" % m['root'])):
# Ommit the parent's root in the child's root
# == Example:
# 140 136 253:2 /rootfs / rw - ext4 /dev/sdb2 rw
# 141 140 253:2 /rootfs/tmp/aaa /tmp/bbb rw - ext4 /dev/sdb2 rw
# == Expected result:
# src=/tmp/aaa
mnt['root'] = mnt['root'][len(m['root']):]
if mnt['parent_id'] != 1:
# Find parent
for m in mntinfo:
if mnt['parent_id'] == m['id']:
if (
len(m['root']) > 1 and
mnt['root'].startswith("%s/" % m['root'])):
# Ommit the parent's root in the child's root
# == Example:
# 204 136 253:2 /rootfs / rw - ext4 /dev/sdb2 rw
# 141 140 253:2 /rootfs/tmp/aaa /tmp/bbb rw - ext4 /dev/sdb2 rw
# == Expected result:
# src=/tmp/aaa
mnt['root'] = mnt['root'][len(m['root']) + 1:]
# Prepend the parent's dst to the child's root
# == Example:
# 42 60 0:35 / /tmp rw - tmpfs tmpfs rw
# 78 42 0:35 /aaa /tmp/bbb rw - tmpfs tmpfs rw
# == Expected result:
# src=/tmp/aaa
if m['dst'] != '/':
mnt['root'] = "%s%s" % (m['dst'], mnt['root'])
src = mnt['root']
break
# Prepend the parent's dst to the child's root
# == Example:
# 42 60 0:35 / /tmp rw - tmpfs tmpfs rw
# 78 42 0:35 /aaa /tmp/bbb rw - tmpfs tmpfs rw
# == Expected result:
# src=/tmp/aaa
if m['dst'] != '/':
mnt['root'] = "%s%s" % (m['dst'], mnt['root'])
src = mnt['root']
else:
src = mnt['src']
record = {
'dst': mnt['dst'],
@ -557,7 +543,7 @@ def get_linux_mounts(module):
'fs': mnt['fs']
}
mounts.append(record)
mounts[mnt['dst']] = record
return mounts
@ -618,7 +604,7 @@ def main():
linux_mounts = []
# Cache all mounts here in order we have consistent results if we need to
# call is_bind_mouted() multiple times
# call is_bind_mounted() multiple times
if get_platform() == 'Linux':
linux_mounts = get_linux_mounts(module)

View file

@ -0,0 +1,25 @@
import os
import tempfile
from ansible.compat.tests import unittest
from ansible.module_utils._text import to_bytes
from ansible.modules.system.mount import get_linux_mounts
class LinuxMountsTestCase(unittest.TestCase):
def _create_file(self, content):
tmp_file = tempfile.NamedTemporaryFile(prefix='ansible-test-', delete=False)
tmp_file.write(to_bytes(content))
tmp_file.close()
self.addCleanup(os.unlink, tmp_file.name)
return tmp_file.name
def test_code_comment(self):
path = self._create_file(
'140 136 253:2 /rootfs / rw - ext4 /dev/sdb2 rw\n'
'141 140 253:2 /rootfs/tmp/aaa /tmp/bbb rw - ext4 /dev/sdb2 rw\n'
)
mounts = get_linux_mounts(None, path)
self.assertEqual(mounts['/tmp/bbb']['src'], '/tmp/aaa')