mirror of
				https://github.com/Rudd-O/ansible-qubes.git
				synced 2025-11-04 05:28:54 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			500 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			500 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
#!/usr/bin/python3 -u
 | 
						|
 | 
						|
import base64
 | 
						|
import pickle
 | 
						|
import contextlib
 | 
						|
import errno
 | 
						|
import fcntl
 | 
						|
import os
 | 
						|
 | 
						|
try:
 | 
						|
    from shlex import quote
 | 
						|
except ImportError:
 | 
						|
    from pipes import quote  # noqa
 | 
						|
 | 
						|
try:
 | 
						|
    from queue import Queue
 | 
						|
except ImportError:
 | 
						|
    from Queue import Queue  # noqa
 | 
						|
import select
 | 
						|
import signal
 | 
						|
import struct
 | 
						|
import subprocess
 | 
						|
import sys
 | 
						|
import syslog
 | 
						|
import threading
 | 
						|
import time
 | 
						|
import traceback
 | 
						|
 | 
						|
 | 
						|
# Size in bytes of the reusable buffer used by the multiplexer and the
# demultiplexer, and therefore the largest payload a single frame carries.
MAX_MUX_READ = 128 * 1024  # 64*1024*1024
# Frame header layout: network byte order, unsigned short, signed char,
# unsigned int, one pad byte.
PACKFORMAT = "!HbIx"
# Number of bytes occupied by one packed frame header.
PACKLEN = struct.calcsize(PACKFORMAT)
 | 
						|
 | 
						|
 | 
						|
def set_proc_name(newname):
    """Rename the calling thread through libc's ``prctl``.

    Args:
        newname: the new name, as ``str`` (encoded to UTF-8 here) or as
            an already-encoded bytes object.
    """
    from ctypes import byref, cdll, create_string_buffer

    raw = newname.encode("utf-8") if isinstance(newname, str) else newname
    libc = cdll.LoadLibrary("libc.so.6")
    # One byte longer than the name so the buffer stays NUL-terminated.
    namebuf = create_string_buffer(len(raw) + 1)
    namebuf.value = raw
    # Option 15 is PR_SET_NAME in Linux's <sys/prctl.h>.
    libc.prctl(15, byref(namebuf), 0, 0, 0)
 | 
						|
 | 
						|
 | 
						|
@contextlib.contextmanager
def mutexfile(filepath):
    """Hold an exclusive lock on ``filepath`` for the duration of a block.

    The file is opened for appending (so it is created if missing) under a
    temporary umask of 0o077, then locked with ``fcntl.lockf(LOCK_EX)``,
    which blocks until the lock is granted.

    Args:
        filepath: path of the lock file.

    The file is closed when the ``with`` body finishes, whether it ends
    normally, returns early, or raises.
    """
    oldumask = os.umask(0o077)
    try:
        f = open(filepath, "a")
    finally:
        # Restore the process umask no matter how open() went.
        os.umask(oldumask)
    try:
        fcntl.lockf(f.fileno(), fcntl.LOCK_EX)
        yield
    finally:
        # Runs on exceptions too; without this, a failing body would leave
        # the lock file open.
        f.close()
 | 
						|
 | 
						|
 | 
						|
def unset_cloexec(fd):
    """Clear the ``FD_CLOEXEC`` flag on file descriptor ``fd``.

    All other descriptor flags are left as they were.
    """
    flags = fcntl.fcntl(fd, fcntl.F_GETFD)
    flags &= ~fcntl.FD_CLOEXEC
    fcntl.fcntl(fd, fcntl.F_SETFD, flags)
 | 
						|
 | 
						|
 | 
						|
def openfdforappend(fd):
    """Wrap descriptor ``fd`` in an unbuffered binary file for output.

    Append mode is tried first; if that fails with ``ESPIPE`` the
    descriptor is opened in plain write mode instead. Any other I/O error
    propagates. The close-on-exec flag is cleared before returning.

    Returns:
        The unbuffered binary file object wrapping ``fd``.
    """
    try:
        stream = os.fdopen(fd, "ab", 0)
    except IOError as err:
        if err.errno == errno.ESPIPE:
            stream = os.fdopen(fd, "wb", 0)
        else:
            raise
    unset_cloexec(stream.fileno())
    return stream
 | 
						|
 | 
						|
 | 
						|
