ansible-qubes/bin/bombshell-client
2022-08-18 15:37:12 +00:00

500 lines
14 KiB
Python
Executable File

#!/usr/bin/python3 -u
import base64
import pickle
import contextlib
import errno
import fcntl
import os
try:
from shlex import quote
except ImportError:
from pipes import quote # noqa
try:
from queue import Queue
except ImportError:
from Queue import Queue # noqa
import select
import signal
import struct
import subprocess
import sys
import syslog
import threading
import time
import traceback
# Largest number of bytes the multiplexer reads from a source in one go;
# the demultiplexer sizes its staging buffer with the same value.
MAX_MUX_READ = 128 * 1024  # 64*1024*1024

# Frame header layout, network byte order without alignment:
#   H  channel number
#   b  flag: true when a payload follows, false when the channel is closed
#   I  payload length in bytes
#   x  one pad byte
PACKFORMAT = "!HbIx"
# Size in bytes of one packed header.
PACKLEN = struct.calcsize(PACKFORMAT)
def set_proc_name(newname):
    """Rename the running process through prctl() from libc.so.6.

    newname may be a str (encoded as UTF-8) or a bytes object.
    """
    from ctypes import byref, cdll, create_string_buffer

    PR_SET_NAME = 15  # prctl option number
    raw = newname.encode("utf-8") if isinstance(newname, str) else newname
    libc = cdll.LoadLibrary("libc.so.6")
    # Initialising from bytes yields a buffer of len(raw) + 1 bytes whose
    # last byte is the terminating NUL.
    name_buffer = create_string_buffer(raw)
    libc.prctl(PR_SET_NAME, byref(name_buffer), 0, 0, 0)
@contextlib.contextmanager
def mutexfile(filepath):
    """Hold an exclusive lock on filepath for the duration of the with-body.

    The file is opened in append mode (so it is created when missing) under
    a umask of 0o077; the previous umask is restored right after the open.
    The call blocks until the lock is granted.  The file is closed, which
    releases the lock, when the body finishes - whether it returns normally
    or raises.
    """
    oldumask = os.umask(0o077)
    try:
        f = open(filepath, "a")
    finally:
        os.umask(oldumask)
    try:
        fcntl.lockf(f.fileno(), fcntl.LOCK_EX)
        yield
    finally:
        # Without this finally an exception in the body (or in lockf) would
        # skip the close and leave the release to garbage collection.
        f.close()
def unset_cloexec(fd):
    """Clear the FD_CLOEXEC flag on file descriptor fd."""
    flags = fcntl.fcntl(fd, fcntl.F_GETFD)
    flags &= ~fcntl.FD_CLOEXEC
    fcntl.fcntl(fd, fcntl.F_SETFD, flags)
def openfdforappend(fd):
    """Wrap fd in an unbuffered binary writer with FD_CLOEXEC cleared.

    Append mode is tried first.  If that fails with ESPIPE the descriptor
    is opened in plain write mode instead; any other error propagates.
    """
    try:
        stream = os.fdopen(fd, "ab", 0)
    except IOError as err:
        if err.errno != errno.ESPIPE:
            raise
        stream = os.fdopen(fd, "wb", 0)
    unset_cloexec(stream.fileno())
    return stream
def openfdforread(fd):
    """Wrap fd in an unbuffered binary reader with FD_CLOEXEC cleared."""
    stream = os.fdopen(fd, "rb", 0)
    unset_cloexec(stream.fileno())
    return stream
# LOG_DEBUG messages are dropped unless this flag is switched on.
debug_enabled = False
# Taken around the formatting and emission of every log line.
debug_lock = threading.Lock()
# Reference time for the elapsed-seconds prefix of each log line.
_startt = time.time()
class LoggingEmu:
    """Small syslog-backed object offering debug(), info() and error().

    Each method takes a message followed by optional %-style arguments.
    Keyword arguments are accepted and ignored.
    """

    def __init__(self, prefix):
        self.prefix = prefix
        syslog.openlog("bombshell-client.%s" % self.prefix)

    def debug(self, *a, **kw):
        """Log at LOG_DEBUG, but only while debug_enabled is true."""
        if debug_enabled:
            self._print(syslog.LOG_DEBUG, *a, **kw)

    def info(self, *a, **kw):
        """Log at LOG_INFO."""
        self._print(syslog.LOG_INFO, *a, **kw)

    def error(self, *a, **kw):
        """Log at LOG_ERR."""
        self._print(syslog.LOG_ERR, *a, **kw)

    def _print(self, prio, *a, **kw):
        """Emit one line: elapsed seconds, thread name, then the message."""
        with debug_lock:
            elapsed = time.time() - _startt
            template, params = a[0], a[1:]
            # A lone message is used verbatim, so a literal "%" in it is safe.
            message = template % params if params else template
            thread_name = threading.current_thread().name
            syslog.syslog(
                prio,
                ("%.3f " % elapsed) + thread_name + ": " + message,
            )
