nixos/test/test-driver: Class-ify the test driver

This commit encapsulates the involved domain into classes and
defines explicit and typed arguments where untyped dicts where used.

It preserves backwards compatibility through legacy wrappers.
This commit is contained in:
David Arnold 2021-06-12 17:47:25 -05:00 committed by David Arnold
parent 3069ba0dd1
commit b0fc9da879
3 changed files with 531 additions and 311 deletions

View file

@ -21,7 +21,6 @@ import shutil
import socket import socket
import subprocess import subprocess
import sys import sys
import telnetlib
import tempfile import tempfile
import time import time
import unicodedata import unicodedata
@ -89,55 +88,6 @@ CHAR_TO_KEY = {
")": "shift-0x0B", ")": "shift-0x0B",
} }
global log, machines, test_script
def eprint(*args: object, **kwargs: Any) -> None:
print(*args, file=sys.stderr, **kwargs)
def make_command(args: list) -> str:
return " ".join(map(shlex.quote, (map(str, args))))
def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]:
log.log("starting VDE switch for network {}".format(vlan_nr))
vde_socket = tempfile.mkdtemp(
prefix="nixos-test-vde-", suffix="-vde{}.ctl".format(vlan_nr)
)
pty_master, pty_slave = pty.openpty()
vde_process = subprocess.Popen(
["vde_switch", "-s", vde_socket, "--dirmode", "0700"],
stdin=pty_slave,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=False,
)
fd = os.fdopen(pty_master, "w")
fd.write("version\n")
# TODO: perl version checks if this can be read from
# an if not, dies. we could hang here forever. Fix it.
assert vde_process.stdout is not None
vde_process.stdout.readline()
if not os.path.exists(os.path.join(vde_socket, "ctl")):
raise Exception("cannot start vde_switch")
return (vlan_nr, vde_socket, vde_process, fd)
def retry(fn: Callable, timeout: int = 900) -> None:
"""Call the given function repeatedly, with 1 second intervals,
until it returns True or a timeout is reached.
"""
for _ in range(timeout):
if fn(False):
return
time.sleep(1)
if not fn(True):
raise Exception(f"action timed out after {timeout} seconds")
class Logger: class Logger:
def __init__(self) -> None: def __init__(self) -> None:
@ -151,6 +101,10 @@ class Logger:
self._print_serial_logs = True self._print_serial_logs = True
@staticmethod
def _eprint(*args: object, **kwargs: Any) -> None:
print(*args, file=sys.stderr, **kwargs)
def close(self) -> None: def close(self) -> None:
self.xml.endElement("logfile") self.xml.endElement("logfile")
self.xml.endDocument() self.xml.endDocument()
@ -169,15 +123,27 @@ class Logger:
self.xml.characters(message) self.xml.characters(message)
self.xml.endElement("line") self.xml.endElement("line")
def info(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs)
def warning(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs)
def error(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs)
sys.exit(1)
def log(self, message: str, attributes: Dict[str, str] = {}) -> None: def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
eprint(self.maybe_prefix(message, attributes)) self._eprint(self.maybe_prefix(message, attributes))
self.drain_log_queue() self.drain_log_queue()
self.log_line(message, attributes) self.log_line(message, attributes)
def log_serial(self, message: str, machine: str) -> None: def log_serial(self, message: str, machine: str) -> None:
self.enqueue({"msg": message, "machine": machine, "type": "serial"}) self.enqueue({"msg": message, "machine": machine, "type": "serial"})
if self._print_serial_logs: if self._print_serial_logs:
eprint(Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL) self._eprint(
Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL
)
def enqueue(self, item: Dict[str, str]) -> None: def enqueue(self, item: Dict[str, str]) -> None:
self.queue.put(item) self.queue.put(item)
@ -194,7 +160,7 @@ class Logger:
@contextmanager @contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
eprint(self.maybe_prefix(message, attributes)) self._eprint(self.maybe_prefix(message, attributes))
self.xml.startElement("nest", attrs={}) self.xml.startElement("nest", attrs={})
self.xml.startElement("head", attributes) self.xml.startElement("head", attributes)
@ -211,6 +177,27 @@ class Logger:
self.xml.endElement("nest") self.xml.endElement("nest")
rootlog = Logger()
def make_command(args: list) -> str:
return " ".join(map(shlex.quote, (map(str, args))))
def retry(fn: Callable, timeout: int = 900) -> None:
"""Call the given function repeatedly, with 1 second intervals,
until it returns True or a timeout is reached.
"""
for _ in range(timeout):
if fn(False):
return
time.sleep(1)
if not fn(True):
raise Exception(f"action timed out after {timeout} seconds")
def _perform_ocr_on_screenshot( def _perform_ocr_on_screenshot(
screenshot_path: str, model_ids: Iterable[int] screenshot_path: str, model_ids: Iterable[int]
) -> List[str]: ) -> List[str]:
@ -242,113 +229,256 @@ def _perform_ocr_on_screenshot(
return model_results return model_results
class StartCommand:
"""The Base Start Command knows how to append the necesary
runtime qemu options as determined by a particular test driver
run. Any such start command is expected to happily receive and
append additional qemu args.
"""
_cmd: str
def cmd(
self,
monitor_socket_path: pathlib.Path,
shell_socket_path: pathlib.Path,
allow_reboot: bool = False, # TODO: unused, legacy?
) -> str:
display_opts = ""
display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"])
if display_available:
display_opts += " -nographic"
# qemu options
qemu_opts = ""
qemu_opts += (
""
if allow_reboot
else " -no-reboot"
" -device virtio-serial"
" -device virtconsole,chardev=shell"
" -device virtio-rng-pci"
" -serial stdio"
)
# TODO: qemu script already catpures this env variable, legacy?
qemu_opts += " " + os.environ.get("QEMU_OPTS", "")
return (
f"{self._cmd}"
f" -monitor unix:{monitor_socket_path}"
f" -chardev socket,id=shell,path={shell_socket_path}"
f"{qemu_opts}"
f"{display_opts}"
)
@staticmethod
def build_environment(
state_dir: pathlib.Path,
shared_dir: pathlib.Path,
) -> dict:
# We make a copy to not update the current environment
env = dict(os.environ)
env.update(
{
"TMPDIR": str(state_dir),
"SHARED_DIR": str(shared_dir),
"USE_TMPDIR": "1",
}
)
return env
def run(
self,
state_dir: pathlib.Path,
shared_dir: pathlib.Path,
monitor_socket_path: pathlib.Path,
shell_socket_path: pathlib.Path,
) -> subprocess.Popen:
return subprocess.Popen(
self.cmd(monitor_socket_path, shell_socket_path),
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
shell=True,
cwd=state_dir,
env=self.build_environment(state_dir, shared_dir),
)
class NixStartScript(StartCommand):
"""A start script from nixos/modules/virtualiation/qemu-vm.nix
that also satisfies the requirement of the BaseStartCommand.
These Nix commands have the particular charactersitic that the
machine name can be extracted out of them via a regex match.
(Admittedly a _very_ implicit contract, evtl. TODO fix)
"""
def __init__(self, script: str):
self._cmd = script
@property
def machine_name(self) -> str:
match = re.search("run-(.+)-vm$", self._cmd)
name = "machine"
if match:
name = match.group(1)
return name
class LegacyStartCommand(StartCommand):
"""Used in some places to create an ad-hoc machine instead of
using nix test instrumentation + module system for that purpose.
Legacy.
"""
def __init__(
self,
netBackendArgs: Optional[str] = None,
netFrontendArgs: Optional[str] = None,
hda: Optional[Tuple[pathlib.Path, str]] = None,
cdrom: Optional[str] = None,
usb: Optional[str] = None,
bios: Optional[str] = None,
qemuFlags: Optional[str] = None,
):
self._cmd = "qemu-kvm -m 384"
# networking
net_backend = "-netdev user,id=net0"
net_frontend = "-device virtio-net-pci,netdev=net0"
if netBackendArgs is not None:
net_backend += "," + netBackendArgs
if netFrontendArgs is not None:
net_frontend += "," + netFrontendArgs
self._cmd += f" {net_backend} {net_frontend}"
# hda
hda_cmd = ""
if hda is not None:
hda_path = hda[0].resolve()
hda_interface = hda[1]
if hda_interface == "scsi":
hda_cmd += (
f" -drive id=hda,file={hda_path},werror=report,if=none"
" -device scsi-hd,drive=hda"
)
else:
hda_cmd += f" -drive file={hda_path},if={hda_interface},werror=report"
self._cmd += hda_cmd
# cdrom
if cdrom is not None:
self._cmd += f" -cdrom {cdrom}"
# usb
usb_cmd = ""
if usb is not None:
# https://github.com/qemu/qemu/blob/master/docs/usb2.txt
usb_cmd += (
" -device usb-ehci"
f" -drive id=usbdisk,file={usb},if=none,readonly"
" -device usb-storage,drive=usbdisk "
)
self._cmd += usb_cmd
# bios
if bios is not None:
self._cmd += f" -bios {bios}"
# qemu flags
if qemuFlags is not None:
self._cmd += f" {qemuFlags}"
class Machine: class Machine:
"""A handle to the machine with this name, that also knows how to manage
the machine lifecycle with the help of a start script / command."""
name: str
tmp_dir: pathlib.Path
shared_dir: pathlib.Path
state_dir: pathlib.Path
monitor_path: pathlib.Path
shell_path: pathlib.Path
start_command: StartCommand
keep_vm_state: bool
allow_reboot: bool
process: Optional[subprocess.Popen] = None
pid: Optional[int] = None
monitor: Optional[socket.socket] = None
shell: Optional[socket.socket] = None
booted: bool = False
connected: bool = False
# Store last serial console lines for use
# of wait_for_console_text
last_lines: Queue = Queue()
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Machine '{self.name}'>" return f"<Machine '{self.name}'>"
def __init__(self, args: Dict[str, Any]) -> None: def __init__(
if "name" in args: self,
self.name = args["name"] tmp_dir: pathlib.Path,
else: start_command: StartCommand,
self.name = "machine" name: str = "machine",
cmd = args.get("startCommand", None) keep_vm_state: bool = False,
if cmd: allow_reboot: bool = False,
match = re.search("run-(.+)-vm$", cmd) ) -> None:
if match: self.tmp_dir = tmp_dir
self.name = match.group(1) self.keep_vm_state = keep_vm_state
self.logger = args["log"] self.allow_reboot = allow_reboot
self.script = args.get("startCommand", self.create_startcommand(args)) self.name = name
self.start_command = start_command
tmp_dir = os.environ.get("TMPDIR", tempfile.gettempdir()) # set up directories
self.shared_dir = self.tmp_dir / "shared-xchg"
self.shared_dir.mkdir(mode=0o700, exist_ok=True)
def create_dir(name: str) -> str: self.state_dir = self.tmp_dir / f"vm-state-{self.name}"
path = os.path.join(tmp_dir, name) self.monitor_path = self.state_dir / "monitor"
os.makedirs(path, mode=0o700, exist_ok=True) self.shell_path = self.state_dir / "shell"
return path if (not self.keep_vm_state) and self.state_dir.exists():
self.state_dir = os.path.join(tmp_dir, f"vm-state-{self.name}")
if not args.get("keepVmState", False):
self.cleanup_statedir() self.cleanup_statedir()
os.makedirs(self.state_dir, mode=0o700, exist_ok=True) self.state_dir.mkdir(mode=0o700, exist_ok=True)
self.shared_dir = create_dir("shared-xchg")
self.booted = False
self.connected = False
self.pid: Optional[int] = None
self.socket = None
self.monitor: Optional[socket.socket] = None
self.allow_reboot = args.get("allowReboot", False)
@staticmethod @staticmethod
def create_startcommand(args: Dict[str, str]) -> str: def create_startcommand(args: Dict[str, str]) -> StartCommand:
net_backend = "-netdev user,id=net0" rootlog.warning(
net_frontend = "-device virtio-net-pci,netdev=net0" "Using legacy create_startcommand(),"
"please use proper nix test vm instrumentation, instead"
if "netBackendArgs" in args: "to generate the appropriate nixos test vm qemu startup script"
net_backend += "," + args["netBackendArgs"] )
hda = None
if "netFrontendArgs" in args: if args.get("hda"):
net_frontend += "," + args["netFrontendArgs"] hda_arg: str = args.get("hda", "")
hda_arg_path: pathlib.Path = pathlib.Path(hda_arg)
start_command = ( hda = (hda_arg_path, args.get("hdaInterface", ""))
args.get("qemuBinary", "qemu-kvm") return LegacyStartCommand(
+ " -m 384 " netBackendArgs=args.get("netBackendArgs"),
+ net_backend netFrontendArgs=args.get("netFrontendArgs"),
+ " " hda=hda,
+ net_frontend cdrom=args.get("cdrom"),
+ " $QEMU_OPTS " usb=args.get("usb"),
bios=args.get("bios"),
qemuFlags=args.get("qemuFlags"),
) )
if "hda" in args:
hda_path = os.path.abspath(args["hda"])
if args.get("hdaInterface", "") == "scsi":
start_command += (
"-drive id=hda,file="
+ hda_path
+ ",werror=report,if=none "
+ "-device scsi-hd,drive=hda "
)
else:
start_command += (
"-drive file="
+ hda_path
+ ",if="
+ args["hdaInterface"]
+ ",werror=report "
)
if "cdrom" in args:
start_command += "-cdrom " + args["cdrom"] + " "
if "usb" in args:
# https://github.com/qemu/qemu/blob/master/docs/usb2.txt
start_command += (
"-device usb-ehci -drive "
+ "id=usbdisk,file="
+ args["usb"]
+ ",if=none,readonly "
+ "-device usb-storage,drive=usbdisk "
)
if "bios" in args:
start_command += "-bios " + args["bios"] + " "
start_command += args.get("qemuFlags", "")
return start_command
def is_up(self) -> bool: def is_up(self) -> bool:
return self.booted and self.connected return self.booted and self.connected
def log(self, msg: str) -> None: def log(self, msg: str) -> None:
self.logger.log(msg, {"machine": self.name}) rootlog.log(msg, {"machine": self.name})
def log_serial(self, msg: str) -> None: def log_serial(self, msg: str) -> None:
self.logger.log_serial(msg, self.name) rootlog.log_serial(msg, self.name)
def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager: def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
my_attrs = {"machine": self.name} my_attrs = {"machine": self.name}
my_attrs.update(attrs) my_attrs.update(attrs)
return self.logger.nested(msg, my_attrs) return rootlog.nested(msg, my_attrs)
def wait_for_monitor_prompt(self) -> str: def wait_for_monitor_prompt(self) -> str:
assert self.monitor is not None assert self.monitor is not None
@ -446,6 +576,7 @@ class Machine:
self.connect() self.connect()
out_command = "( set -euo pipefail; {} ); echo '|!=EOF' $?\n".format(command) out_command = "( set -euo pipefail; {} ); echo '|!=EOF' $?\n".format(command)
assert self.shell
self.shell.send(out_command.encode()) self.shell.send(out_command.encode())
output = "" output = ""
@ -466,6 +597,8 @@ class Machine:
Should only be used during test development, not in the production test.""" Should only be used during test development, not in the production test."""
self.connect() self.connect()
self.log("Terminal is ready (there is no prompt):") self.log("Terminal is ready (there is no prompt):")
assert self.shell
subprocess.run( subprocess.run(
["socat", "READLINE", f"FD:{self.shell.fileno()}"], ["socat", "READLINE", f"FD:{self.shell.fileno()}"],
pass_fds=[self.shell.fileno()], pass_fds=[self.shell.fileno()],
@ -534,6 +667,7 @@ class Machine:
with self.nested("waiting for the VM to power off"): with self.nested("waiting for the VM to power off"):
sys.stdout.flush() sys.stdout.flush()
assert self.process
self.process.wait() self.process.wait()
self.pid = None self.pid = None
@ -611,6 +745,8 @@ class Machine:
with self.nested("waiting for the VM to finish booting"): with self.nested("waiting for the VM to finish booting"):
self.start() self.start()
assert self.shell
tic = time.time() tic = time.time()
self.shell.recv(1024) self.shell.recv(1024)
# TODO: Timeout # TODO: Timeout
@ -750,65 +886,35 @@ class Machine:
self.log("starting vm") self.log("starting vm")
def create_socket(path: str) -> socket.socket: def clear(path: pathlib.Path) -> pathlib.Path:
if os.path.exists(path): if path.exists():
os.unlink(path) path.unlink()
return path
def create_socket(path: pathlib.Path) -> socket.socket:
s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM) s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
s.bind(path) s.bind(str(path))
s.listen(1) s.listen(1)
return s return s
monitor_path = os.path.join(self.state_dir, "monitor") monitor_socket = create_socket(clear(self.monitor_path))
self.monitor_socket = create_socket(monitor_path) shell_socket = create_socket(clear(self.shell_path))
self.process = self.start_command.run(
shell_path = os.path.join(self.state_dir, "shell") self.state_dir,
self.shell_socket = create_socket(shell_path) self.shared_dir,
self.monitor_path,
display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"]) self.shell_path,
qemu_options = (
" ".join(
[
"" if self.allow_reboot else "-no-reboot",
"-monitor unix:{}".format(monitor_path),
"-chardev socket,id=shell,path={}".format(shell_path),
"-device virtio-serial",
"-device virtconsole,chardev=shell",
"-device virtio-rng-pci",
"-serial stdio" if display_available else "-nographic",
]
)
+ " "
+ os.environ.get("QEMU_OPTS", "")
) )
self.monitor, _ = monitor_socket.accept()
environment = dict(os.environ) self.shell, _ = shell_socket.accept()
environment.update(
{
"TMPDIR": self.state_dir,
"SHARED_DIR": self.shared_dir,
"USE_TMPDIR": "1",
"QEMU_OPTS": qemu_options,
}
)
self.process = subprocess.Popen(
self.script,
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
shell=True,
cwd=self.state_dir,
env=environment,
)
self.monitor, _ = self.monitor_socket.accept()
self.shell, _ = self.shell_socket.accept()
# Store last serial console lines for use # Store last serial console lines for use
# of wait_for_console_text # of wait_for_console_text
self.last_lines: Queue = Queue() self.last_lines: Queue = Queue()
def process_serial_output() -> None: def process_serial_output() -> None:
assert self.process.stdout is not None assert self.process
assert self.process.stdout
for _line in self.process.stdout: for _line in self.process.stdout:
# Ignore undecodable bytes that may occur in boot menus # Ignore undecodable bytes that may occur in boot menus
line = _line.decode(errors="ignore").replace("\r", "").rstrip() line = _line.decode(errors="ignore").replace("\r", "").rstrip()
@ -825,15 +931,15 @@ class Machine:
self.log("QEMU running (pid {})".format(self.pid)) self.log("QEMU running (pid {})".format(self.pid))
def cleanup_statedir(self) -> None: def cleanup_statedir(self) -> None:
if os.path.isdir(self.state_dir): shutil.rmtree(self.state_dir)
shutil.rmtree(self.state_dir) rootlog.log(f"deleting VM state directory {self.state_dir}")
self.logger.log(f"deleting VM state directory {self.state_dir}") rootlog.log("if you want to keep the VM state, pass --keep-vm-state")
self.logger.log("if you want to keep the VM state, pass --keep-vm-state")
def shutdown(self) -> None: def shutdown(self) -> None:
if not self.booted: if not self.booted:
return return
assert self.shell
self.shell.send("poweroff\n".encode()) self.shell.send("poweroff\n".encode())
self.wait_for_shutdown() self.wait_for_shutdown()
@ -908,41 +1014,225 @@ class Machine:
"""Make the machine reachable.""" """Make the machine reachable."""
self.send_monitor_command("set_link virtio-net-pci.1 on") self.send_monitor_command("set_link virtio-net-pci.1 on")
def release(self) -> None:
def create_machine(args: Dict[str, Any]) -> Machine: if self.pid is None:
args["log"] = log return
return Machine(args) rootlog.info(f"kill machine (pid {self.pid})")
assert self.process
assert self.shell
assert self.monitor
self.process.terminate()
self.shell.close()
self.monitor.close()
def start_all() -> None: class VLan:
with log.nested("starting all VMs"): """A handle to the vlan with this number, that also knows how to manage
for machine in machines: it's lifecycle.
machine.start() """
nr: int
socket_dir: pathlib.Path
process: Optional[subprocess.Popen]
pid: Optional[int]
fd: Optional[io.TextIOBase]
def __repr__(self) -> str:
return f"<Vlan Nr. {self.nr}>"
def __init__(self, nr: int, tmp_dir: pathlib.Path):
self.nr = nr
self.socket_dir = tmp_dir / f"vde{self.nr}.ctl"
# TODO: don't side-effect environment here
os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir)
def start(self) -> None:
rootlog.info("start vlan")
pty_master, pty_slave = pty.openpty()
self.process = subprocess.Popen(
["vde_switch", "-s", self.socket_dir, "--dirmode", "0700"],
stdin=pty_slave,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=False,
)
self.pid = self.process.pid
self.fd = os.fdopen(pty_master, "w")
self.fd.write("version\n")
# TODO: perl version checks if this can be read from
# an if not, dies. we could hang here forever. Fix it.
assert self.process.stdout is not None
self.process.stdout.readline()
if not (self.socket_dir / "ctl").exists():
rootlog.error("cannot start vde_switch")
rootlog.info(f"running vlan (pid {self.pid})")
def release(self) -> None:
if self.pid is None:
return
rootlog.info(f"kill vlan (pid {self.pid})")
assert self.fd
assert self.process
self.fd.close()
self.process.terminate()
def join_all() -> None: class Driver:
with log.nested("waiting for all VMs to finish"): """A handle to the driver that sets up the environment
for machine in machines: and runs the tests"""
machine.wait_for_shutdown()
tests: str
vlans: List[VLan]
machines: List[Machine]
def run_tests(interactive: bool = False) -> None: def __init__(
if interactive: self,
ptpython.repl.embed(test_symbols(), {}) start_scripts: List[str],
else: vlans: List[int],
test_script() tests: str,
keep_vm_state: bool = False,
):
self.tests = tests
tmp_dir = pathlib.Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
tmp_dir.mkdir(mode=0o700, exist_ok=True)
self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
with rootlog.nested("start all VLans"):
for vlan in self.vlans:
vlan.start()
def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
for s in scripts:
yield NixStartScript(s)
self.machines = [
Machine(
start_command=cmd,
keep_vm_state=keep_vm_state,
name=cmd.machine_name,
tmp_dir=tmp_dir,
)
for cmd in cmd(start_scripts)
]
@atexit.register
def clean_up() -> None:
with rootlog.nested("clean up"):
for machine in self.machines:
machine.release()
for vlan in self.vlans:
vlan.release()
def subtest(self, name: str) -> Iterator[None]:
"""Group logs under a given test name"""
with rootlog.nested(name):
try:
yield
return True
except:
rootlog.error(f'Test "{name}" failed with error:')
raise
def test_symbols(self) -> Dict[str, Any]:
@contextmanager
def subtest(name: str) -> Iterator[None]:
return self.subtest(name)
general_symbols = dict(
start_all=self.start_all,
test_script=self.test_script,
machines=self.machines,
vlans=self.vlans,
driver=self,
log=rootlog,
os=os,
create_machine=self.create_machine,
subtest=subtest,
run_tests=self.run_tests,
join_all=self.join_all,
retry=retry,
serial_stdout_off=self.serial_stdout_off,
serial_stdout_on=self.serial_stdout_on,
Machine=Machine, # for typing
)
machine_symbols = {
m.name: self.machines[idx] for idx, m in enumerate(self.machines)
}
vlan_symbols = {
f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans)
}
print(
"additionally exposed symbols:\n "
+ ", ".join(map(lambda m: m.name, self.machines))
+ ",\n "
+ ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans))
+ ",\n "
+ ", ".join(list(general_symbols.keys()))
)
return {**general_symbols, **machine_symbols, **vlan_symbols}
def test_script(self) -> None:
"""Run the test script"""
with rootlog.nested("run the VM test script"):
symbols = self.test_symbols() # call eagerly
exec(self.tests, symbols, None)
def run_tests(self) -> None:
"""Run the test script (for non-interactive test runs)"""
self.test_script()
# TODO: Collect coverage data # TODO: Collect coverage data
for machine in machines: for machine in self.machines:
if machine.is_up(): if machine.is_up():
machine.execute("sync") machine.execute("sync")
def start_all(self) -> None:
"""Start all machines"""
with rootlog.nested("start all VMs"):
for machine in self.machines:
machine.start()
def serial_stdout_on() -> None: def join_all(self) -> None:
log._print_serial_logs = True """Wait for all machines to shut down"""
with rootlog.nested("wait for all VMs to finish"):
for machine in self.machines:
machine.wait_for_shutdown()
def create_machine(self, args: Dict[str, Any]) -> Machine:
rootlog.warning(
"Using legacy create_machine(), please instantiate the"
"Machine class directly, instead"
)
tmp_dir = pathlib.Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
tmp_dir.mkdir(mode=0o700, exist_ok=True)
def serial_stdout_off() -> None: if args.get("startCommand"):
log._print_serial_logs = False start_command: str = args.get("startCommand", "")
cmd = NixStartScript(start_command)
name = args.get("name", cmd.machine_name)
else:
cmd = Machine.create_startcommand(args) # type: ignore
name = args.get("name", "machine")
return Machine(
tmp_dir=tmp_dir,
start_command=cmd,
name=name,
keep_vm_state=args.get("keep_vm_state", False),
allow_reboot=args.get("allow_reboot", False),
)
def serial_stdout_on(self) -> None:
rootlog._print_serial_logs = True
def serial_stdout_off(self) -> None:
rootlog._print_serial_logs = False
class EnvDefault(argparse.Action): class EnvDefault(argparse.Action):
@ -970,52 +1260,6 @@ class EnvDefault(argparse.Action):
setattr(namespace, self.dest, values) setattr(namespace, self.dest, values)
@contextmanager
def subtest(name: str) -> Iterator[None]:
with log.nested(name):
try:
yield
return True
except Exception as e:
log.log(f'Test "{name}" failed with error: "{e}"')
raise e
return False
def _test_symbols() -> Dict[str, Any]:
general_symbols = dict(
start_all=start_all,
test_script=globals().get("test_script"), # same
machines=globals().get("machines"), # without being initialized
log=globals().get("log"), # extracting those symbol keys
os=os,
create_machine=create_machine,
subtest=subtest,
run_tests=run_tests,
join_all=join_all,
retry=retry,
serial_stdout_off=serial_stdout_off,
serial_stdout_on=serial_stdout_on,
Machine=Machine, # for typing
)
return general_symbols
def test_symbols() -> Dict[str, Any]:
general_symbols = _test_symbols()
machine_symbols = {m.name: machines[idx] for idx, m in enumerate(machines)}
print(
"additionally exposed symbols:\n "
+ ", ".join(map(lambda m: m.name, machines))
+ ",\n "
+ ", ".join(list(general_symbols.keys()))
)
return {**general_symbols, **machine_symbols}
if __name__ == "__main__": if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(prog="nixos-test-driver") arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
arg_parser.add_argument( arg_parser.add_argument(
@ -1055,44 +1299,18 @@ if __name__ == "__main__":
) )
args = arg_parser.parse_args() args = arg_parser.parse_args()
testscript = pathlib.Path(args.testscript).read_text()
global log, machines, test_script if not args.keep_vm_state:
rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")
log = Logger() driver = Driver(
args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state
)
vde_sockets = [create_vlan(v) for v in args.vlans] if args.interactive:
for nr, vde_socket, _, _ in vde_sockets: ptpython.repl.embed(driver.test_symbols(), {})
os.environ["QEMU_VDE_SOCKET_{}".format(nr)] = vde_socket else:
tic = time.time()
machines = [ driver.run_tests()
create_machine({"startCommand": s, "keepVmState": args.keep_vm_state}) toc = time.time()
for s in args.start_scripts rootlog.info(f"test script finished in {(toc-tic):.2f}s")
]
machine_eval = [
"{0} = machines[{1}]".format(m.name, idx) for idx, m in enumerate(machines)
]
exec("\n".join(machine_eval))
@atexit.register
def clean_up() -> None:
with log.nested("cleaning up"):
for machine in machines:
if machine.pid is None:
continue
log.log("killing {} (pid {})".format(machine.name, machine.pid))
machine.process.kill()
for _, _, process, _ in vde_sockets:
process.terminate()
log.close()
def test_script() -> None:
with log.nested("running the VM test script"):
symbols = test_symbols() # call eagerly
exec(testscript, symbols, None)
interactive = args.interactive or (not bool(testscript))
tic = time.time()
run_tests(interactive)
toc = time.time()
print("test script finished in {:.2f}s".format(toc - tic))