def openfdforread(fd):
    """Wrap descriptor ``fd`` in an unbuffered binary file for input.

    The close-on-exec flag is cleared on the descriptor before the file
    object is returned.
    """
    stream = os.fdopen(fd, "rb", 0)
    unset_cloexec(stream.fileno())
    return stream
 | 
						|
 | 
						|
 | 
						|
# Serializes LoggingEmu output across threads.  It is reentrant on purpose:
# SignalSender.copy() logs from a signal handler, and Python runs signal
# handlers on the main thread, so a handler firing while the main thread is
# already inside LoggingEmu._print() would block forever on a plain Lock.
debug_lock = threading.RLock()
# Turned on by the -d command-line flag; gates LoggingEmu.debug().
debug_enabled = False
# Reference timestamp; log lines are prefixed with the seconds elapsed
# since this moment.
_startt = time.time()
 | 
						|
 | 
						|
 | 
						|
class LoggingEmu:
    """Small logger look-alike that sends its messages to syslog.

    Messages use %-style formatting: the first positional argument is the
    format string and the remaining ones are its arguments. A lone
    argument is emitted as-is. Keyword arguments are accepted and ignored.
    """

    def __init__(self, prefix):
        """Open syslog with the ident ``bombshell-client.<prefix>``."""
        self.prefix = prefix
        syslog.openlog("bombshell-client.%s" % self.prefix)

    def debug(self, *a, **kw):
        """Log at LOG_DEBUG, but only while ``debug_enabled`` is set."""
        if debug_enabled:
            self._print(syslog.LOG_DEBUG, *a, **kw)

    def info(self, *a, **kw):
        """Log at LOG_INFO."""
        self._print(syslog.LOG_INFO, *a, **kw)

    def error(self, *a, **kw):
        """Log at LOG_ERR."""
        self._print(syslog.LOG_ERR, *a, **kw)

    def _print(self, prio, *a, **kw):
        """Format one message and hand it to syslog under ``debug_lock``.

        Each line is prefixed with the seconds elapsed since ``_startt``
        and the name of the current thread.
        """
        with debug_lock:
            elapsed = time.time() - _startt
            message = a[0] if len(a) == 1 else a[0] % a[1:]
            thread_name = threading.current_thread().name
            line = ("%.3f  " % elapsed) + thread_name + ": " + message
            syslog.syslog(prio, line)
 | 
						|
 | 
						|
 | 
						|
# Module-wide logger.  main_master() and main_remote() rebind this
# placeholder to a LoggingEmu instance before any message is logged.
logging = None
 | 
						|
 | 
						|
 | 
						|
def send_confirmation(chan, retval, errmsg):
    """Write a confirmation record to ``chan`` and flush it.

    The record is a 2-byte unsigned status (network byte order), a 4-byte
    unsigned message length, then the message bytes themselves.

    Args:
        chan: writable binary file-like object.
        retval: status code; must fit in an unsigned 16-bit integer.
        errmsg: message as bytes (may be empty).
    """
    status_field = struct.pack("!H", retval)
    chan.write(status_field)
    msglen = len(errmsg)
    assert msglen < 1 << 32
    length_field = struct.pack("!I", msglen)
    chan.write(length_field)
    chan.write(errmsg)
    chan.flush()
    logging.debug(
        "Sent confirmation on channel %s: %s %s", chan, retval, errmsg
    )
 | 
						|
 | 
						|
 | 
						|
def _read_exactly(chan, count):
    """Read ``count`` bytes from ``chan``, looping over short reads.

    A single read() on an unbuffered stream may return fewer bytes than
    requested, so keep reading until the request is satisfied.

    Returns:
        The bytes read; shorter than ``count`` only if end-of-file was
        reached first.
    """
    pieces = []
    missing = count
    while missing > 0:
        piece = chan.read(missing)
        if not piece:
            break
        pieces.append(piece)
        missing -= len(piece)
    return b"".join(pieces)


def recv_confirmation(chan):
    """Read one confirmation record (see send_confirmation) from ``chan``.

    Args:
        chan: readable binary file-like object.

    Returns:
        A ``(status, message)`` tuple. If the channel is at end-of-file
        before any byte arrives, the result is
        ``(125, "domain does not exist")``.

    Raises:
        AssertionError: if the stream ends in the middle of the status or
            length field.
    """
    logging.debug("Waiting for confirmation on channel %s", chan)
    r = _read_exactly(chan, 2)
    if len(r) == 0:
        # This happens when the remote domain does not exist.
        r, errmsg = 125, "domain does not exist"
        logging.debug("No confirmation: %s %s", r, errmsg)
        return r, errmsg
    assert len(r) == 2, r
    r = struct.unpack("!H", r)[0]
    lc = _read_exactly(chan, 4)
    assert len(lc) == 4, lc
    lu = struct.unpack("!I", lc)[0]
    errmsg = _read_exactly(chan, lu)
    logging.debug("Received confirmation: %s %s", r, errmsg)
    return r, errmsg
 | 
						|
 | 
						|
 | 
						|
