Merge pull request #149329 from marijanp/test-driver-restructuring
nixos/test-driver: make the test-driver a python package
This commit is contained in:
commit
b6bf1ca717
8 changed files with 498 additions and 464 deletions
32
nixos/lib/test-driver/default.nix
Normal file
32
nixos/lib/test-driver/default.nix
Normal file
|
@ -0,0 +1,32 @@
|
|||
{ lib
|
||||
, python3Packages
|
||||
, enableOCR ? false
|
||||
, qemu_pkg ? qemu_test
|
||||
, coreutils
|
||||
, imagemagick_light
|
||||
, libtiff
|
||||
, netpbm
|
||||
, qemu_test
|
||||
, socat
|
||||
, tesseract4
|
||||
, vde2
|
||||
}:
|
||||
|
||||
python3Packages.buildPythonApplication rec {
|
||||
pname = "nixos-test-driver";
|
||||
version = "1.0";
|
||||
src = ./.;
|
||||
|
||||
propagatedBuildInputs = [ coreutils netpbm python3Packages.colorama python3Packages.ptpython qemu_pkg socat vde2 ]
|
||||
++ (lib.optionals enableOCR [ imagemagick_light tesseract4 ]);
|
||||
|
||||
doCheck = true;
|
||||
checkInputs = with python3Packages; [ mypy pylint black ];
|
||||
checkPhase = ''
|
||||
mypy --disallow-untyped-defs \
|
||||
--no-implicit-optional \
|
||||
--ignore-missing-imports ${src}/test_driver
|
||||
pylint --errors-only ${src}/test_driver
|
||||
black --check --diff ${src}/test_driver
|
||||
'';
|
||||
}
|
13
nixos/lib/test-driver/setup.py
Normal file
13
nixos/lib/test-driver/setup.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="nixos-test-driver",
|
||||
version='1.0',
|
||||
packages=find_packages(),
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"nixos-test-driver=test_driver:main",
|
||||
"generate-driver-symbols=test_driver:generate_driver_symbols"
|
||||
]
|
||||
},
|
||||
)
|
100
nixos/lib/test-driver/test_driver/__init__.py
Executable file
100
nixos/lib/test-driver/test_driver/__init__.py
Executable file
|
@ -0,0 +1,100 @@
|
|||
from pathlib import Path
|
||||
import argparse
|
||||
import ptpython.repl
|
||||
import os
|
||||
import time
|
||||
|
||||
from test_driver.logger import rootlog
|
||||
from test_driver.driver import Driver
|
||||
|
||||
|
||||
class EnvDefault(argparse.Action):
|
||||
"""An argpars Action that takes values from the specified
|
||||
environment variable as the flags default value.
|
||||
"""
|
||||
|
||||
def __init__(self, envvar, required=False, default=None, nargs=None, **kwargs): # type: ignore
|
||||
if not default and envvar:
|
||||
if envvar in os.environ:
|
||||
if nargs is not None and (nargs.isdigit() or nargs in ["*", "+"]):
|
||||
default = os.environ[envvar].split()
|
||||
else:
|
||||
default = os.environ[envvar]
|
||||
kwargs["help"] = (
|
||||
kwargs["help"] + f" (default from environment: {default})"
|
||||
)
|
||||
if required and default:
|
||||
required = False
|
||||
super(EnvDefault, self).__init__(
|
||||
default=default, required=required, nargs=nargs, **kwargs
|
||||
)
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None): # type: ignore
|
||||
setattr(namespace, self.dest, values)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
|
||||
arg_parser.add_argument(
|
||||
"-K",
|
||||
"--keep-vm-state",
|
||||
help="re-use a VM state coming from a previous run",
|
||||
action="store_true",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"-I",
|
||||
"--interactive",
|
||||
help="drop into a python repl and run the tests interactively",
|
||||
action="store_true",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"--start-scripts",
|
||||
metavar="START-SCRIPT",
|
||||
action=EnvDefault,
|
||||
envvar="startScripts",
|
||||
nargs="*",
|
||||
help="start scripts for participating virtual machines",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"--vlans",
|
||||
metavar="VLAN",
|
||||
action=EnvDefault,
|
||||
envvar="vlans",
|
||||
nargs="*",
|
||||
help="vlans to span by the driver",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"testscript",
|
||||
action=EnvDefault,
|
||||
envvar="testScript",
|
||||
help="the test script to run",
|
||||
type=Path,
|
||||
)
|
||||
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
if not args.keep_vm_state:
|
||||
rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")
|
||||
|
||||
with Driver(
|
||||
args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state
|
||||
) as driver:
|
||||
if args.interactive:
|
||||
ptpython.repl.embed(driver.test_symbols(), {})
|
||||
else:
|
||||
tic = time.time()
|
||||
driver.run_tests()
|
||||
toc = time.time()
|
||||
rootlog.info(f"test script finished in {(toc-tic):.2f}s")
|
||||
|
||||
|
||||
def generate_driver_symbols() -> None:
|
||||
"""
|
||||
This generates a file with symbols of the test-driver code that can be used
|
||||
in user's test scripts. That list is then used by pyflakes to lint those
|
||||
scripts.
|
||||
"""
|
||||
d = Driver([], [], "")
|
||||
test_symbols = d.test_symbols()
|
||||
with open("driver-symbols", "w") as fp:
|
||||
fp.write(",".join(test_symbols.keys()))
|
161
nixos/lib/test-driver/test_driver/driver.py
Normal file
161
nixos/lib/test-driver/test_driver/driver.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterator, List
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from test_driver.logger import rootlog
|
||||
from test_driver.machine import Machine, NixStartScript, retry
|
||||
from test_driver.vlan import VLan
|
||||
|
||||
|
||||
class Driver:
|
||||
"""A handle to the driver that sets up the environment
|
||||
and runs the tests"""
|
||||
|
||||
tests: str
|
||||
vlans: List[VLan]
|
||||
machines: List[Machine]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_scripts: List[str],
|
||||
vlans: List[int],
|
||||
tests: str,
|
||||
keep_vm_state: bool = False,
|
||||
):
|
||||
self.tests = tests
|
||||
|
||||
tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
|
||||
tmp_dir.mkdir(mode=0o700, exist_ok=True)
|
||||
|
||||
with rootlog.nested("start all VLans"):
|
||||
self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
|
||||
|
||||
def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
|
||||
for s in scripts:
|
||||
yield NixStartScript(s)
|
||||
|
||||
self.machines = [
|
||||
Machine(
|
||||
start_command=cmd,
|
||||
keep_vm_state=keep_vm_state,
|
||||
name=cmd.machine_name,
|
||||
tmp_dir=tmp_dir,
|
||||
)
|
||||
for cmd in cmd(start_scripts)
|
||||
]
|
||||
|
||||
def __enter__(self) -> "Driver":
|
||||
return self
|
||||
|
||||
def __exit__(self, *_: Any) -> None:
|
||||
with rootlog.nested("cleanup"):
|
||||
for machine in self.machines:
|
||||
machine.release()
|
||||
|
||||
def subtest(self, name: str) -> Iterator[None]:
|
||||
"""Group logs under a given test name"""
|
||||
with rootlog.nested(name):
|
||||
try:
|
||||
yield
|
||||
return True
|
||||
except Exception as e:
|
||||
rootlog.error(f'Test "{name}" failed with error: "{e}"')
|
||||
raise e
|
||||
|
||||
def test_symbols(self) -> Dict[str, Any]:
|
||||
@contextmanager
|
||||
def subtest(name: str) -> Iterator[None]:
|
||||
return self.subtest(name)
|
||||
|
||||
general_symbols = dict(
|
||||
start_all=self.start_all,
|
||||
test_script=self.test_script,
|
||||
machines=self.machines,
|
||||
vlans=self.vlans,
|
||||
driver=self,
|
||||
log=rootlog,
|
||||
os=os,
|
||||
create_machine=self.create_machine,
|
||||
subtest=subtest,
|
||||
run_tests=self.run_tests,
|
||||
join_all=self.join_all,
|
||||
retry=retry,
|
||||
serial_stdout_off=self.serial_stdout_off,
|
||||
serial_stdout_on=self.serial_stdout_on,
|
||||
Machine=Machine, # for typing
|
||||
)
|
||||
machine_symbols = {m.name: m for m in self.machines}
|
||||
# If there's exactly one machine, make it available under the name
|
||||
# "machine", even if it's not called that.
|
||||
if len(self.machines) == 1:
|
||||
(machine_symbols["machine"],) = self.machines
|
||||
vlan_symbols = {
|
||||
f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans)
|
||||
}
|
||||
print(
|
||||
"additionally exposed symbols:\n "
|
||||
+ ", ".join(map(lambda m: m.name, self.machines))
|
||||
+ ",\n "
|
||||
+ ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans))
|
||||
+ ",\n "
|
||||
+ ", ".join(list(general_symbols.keys()))
|
||||
)
|
||||
return {**general_symbols, **machine_symbols, **vlan_symbols}
|
||||
|
||||
def test_script(self) -> None:
|
||||
"""Run the test script"""
|
||||
with rootlog.nested("run the VM test script"):
|
||||
symbols = self.test_symbols() # call eagerly
|
||||
exec(self.tests, symbols, None)
|
||||
|
||||
def run_tests(self) -> None:
|
||||
"""Run the test script (for non-interactive test runs)"""
|
||||
self.test_script()
|
||||
# TODO: Collect coverage data
|
||||
for machine in self.machines:
|
||||
if machine.is_up():
|
||||
machine.execute("sync")
|
||||
|
||||
def start_all(self) -> None:
|
||||
"""Start all machines"""
|
||||
with rootlog.nested("start all VMs"):
|
||||
for machine in self.machines:
|
||||
machine.start()
|
||||
|
||||
def join_all(self) -> None:
|
||||
"""Wait for all machines to shut down"""
|
||||
with rootlog.nested("wait for all VMs to finish"):
|
||||
for machine in self.machines:
|
||||
machine.wait_for_shutdown()
|
||||
|
||||
def create_machine(self, args: Dict[str, Any]) -> Machine:
|
||||
rootlog.warning(
|
||||
"Using legacy create_machine(), please instantiate the"
|
||||
"Machine class directly, instead"
|
||||
)
|
||||
tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
|
||||
tmp_dir.mkdir(mode=0o700, exist_ok=True)
|
||||
|
||||
if args.get("startCommand"):
|
||||
start_command: str = args.get("startCommand", "")
|
||||
cmd = NixStartScript(start_command)
|
||||
name = args.get("name", cmd.machine_name)
|
||||
else:
|
||||
cmd = Machine.create_startcommand(args) # type: ignore
|
||||
name = args.get("name", "machine")
|
||||
|
||||
return Machine(
|
||||
tmp_dir=tmp_dir,
|
||||
start_command=cmd,
|
||||
name=name,
|
||||
keep_vm_state=args.get("keep_vm_state", False),
|
||||
allow_reboot=args.get("allow_reboot", False),
|
||||
)
|
||||
|
||||
def serial_stdout_on(self) -> None:
|
||||
rootlog._print_serial_logs = True
|
||||
|
||||
def serial_stdout_off(self) -> None:
|
||||
rootlog._print_serial_logs = False
|
101
nixos/lib/test-driver/test_driver/logger.py
Normal file
101
nixos/lib/test-driver/test_driver/logger.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
from colorama import Style
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Iterator
|
||||
from queue import Queue, Empty
|
||||
from xml.sax.saxutils import XMLGenerator
|
||||
import codecs
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import unicodedata
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self) -> None:
|
||||
self.logfile = os.environ.get("LOGFILE", "/dev/null")
|
||||
self.logfile_handle = codecs.open(self.logfile, "wb")
|
||||
self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
|
||||
self.queue: "Queue[Dict[str, str]]" = Queue()
|
||||
|
||||
self.xml.startDocument()
|
||||
self.xml.startElement("logfile", attrs={})
|
||||
|
||||
self._print_serial_logs = True
|
||||
|
||||
@staticmethod
|
||||
def _eprint(*args: object, **kwargs: Any) -> None:
|
||||
print(*args, file=sys.stderr, **kwargs)
|
||||
|
||||
def close(self) -> None:
|
||||
self.xml.endElement("logfile")
|
||||
self.xml.endDocument()
|
||||
self.logfile_handle.close()
|
||||
|
||||
def sanitise(self, message: str) -> str:
|
||||
return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")
|
||||
|
||||
def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
|
||||
if "machine" in attributes:
|
||||
return "{}: {}".format(attributes["machine"], message)
|
||||
return message
|
||||
|
||||
def log_line(self, message: str, attributes: Dict[str, str]) -> None:
|
||||
self.xml.startElement("line", attributes)
|
||||
self.xml.characters(message)
|
||||
self.xml.endElement("line")
|
||||
|
||||
def info(self, *args, **kwargs) -> None: # type: ignore
|
||||
self.log(*args, **kwargs)
|
||||
|
||||
def warning(self, *args, **kwargs) -> None: # type: ignore
|
||||
self.log(*args, **kwargs)
|
||||
|
||||
def error(self, *args, **kwargs) -> None: # type: ignore
|
||||
self.log(*args, **kwargs)
|
||||
sys.exit(1)
|
||||
|
||||
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
|
||||
self._eprint(self.maybe_prefix(message, attributes))
|
||||
self.drain_log_queue()
|
||||
self.log_line(message, attributes)
|
||||
|
||||
def log_serial(self, message: str, machine: str) -> None:
|
||||
self.enqueue({"msg": message, "machine": machine, "type": "serial"})
|
||||
if self._print_serial_logs:
|
||||
self._eprint(
|
||||
Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL
|
||||
)
|
||||
|
||||
def enqueue(self, item: Dict[str, str]) -> None:
|
||||
self.queue.put(item)
|
||||
|
||||
def drain_log_queue(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
item = self.queue.get_nowait()
|
||||
msg = self.sanitise(item["msg"])
|
||||
del item["msg"]
|
||||
self.log_line(msg, item)
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
self._eprint(self.maybe_prefix(message, attributes))
|
||||
|
||||
self.xml.startElement("nest", attrs={})
|
||||
self.xml.startElement("head", attributes)
|
||||
self.xml.characters(message)
|
||||
self.xml.endElement("head")
|
||||
|
||||
tic = time.time()
|
||||
self.drain_log_queue()
|
||||
yield
|
||||
self.drain_log_queue()
|
||||
toc = time.time()
|
||||
self.log("(finished: {}, in {:.2f} seconds)".format(message, toc - tic))
|
||||
|
||||
self.xml.endElement("nest")
|
||||
|
||||
|
||||
rootlog = Logger()
|
424
nixos/lib/test-driver/test-driver.py → nixos/lib/test-driver/test_driver/machine.py
Executable file → Normal file
424
nixos/lib/test-driver/test-driver.py → nixos/lib/test-driver/test_driver/machine.py
Executable file → Normal file
|
@ -1,19 +1,11 @@
|
|||
#! /somewhere/python3
|
||||
from contextlib import contextmanager, _GeneratorContextManager
|
||||
from queue import Queue, Empty
|
||||
from typing import Tuple, Any, Callable, Dict, Iterator, Optional, List, Iterable
|
||||
from xml.sax.saxutils import XMLGenerator
|
||||
from colorama import Style
|
||||
from contextlib import _GeneratorContextManager
|
||||
from pathlib import Path
|
||||
import queue
|
||||
import io
|
||||
import threading
|
||||
import argparse
|
||||
from queue import Queue
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
import base64
|
||||
import codecs
|
||||
import io
|
||||
import os
|
||||
import ptpython.repl
|
||||
import pty
|
||||
import queue
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
|
@ -21,8 +13,10 @@ import socket
|
|||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unicodedata
|
||||
|
||||
from test_driver.logger import rootlog
|
||||
|
||||
CHAR_TO_KEY = {
|
||||
"A": "shift-a",
|
||||
|
@ -88,115 +82,10 @@ CHAR_TO_KEY = {
|
|||
}
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self) -> None:
|
||||
self.logfile = os.environ.get("LOGFILE", "/dev/null")
|
||||
self.logfile_handle = codecs.open(self.logfile, "wb")
|
||||
self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
|
||||
self.queue: "Queue[Dict[str, str]]" = Queue()
|
||||
|
||||
self.xml.startDocument()
|
||||
self.xml.startElement("logfile", attrs={})
|
||||
|
||||
self._print_serial_logs = True
|
||||
|
||||
@staticmethod
|
||||
def _eprint(*args: object, **kwargs: Any) -> None:
|
||||
print(*args, file=sys.stderr, **kwargs)
|
||||
|
||||
def close(self) -> None:
|
||||
self.xml.endElement("logfile")
|
||||
self.xml.endDocument()
|
||||
self.logfile_handle.close()
|
||||
|
||||
def sanitise(self, message: str) -> str:
|
||||
return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")
|
||||
|
||||
def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
|
||||
if "machine" in attributes:
|
||||
return "{}: {}".format(attributes["machine"], message)
|
||||
return message
|
||||
|
||||
def log_line(self, message: str, attributes: Dict[str, str]) -> None:
|
||||
self.xml.startElement("line", attributes)
|
||||
self.xml.characters(message)
|
||||
self.xml.endElement("line")
|
||||
|
||||
def info(self, *args, **kwargs) -> None: # type: ignore
|
||||
self.log(*args, **kwargs)
|
||||
|
||||
def warning(self, *args, **kwargs) -> None: # type: ignore
|
||||
self.log(*args, **kwargs)
|
||||
|
||||
def error(self, *args, **kwargs) -> None: # type: ignore
|
||||
self.log(*args, **kwargs)
|
||||
sys.exit(1)
|
||||
|
||||
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
|
||||
self._eprint(self.maybe_prefix(message, attributes))
|
||||
self.drain_log_queue()
|
||||
self.log_line(message, attributes)
|
||||
|
||||
def log_serial(self, message: str, machine: str) -> None:
|
||||
self.enqueue({"msg": message, "machine": machine, "type": "serial"})
|
||||
if self._print_serial_logs:
|
||||
self._eprint(
|
||||
Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL
|
||||
)
|
||||
|
||||
def enqueue(self, item: Dict[str, str]) -> None:
|
||||
self.queue.put(item)
|
||||
|
||||
def drain_log_queue(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
item = self.queue.get_nowait()
|
||||
msg = self.sanitise(item["msg"])
|
||||
del item["msg"]
|
||||
self.log_line(msg, item)
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
self._eprint(self.maybe_prefix(message, attributes))
|
||||
|
||||
self.xml.startElement("nest", attrs={})
|
||||
self.xml.startElement("head", attributes)
|
||||
self.xml.characters(message)
|
||||
self.xml.endElement("head")
|
||||
|
||||
tic = time.time()
|
||||
self.drain_log_queue()
|
||||
yield
|
||||
self.drain_log_queue()
|
||||
toc = time.time()
|
||||
self.log("(finished: {}, in {:.2f} seconds)".format(message, toc - tic))
|
||||
|
||||
self.xml.endElement("nest")
|
||||
|
||||
|
||||
rootlog = Logger()
|
||||
|
||||
|
||||
def make_command(args: list) -> str:
|
||||
return " ".join(map(shlex.quote, (map(str, args))))
|
||||
|
||||
|
||||
def retry(fn: Callable, timeout: int = 900) -> None:
|
||||
"""Call the given function repeatedly, with 1 second intervals,
|
||||
until it returns True or a timeout is reached.
|
||||
"""
|
||||
|
||||
for _ in range(timeout):
|
||||
if fn(False):
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
if not fn(True):
|
||||
raise Exception(f"action timed out after {timeout} seconds")
|
||||
|
||||
|
||||
def _perform_ocr_on_screenshot(
|
||||
screenshot_path: str, model_ids: Iterable[int]
|
||||
) -> List[str]:
|
||||
|
@ -228,6 +117,20 @@ def _perform_ocr_on_screenshot(
|
|||
return model_results
|
||||
|
||||
|
||||
def retry(fn: Callable, timeout: int = 900) -> None:
|
||||
"""Call the given function repeatedly, with 1 second intervals,
|
||||
until it returns True or a timeout is reached.
|
||||
"""
|
||||
|
||||
for _ in range(timeout):
|
||||
if fn(False):
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
if not fn(True):
|
||||
raise Exception(f"action timed out after {timeout} seconds")
|
||||
|
||||
|
||||
class StartCommand:
|
||||
"""The Base Start Command knows how to append the necesary
|
||||
runtime qemu options as determined by a particular test driver
|
||||
|
@ -1066,286 +969,3 @@ class Machine:
|
|||
self.shell.close()
|
||||
self.monitor.close()
|
||||
self.serial_thread.join()
|
||||
|
||||
|
||||
class VLan:
|
||||
"""This class handles a VLAN that the run-vm scripts identify via its
|
||||
number handles. The network's lifetime equals the object's lifetime.
|
||||
"""
|
||||
|
||||
nr: int
|
||||
socket_dir: Path
|
||||
|
||||
process: subprocess.Popen
|
||||
pid: int
|
||||
fd: io.TextIOBase
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Vlan Nr. {self.nr}>"
|
||||
|
||||
def __init__(self, nr: int, tmp_dir: Path):
|
||||
self.nr = nr
|
||||
self.socket_dir = tmp_dir / f"vde{self.nr}.ctl"
|
||||
|
||||
# TODO: don't side-effect environment here
|
||||
os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir)
|
||||
|
||||
rootlog.info("start vlan")
|
||||
pty_master, pty_slave = pty.openpty()
|
||||
|
||||
self.process = subprocess.Popen(
|
||||
["vde_switch", "-s", self.socket_dir, "--dirmode", "0700"],
|
||||
stdin=pty_slave,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
shell=False,
|
||||
)
|
||||
self.pid = self.process.pid
|
||||
self.fd = os.fdopen(pty_master, "w")
|
||||
self.fd.write("version\n")
|
||||
|
||||
# TODO: perl version checks if this can be read from
|
||||
# an if not, dies. we could hang here forever. Fix it.
|
||||
assert self.process.stdout is not None
|
||||
self.process.stdout.readline()
|
||||
if not (self.socket_dir / "ctl").exists():
|
||||
rootlog.error("cannot start vde_switch")
|
||||
|
||||
rootlog.info(f"running vlan (pid {self.pid})")
|
||||
|
||||
def __del__(self) -> None:
|
||||
rootlog.info(f"kill vlan (pid {self.pid})")
|
||||
self.fd.close()
|
||||
self.process.terminate()
|
||||
|
||||
|
||||
class Driver:
|
||||
"""A handle to the driver that sets up the environment
|
||||
and runs the tests"""
|
||||
|
||||
tests: str
|
||||
vlans: List[VLan]
|
||||
machines: List[Machine]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_scripts: List[str],
|
||||
vlans: List[int],
|
||||
tests: str,
|
||||
keep_vm_state: bool = False,
|
||||
):
|
||||
self.tests = tests
|
||||
|
||||
tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
|
||||
tmp_dir.mkdir(mode=0o700, exist_ok=True)
|
||||
|
||||
with rootlog.nested("start all VLans"):
|
||||
self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
|
||||
|
||||
def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
|
||||
for s in scripts:
|
||||
yield NixStartScript(s)
|
||||
|
||||
self.machines = [
|
||||
Machine(
|
||||
start_command=cmd,
|
||||
keep_vm_state=keep_vm_state,
|
||||
name=cmd.machine_name,
|
||||
tmp_dir=tmp_dir,
|
||||
)
|
||||
for cmd in cmd(start_scripts)
|
||||
]
|
||||
|
||||
def __enter__(self) -> "Driver":
|
||||
return self
|
||||
|
||||
def __exit__(self, *_: Any) -> None:
|
||||
with rootlog.nested("cleanup"):
|
||||
for machine in self.machines:
|
||||
machine.release()
|
||||
|
||||
def subtest(self, name: str) -> Iterator[None]:
|
||||
"""Group logs under a given test name"""
|
||||
with rootlog.nested(name):
|
||||
try:
|
||||
yield
|
||||
return True
|
||||
except Exception as e:
|
||||
rootlog.error(f'Test "{name}" failed with error: "{e}"')
|
||||
raise e
|
||||
|
||||
def test_symbols(self) -> Dict[str, Any]:
|
||||
@contextmanager
|
||||
def subtest(name: str) -> Iterator[None]:
|
||||
return self.subtest(name)
|
||||
|
||||
general_symbols = dict(
|
||||
start_all=self.start_all,
|
||||
test_script=self.test_script,
|
||||
machines=self.machines,
|
||||
vlans=self.vlans,
|
||||
driver=self,
|
||||
log=rootlog,
|
||||
os=os,
|
||||
create_machine=self.create_machine,
|
||||
subtest=subtest,
|
||||
run_tests=self.run_tests,
|
||||
join_all=self.join_all,
|
||||
retry=retry,
|
||||
serial_stdout_off=self.serial_stdout_off,
|
||||
serial_stdout_on=self.serial_stdout_on,
|
||||
Machine=Machine, # for typing
|
||||
)
|
||||
machine_symbols = {m.name: m for m in self.machines}
|
||||
# If there's exactly one machine, make it available under the name
|
||||
# "machine", even if it's not called that.
|
||||
if len(self.machines) == 1:
|
||||
(machine_symbols["machine"],) = self.machines
|
||||
vlan_symbols = {
|
||||
f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans)
|
||||
}
|
||||
print(
|
||||
"additionally exposed symbols:\n "
|
||||
+ ", ".join(map(lambda m: m.name, self.machines))
|
||||
+ ",\n "
|
||||
+ ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans))
|
||||
+ ",\n "
|
||||
+ ", ".join(list(general_symbols.keys()))
|
||||
)
|
||||
return {**general_symbols, **machine_symbols, **vlan_symbols}
|
||||
|
||||
def test_script(self) -> None:
|
||||
"""Run the test script"""
|
||||
with rootlog.nested("run the VM test script"):
|
||||
symbols = self.test_symbols() # call eagerly
|
||||
exec(self.tests, symbols, None)
|
||||
|
||||
def run_tests(self) -> None:
|
||||
"""Run the test script (for non-interactive test runs)"""
|
||||
self.test_script()
|
||||
# TODO: Collect coverage data
|
||||
for machine in self.machines:
|
||||
if machine.is_up():
|
||||
machine.execute("sync")
|
||||
|
||||
def start_all(self) -> None:
|
||||
"""Start all machines"""
|
||||
with rootlog.nested("start all VMs"):
|
||||
for machine in self.machines:
|
||||
machine.start()
|
||||
|
||||
def join_all(self) -> None:
|
||||
"""Wait for all machines to shut down"""
|
||||
with rootlog.nested("wait for all VMs to finish"):
|
||||
for machine in self.machines:
|
||||
machine.wait_for_shutdown()
|
||||
|
||||
def create_machine(self, args: Dict[str, Any]) -> Machine:
|
||||
rootlog.warning(
|
||||
"Using legacy create_machine(), please instantiate the"
|
||||
"Machine class directly, instead"
|
||||
)
|
||||
tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
|
||||
tmp_dir.mkdir(mode=0o700, exist_ok=True)
|
||||
|
||||
if args.get("startCommand"):
|
||||
start_command: str = args.get("startCommand", "")
|
||||
cmd = NixStartScript(start_command)
|
||||
name = args.get("name", cmd.machine_name)
|
||||
else:
|
||||
cmd = Machine.create_startcommand(args) # type: ignore
|
||||
name = args.get("name", "machine")
|
||||
|
||||
return Machine(
|
||||
tmp_dir=tmp_dir,
|
||||
start_command=cmd,
|
||||
name=name,
|
||||
keep_vm_state=args.get("keep_vm_state", False),
|
||||
allow_reboot=args.get("allow_reboot", False),
|
||||
)
|
||||
|
||||
def serial_stdout_on(self) -> None:
|
||||
rootlog._print_serial_logs = True
|
||||
|
||||
def serial_stdout_off(self) -> None:
|
||||
rootlog._print_serial_logs = False
|
||||
|
||||
|
||||
class EnvDefault(argparse.Action):
|
||||
"""An argpars Action that takes values from the specified
|
||||
environment variable as the flags default value.
|
||||
"""
|
||||
|
||||
def __init__(self, envvar, required=False, default=None, nargs=None, **kwargs): # type: ignore
|
||||
if not default and envvar:
|
||||
if envvar in os.environ:
|
||||
if nargs is not None and (nargs.isdigit() or nargs in ["*", "+"]):
|
||||
default = os.environ[envvar].split()
|
||||
else:
|
||||
default = os.environ[envvar]
|
||||
kwargs["help"] = (
|
||||
kwargs["help"] + f" (default from environment: {default})"
|
||||
)
|
||||
if required and default:
|
||||
required = False
|
||||
super(EnvDefault, self).__init__(
|
||||
default=default, required=required, nargs=nargs, **kwargs
|
||||
)
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None): # type: ignore
|
||||
setattr(namespace, self.dest, values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
|
||||
arg_parser.add_argument(
|
||||
"-K",
|
||||
"--keep-vm-state",
|
||||
help="re-use a VM state coming from a previous run",
|
||||
action="store_true",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"-I",
|
||||
"--interactive",
|
||||
help="drop into a python repl and run the tests interactively",
|
||||
action="store_true",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"--start-scripts",
|
||||
metavar="START-SCRIPT",
|
||||
action=EnvDefault,
|
||||
envvar="startScripts",
|
||||
nargs="*",
|
||||
help="start scripts for participating virtual machines",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"--vlans",
|
||||
metavar="VLAN",
|
||||
action=EnvDefault,
|
||||
envvar="vlans",
|
||||
nargs="*",
|
||||
help="vlans to span by the driver",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"testscript",
|
||||
action=EnvDefault,
|
||||
envvar="testScript",
|
||||
help="the test script to run",
|
||||
type=Path,
|
||||
)
|
||||
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
if not args.keep_vm_state:
|
||||
rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")
|
||||
|
||||
with Driver(
|
||||
args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state
|
||||
) as driver:
|
||||
if args.interactive:
|
||||
ptpython.repl.embed(driver.test_symbols(), {})
|
||||
else:
|
||||
tic = time.time()
|
||||
driver.run_tests()
|
||||
toc = time.time()
|
||||
rootlog.info(f"test script finished in {(toc-tic):.2f}s")
|
58
nixos/lib/test-driver/test_driver/vlan.py
Normal file
58
nixos/lib/test-driver/test_driver/vlan.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
from pathlib import Path
|
||||
import io
|
||||
import os
|
||||
import pty
|
||||
import subprocess
|
||||
|
||||
from test_driver.logger import rootlog
|
||||
|
||||
|
||||
class VLan:
|
||||
"""This class handles a VLAN that the run-vm scripts identify via its
|
||||
number handles. The network's lifetime equals the object's lifetime.
|
||||
"""
|
||||
|
||||
nr: int
|
||||
socket_dir: Path
|
||||
|
||||
process: subprocess.Popen
|
||||
pid: int
|
||||
fd: io.TextIOBase
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Vlan Nr. {self.nr}>"
|
||||
|
||||
def __init__(self, nr: int, tmp_dir: Path):
|
||||
self.nr = nr
|
||||
self.socket_dir = tmp_dir / f"vde{self.nr}.ctl"
|
||||
|
||||
# TODO: don't side-effect environment here
|
||||
os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir)
|
||||
|
||||
rootlog.info("start vlan")
|
||||
pty_master, pty_slave = pty.openpty()
|
||||
|
||||
self.process = subprocess.Popen(
|
||||
["vde_switch", "-s", self.socket_dir, "--dirmode", "0700"],
|
||||
stdin=pty_slave,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
shell=False,
|
||||
)
|
||||
self.pid = self.process.pid
|
||||
self.fd = os.fdopen(pty_master, "w")
|
||||
self.fd.write("version\n")
|
||||
|
||||
# TODO: perl version checks if this can be read from
|
||||
# an if not, dies. we could hang here forever. Fix it.
|
||||
assert self.process.stdout is not None
|
||||
self.process.stdout.readline()
|
||||
if not (self.socket_dir / "ctl").exists():
|
||||
rootlog.error("cannot start vde_switch")
|
||||
|
||||
rootlog.info(f"running vlan (pid {self.pid})")
|
||||
|
||||
def __del__(self) -> None:
|
||||
rootlog.info(f"kill vlan (pid {self.pid})")
|
||||
self.fd.close()
|
||||
self.process.terminate()
|
|
@ -16,65 +16,6 @@ rec {
|
|||
|
||||
inherit pkgs;
|
||||
|
||||
# Reifies and correctly wraps the python test driver for
|
||||
# the respective qemu version and with or without ocr support
|
||||
pythonTestDriver = {
|
||||
qemu_pkg ? pkgs.qemu_test
|
||||
, enableOCR ? false
|
||||
}:
|
||||
let
|
||||
name = "nixos-test-driver";
|
||||
testDriverScript = ./test-driver/test-driver.py;
|
||||
ocrProg = tesseract4.override { enableLanguages = [ "eng" ]; };
|
||||
imagemagick_tiff = imagemagick_light.override { inherit libtiff; };
|
||||
in stdenv.mkDerivation {
|
||||
inherit name;
|
||||
|
||||
nativeBuildInputs = [ makeWrapper ];
|
||||
buildInputs = [ (python3.withPackages (p: [ p.ptpython p.colorama ])) ];
|
||||
checkInputs = with python3Packages; [ pylint black mypy ];
|
||||
|
||||
dontUnpack = true;
|
||||
|
||||
preferLocalBuild = true;
|
||||
|
||||
buildPhase = ''
|
||||
python <<EOF
|
||||
from pydoc import importfile
|
||||
with open('driver-symbols', 'w') as fp:
|
||||
t = importfile('${testDriverScript}')
|
||||
d = t.Driver([],[],"")
|
||||
test_symbols = d.test_symbols()
|
||||
fp.write(','.join(test_symbols.keys()))
|
||||
EOF
|
||||
'';
|
||||
|
||||
doCheck = true;
|
||||
checkPhase = ''
|
||||
mypy --disallow-untyped-defs \
|
||||
--no-implicit-optional \
|
||||
--ignore-missing-imports ${testDriverScript}
|
||||
pylint --errors-only ${testDriverScript}
|
||||
black --check --diff ${testDriverScript}
|
||||
'';
|
||||
|
||||
installPhase =
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
cp ${testDriverScript} $out/bin/nixos-test-driver
|
||||
chmod u+x $out/bin/nixos-test-driver
|
||||
# TODO: copy user script part into this file (append)
|
||||
|
||||
wrapProgram $out/bin/nixos-test-driver \
|
||||
--argv0 ${name} \
|
||||
--prefix PATH : "${lib.makeBinPath [ qemu_pkg vde2 netpbm coreutils socat ]}" \
|
||||
${lib.optionalString enableOCR
|
||||
"--prefix PATH : '${ocrProg}/bin:${imagemagick_tiff}/bin'"} \
|
||||
|
||||
install -m 0644 -vD driver-symbols $out/nix-support/driver-symbols
|
||||
'';
|
||||
};
|
||||
|
||||
# Run an automated test suite in the given virtual network.
|
||||
runTests = { driver, pos }:
|
||||
stdenv.mkDerivation {
|
||||
|
@ -112,8 +53,15 @@ rec {
|
|||
, passthru ? {}
|
||||
}:
|
||||
let
|
||||
# FIXME: get this pkg from the module system
|
||||
testDriver = pythonTestDriver { inherit qemu_pkg enableOCR;};
|
||||
# Reifies and correctly wraps the python test driver for
|
||||
# the respective qemu version and with or without ocr support
|
||||
testDriver = pkgs.callPackage ./test-driver {
|
||||
inherit enableOCR;
|
||||
qemu_pkg = qemu_test;
|
||||
imagemagick_light = imagemagick_light.override { inherit libtiff; };
|
||||
tesseract4 = tesseract4.override { enableLanguages = [ "eng" ]; };
|
||||
};
|
||||
|
||||
|
||||
testDriverName =
|
||||
let
|
||||
|
@ -178,10 +126,11 @@ rec {
|
|||
echo -n "$testScript" > $out/test-script
|
||||
ln -s ${testDriver}/bin/nixos-test-driver $out/bin/nixos-test-driver
|
||||
|
||||
${testDriver}/bin/generate-driver-symbols
|
||||
${lib.optionalString (!skipLint) ''
|
||||
PYFLAKES_BUILTINS="$(
|
||||
echo -n ${lib.escapeShellArg (lib.concatStringsSep "," nodeHostNames)},
|
||||
< ${lib.escapeShellArg "${testDriver}/nix-support/driver-symbols"}
|
||||
< ${lib.escapeShellArg "driver-symbols"}
|
||||
)" ${python3Packages.pyflakes}/bin/pyflakes $out/test-script
|
||||
''}
|
||||
|
||||
|
|
Loading…
Reference in a new issue