# Module-wide logger.  main_master() and main_remote() rebind this name to a
# LoggingEmu instance; until then it is None.
logging = None
def send_confirmation(chan, retval, errmsg):
    """Write a status report to chan and flush it.

    Wire layout: retval as an unsigned 16-bit big-endian integer, the length
    of errmsg as an unsigned 32-bit big-endian integer, then errmsg itself
    (a bytes object shorter than 2**32 bytes).
    """
    chan.write(struct.pack("!H", retval))
    size = len(errmsg)
    assert size < 1 << 32
    chan.write(struct.pack("!I", size))
    chan.write(errmsg)
    chan.flush()
    logging.debug(
        "Sent confirmation on channel %s: %s %s", chan, retval, errmsg
    )
def recv_confirmation(chan):
    """Read a status report in the layout written by send_confirmation().

    Returns a (retval, errmsg) tuple where errmsg is bytes.  If chan is at
    end of file before any byte arrives, (125, "domain does not exist") is
    returned instead.  Raises AssertionError when the stream ends in the
    middle of the fixed-size fields.
    """

    def read_exactly(count):
        # The caller hands in an unbuffered pipe, where one read() may
        # return fewer bytes than requested.  Keep reading until count
        # bytes have arrived or the stream ends.
        pieces = []
        remaining = count
        while remaining > 0:
            piece = chan.read(remaining)
            if not piece:
                break
            pieces.append(piece)
            remaining -= len(piece)
        return b"".join(pieces)

    logging.debug("Waiting for confirmation on channel %s", chan)
    r = read_exactly(2)
    if len(r) == 0:
        # This happens when the remote domain does not exist.
        r, errmsg = 125, "domain does not exist"
        logging.debug("No confirmation: %s %s", r, errmsg)
        return r, errmsg
    assert len(r) == 2, r
    r = struct.unpack("!H", r)[0]
    lc = read_exactly(4)
    assert len(lc) == 4, lc
    lu = struct.unpack("!I", lc)[0]
    errmsg = read_exactly(lu)
    logging.debug("Received confirmation: %s %s", r, errmsg)
    return r, errmsg
class MyThread(threading.Thread):
    """Thread base class whose unexpected failures end the whole program.

    Subclasses implement _run().  If an Exception escapes from it, the
    thread name and traceback are logged and the process exits at once
    with status 124.
    """

    def run(self):
        try:
            self._run()
        except Exception:
            thread_name = threading.current_thread().name
            logging.error("%s: unexpected exception", thread_name)
            logging.error(
                "%s: traceback: %s", thread_name, traceback.format_exc()
            )
            logging.error("%s: exiting program", thread_name)
            # os._exit() rather than sys.exit(): the latter would only end
            # this thread.
            os._exit(124)
class SignalSender(MyThread):
    def __init__(self, signals, sigqueue):
        """Handles signals by pushing them into a file-like object.

        Installs self.copy as the handler for every signal in signals.
        The thread then writes each received signal number to sigqueue as
        an unsigned 16-bit big-endian integer.
        """
        threading.Thread.__init__(self)
        self.daemon = True
        self.queue = Queue()
        self.sigqueue = sigqueue
        for sig in signals:
            signal.signal(sig, self.copy)

    def copy(self, signum, frame):
        """Signal handler: hand signum over to the sender thread."""
        # Deliberately no logging here.  Python runs signal handlers in the
        # main thread between two bytecodes, and the logger takes a
        # non-reentrant lock: a signal arriving while the main thread is in
        # the middle of logging would block on that lock forever.  The
        # sender thread logs the signal as soon as it dequeues it.
        self.queue.put(signum)

    def _run(self):
        """Forward queued signal numbers until None is dequeued."""
        while True:
            signum = self.queue.get()
            logging.debug("Dequeued signal %s", signum)
            if signum is None:
                break
            assert signum > 0
            self.sigqueue.write(struct.pack("!H", signum))
            self.sigqueue.flush()
            logging.debug("Wrote signal %s to remote end", signum)