class MyThread(threading.Thread):
    """Thread base class whose work is implemented in ``_run``.

    Any ``Exception`` escaping ``_run`` is logged together with its
    traceback, after which the whole process is terminated with exit
    status 124 via ``os._exit``.
    """

    def run(self):
        """Execute ``_run`` and abort the program if it fails."""
        try:
            self._run()
        except Exception:
            who = threading.current_thread().name
            logging.error("%s: unexpected exception", who)
            details = traceback.format_exc()
            logging.error("%s: traceback: %s", who, details)
            logging.error("%s: exiting program", who)
            os._exit(124)
 | 
						|
 | 
						|
 | 
						|
class SignalSender(MyThread):
    """Daemon thread that forwards received signals to a file-like object.

    The signal handler only enqueues the signal number; the thread body
    dequeues numbers and writes each one to ``sigqueue`` as a 2-byte
    unsigned integer in network byte order.
    """

    def __init__(self, signals, sigqueue):
        """Handles signals by pushing them into a file-like object."""
        super().__init__()
        self.daemon = True
        self.queue = Queue()
        self.sigqueue = sigqueue
        # Installing the handlers happens here, in the constructing thread.
        for signo in signals:
            signal.signal(signo, self.copy)

    def copy(self, signum, frame):
        """Signal handler: enqueue ``signum`` for the thread to relay."""
        self.queue.put(signum)
        logging.debug("Signal %s pushed to queue", signum)

    def _run(self):
        """Relay queued signal numbers until ``None`` is dequeued."""
        while True:
            signum = self.queue.get()
            logging.debug("Dequeued signal %s", signum)
            if signum is None:
                return
            assert signum > 0
            packed = struct.pack("!H", signum)
            self.sigqueue.write(packed)
            self.sigqueue.flush()
            logging.debug("Wrote signal %s to remote end", signum)
 | 
						|
 | 
						|
 | 
						|
class Signaler(MyThread):
    """Daemon thread that turns relayed signal numbers into real signals.

    Each record read from ``sigqueue`` is a 2-byte unsigned integer in
    network byte order; it is delivered to ``process`` with
    ``send_signal``. Delivery failures are logged and do not stop the
    thread.
    """

    def __init__(self, process, sigqueue):
        """Reads integers from a file-like object and relays that as kill()."""
        super().__init__()
        self.daemon = True
        self.process = process
        self.sigqueue = sigqueue

    def _run(self):
        """Relay signals until ``sigqueue`` reports end-of-file."""
        while True:
            packet = self.sigqueue.read(2)
            if len(packet) == 0:
                logging.debug("Received no signal data")
                break
            assert len(packet) == 2
            (signum,) = struct.unpack("!H", packet)
            logging.debug(
                "Received relayed signal %s, sending to process %s",
                signum,
                self.process.pid,
            )
            try:
                self.process.send_signal(signum)
            except BaseException as exc:
                logging.error(
                    "Failed to relay signal %s to process %s: %s",
                    signum,
                    self.process.pid,
                    exc,
                )
        logging.debug("End of signaler")
 | 
						|
 | 
						|
 | 
						|
def write(dst, buffer, ln):
    """Write the first ``ln`` bytes of ``buffer`` to ``dst`` completely.

    ``dst.write()`` on an unbuffered stream may accept only part of the
    data, so the count it returns is honoured and the remainder is written
    again until nothing is left.

    Args:
        dst: writable binary file-like object whose ``write`` returns the
            number of bytes it accepted.
        buffer: object supporting the buffer protocol.
        ln: number of leading bytes of ``buffer`` to write.

    Raises:
        Exception: if ``dst.write`` reports that it wrote nothing.
    """
    pending = memoryview(buffer)[:ln]
    while len(pending):
        written = dst.write(pending)
        if written is None or written < 1:
            raise Exception("copy: Failed to write any bytes")
        # Zero-copy advance past what was actually written.
        pending = pending[written:]
 | 
						|
 | 
						|
 | 
						|