View file

@ -43,7 +43,8 @@ rec {
from pydoc import importfile from pydoc import importfile
with open('driver-symbols', 'w') as fp: with open('driver-symbols', 'w') as fp:
t = importfile('${testDriverScript}') t = importfile('${testDriverScript}')
test_symbols = t._test_symbols() d = t.Driver([],[],"")
test_symbols = d.test_symbols()
fp.write(','.join(test_symbols.keys())) fp.write(','.join(test_symbols.keys()))
EOF EOF
''; '';
@ -188,14 +189,6 @@ rec {
--set startScripts "''${vmStartScripts[*]}" \ --set startScripts "''${vmStartScripts[*]}" \
--set testScript "$out/test-script" \ --set testScript "$out/test-script" \
--set vlans '${toString vlans}' --set vlans '${toString vlans}'
${lib.optionalString (testScript == "") ''
ln -s ${testDriver}/bin/nixos-test-driver $out/bin/nixos-run-vms
wrapProgram $out/bin/nixos-run-vms \
--set startScripts "''${vmStartScripts[*]}" \
--set testScript "${pkgs.writeText "start-all" "start_all(); join_all();"}" \
--set vlans '${toString vlans}'
''}
''); '');
# Make a full-blown test # Make a full-blown test

View file

@ -8,11 +8,20 @@ let
_file = "${networkExpr}@node-${vm}"; _file = "${networkExpr}@node-${vm}";
imports = [ module ]; imports = [ module ];
}) (import networkExpr); }) (import networkExpr);
testing = import ../../../../lib/testing-python.nix {
inherit system;
pkgs = import ../../../../.. { inherit system config; };
};
interactiveDriver = (testing.makeTest { inherit nodes; testScript = "start_all(); join_all();"; }).driverInteractive;
in in
with import ../../../../lib/testing-python.nix {
inherit system;
pkgs = import ../../../../.. { inherit system config; };
};
(makeTest { inherit nodes; testScript = ""; }).driverInteractive pkgs.runCommand "nixos-build-vms" ''
mkdir -p $out/bin
ln -s ${interactiveDriver}/bin/nixos-test-driver $out/bin/nixos-test-driver
ln -s ${interactiveDriver}/bin/nixos-test-driver $out/bin/nixos-run-vms
wrapProgram $out/bin/nixos-test-driver \
--add-flags "--interactive"
''