diff --git a/nixos/lib/test-driver/test_driver/__init__.py b/nixos/lib/test-driver/test_driver/__init__.py index 5477ab5cd038..498a4f56c55b 100755 --- a/nixos/lib/test-driver/test_driver/__init__.py +++ b/nixos/lib/test-driver/test_driver/__init__.py @@ -33,6 +33,22 @@ class EnvDefault(argparse.Action): setattr(namespace, self.dest, values) +def writeable_dir(arg: str) -> Path: + """Raises an ArgumentTypeError if the given argument isn't a writeable directory + Note: We want to fail as early as possible if a directory isn't writeable, + since an executed nixos-test could fail (very late) because of the test-driver + writing in a directory without proper permissions. + """ + path = Path(arg) + if not path.is_dir(): + raise argparse.ArgumentTypeError("{0} is not a directory".format(path)) + if not os.access(path, os.W_OK): + raise argparse.ArgumentTypeError( + "{0} is not a writeable directory".format(path) + ) + return path + + def main() -> None: arg_parser = argparse.ArgumentParser(prog="nixos-test-driver") arg_parser.add_argument( @@ -63,6 +79,14 @@ def main() -> None: nargs="*", help="vlans to span by the driver", ) + arg_parser.add_argument( + "-o", + "--output_directory", + help="""The path to the directory where outputs copied from the VM will be placed. + By e.g. Machine.copy_from_vm or Machine.screenshot""", + default=Path.cwd(), + type=writeable_dir, + ) arg_parser.add_argument( "testscript", action=EnvDefault, @@ -77,7 +101,11 @@ def main() -> None: rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state") with Driver( - args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state + args.start_scripts, + args.vlans, + args.testscript.read_text(), + args.output_directory.resolve(), + args.keep_vm_state, ) as driver: if args.interactive: ptpython.repl.embed(driver.test_symbols(), {}) @@ -94,7 +122,7 @@ def generate_driver_symbols() -> None: in user's test scripts. That list is then used by pyflakes to lint those scripts. """ - d = Driver([], [], "") + d = Driver([], [], "", Path()) test_symbols = d.test_symbols() with open("driver-symbols", "w") as fp: fp.write(",".join(test_symbols.keys())) diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py index 49a42fe5fb4e..880b1c5fdec0 100644 --- a/nixos/lib/test-driver/test_driver/driver.py +++ b/nixos/lib/test-driver/test_driver/driver.py @@ -10,6 +10,28 @@ from test_driver.vlan import VLan from test_driver.polling_condition import PollingCondition +def get_tmp_dir() -> Path: + """Returns a temporary directory that is defined by TMPDIR, TEMP, TMP or CWD + Raises an exception in case the retrieved temporary directory is not writeable + See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir + """ + tmp_dir = Path(tempfile.gettempdir()) + tmp_dir.mkdir(mode=0o700, exist_ok=True) + if not tmp_dir.is_dir(): + raise NotADirectoryError( + "The directory defined by TMPDIR, TEMP, TMP or CWD: {0} is not a directory".format( + tmp_dir + ) + ) + if not os.access(tmp_dir, os.W_OK): + raise PermissionError( + "The directory defined by TMPDIR, TEMP, TMP, or CWD: {0} is not writeable".format( + tmp_dir + ) + ) + return tmp_dir + + class Driver: """A handle to the driver that sets up the environment and runs the tests""" @@ -24,12 +46,13 @@ class Driver: start_scripts: List[str], vlans: List[int], tests: str, + out_dir: Path, keep_vm_state: bool = False, ): self.tests = tests + self.out_dir = out_dir - tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) - tmp_dir.mkdir(mode=0o700, exist_ok=True) + tmp_dir = get_tmp_dir() with rootlog.nested("start all VLans"): self.vlans = [VLan(nr, tmp_dir) for nr in vlans] @@ -47,6 +70,7 @@ class Driver: name=cmd.machine_name, tmp_dir=tmp_dir, callbacks=[self.check_polling_conditions], + out_dir=self.out_dir, ) for cmd in cmd(start_scripts) ] @@ -141,8 +165,8 @@ class Driver: "Using legacy create_machine(), please instantiate the" "Machine class directly, instead" ) - tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) - tmp_dir.mkdir(mode=0o700, exist_ok=True) + + tmp_dir = get_tmp_dir() if args.get("startCommand"): start_command: str = args.get("startCommand", "") @@ -154,6 +178,7 @@ class Driver: return Machine( tmp_dir=tmp_dir, + out_dir=self.out_dir, start_command=cmd, name=name, keep_vm_state=args.get("keep_vm_state", False), diff --git a/nixos/lib/test-driver/test_driver/machine.py b/nixos/lib/test-driver/test_driver/machine.py index e050cbd7d990..a41c419ebe6a 100644 --- a/nixos/lib/test-driver/test_driver/machine.py +++ b/nixos/lib/test-driver/test_driver/machine.py @@ -297,6 +297,7 @@ class Machine: the machine lifecycle with the help of a start script / command.""" name: str + out_dir: Path tmp_dir: Path shared_dir: Path state_dir: Path @@ -325,6 +326,7 @@ class Machine: def __init__( self, + out_dir: Path, tmp_dir: Path, start_command: StartCommand, name: str = "machine", @@ -332,6 +334,7 @@ class Machine: allow_reboot: bool = False, callbacks: Optional[List[Callable]] = None, ) -> None: + self.out_dir = out_dir self.tmp_dir = tmp_dir self.keep_vm_state = keep_vm_state self.allow_reboot = allow_reboot @@ -702,10 +705,9 @@ class Machine: self.connected = True def screenshot(self, filename: str) -> None: - out_dir = os.environ.get("out", os.getcwd()) word_pattern = re.compile(r"^\w+$") if word_pattern.match(filename): - filename = os.path.join(out_dir, "{}.png".format(filename)) + filename = os.path.join(self.out_dir, "{}.png".format(filename)) tmp = "{}.ppm".format(filename) with self.nested( @@ -756,7 +758,6 @@ class Machine: all the VMs (using a temporary directory). """ # Compute the source, target, and intermediate shared file names - out_dir = Path(os.environ.get("out", os.getcwd())) vm_src = Path(source) with tempfile.TemporaryDirectory(dir=self.shared_dir) as shared_td: shared_temp = Path(shared_td) @@ -766,7 +767,7 @@ class Machine: # Copy the file to the shared directory inside VM self.succeed(make_command(["mkdir", "-p", vm_shared_temp])) self.succeed(make_command(["cp", "-r", vm_src, vm_intermediate])) - abs_target = out_dir / target_dir / vm_src.name + abs_target = self.out_dir / target_dir / vm_src.name abs_target.parent.mkdir(exist_ok=True, parents=True) # Copy the file from the shared directory outside VM if intermediate.is_dir():