def copy(src, dst, buffer, ln):
    """Move exactly ``ln`` bytes from ``src`` to ``dst`` through ``buffer``.

    The bytes are first read into the start of ``buffer`` (waiting with
    ``select`` before each read), then handed to ``write``.

    Args:
        src: readable object usable with ``select`` and ``readinto``.
        dst: destination passed on to ``write``.
        buffer: writable buffer of at least ``ln`` bytes.
        ln: number of bytes to transfer.

    Raises:
        AssertionError: if ``buffer`` is shorter than ``ln``.
        Exception: if a read yields no bytes before ``ln`` were received.
    """
    window = memoryview(buffer)[:ln]
    assert len(window) == ln, "Buffer object is too small: %s %s" % (
        len(window),
        ln,
    )
    remaining = window
    while len(remaining):
        select.select([src], (), ())
        got = src.readinto(remaining)
        if got is None or got < 1:
            raise Exception("copy: Failed to read any bytes")
        remaining = remaining[got:]
    return write(dst, buffer, ln)
 | 
						|
 | 
						|
 | 
						|
class DataMultiplexer(MyThread):
    """Daemon thread that interleaves several sources into one sink.

    Every chunk read from a source is written to the sink as a header
    (``PACKFORMAT``: source index, active flag, payload length) followed
    by the payload. When a source reaches end-of-file a header with the
    active flag cleared and length 0 is sent and the source is dropped.
    The thread ends once no sources remain.
    """

    def __init__(self, sources, sink):
        """Number ``sources`` by position and remember ``sink``."""
        super().__init__()
        self.daemon = True
        self.sources = {src: idx for idx, src in enumerate(sources)}
        self.sink = sink

    def _run(self):
        """Pump data from whichever sources are readable into the sink."""
        logging.debug(
            "mux: Started with sources %s and sink %s", self.sources, self.sink
        )
        buffer = bytearray(MAX_MUX_READ)
        while self.sources:
            watched = list(self.sources)
            ready, _, errored = select.select(watched, (), watched)
            assert not errored, errored
            for src in ready:
                idx = self.sources[src]
                logging.debug("mux: Source %s (%s) is active", idx, src)
                nread = src.readinto(buffer)
                if nread == 0:
                    logging.debug(
                        "mux: Received no bytes from source %s, signaling"
                        " peer to close corresponding source",
                        idx,
                    )
                    del self.sources[src]
                    self.sink.write(struct.pack(PACKFORMAT, idx, False, 0))
                    continue
                self.sink.write(struct.pack(PACKFORMAT, idx, True, nread))
                write(self.sink, buffer, nread)
        logging.debug("mux: End of data multiplexer")
 | 
						|
 | 
						|
 | 
						|
class DataDemultiplexer(MyThread):
    """Daemon thread that splits one framed source into several sinks.

    It consumes the frames produced by DataMultiplexer: a header packed
    with ``PACKFORMAT`` (sink index, active flag, payload length),
    optionally followed by the payload. An inactive header closes and
    forgets the matching sink; end-of-file on the source closes all
    remaining sinks. The thread ends once no sinks remain.
    """

    def __init__(self, source, sinks):
        """Number ``sinks`` by position and remember ``source``."""
        threading.Thread.__init__(self)
        self.daemon = True
        self.sinks = dict(enumerate(sinks))
        self.source = source

    def _read_header(self, stream):
        """Read one ``PACKLEN``-byte header, looping over short reads.

        Returns ``b""`` at a clean end-of-file; a non-empty result shorter
        than ``PACKLEN`` means the stream ended in the middle of a header.
        """
        header = b""
        while len(header) < PACKLEN:
            piece = stream.read(PACKLEN - len(header))
            if not piece:
                break
            header += piece
        return header

    def _run(self):
        """Dispatch frames from the source until every sink is closed."""
        logging.debug(
            "demux: Started with source %s and sinks %s",
            self.source,
            self.sinks,
        )
        buffer = bytearray(MAX_MUX_READ)
        while self.sinks:
            r, _, x = select.select([self.source], (), [self.source])
            assert not x, x
            for s in r:
                # A single read() on an unbuffered stream may return only
                # part of the header, so gather it in full before unpacking.
                header = self._read_header(s)
                if header == b"":
                    logging.debug(
                        "demux: Received no bytes from source, closing sinks",
                    )
                    for sink in self.sinks.values():
                        sink.close()
                    # Keep the attribute a dict; emptying it ends the loop.
                    self.sinks = {}
                    break
                # A truncated header makes unpack raise struct.error, which
                # MyThread.run() reports before exiting the program.
                n, active, ln = struct.unpack(PACKFORMAT, header)
                if not active:
                    logging.debug(
                        "demux: Source %s inactive, closing matching sink %s",
                        s,
                        self.sinks[n],
                    )
                    self.sinks[n].close()
                    del self.sinks[n]
                else:
                    copy(self.source, self.sinks[n], buffer, ln)
        logging.debug("demux: End of data demultiplexer")
 | 
						|
 | 
						|
 | 
						|
