nixos/tests/test-driver: better control test env symbols

Previous to this commit, the entire test driver environment was shared
with the actual python test environment.

This is a hefty api surface. This commit selectively exposes only those
symbols to the test environment that are actually meant to be used by
tests.
This commit is contained in:
David Arnold 2021-06-06 13:50:02 -05:00 committed by David Arnold
parent 5edf5b60c3
commit db614e11d6
No known key found for this signature in database
GPG key ID: AB15A6AF1101390D
2 changed files with 45 additions and 16 deletions

View file

@ -89,9 +89,7 @@ CHAR_TO_KEY = {
")": "shift-0x0B",
}
# Forward references
log: "Logger"
machines: "List[Machine]"
global log, machines, test_script
def eprint(*args: object, **kwargs: Any) -> None:
@ -103,7 +101,6 @@ def make_command(args: list) -> str:
def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]:
global log
log.log("starting VDE switch for network {}".format(vlan_nr))
vde_socket = tempfile.mkdtemp(
prefix="nixos-test-vde-", suffix="-vde{}.ctl".format(vlan_nr)
@ -246,6 +243,9 @@ def _perform_ocr_on_screenshot(
class Machine:
def __repr__(self) -> str:
return f"<Machine '{self.name}'>"
def __init__(self, args: Dict[str, Any]) -> None:
if "name" in args:
self.name = args["name"]
@ -910,29 +910,25 @@ class Machine:
def create_machine(args: Dict[str, Any]) -> Machine:
global log
args["log"] = log
return Machine(args)
def start_all() -> None:
global machines
with log.nested("starting all VMs"):
for machine in machines:
machine.start()
def join_all() -> None:
global machines
with log.nested("waiting for all VMs to finish"):
for machine in machines:
machine.wait_for_shutdown()
def run_tests(interactive: bool = False) -> None:
global machines
if interactive:
ptpython.repl.embed(globals(), locals())
ptpython.repl.embed(test_symbols(), {})
else:
test_script()
# TODO: Collect coverage data
@ -942,12 +938,10 @@ def run_tests(interactive: bool = False) -> None:
def serial_stdout_on() -> None:
global log
log._print_serial_logs = True
def serial_stdout_off() -> None:
global log
log._print_serial_logs = False
@ -989,6 +983,37 @@ def subtest(name: str) -> Iterator[None]:
return False
def _test_symbols() -> Dict[str, Any]:
general_symbols = dict(
start_all=start_all,
test_script=globals().get("test_script"), # same
machines=globals().get("machines"), # without being initialized
log=globals().get("log"), # extracting those symbol keys
os=os,
create_machine=create_machine,
subtest=subtest,
run_tests=run_tests,
join_all=join_all,
serial_stdout_off=serial_stdout_off,
serial_stdout_on=serial_stdout_on,
)
return general_symbols
def test_symbols() -> Dict[str, Any]:
general_symbols = _test_symbols()
machine_symbols = {m.name: machines[idx] for idx, m in enumerate(machines)}
print(
"additionally exposed symbols:\n "
+ ", ".join(map(lambda m: m.name, machines))
+ ",\n "
+ ", ".join(list(general_symbols.keys()))
)
return {**general_symbols, **machine_symbols}
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
arg_parser.add_argument(
@ -1028,12 +1053,9 @@ if __name__ == "__main__":
)
args = arg_parser.parse_args()
global test_script
testscript = pathlib.Path(args.testscript).read_text()
def test_script() -> None:
with log.nested("running the VM test script"):
exec(testscript, globals())
global log, machines, test_script
log = Logger()
@ -1062,6 +1084,11 @@ if __name__ == "__main__":
process.terminate()
log.close()
def test_script() -> None:
with log.nested("running the VM test script"):
symbols = test_symbols() # call eagerly
exec(testscript, symbols, None)
interactive = args.interactive or (not bool(testscript))
tic = time.time()
run_tests(interactive)

View file

@ -42,7 +42,9 @@ rec {
python <<EOF
from pydoc import importfile
with open('driver-symbols', 'w') as fp:
fp.write(','.join(dir(importfile('${testDriverScript}'))))
t = importfile('${testDriverScript}')
test_symbols = t._test_symbols()
fp.write(','.join(test_symbols.keys()))
EOF
'';