class Signaler(MyThread):
    def __init__(self, process, sigqueue):
        """Reads integers from a file-like object and relays that as kill()."""
        threading.Thread.__init__(self)
        self.daemon = True
        self.process = process
        self.sigqueue = sigqueue

    def _relay(self, signum):
        """Send signum to the process; a failure is logged, not raised."""
        logging.debug(
            "Received relayed signal %s, sending to process %s",
            signum,
            self.process.pid,
        )
        try:
            self.process.send_signal(signum)
        except BaseException as e:
            logging.error(
                "Failed to relay signal %s to process %s: %s",
                signum,
                self.process.pid,
                e,
            )

    def _run(self):
        """Relay 16-bit big-endian signal numbers until sigqueue hits EOF."""
        while True:
            data = self.sigqueue.read(2)
            if len(data) == 0:
                logging.debug("Received no signal data")
                break
            assert len(data) == 2
            (signum,) = struct.unpack("!H", data)
            self._relay(signum)
        logging.debug("End of signaler")
def write(dst, buffer, ln):
    """Write the first ln bytes of buffer to dst.

    dst.write() is called again with the unwritten tail whenever it reports
    a partial write, so all ln bytes are delivered.  Raises Exception when
    a call reports that nothing was written (None or a count below 1).
    """
    mv = memoryview(buffer)[:ln]
    while len(mv):
        # Use the count reported by write(): an unbuffered stream may
        # accept only part of the data in one call.
        writtenthisloop = dst.write(mv)
        if writtenthisloop is None or writtenthisloop < 1:
            raise Exception("copy: Failed to write any bytes")
        mv = mv[writtenthisloop:]
def copy(src, dst, buffer, ln):
    """Move exactly ln bytes from src to dst, staging them in buffer.

    Waits with select() before every read.  Raises AssertionError when
    buffer holds fewer than ln bytes and Exception when a read delivers
    nothing.  Returns whatever write() returns.
    """
    window = memoryview(buffer)[:ln]
    assert len(window) == ln, "Buffer object is too small: %s %s" % (
        len(window),
        ln,
    )
    filled = 0
    while filled < ln:
        select.select([src], (), ())
        got = src.readinto(window[filled:])
        if got is None or got < 1:
            raise Exception("copy: Failed to read any bytes")
        filled += got
    return write(dst, buffer, ln)
class DataMultiplexer(MyThread):
    """Interleaves the data of several sources into one framed sink.

    Each source gets its position in sources as channel number.  Data read
    from a source is written to the sink as a PACKFORMAT header (channel,
    true flag, length) followed by the payload.  When a source reports end
    of file, a header with a false flag and length 0 is written and the
    source is dropped.  The thread ends once no sources are left.
    """

    def __init__(self, sources, sink):
        threading.Thread.__init__(self)
        self.daemon = True
        self.sources = {src: num for num, src in enumerate(sources)}
        self.sink = sink

    def _run(self):
        logging.debug(
            "mux: Started with sources %s and sink %s", self.sources, self.sink
        )
        buffer = bytearray(MAX_MUX_READ)
        while self.sources:
            watched = list(self.sources)
            ready, _, failed = select.select(watched, (), watched)
            assert not failed, failed
            for s in ready:
                n = self.sources[s]
                logging.debug("mux: Source %s (%s) is active", n, s)
                ln = s.readinto(buffer)
                if ln == 0:
                    logging.debug(
                        "mux: Received no bytes from source %s, signaling"
                        " peer to close corresponding source",
                        n,
                    )
                    del self.sources[s]
                    self.sink.write(struct.pack(PACKFORMAT, n, False, 0))
                else:
                    self.sink.write(struct.pack(PACKFORMAT, n, True, ln))
                    write(self.sink, buffer, ln)
        logging.debug("mux: End of data multiplexer")
