ansible-qubes/bin/bombshell-client

369 lines
9.8 KiB
Python
Executable File

#!/usr/bin/python -u
import cPickle
import contextlib
import fcntl
import os
import pipes
import Queue
import select
import signal
import struct
import subprocess
import sys
import threading
@contextlib.contextmanager
def mutexfile(filepath):
    """Hold an exclusive lock on a file for the duration of a with-block.

    The file is opened in append mode (so it is created when missing) while
    the umask is temporarily set to 0077; the previous umask is restored
    right after the open.  The lock is then taken with fcntl.lockf() and
    LOCK_EX, and the file is closed when the with-block ends, whether it
    ends normally or by an exception.

    Args:
      filepath: path of the file to lock.
    """
    oldumask = os.umask(0o077)
    try:
        f = open(filepath, "a")
    finally:
        os.umask(oldumask)
    try:
        fcntl.lockf(f.fileno(), fcntl.LOCK_EX)
        yield
    finally:
        # Always close the file, also when the with-body raised; otherwise
        # the descriptor (and the lock held through it) would linger.
        f.close()
# Serializes the writes to stderr made through LoggingEmu.
debug_lock = threading.Lock()
# Gates LoggingEmu.debug(); main_master() and main_remote() set it for -d.
debug_enabled = False


class LoggingEmu(object):
    """Small stand-in for a logger that prints prefixed lines to stderr.

    Every method takes a message followed by optional %-format arguments,
    e.g. log.info("copied %s bytes", n).  Keyword arguments are accepted
    and ignored.
    """

    def __init__(self, prefix):
        """Remember the prefix that is printed in front of every message."""
        self.prefix = prefix

    def debug(self, *a, **kw):
        """Print the message only when the module-level debug flag is set."""
        if not debug_enabled:
            return
        self._print(*a, **kw)

    def info(self, *a, **kw):
        """Print the message unconditionally."""
        self._print(*a, **kw)

    def error(self, *a, **kw):
        """Print the message unconditionally."""
        self._print(*a, **kw)

    def _print(self, *a, **kw):
        """Format the message and write it to stderr as a single line."""
        if len(a) == 1:
            # A lone argument is printed verbatim so that stray "%"
            # characters in it are not treated as format directives.
            string = a[0]
        else:
            string = a[0] % a[1:]
        line = "%s %s\n" % (self.prefix, string)
        with debug_lock:
            # One write() per line keeps the prefix, text and newline together.
            sys.stderr.write(line)


# The prefix tells the two ends apart: it is "master:" when __file__ is
# defined in the module globals and "remote:" otherwise.
logging = LoggingEmu("master:" if "__file__" in globals() else "remote:")
def send_command(chan, cmd):
    """Pickle a command and write it to a channel with a length prefix.

    The wire format is a 4-byte unsigned integer in network byte order
    holding the size of the pickle, followed by the pickle itself.

    Args:
      chan: writable file-like object to send the command to
      cmd: command to send, iterable of strings
    """
    payload = cPickle.dumps(cmd)
    size = len(payload)
    # The size has to fit in the 32-bit length prefix.
    assert size < 1<<32, size
    for piece in (struct.pack("!I", size), payload):
        chan.write(piece)
    chan.flush()
    logging.debug("Sent command: %s", cmd)
def recv_command(chan):
    """Read one length-prefixed, pickled command from a channel.

    Expects a 4-byte unsigned integer in network byte order holding the
    size of the pickle, followed by the pickle itself.

    Args:
      chan: readable file-like object to receive the command from

    Returns:
      The unpickled command.

    Raises:
      EOFError: the channel ended before a complete length prefix arrived.
    """
    # Read the prefix from the channel that was passed in, not from a
    # hard-coded sys.stdin, so header and payload come from the same stream.
    header = chan.read(4)
    if len(header) != 4:
        raise EOFError(
            "expected a 4-byte command length, got %d bytes" % len(header)
        )
    size = struct.unpack("!I", header)[0]
    pickled = chan.read(size)
    # NOTE(security): unpickling can execute arbitrary code, so the peer
    # feeding this channel must be trusted.
    cmd = cPickle.loads(pickled)
    logging.debug("Received command: %s", cmd)
    return cmd
