diff --git a/nixos/lib/test-driver/default.nix b/nixos/lib/test-driver/default.nix new file mode 100644 index 00000000000..3f63bc705b9 --- /dev/null +++ b/nixos/lib/test-driver/default.nix @@ -0,0 +1,32 @@ +{ lib +, python3Packages +, enableOCR ? false +, qemu_pkg ? qemu_test +, coreutils +, imagemagick_light +, libtiff +, netpbm +, qemu_test +, socat +, tesseract4 +, vde2 +}: + +python3Packages.buildPythonApplication rec { + pname = "nixos-test-driver"; + version = "1.0"; + src = ./.; + + propagatedBuildInputs = [ coreutils netpbm python3Packages.colorama python3Packages.ptpython qemu_pkg socat vde2 ] + ++ (lib.optionals enableOCR [ imagemagick_light tesseract4 ]); + + doCheck = true; + checkInputs = with python3Packages; [ mypy pylint black ]; + checkPhase = '' + mypy --disallow-untyped-defs \ + --no-implicit-optional \ + --ignore-missing-imports ${src}/test_driver + pylint --errors-only ${src}/test_driver + black --check --diff ${src}/test_driver + ''; +} diff --git a/nixos/lib/test-driver/setup.py b/nixos/lib/test-driver/setup.py new file mode 100644 index 00000000000..15699547216 --- /dev/null +++ b/nixos/lib/test-driver/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name="nixos-test-driver", + version='1.0', + packages=find_packages(), + entry_points={ + "console_scripts": [ + "nixos-test-driver=test_driver:main", + "generate-driver-symbols=test_driver:generate_driver_symbols" + ] + }, +) diff --git a/nixos/lib/test-driver/test_driver/__init__.py b/nixos/lib/test-driver/test_driver/__init__.py new file mode 100755 index 00000000000..5477ab5cd03 --- /dev/null +++ b/nixos/lib/test-driver/test_driver/__init__.py @@ -0,0 +1,100 @@ +from pathlib import Path +import argparse +import ptpython.repl +import os +import time + +from test_driver.logger import rootlog +from test_driver.driver import Driver + + +class EnvDefault(argparse.Action): + """An argpars Action that takes values from the specified + environment variable as the flags default value. + """ + + def __init__(self, envvar, required=False, default=None, nargs=None, **kwargs): # type: ignore + if not default and envvar: + if envvar in os.environ: + if nargs is not None and (nargs.isdigit() or nargs in ["*", "+"]): + default = os.environ[envvar].split() + else: + default = os.environ[envvar] + kwargs["help"] = ( + kwargs["help"] + f" (default from environment: {default})" + ) + if required and default: + required = False + super(EnvDefault, self).__init__( + default=default, required=required, nargs=nargs, **kwargs + ) + + def __call__(self, parser, namespace, values, option_string=None): # type: ignore + setattr(namespace, self.dest, values) + + +def main() -> None: + arg_parser = argparse.ArgumentParser(prog="nixos-test-driver") + arg_parser.add_argument( + "-K", + "--keep-vm-state", + help="re-use a VM state coming from a previous run", + action="store_true", + ) + arg_parser.add_argument( + "-I", + "--interactive", + help="drop into a python repl and run the tests interactively", + action="store_true", + ) + arg_parser.add_argument( + "--start-scripts", + metavar="START-SCRIPT", + action=EnvDefault, + envvar="startScripts", + nargs="*", + help="start scripts for participating virtual machines", + ) + arg_parser.add_argument( + "--vlans", + metavar="VLAN", + action=EnvDefault, + envvar="vlans", + nargs="*", + help="vlans to span by the driver", + ) + arg_parser.add_argument( + "testscript", + action=EnvDefault, + envvar="testScript", + help="the test script to run", + type=Path, + ) + + args = arg_parser.parse_args() + + if not args.keep_vm_state: + rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state") + + with Driver( + args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state + ) as driver: + if args.interactive: + ptpython.repl.embed(driver.test_symbols(), {}) + else: + tic = time.time() + driver.run_tests() + toc = time.time() + rootlog.info(f"test script finished in {(toc-tic):.2f}s") + + +def generate_driver_symbols() -> None: + """ + This generates a file with symbols of the test-driver code that can be used + in user's test scripts. That list is then used by pyflakes to lint those + scripts. + """ + d = Driver([], [], "") + test_symbols = d.test_symbols() + with open("driver-symbols", "w") as fp: + fp.write(",".join(test_symbols.keys())) diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py new file mode 100644 index 00000000000..f3af98537ad --- /dev/null +++ b/nixos/lib/test-driver/test_driver/driver.py @@ -0,0 +1,161 @@ +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Dict, Iterator, List +import os +import tempfile + +from test_driver.logger import rootlog +from test_driver.machine import Machine, NixStartScript, retry +from test_driver.vlan import VLan + + +class Driver: + """A handle to the driver that sets up the environment + and runs the tests""" + + tests: str + vlans: List[VLan] + machines: List[Machine] + + def __init__( + self, + start_scripts: List[str], + vlans: List[int], + tests: str, + keep_vm_state: bool = False, + ): + self.tests = tests + + tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) + tmp_dir.mkdir(mode=0o700, exist_ok=True) + + with rootlog.nested("start all VLans"): + self.vlans = [VLan(nr, tmp_dir) for nr in vlans] + + def cmd(scripts: List[str]) -> Iterator[NixStartScript]: + for s in scripts: + yield NixStartScript(s) + + self.machines = [ + Machine( + start_command=cmd, + keep_vm_state=keep_vm_state, + name=cmd.machine_name, + tmp_dir=tmp_dir, + ) + for cmd in cmd(start_scripts) + ] + + def __enter__(self) -> "Driver": + return self + + def __exit__(self, *_: Any) -> None: + with rootlog.nested("cleanup"): + for machine in self.machines: + machine.release() + + def subtest(self, name: str) -> Iterator[None]: + """Group logs under a given test name""" + with rootlog.nested(name): + try: + yield + return True + except Exception as e: + rootlog.error(f'Test "{name}" failed with error: "{e}"') + raise e + + def test_symbols(self) -> Dict[str, Any]: + @contextmanager + def subtest(name: str) -> Iterator[None]: + return self.subtest(name) + + general_symbols = dict( + start_all=self.start_all, + test_script=self.test_script, + machines=self.machines, + vlans=self.vlans, + driver=self, + log=rootlog, + os=os, + create_machine=self.create_machine, + subtest=subtest, + run_tests=self.run_tests, + join_all=self.join_all, + retry=retry, + serial_stdout_off=self.serial_stdout_off, + serial_stdout_on=self.serial_stdout_on, + Machine=Machine, # for typing + ) + machine_symbols = {m.name: m for m in self.machines} + # If there's exactly one machine, make it available under the name + # "machine", even if it's not called that. + if len(self.machines) == 1: + (machine_symbols["machine"],) = self.machines + vlan_symbols = { + f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans) + } + print( + "additionally exposed symbols:\n " + + ", ".join(map(lambda m: m.name, self.machines)) + + ",\n " + + ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans)) + + ",\n " + + ", ".join(list(general_symbols.keys())) + ) + return {**general_symbols, **machine_symbols, **vlan_symbols} + + def test_script(self) -> None: + """Run the test script""" + with rootlog.nested("run the VM test script"): + symbols = self.test_symbols() # call eagerly + exec(self.tests, symbols, None) + + def run_tests(self) -> None: + """Run the test script (for non-interactive test runs)""" + self.test_script() + # TODO: Collect coverage data + for machine in self.machines: + if machine.is_up(): + machine.execute("sync") + + def start_all(self) -> None: + """Start all machines""" + with rootlog.nested("start all VMs"): + for machine in self.machines: + machine.start() + + def join_all(self) -> None: + """Wait for all machines to shut down""" + with rootlog.nested("wait for all VMs to finish"): + for machine in self.machines: + machine.wait_for_shutdown() + + def create_machine(self, args: Dict[str, Any]) -> Machine: + rootlog.warning( + "Using legacy create_machine(), please instantiate the" + "Machine class directly, instead" + ) + tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) + tmp_dir.mkdir(mode=0o700, exist_ok=True) + + if args.get("startCommand"): + start_command: str = args.get("startCommand", "") + cmd = NixStartScript(start_command) + name = args.get("name", cmd.machine_name) + else: + cmd = Machine.create_startcommand(args) # type: ignore + name = args.get("name", "machine") + + return Machine( + tmp_dir=tmp_dir, + start_command=cmd, + name=name, + keep_vm_state=args.get("keep_vm_state", False), + allow_reboot=args.get("allow_reboot", False), + ) + + def serial_stdout_on(self) -> None: + rootlog._print_serial_logs = True + + def serial_stdout_off(self) -> None: + rootlog._print_serial_logs = False diff --git a/nixos/lib/test-driver/test_driver/logger.py b/nixos/lib/test-driver/test_driver/logger.py new file mode 100644 index 00000000000..5b3091a5129 --- /dev/null +++ b/nixos/lib/test-driver/test_driver/logger.py @@ -0,0 +1,101 @@ +from colorama import Style +from contextlib import contextmanager +from typing import Any, Dict, Iterator +from queue import Queue, Empty +from xml.sax.saxutils import XMLGenerator +import codecs +import os +import sys +import time +import unicodedata + + +class Logger: + def __init__(self) -> None: + self.logfile = os.environ.get("LOGFILE", "/dev/null") + self.logfile_handle = codecs.open(self.logfile, "wb") + self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8") + self.queue: "Queue[Dict[str, str]]" = Queue() + + self.xml.startDocument() + self.xml.startElement("logfile", attrs={}) + + self._print_serial_logs = True + + @staticmethod + def _eprint(*args: object, **kwargs: Any) -> None: + print(*args, file=sys.stderr, **kwargs) + + def close(self) -> None: + self.xml.endElement("logfile") + self.xml.endDocument() + self.logfile_handle.close() + + def sanitise(self, message: str) -> str: + return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C") + + def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str: + if "machine" in attributes: + return "{}: {}".format(attributes["machine"], message) + return message + + def log_line(self, message: str, attributes: Dict[str, str]) -> None: + self.xml.startElement("line", attributes) + self.xml.characters(message) + self.xml.endElement("line") + + def info(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + + def warning(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + + def error(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + sys.exit(1) + + def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + self._eprint(self.maybe_prefix(message, attributes)) + self.drain_log_queue() + self.log_line(message, attributes) + + def log_serial(self, message: str, machine: str) -> None: + self.enqueue({"msg": message, "machine": machine, "type": "serial"}) + if self._print_serial_logs: + self._eprint( + Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL + ) + + def enqueue(self, item: Dict[str, str]) -> None: + self.queue.put(item) + + def drain_log_queue(self) -> None: + try: + while True: + item = self.queue.get_nowait() + msg = self.sanitise(item["msg"]) + del item["msg"] + self.log_line(msg, item) + except Empty: + pass + + @contextmanager + def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + self._eprint(self.maybe_prefix(message, attributes)) + + self.xml.startElement("nest", attrs={}) + self.xml.startElement("head", attributes) + self.xml.characters(message) + self.xml.endElement("head") + + tic = time.time() + self.drain_log_queue() + yield + self.drain_log_queue() + toc = time.time() + self.log("(finished: {}, in {:.2f} seconds)".format(message, toc - tic)) + + self.xml.endElement("nest") + + +rootlog = Logger() diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test_driver/machine.py old mode 100755 new mode 100644 similarity index 72% rename from nixos/lib/test-driver/test-driver.py rename to nixos/lib/test-driver/test_driver/machine.py index 90c9e9be45c..b3dbe5126fc --- a/nixos/lib/test-driver/test-driver.py +++ b/nixos/lib/test-driver/test_driver/machine.py @@ -1,19 +1,11 @@ -#! /somewhere/python3 -from contextlib import contextmanager, _GeneratorContextManager -from queue import Queue, Empty -from typing import Tuple, Any, Callable, Dict, Iterator, Optional, List, Iterable -from xml.sax.saxutils import XMLGenerator -from colorama import Style +from contextlib import _GeneratorContextManager from pathlib import Path -import queue -import io -import threading -import argparse +from queue import Queue +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import base64 -import codecs +import io import os -import ptpython.repl -import pty +import queue import re import shlex import shutil @@ -21,8 +13,10 @@ import socket import subprocess import sys import tempfile +import threading import time -import unicodedata + +from test_driver.logger import rootlog CHAR_TO_KEY = { "A": "shift-a", @@ -88,115 +82,10 @@ CHAR_TO_KEY = { } -class Logger: - def __init__(self) -> None: - self.logfile = os.environ.get("LOGFILE", "/dev/null") - self.logfile_handle = codecs.open(self.logfile, "wb") - self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8") - self.queue: "Queue[Dict[str, str]]" = Queue() - - self.xml.startDocument() - self.xml.startElement("logfile", attrs={}) - - self._print_serial_logs = True - - @staticmethod - def _eprint(*args: object, **kwargs: Any) -> None: - print(*args, file=sys.stderr, **kwargs) - - def close(self) -> None: - self.xml.endElement("logfile") - self.xml.endDocument() - self.logfile_handle.close() - - def sanitise(self, message: str) -> str: - return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C") - - def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str: - if "machine" in attributes: - return "{}: {}".format(attributes["machine"], message) - return message - - def log_line(self, message: str, attributes: Dict[str, str]) -> None: - self.xml.startElement("line", attributes) - self.xml.characters(message) - self.xml.endElement("line") - - def info(self, *args, **kwargs) -> None: # type: ignore - self.log(*args, **kwargs) - - def warning(self, *args, **kwargs) -> None: # type: ignore - self.log(*args, **kwargs) - - def error(self, *args, **kwargs) -> None: # type: ignore - self.log(*args, **kwargs) - sys.exit(1) - - def log(self, message: str, attributes: Dict[str, str] = {}) -> None: - self._eprint(self.maybe_prefix(message, attributes)) - self.drain_log_queue() - self.log_line(message, attributes) - - def log_serial(self, message: str, machine: str) -> None: - self.enqueue({"msg": message, "machine": machine, "type": "serial"}) - if self._print_serial_logs: - self._eprint( - Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL - ) - - def enqueue(self, item: Dict[str, str]) -> None: - self.queue.put(item) - - def drain_log_queue(self) -> None: - try: - while True: - item = self.queue.get_nowait() - msg = self.sanitise(item["msg"]) - del item["msg"] - self.log_line(msg, item) - except Empty: - pass - - @contextmanager - def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: - self._eprint(self.maybe_prefix(message, attributes)) - - self.xml.startElement("nest", attrs={}) - self.xml.startElement("head", attributes) - self.xml.characters(message) - self.xml.endElement("head") - - tic = time.time() - self.drain_log_queue() - yield - self.drain_log_queue() - toc = time.time() - self.log("(finished: {}, in {:.2f} seconds)".format(message, toc - tic)) - - self.xml.endElement("nest") - - -rootlog = Logger() - - def make_command(args: list) -> str: return " ".join(map(shlex.quote, (map(str, args)))) -def retry(fn: Callable, timeout: int = 900) -> None: - """Call the given function repeatedly, with 1 second intervals, - until it returns True or a timeout is reached. - """ - - for _ in range(timeout): - if fn(False): - return - time.sleep(1) - - if not fn(True): - raise Exception(f"action timed out after {timeout} seconds") - - def _perform_ocr_on_screenshot( screenshot_path: str, model_ids: Iterable[int] ) -> List[str]: @@ -228,6 +117,20 @@ def _perform_ocr_on_screenshot( return model_results +def retry(fn: Callable, timeout: int = 900) -> None: + """Call the given function repeatedly, with 1 second intervals, + until it returns True or a timeout is reached. + """ + + for _ in range(timeout): + if fn(False): + return + time.sleep(1) + + if not fn(True): + raise Exception(f"action timed out after {timeout} seconds") + + class StartCommand: """The Base Start Command knows how to append the necesary runtime qemu options as determined by a particular test driver @@ -1066,286 +969,3 @@ class Machine: self.shell.close() self.monitor.close() self.serial_thread.join() - - -class VLan: - """This class handles a VLAN that the run-vm scripts identify via its - number handles. The network's lifetime equals the object's lifetime. - """ - - nr: int - socket_dir: Path - - process: subprocess.Popen - pid: int - fd: io.TextIOBase - - def __repr__(self) -> str: - return f"" - - def __init__(self, nr: int, tmp_dir: Path): - self.nr = nr - self.socket_dir = tmp_dir / f"vde{self.nr}.ctl" - - # TODO: don't side-effect environment here - os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir) - - rootlog.info("start vlan") - pty_master, pty_slave = pty.openpty() - - self.process = subprocess.Popen( - ["vde_switch", "-s", self.socket_dir, "--dirmode", "0700"], - stdin=pty_slave, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=False, - ) - self.pid = self.process.pid - self.fd = os.fdopen(pty_master, "w") - self.fd.write("version\n") - - # TODO: perl version checks if this can be read from - # an if not, dies. we could hang here forever. Fix it. - assert self.process.stdout is not None - self.process.stdout.readline() - if not (self.socket_dir / "ctl").exists(): - rootlog.error("cannot start vde_switch") - - rootlog.info(f"running vlan (pid {self.pid})") - - def __del__(self) -> None: - rootlog.info(f"kill vlan (pid {self.pid})") - self.fd.close() - self.process.terminate() - - -class Driver: - """A handle to the driver that sets up the environment - and runs the tests""" - - tests: str - vlans: List[VLan] - machines: List[Machine] - - def __init__( - self, - start_scripts: List[str], - vlans: List[int], - tests: str, - keep_vm_state: bool = False, - ): - self.tests = tests - - tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) - tmp_dir.mkdir(mode=0o700, exist_ok=True) - - with rootlog.nested("start all VLans"): - self.vlans = [VLan(nr, tmp_dir) for nr in vlans] - - def cmd(scripts: List[str]) -> Iterator[NixStartScript]: - for s in scripts: - yield NixStartScript(s) - - self.machines = [ - Machine( - start_command=cmd, - keep_vm_state=keep_vm_state, - name=cmd.machine_name, - tmp_dir=tmp_dir, - ) - for cmd in cmd(start_scripts) - ] - - def __enter__(self) -> "Driver": - return self - - def __exit__(self, *_: Any) -> None: - with rootlog.nested("cleanup"): - for machine in self.machines: - machine.release() - - def subtest(self, name: str) -> Iterator[None]: - """Group logs under a given test name""" - with rootlog.nested(name): - try: - yield - return True - except Exception as e: - rootlog.error(f'Test "{name}" failed with error: "{e}"') - raise e - - def test_symbols(self) -> Dict[str, Any]: - @contextmanager - def subtest(name: str) -> Iterator[None]: - return self.subtest(name) - - general_symbols = dict( - start_all=self.start_all, - test_script=self.test_script, - machines=self.machines, - vlans=self.vlans, - driver=self, - log=rootlog, - os=os, - create_machine=self.create_machine, - subtest=subtest, - run_tests=self.run_tests, - join_all=self.join_all, - retry=retry, - serial_stdout_off=self.serial_stdout_off, - serial_stdout_on=self.serial_stdout_on, - Machine=Machine, # for typing - ) - machine_symbols = {m.name: m for m in self.machines} - # If there's exactly one machine, make it available under the name - # "machine", even if it's not called that. - if len(self.machines) == 1: - (machine_symbols["machine"],) = self.machines - vlan_symbols = { - f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans) - } - print( - "additionally exposed symbols:\n " - + ", ".join(map(lambda m: m.name, self.machines)) - + ",\n " - + ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans)) - + ",\n " - + ", ".join(list(general_symbols.keys())) - ) - return {**general_symbols, **machine_symbols, **vlan_symbols} - - def test_script(self) -> None: - """Run the test script""" - with rootlog.nested("run the VM test script"): - symbols = self.test_symbols() # call eagerly - exec(self.tests, symbols, None) - - def run_tests(self) -> None: - """Run the test script (for non-interactive test runs)""" - self.test_script() - # TODO: Collect coverage data - for machine in self.machines: - if machine.is_up(): - machine.execute("sync") - - def start_all(self) -> None: - """Start all machines""" - with rootlog.nested("start all VMs"): - for machine in self.machines: - machine.start() - - def join_all(self) -> None: - """Wait for all machines to shut down""" - with rootlog.nested("wait for all VMs to finish"): - for machine in self.machines: - machine.wait_for_shutdown() - - def create_machine(self, args: Dict[str, Any]) -> Machine: - rootlog.warning( - "Using legacy create_machine(), please instantiate the" - "Machine class directly, instead" - ) - tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) - tmp_dir.mkdir(mode=0o700, exist_ok=True) - - if args.get("startCommand"): - start_command: str = args.get("startCommand", "") - cmd = NixStartScript(start_command) - name = args.get("name", cmd.machine_name) - else: - cmd = Machine.create_startcommand(args) # type: ignore - name = args.get("name", "machine") - - return Machine( - tmp_dir=tmp_dir, - start_command=cmd, - name=name, - keep_vm_state=args.get("keep_vm_state", False), - allow_reboot=args.get("allow_reboot", False), - ) - - def serial_stdout_on(self) -> None: - rootlog._print_serial_logs = True - - def serial_stdout_off(self) -> None: - rootlog._print_serial_logs = False - - -class EnvDefault(argparse.Action): - """An argpars Action that takes values from the specified - environment variable as the flags default value. - """ - - def __init__(self, envvar, required=False, default=None, nargs=None, **kwargs): # type: ignore - if not default and envvar: - if envvar in os.environ: - if nargs is not None and (nargs.isdigit() or nargs in ["*", "+"]): - default = os.environ[envvar].split() - else: - default = os.environ[envvar] - kwargs["help"] = ( - kwargs["help"] + f" (default from environment: {default})" - ) - if required and default: - required = False - super(EnvDefault, self).__init__( - default=default, required=required, nargs=nargs, **kwargs - ) - - def __call__(self, parser, namespace, values, option_string=None): # type: ignore - setattr(namespace, self.dest, values) - - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser(prog="nixos-test-driver") - arg_parser.add_argument( - "-K", - "--keep-vm-state", - help="re-use a VM state coming from a previous run", - action="store_true", - ) - arg_parser.add_argument( - "-I", - "--interactive", - help="drop into a python repl and run the tests interactively", - action="store_true", - ) - arg_parser.add_argument( - "--start-scripts", - metavar="START-SCRIPT", - action=EnvDefault, - envvar="startScripts", - nargs="*", - help="start scripts for participating virtual machines", - ) - arg_parser.add_argument( - "--vlans", - metavar="VLAN", - action=EnvDefault, - envvar="vlans", - nargs="*", - help="vlans to span by the driver", - ) - arg_parser.add_argument( - "testscript", - action=EnvDefault, - envvar="testScript", - help="the test script to run", - type=Path, - ) - - args = arg_parser.parse_args() - - if not args.keep_vm_state: - rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state") - - with Driver( - args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state - ) as driver: - if args.interactive: - ptpython.repl.embed(driver.test_symbols(), {}) - else: - tic = time.time() - driver.run_tests() - toc = time.time() - rootlog.info(f"test script finished in {(toc-tic):.2f}s") diff --git a/nixos/lib/test-driver/test_driver/vlan.py b/nixos/lib/test-driver/test_driver/vlan.py new file mode 100644 index 00000000000..e5c8f07b4ed --- /dev/null +++ b/nixos/lib/test-driver/test_driver/vlan.py @@ -0,0 +1,58 @@ +from pathlib import Path +import io +import os +import pty +import subprocess + +from test_driver.logger import rootlog + + +class VLan: + """This class handles a VLAN that the run-vm scripts identify via its + number handles. The network's lifetime equals the object's lifetime. + """ + + nr: int + socket_dir: Path + + process: subprocess.Popen + pid: int + fd: io.TextIOBase + + def __repr__(self) -> str: + return f"" + + def __init__(self, nr: int, tmp_dir: Path): + self.nr = nr + self.socket_dir = tmp_dir / f"vde{self.nr}.ctl" + + # TODO: don't side-effect environment here + os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir) + + rootlog.info("start vlan") + pty_master, pty_slave = pty.openpty() + + self.process = subprocess.Popen( + ["vde_switch", "-s", self.socket_dir, "--dirmode", "0700"], + stdin=pty_slave, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + ) + self.pid = self.process.pid + self.fd = os.fdopen(pty_master, "w") + self.fd.write("version\n") + + # TODO: perl version checks if this can be read from + # an if not, dies. we could hang here forever. Fix it. + assert self.process.stdout is not None + self.process.stdout.readline() + if not (self.socket_dir / "ctl").exists(): + rootlog.error("cannot start vde_switch") + + rootlog.info(f"running vlan (pid {self.pid})") + + def __del__(self) -> None: + rootlog.info(f"kill vlan (pid {self.pid})") + self.fd.close() + self.process.terminate() diff --git a/nixos/lib/testing-python.nix b/nixos/lib/testing-python.nix index 4306d102b2d..365e2271457 100644 --- a/nixos/lib/testing-python.nix +++ b/nixos/lib/testing-python.nix @@ -16,65 +16,6 @@ rec { inherit pkgs; - # Reifies and correctly wraps the python test driver for - # the respective qemu version and with or without ocr support - pythonTestDriver = { - qemu_pkg ? pkgs.qemu_test - , enableOCR ? false - }: - let - name = "nixos-test-driver"; - testDriverScript = ./test-driver/test-driver.py; - ocrProg = tesseract4.override { enableLanguages = [ "eng" ]; }; - imagemagick_tiff = imagemagick_light.override { inherit libtiff; }; - in stdenv.mkDerivation { - inherit name; - - nativeBuildInputs = [ makeWrapper ]; - buildInputs = [ (python3.withPackages (p: [ p.ptpython p.colorama ])) ]; - checkInputs = with python3Packages; [ pylint black mypy ]; - - dontUnpack = true; - - preferLocalBuild = true; - - buildPhase = '' - python < $out/test-script ln -s ${testDriver}/bin/nixos-test-driver $out/bin/nixos-test-driver + ${testDriver}/bin/generate-driver-symbols ${lib.optionalString (!skipLint) '' PYFLAKES_BUILTINS="$( echo -n ${lib.escapeShellArg (lib.concatStringsSep "," nodeHostNames)}, - < ${lib.escapeShellArg "${testDriver}/nix-support/driver-symbols"} + < ${lib.escapeShellArg "driver-symbols"} )" ${python3Packages.pyflakes}/bin/pyflakes $out/test-script ''}