Fix connection plugin tests.

This commit is contained in:
Rudd-O 2020-05-05 10:53:09 +00:00
parent c2f006868c
commit eb8d0ab162
2 changed files with 81 additions and 59 deletions

View File

@ -450,30 +450,30 @@ class Connection(ConnectionBase):
super(Connection, self).fetch_file(in_path, out_path) super(Connection, self).fetch_file(in_path, out_path)
display.vvvv("FETCH %s to %s" % (in_path, out_path), host=self._play_context.remote_addr) display.vvvv("FETCH %s to %s" % (in_path, out_path), host=self._play_context.remote_addr)
in_path = _prefix_login_path(in_path) in_path = _prefix_login_path(in_path)
out_file = open(out_path, "wb") with open(out_path, "wb") as out_file:
try: try:
payload = 'fetch(%r, %r)\n' % (in_path, BUFSIZE) payload = 'fetch(%r, %r)\n' % (in_path, BUFSIZE)
self._transport.stdin.write(payload.encode("utf-8")) self._transport.stdin.write(payload.encode("utf-8"))
self._transport.stdin.flush() self._transport.stdin.flush()
while True: while True:
chunk_len = self._transport.stdout.readline(16) chunk_len = self._transport.stdout.readline(16)
try: try:
chunk_len = int(chunk_len) chunk_len = int(chunk_len)
except Exception: except Exception:
if chunk_len == "N\n": if chunk_len == "N\n":
exc = decode_exception(self._transport.stdin) exc = decode_exception(self._transport.stdin)
raise exc raise exc
else: else:
self._abort_transport() self._abort_transport()
raise errors.AnsibleError("chunk size from remote end is unexpected: %r" % chunk_len) raise errors.AnsibleError("chunk size from remote end is unexpected: %r" % chunk_len)
if chunk_len > BUFSIZE or chunk_len < 0: if chunk_len > BUFSIZE or chunk_len < 0:
raise errors.AnsibleError("chunk size from remote end is invalid: %r" % chunk_len) raise errors.AnsibleError("chunk size from remote end is invalid: %r" % chunk_len)
if chunk_len == 0: if chunk_len == 0:
break break
chunk = self._transport.stdout.read(chunk_len) chunk = self._transport.stdout.read(chunk_len)
if len(chunk) != chunk_len: if len(chunk) != chunk_len:
raise errors.AnsibleError("stderr size from remote end does not match actual stderr length: %s != %s" % (chunk_len, len(chunk))) raise errors.AnsibleError("stderr size from remote end does not match actual stderr length: %s != %s" % (chunk_len, len(chunk)))
out_file.write(chunk) out_file.write(chunk)
except Exception: except Exception:
self._abort_transport() self._abort_transport()
raise raise

View File