class DataDemultiplexer(MyThread):
    """Splits the framed stream read from source back out to the sinks.

    Each sink gets its position in sinks as channel number.  For every
    PACKFORMAT header read from source, either the announced number of
    payload bytes is copied to the matching sink, or - when the header's
    flag is false - that sink is closed and dropped.  End of file on source
    closes all remaining sinks.  The thread ends once no sinks are left.
    """

    def __init__(self, source, sinks):
        threading.Thread.__init__(self)
        self.daemon = True
        self.sinks = dict(enumerate(sinks))
        self.source = source

    def _read_header(self):
        """Return one complete header, or b"" on end of file.

        A single read() on the source may return only part of the header,
        so keep reading until PACKLEN bytes are available.  Raises
        Exception if the stream ends part-way through a header.
        """
        header = self.source.read(PACKLEN)
        while header and len(header) < PACKLEN:
            more = self.source.read(PACKLEN - len(header))
            if not more:
                raise Exception("demux: Truncated header: %r" % (header,))
            header += more
        return header

    def _run(self):
        logging.debug(
            "demux: Started with source %s and sinks %s",
            self.source,
            self.sinks,
        )
        buffer = bytearray(MAX_MUX_READ)
        while self.sinks:
            r, _, x = select.select([self.source], (), [self.source])
            assert not x, x
            for s in r:
                header = self._read_header()
                if header == b"":
                    logging.debug(
                        "demux: Received no bytes from source, closing sinks",
                    )
                    for sink in self.sinks.values():
                        sink.close()
                    # Emptying the mapping ends the outer while loop.
                    self.sinks = {}
                    break
                n, active, ln = struct.unpack(PACKFORMAT, header)
                if not active:
                    logging.debug(
                        "demux: Source %s inactive, closing matching sink %s",
                        s,
                        self.sinks[n],
                    )
                    self.sinks[n].close()
                    del self.sinks[n]
                else:
                    copy(self.source, self.sinks[n], buffer, ln)
        logging.debug("demux: End of data demultiplexer")
def quotedargs():
    """Return sys.argv[1:] as one string of shell-quoted words."""
    return " ".join(map(quote, sys.argv[1:]))