def send_confirmation(chan, retval, errmsg):
    """Write a status code and a message to a channel.

    The wire format is the code as a 2-byte unsigned integer, the message
    size as a 4-byte unsigned integer (both in network byte order), then
    the message itself.

    Args:
      chan: writable file-like object to send the confirmation to
      retval: integer status code
      errmsg: message that accompanies the status code
    """
    chan.write(struct.pack("!H", retval))
    size = len(errmsg)
    # The size has to fit in the 32-bit length field.
    assert size < 1<<32
    for piece in (struct.pack("!I", size), errmsg):
        chan.write(piece)
    chan.flush()
    logging.debug("Sent confirmation: %s %s", retval, errmsg)
def recv_confirmation(chan):
    """Read a status code and a message from a channel.

    This is the counterpart of send_confirmation().  When the channel
    yields no bytes at all, a synthetic (125, "domain does not exist")
    result is returned instead.

    Args:
      chan: readable file-like object to receive the confirmation from

    Returns:
      A (code, message) tuple.
    """
    logging.debug("Waiting for confirmation")
    status = chan.read(2)
    if not status:
        # This happens when the remote domain does not exist.
        code, errmsg = 125, "domain does not exist"
        logging.debug("No confirmation: %s %s", code, errmsg)
        return code, errmsg
    assert len(status) == 2, status
    (code,) = struct.unpack("!H", status)
    size_field = chan.read(4)
    assert len(size_field) == 4, size_field
    (size,) = struct.unpack("!I", size_field)
    errmsg = chan.read(size)
    logging.debug("Received confirmation: %s %s", code, errmsg)
    return code, errmsg
class SignalSender(threading.Thread):
    """Daemon thread that forwards received signals to a file-like object.

    Signal numbers are queued by the signal handler and written out by the
    thread, each as a 2-byte unsigned integer in network byte order.
    """

    def __init__(self, signals, sigqueue):
        """Install self.copy as the handler for every signal in `signals`.

        Args:
          signals: iterable of signal numbers to intercept
          sigqueue: writable file-like object that receives the numbers
        """
        threading.Thread.__init__(self)
        self.daemon = True
        self.queue = Queue.Queue()
        self.sigqueue = sigqueue
        for signo in signals:
            signal.signal(signo, self.copy)

    def copy(self, signum, frame):
        """Signal handler: hand the signal number over to the thread."""
        self.queue.put(signum)
        logging.debug("Signal %s pushed to queue", signum)

    def run(self):
        """Write out queued signal numbers until None is dequeued."""
        while True:
            pending = self.queue.get()
            logging.debug("Dequeued signal %s", pending)
            if pending is None:
                # None is the request to stop.
                return
            assert pending > 0
            self.sigqueue.write(struct.pack("!H", pending))
            self.sigqueue.flush()
            logging.debug("Wrote signal %s to remote end", pending)
class Signaler(threading.Thread):
    """Daemon thread that turns numbers read from a file into signals.

    Each number is read as a 2-byte unsigned integer in network byte order
    and delivered with process.send_signal().
    """

    def __init__(self, process, sigqueue):
        """Store the target process and the file to read signals from.

        Args:
          process: object with a send_signal(signum) method
          sigqueue: readable file-like object that yields signal numbers
        """
        threading.Thread.__init__(self)
        self.daemon = True
        self.process = process
        self.sigqueue = sigqueue

    def run(self):
        """Relay signals until the file yields no more data."""
        while True:
            packed = self.sigqueue.read(2)
            if not packed:
                logging.debug("Received no signal data")
                break
            assert len(packed) == 2
            (signum,) = struct.unpack("!H", packed)
            logging.debug("Received relayed signal %s, sending to process", signum)
            self.process.send_signal(signum)
        logging.debug("End of signaler")
def unblock(fobj):
    """Add O_NONBLOCK to the file status flags of `fobj`.

    Args:
      fobj: file object (or descriptor) accepted by fcntl.fcntl()
    """
    flags = fcntl.fcntl(fobj, fcntl.F_GETFL)
    flags = flags | os.O_NONBLOCK
    fcntl.fcntl(fobj, fcntl.F_SETFL, flags)