def quotedargs():
    """Return the command-line arguments as one shell-quoted string.

    Every element of ``sys.argv`` after the program name is quoted
    individually and the results are joined with single spaces.
    """
    return " ".join(map(quote, sys.argv[1:]))
 | 
						|
 | 
						|
 | 
						|
def main_master():
    """Run the local ("master") side of bombshell.

    Command line: ``[-d] REMOTE_VM COMMAND [ARG...]``; ``-d`` enables
    debug logging.

    The function starts ``qrexec-client-vm REMOTE_VM qubes.VMShell`` and
    writes it one shell line that re-executes this file's own source with
    ``python -u -c`` on the other side, passing the pickled, base64-encoded
    command. After the remote confirms that the command started, three
    threads relay traffic: local stdin and handled signals go to the
    remote, and the remote's two output streams come back to local stdout
    and stderr.

    Returns:
        127 if ``qrexec-client-vm`` cannot be launched; the non-zero code
        from the remote confirmation if the remote reported a failure;
        otherwise the exit status of ``qrexec-client-vm``.
    """
    set_proc_name("bombshell-client (master) %s" % quotedargs())
    global logging
    logging = LoggingEmu("master")

    logging.info("Started with arguments: %s", sys.argv[1:])

    global debug_enabled
    args = sys.argv[1:]
    if args[0] == "-d":
        args = args[1:]
        debug_enabled = True

    remote_vm = args[0]
    remote_command = args[1:]
    assert remote_command

    def anypython(exe):
        # Shell snippet: use the same interpreter path as this process if it
        # is executable on the other side, else fall back to "python3".
        return "` test -x %s && echo %s || echo python3`" % (
            quote(exe),
            quote(exe),
        )

    # Read our own source under a context manager so the file handle is
    # closed right away instead of being left to the garbage collector.
    with open(__file__, "r") as source_file:
        own_source = source_file.read()

    remote_helper_text = b"exec "
    remote_helper_text += bytes(anypython(sys.executable), "utf-8")
    remote_helper_text += bytes(" -u -c ", "utf-8")
    remote_helper_text += bytes(
        quote(own_source),
        "ascii",
    )
    remote_helper_text += b" -d " if debug_enabled else b" "
    remote_helper_text += base64.b64encode(pickle.dumps(remote_command, 2))
    remote_helper_text += b"\n"

    # Keep a private duplicate of stderr to use as the sink for the
    # remote's second output stream.
    saved_stderr = openfdforappend(os.dup(sys.stderr.fileno()))

    # The lock file is held while the proxy is launched and until the
    # remote confirmation has been read.
    with mutexfile(os.path.expanduser("~/.bombshell-lock")):
        try:
            p = subprocess.Popen(
                ["qrexec-client-vm", remote_vm, "qubes.VMShell"],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                close_fds=True,
                preexec_fn=os.setpgrp,
                bufsize=0,
            )
        except OSError as e:
            logging.error("cannot launch qrexec-client-vm: %s", e)
            return 127

        logging.debug("Writing the helper text into the other side")
        p.stdin.write(remote_helper_text)
        p.stdin.flush()

        confirmation, errmsg = recv_confirmation(p.stdout)
        if confirmation != 0:
            logging.error("remote: %s", errmsg)
            return confirmation

    # Signals caught locally and relayed to the remote command.
    handled_signals = (
        signal.SIGINT,
        signal.SIGABRT,
        signal.SIGALRM,
        signal.SIGTERM,
        signal.SIGUSR1,
        signal.SIGUSR2,
        signal.SIGTSTP,
        signal.SIGCONT,
    )
    read_signals, write_signals = pairofpipes()
    signaler = SignalSender(handled_signals, write_signals)
    signaler.name = "master signaler"
    signaler.start()

    # Source 0 is local stdin, source 1 is the signal pipe.
    muxer = DataMultiplexer([sys.stdin, read_signals], p.stdin)
    muxer.name = "master multiplexer"
    muxer.start()

    # Sink 0 is local stdout, sink 1 is the saved stderr.
    demuxer = DataDemultiplexer(p.stdout, [sys.stdout, saved_stderr])
    demuxer.name = "master demultiplexer"
    demuxer.start()

    retval = p.wait()
    logging.info("Return code %s for qubes.VMShell proxy", retval)
    # Let the demultiplexer finish delivering output before returning.
    demuxer.join()
    logging.info("Ending bombshell")
    return retval
 | 
						|
 | 
						|
 | 
						|