@ -1,81 +1,103 @@
import sys, os ; sys.path.append(os.path.dirname(__file__)) import sys, os ; sys.path.append(os.path.dirname(__file__))
import StringIO import contextlib
import qubes try:
from StringIO import StringIO
BytesIO = StringIO
except ImportError:
from io import StringIO, BytesIO
import unittest import unittest
import tempfile import tempfile
import qubes
cases = [
(['true'], '', 'Y\n0\n0\n0\n'),
(['false'], '', 'Y\n1\n0\n0\n'), if sys.version_info.major == 3:
(['sh', '-c', 'echo yes'], '', 'Y\n0\n4\nyes\n0\n'), cases = [
(['sh', '-c', 'echo yes >&2'], '', 'Y\n0\n0\n4\nyes\n'), (['true'], '', b'Y\n0\n0\n0\n'),
] (['false'], '', b'Y\n1\n0\n0\n'),
cases_with_harness = [ (['sh', '-c', 'echo yes'], '', b'Y\n0\n4\nyes\n0\n'),
(['true'], '', 0, '', ''), (['sh', '-c', 'echo yes >&2'], '', b'Y\n0\n0\n4\nyes\n'),
(['false'], '', 1, '', ''), ]
(['sh', '-c', 'echo yes'], '', 0, 'yes\n', ''), cases_with_harness = [
(['sh', '-c', 'echo yes >&2'], '', 0, '', 'yes\n'), (['true'], '', 0, '', ''),
] (['false'], '', 1, '', ''),
(['sh', '-c', 'echo yes'], '', 0, b'yes\n', ''),
(['sh', '-c', 'echo yes >&2'], '', 0, '', b'yes\n'),
]
else:
cases = []
cases_with_harness = []
class MockPlayContext(object): class MockPlayContext(object):
shell = 'sh' shell = 'sh'
executable = 'sh'
become = False become = False
become_method = 'sudo' become_method = 'sudo'
remote_addr = '127.0.0.7' remote_addr = '127.0.0.7'
@contextlib.contextmanager
def local_connection(): def local_connection():
c = qubes.Connection( c = qubes.Connection(
MockPlayContext(), None, MockPlayContext(), None,
transport_cmd=['sh', '-c', '"$@"'] transport_cmd=['sh', '-c', '"$@"']
) )
c._options = {"management_proxy": None} c._options = {"management_proxy": None}
return c try:
yield c
finally:
c.close()
class TestBasicThings(unittest.TestCase): class TestBasicThings(unittest.TestCase):
def test_popen(self): def test_popen(self):
for cmd, in_, out in cases: for cmd, in_, out in cases:
outf = StringIO.StringIO() outf = BytesIO()
qubes.popen(cmd, in_, outf=outf) qubes.popen(cmd, in_, outf=outf)
self.assertMultiLineEqual( self.assertEqual(
outf.getvalue(), outf.getvalue(),
out out
) )
def test_exec_command_with_harness(self): def test_exec_command_with_harness(self):
for cmd, in_, ret, out, err in cases_with_harness: for cmd, in_, ret, out, err in cases_with_harness:
c = local_connection() with local_connection() as c:
retcode, stdout, stderr = c.exec_command(cmd) retcode, stdout, stderr = c.exec_command(cmd)
self.assertEqual(ret, retcode) self.assertEqual(ret, retcode)
self.assertMultiLineEqual(out, stdout) self.assertEqual(out, stdout)
self.assertMultiLineEqual(err, stderr) self.assertEqual(err, stderr)
c.close()
self.assertEqual(c._transport, None) self.assertEqual(c._transport, None)
def test_fetch_file_with_harness(self): def test_fetch_file_with_harness(self):
in_text = "abcd" if sys.version_info.major == 2:
in_text = "abcd"
else:
in_text = b"abcd"
with tempfile.NamedTemporaryFile() as x: with tempfile.NamedTemporaryFile() as x:
x.write(in_text) x.write(in_text)
x.flush() x.flush()
with tempfile.NamedTemporaryFile() as y: with tempfile.NamedTemporaryFile() as y:
c = local_connection() with local_connection() as c:
c.fetch_file(in_path=x.name, out_path=y.name) c.fetch_file(in_path=x.name, out_path=y.name)
out_text = y.read() y.seek(0)
out_text = y.read()
self.assertEqual(in_text, out_text) self.assertEqual(in_text, out_text)
def test_put_file_with_harness(self): def test_put_file_with_harness(self):
in_text = "abcd" if sys.version_info.major == 2:
in_text = "abcd"
else:
in_text = b"abcd"
with tempfile.NamedTemporaryFile() as x: with tempfile.NamedTemporaryFile() as x:
x.write(in_text) x.write(in_text)
x.flush() x.flush()
with tempfile.NamedTemporaryFile() as y: with tempfile.NamedTemporaryFile() as y:
c = local_connection() with local_connection() as c:
c.put_file(in_path=x.name, out_path=y.name) c.put_file(in_path=x.name, out_path=y.name)
out_text = y.read() y.seek(0)
out_text = y.read()
self.assertEqual(in_text, out_text) self.assertEqual(in_text, out_text)