class DataMultiplexer(threading.Thread):
    """Daemon thread that merges several readable files into one sink.

    Every source is identified by its position in the `sources` sequence.
    Data is forwarded as frames: a 2-byte source number (network byte
    order), a 1-byte flag, and, when the flag is true, a 4-byte length
    (network byte order) followed by that many bytes.  A false flag tells
    the reader that the source has ended.
    """

    def __init__(self, sources, sink):
        """Store the sink and number the sources.

        Args:
          sources: sequence of readable file objects
          sink: writable file-like object that receives the frames
        """
        threading.Thread.__init__(self)
        self.setDaemon(True)
        self.sources = dict((s, num) for num, s in enumerate(sources))
        self.sink = sink

    def _wait_for_sources(self):
        """Block in select() and return the sources that are readable."""
        pending = list(self.sources)
        readable, _, failed = select.select(pending, (), pending)
        # Check the result of THIS call; a stale value from an earlier
        # select() would hide exceptional conditions reported later.
        assert not failed, failed
        return readable

    def run(self):
        """Forward data until every source has reported end of file."""
        for s in self.sources:
            unblock(s)
        sources = self._wait_for_sources()
        while sources:
            logging.debug("mux: Sources that alarmed: %s", [self.sources[s] for s in sources])
            for s in sources:
                n = self.sources[s]
                data = s.read()
                if not data:
                    logging.debug("Received no bytes from source %s", n)
                    del self.sources[s]
                    self.sink.write(struct.pack("!Hb", n, False))
                    self.sink.flush()
                    logging.debug("Informed sink about death of source %s", n)
                    continue
                l = len(data)
                logging.debug("Received %s bytes from source %s", l, n)
                # The size has to fit in the 32-bit length field.
                assert l < 1<<32
                self.sink.write(struct.pack("!HbI", n, True, l))
                self.sink.write(data)
                self.sink.flush()
                logging.debug("Copied those %s bytes to sink", l)
            if not self.sources:
                break
            sources = self._wait_for_sources()
        logging.debug("End of data multiplexer")
class DataDemultiplexer(threading.Thread):
    """Daemon thread that splits one framed stream into several sinks.

    Every sink is identified by its position in the `sinks` sequence.  The
    source is read as frames: a 2-byte sink number (network byte order), a
    1-byte flag, and, when the flag is true, a 4-byte length (network byte
    order) followed by that many bytes, which are written to the sink.  A
    false flag closes the sink.
    """

    def __init__(self, source, sinks):
        """Store the source and number the sinks.

        Args:
          source: readable file object carrying the frames
          sinks: sequence of writable file-like objects
        """
        threading.Thread.__init__(self)
        self.setDaemon(True)
        self.sinks = dict(enumerate(sinks))
        self.source = source

    def run(self):
        """Relay frames until all sinks are closed or the source ends."""
        while self.sinks:
            r, _, x = select.select([self.source], (), [self.source])
            assert not x, x
            logging.debug("demux: Source alarmed")
            for s in r:
                header = s.read(2)
                if not header:
                    logging.debug("Received no bytes from source, closing all sinks")
                    for sink in self.sinks.values():
                        sink.close()
                    # Keep the attribute a dict; emptying it ends the loop.
                    self.sinks = {}
                    break
                # Report the bytes actually read; no payload exists yet.
                assert len(header) == 2, header
                n = struct.unpack("!H", header)[0]
                active = s.read(1)
                assert len(active) == 1, active
                active = struct.unpack("b", active)[0]
                if not active:
                    logging.debug("Source %s now inactive, closing corresponding sink", n)
                    self.sinks[n].close()
                    del self.sinks[n]
                else:
                    l = s.read(4)
                    assert len(l) == 4, l
                    l = struct.unpack("!I", l)[0]
                    data = s.read(l)
                    assert len(data) == l, len(data)
                    logging.debug("Received %s bytes from source %s, relaying to corresponding sink", l, n)
                    self.sinks[n].write(data)
                    self.sinks[n].flush()
                    logging.debug("Relayed %s bytes to sink %s", l, n)
        logging.debug("End of data demultiplexer")