# Local ("master") side of the connection.
#
# Reads sys.argv as [-d] <vm> <command...>, starts
# "qrexec-client-vm <vm> qubes.VMShell", feeds it one shell line that runs
# this file's own source on the other end, and then relays stdin plus the
# handled signals to the remote side and the remote output back to stdout
# and stderr.
#
# Returns 127 when qrexec-client-vm cannot be started, the confirmation code
# received from the other end when that is non-zero, and otherwise the exit
# status of the qrexec-client-vm process.
#
# NOTE(review): this copy of the file has lost its indentation, so the extent
# of the "with mutexfile(...)" body cannot be read off the text - confirm
# against the upstream file before re-indenting.
def main_master():
set_proc_name("bombshell-client (master) %s" % quotedargs())
global logging
logging = LoggingEmu("master")
logging.info("Started with arguments: %s", sys.argv[1:])
global debug_enabled
args = sys.argv[1:]
# A leading -d turns on debug logging here and is passed on to the other end.
if args[0] == "-d":
args = args[1:]
debug_enabled = True
remote_vm = args[0]
remote_command = args[1:]
assert remote_command
# Shell backtick expression that expands to exe when exe is executable where
# the shell runs, and to plain "python" otherwise.
def anypython(exe):
return "` test -x %s && echo %s || echo python`" % (
quote(exe),
quote(exe),
)
# Build the shell line:
#   exec <python> -u -c '<source of this file>' [-d] <base64 pickle of command>
remote_helper_text = b"exec "
remote_helper_text += bytes(anypython(sys.executable), "utf-8")
remote_helper_text += bytes(" -u -c ", "utf-8")
# The quoted source is encoded as ASCII, so this file must stay ASCII-only.
# NOTE(review): the handle returned by open(__file__) is never closed
# explicitly.
remote_helper_text += bytes(
quote(open(__file__, "r").read()),
"ascii",
)
remote_helper_text += b" -d " if debug_enabled else b" "
remote_helper_text += base64.b64encode(pickle.dumps(remote_command, 2))
remote_helper_text += b"\n"
# Private duplicate of stderr, used below as the second demultiplexer sink.
saved_stderr = openfdforappend(os.dup(sys.stderr.fileno()))
with mutexfile(os.path.expanduser("~/.bombshell-lock")):
try:
p = subprocess.Popen(
["qrexec-client-vm", remote_vm, "qubes.VMShell"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
close_fds=True,
preexec_fn=os.setpgrp,
bufsize=0,
)
except OSError as e:
logging.error("cannot launch qrexec-client-vm: %s", e)
return 127
logging.debug("Writing the helper text into the other side")
p.stdin.write(remote_helper_text)
p.stdin.flush()
# The other end answers with a status code and message once it has tried to
# start the command.
confirmation, errmsg = recv_confirmation(p.stdout)
if confirmation != 0:
logging.error("remote: %s", errmsg)
return confirmation
# Signals that are forwarded to the other end instead of acting locally.
handled_signals = (
signal.SIGINT,
signal.SIGABRT,
signal.SIGALRM,
signal.SIGTERM,
signal.SIGUSR1,
signal.SIGUSR2,
signal.SIGTSTP,
signal.SIGCONT,
)
read_signals, write_signals = pairofpipes()
signaler = SignalSender(handled_signals, write_signals)
signaler.name = "master signaler"
signaler.start()
# Outgoing channel 0 carries stdin, channel 1 the signal numbers.
muxer = DataMultiplexer([sys.stdin, read_signals], p.stdin)
muxer.name = "master multiplexer"
muxer.start()
# Incoming channel 0 goes to stdout, channel 1 to the saved stderr.
demuxer = DataDemultiplexer(p.stdout, [sys.stdout, saved_stderr])
demuxer.name = "master demultiplexer"
demuxer.start()
retval = p.wait()
logging.info("Return code %s for qubes.VMShell proxy", retval)
# Let the demultiplexer thread finish before returning.
demuxer.join()
logging.info("Ending bombshell")
return retval
def pairofpipes():
    """Create a pipe and return (reader, writer) as unbuffered binary files."""
    read_fd, write_fd = os.pipe()
    reader = os.fdopen(read_fd, "rb", 0)
    writer = os.fdopen(write_fd, "wb", 0)
    return reader, writer
def main_remote():
    """Run the remote side: start the requested command and relay its I/O.

    The command is taken from sys.argv as a base64-encoded pickle (after an
    optional -d, which turns on debug logging).  A confirmation is written
    to stdout: code 0 once the command has been started, otherwise 127 for
    an OSError or 126 for anything else together with the error text, after
    which the process exits with status 0.  On success, stdin is
    demultiplexed into the command's stdin and relayed signals, the
    command's stdout and stderr are multiplexed onto stdout, and the
    command's exit status is returned.

    NOTE(review): pickle.loads() is applied to an argument supplied by the
    caller; that is only acceptable while the caller is trusted.
    """
    global logging
    global debug_enabled
    set_proc_name("bombshell-client (remote) %s" % quotedargs())
    logging = LoggingEmu("remote")
    logging.info("Started with arguments: %s", sys.argv[1:])

    if "-d" in sys.argv[1:]:
        debug_enabled = True
        encoded = sys.argv[2]
    else:
        encoded = sys.argv[1]
    cmd = pickle.loads(base64.b64decode(encoded))
    logging.debug("Received command: %s", cmd)
    nicecmd = " ".join(quote(a) for a in cmd)

    try:
        # For debugging, cmd can be prefixed with ["strace", "-s4096", "-ff"].
        p = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            close_fds=True,
            bufsize=0,
        )
        send_confirmation(sys.stdout, 0, b"")
    except BaseException as e:
        # The failure is reported through the confirmation code, which is
        # why the process itself exits with status 0.
        status = 127 if isinstance(e, OSError) else 126
        msg = "cannot execute %s: %s" % (nicecmd, e)
        logging.error(msg)
        send_confirmation(sys.stdout, status, bytes(msg, "utf-8"))
        sys.exit(0)

    signals_read, signals_written = pairofpipes()

    signaler = Signaler(p, signals_read)
    signaler.name = "remote signaler"
    signaler.start()

    # Incoming channel 0 feeds the command's stdin, channel 1 the signaler.
    demuxer = DataDemultiplexer(sys.stdin, [p.stdin, signals_written])
    demuxer.name = "remote demultiplexer"
    demuxer.start()

    # Outgoing channel 0 is the command's stdout, channel 1 its stderr.
    muxer = DataMultiplexer([p.stdout, p.stderr], sys.stdout)
    muxer.name = "remote multiplexer"
    muxer.start()

    logging.info("Started %s", nicecmd)
    retval = p.wait()
    logging.info("Return code %s for %s", retval, nicecmd)
    muxer.join()
    logging.info("Ending bombshell")
    return retval
# Replace the standard streams with unbuffered binary files on the same
# descriptors.
sys.stdin = openfdforread(sys.stdin.fileno())
sys.stdout = openfdforappend(sys.stdout.fileno())

# Run as master only when this code was loaded from a file (so __file__
# exists) and the first argument is not -s; otherwise run as remote.
launched_from_file = "__file__" in locals()
if launched_from_file and "-s" not in sys.argv[1:2]:
    sys.exit(main_master())
else:
    sys.exit(main_remote())