def pairofpipes():
    """Create an OS pipe wrapped in unbuffered binary file objects.

    Returns:
        A ``(reader, writer)`` tuple for the pipe's read and write ends.
    """
    rfd, wfd = os.pipe()
    reader = os.fdopen(rfd, "rb", 0)
    writer = os.fdopen(wfd, "wb", 0)
    return reader, writer
 | 
						|
 | 
						|
 | 
						|
def main_remote():
    """Run the remote side of bombshell.

    The command to execute arrives as a base64-encoded pickle in
    ``sys.argv`` (after an optional ``-d`` flag that enables debug
    logging). The command is started with pipes on all three standard
    streams and a confirmation record is written to stdout: status 0 on
    success, 127 if starting it raised ``OSError``, 126 for any other
    exception. After a failure confirmation the process exits with
    status 0.

    Once the command runs, frames arriving on stdin are demultiplexed into
    the command's stdin and a signal pipe, and the command's stdout and
    stderr are multiplexed onto stdout.

    Returns:
        The value of ``Popen.wait()`` for the command.
    """
    global logging, debug_enabled

    set_proc_name("bombshell-client (remote) %s" % quotedargs())
    logging = LoggingEmu("remote")

    logging.info("Started with arguments: %s", sys.argv[1:])

    if "-d" in sys.argv[1:]:
        debug_enabled = True
        payload = sys.argv[2]
    else:
        payload = sys.argv[1]

    # NOTE(review): pickle.loads executes whatever the payload encodes, so
    # this is only as trustworthy as whoever invoked this side - confirm
    # that is acceptable for the deployment.
    cmd = pickle.loads(base64.b64decode(payload))
    logging.debug("Received command: %s", cmd)

    nicecmd = " ".join(quote(a) for a in cmd)
    try:
        p = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            close_fds=True,
            bufsize=0,
        )
        send_confirmation(sys.stdout, 0, b"")
    except BaseException as e:
        msg = "cannot execute %s: %s" % (nicecmd, e)
        logging.error(msg)
        # OSError is reported as 127, everything else as 126.
        status = 127 if isinstance(e, OSError) else 126
        send_confirmation(sys.stdout, status, bytes(msg, "utf-8"))
        sys.exit(0)

    signals_read, signals_written = pairofpipes()

    signaler = Signaler(p, signals_read)
    signaler.name = "remote signaler"
    signaler.start()

    # Sink 0 is the command's stdin, sink 1 feeds the signaler.
    demuxer = DataDemultiplexer(sys.stdin, [p.stdin, signals_written])
    demuxer.name = "remote demultiplexer"
    demuxer.start()

    # Source 0 is the command's stdout, source 1 its stderr.
    muxer = DataMultiplexer([p.stdout, p.stderr], sys.stdout)
    muxer.name = "remote multiplexer"
    muxer.start()

    logging.info("Started %s", nicecmd)

    retval = p.wait()
    logging.info("Return code %s for %s", retval, nicecmd)
    # Let the multiplexer finish forwarding output before returning.
    muxer.join()
    logging.info("Ending bombshell")
    return retval
 | 
						|
 | 
						|
 | 
						|
# Rebind the standard streams to unbuffered binary file objects; the
# multiplexing code reads and writes raw bytes on them.
sys.stdin = openfdforread(sys.stdin.fileno())
sys.stdout = openfdforappend(sys.stdout.fileno())
# The master ships this source to the other side through "python -c", where
# no __file__ is defined; that (or an explicit leading "-s" argument)
# selects the remote role.
if "__file__" not in locals() or "-s" in sys.argv[1:2]:
    sys.exit(main_remote())
else:
    sys.exit(main_master())
 |