def main_master():
    """Run a command in another VM and relay its I/O and signals.

    Command line: [-d] <vm> <command> [args...].  The text of this very
    file is sent to a qubes.VMShell session in the target VM, where it is
    executed with "python -u -c"; the command is then sent to it and its
    confirmation awaited.  After that, local stdin and intercepted signals
    are multiplexed to the remote end, and remote output is demultiplexed
    to local stdout and stderr.

    Returns:
      os.EX_USAGE when the command line is incomplete, 127 when
      qrexec-client-vm cannot be launched, the remote confirmation code
      when it is nonzero, and otherwise the exit status of
      qrexec-client-vm.
    """
    global debug_enabled
    args = sys.argv[1:]
    if args and args[0] == "-d":
        args = args[1:]
        debug_enabled = True

    # Both a VM name and a command are required; report this as a usage
    # error rather than failing with IndexError or AssertionError.
    if len(args) < 2:
        logging.error(
            "usage: %s [-d] <vm> <command> [args...]",
            os.path.basename(sys.argv[0]),
        )
        return os.EX_USAGE
    remote_vm = args[0]
    remote_command = args[1:]

    with open(__file__, "rb") as source:
        program = source.read()
    remote_helper_text = "\n".join([
        "exec python -u -c '",
        # The program is embedded in a single-quoted shell word, so every
        # single quote in it has to be escaped for the shell.
        program.replace("'", "'\\''"),
        "'" + (" -d" if debug_enabled else ""),
        "",
    ])

    # Duplicate stderr so the demultiplexer closes the copy, not the original.
    saved_stderr = os.fdopen(os.dup(sys.stderr.fileno()), "a")

    # NOTE(review): the lock is assumed to cover connection setup only
    # (launch, helper upload, command handshake) -- confirm.
    with mutexfile(os.path.expanduser("~/.bombshell-lock")):
        try:
            p = subprocess.Popen(
                ["qrexec-client-vm", remote_vm, "qubes.VMShell"],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                close_fds=True,
                preexec_fn=os.setpgrp,
            )
        except OSError as e:
            logging.error("cannot launch qrexec-client-vm: %s", e)
            return 127

        p.stdin.write(remote_helper_text)
        p.stdin.flush()

        send_command(p.stdin, remote_command)
        confirmation, errmsg = recv_confirmation(p.stdout)
        if confirmation != 0:
            logging.error("remote: %s", errmsg)
            return confirmation

    handled_signals = (
        signal.SIGINT,
        signal.SIGABRT,
        signal.SIGALRM,
        signal.SIGTERM,
        signal.SIGUSR1,
        signal.SIGUSR2,
        signal.SIGTSTP,
        signal.SIGCONT,
    )
    read_signals, write_signals = pairofpipes()
    signaler = SignalSender(handled_signals, write_signals)
    signaler.start()

    # Source 0 is local stdin and source 1 the signal pipe; sink 0 is local
    # stdout and sink 1 the saved stderr.
    muxer = DataMultiplexer([sys.stdin, read_signals], p.stdin)
    muxer.start()

    demuxer = DataDemultiplexer(p.stdout, [sys.stdout, saved_stderr])
    demuxer.start()

    retval = p.wait()
    # Let the demultiplexer finish relaying before returning.
    demuxer.join()
    return retval
def pairofpipes():
    """Create a pipe and return its two ends as binary file objects.

    Returns:
      A (reader, writer) tuple of file objects opened "rb" and "wb".
    """
    read_fd, write_fd = os.pipe()
    reader = os.fdopen(read_fd, "rb")
    writer = os.fdopen(write_fd, "wb")
    return reader, writer
def main_remote():
    """Execute the command received on stdin and relay its I/O and signals.

    The command is read with recv_command() and started with pipes on
    stdin, stdout and stderr.  A confirmation is written to stdout: code 0
    on success, 127 when starting the command raised OSError, 126 for any
    other failure.  After a failed start the process exits with status 0,
    the failure having been reported through the confirmation.

    Returns:
      The exit status of the executed command.
    """
    global debug_enabled
    if len(sys.argv) > 1 and sys.argv[1] == "-d":
        debug_enabled = True

    cmd = recv_command(sys.stdin)
    nicecmd = " ".join(pipes.quote(a) for a in cmd)
    try:
        p = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            close_fds=True,
        )
    except BaseException as e:
        # Deliberately broad: whatever went wrong, report it to the peer
        # through a confirmation rather than just dying.
        code = 127 if isinstance(e, OSError) else 126
        msg = "cannot execute %s: %s" % (nicecmd, e)
        logging.error(msg)
        send_confirmation(sys.stdout, code, msg)
        sys.exit(0)
    # Sent outside the try block so that a failure to write the
    # confirmation is not reported as a failure to execute the command.
    send_confirmation(sys.stdout, 0, "")

    signals_read, signals_written = pairofpipes()
    signaler = Signaler(p, signals_read)
    signaler.start()

    # Sink 0 is the command's stdin and sink 1 the signal pipe; source 0 is
    # the command's stdout and source 1 its stderr.
    demuxer = DataDemultiplexer(sys.stdin, [p.stdin, signals_written])
    demuxer.start()

    muxer = DataMultiplexer([p.stdout, p.stderr], sys.stdout)
    muxer.start()

    logging.info("started %s", nicecmd)

    retval = p.wait()
    logging.info("return code %s for %s", retval, nicecmd)
    # Let the multiplexer flush the remaining output before returning.
    muxer.join()
    return retval
# main_master() sends this file's text to be run with "python -c", where
# __file__ is not defined; its presence therefore selects the role.
entry_point = main_master if "__file__" in locals() else main_remote
sys.exit(entry_point())