"""Tests for the interactive console in ``trio._repl`` and ``python -m trio``."""

from __future__ import annotations

import os
import pathlib
import signal
import subprocess
import sys
from functools import partial
from typing import Protocol

import pytest

import trio._repl


class RawInput(Protocol):
    """Callable with the signature used to replace the console's ``raw_input``."""

    def __call__(self, prompt: str = "") -> str: ...


def build_raw_input(cmds: list[str]) -> RawInput:
    """
    Pass in a list of strings.
    Returns a callable that returns each string in turn, one per call.
    When there are no more strings to return, it raises EOFError.
    """
    cmds_iter = iter(cmds)

    def _raw_helper(prompt: str = "") -> str:
        # The prompt is accepted for signature compatibility but is not used.
        try:
            return next(cmds_iter)
        except StopIteration:
            raise EOFError from None

    return _raw_helper


def test_build_raw_input() -> None:
    """Quick test of our helper function."""
    raw_input = build_raw_input(["cmd1"])
    assert raw_input() == "cmd1"
    with pytest.raises(EOFError):
        raw_input()


async def test_basic_interaction(
    capsys: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    Run some basic commands through the interpreter while capturing stdout.
    Ensure that the interpreter prints the expected results.
    """
    console = trio._repl.TrioInteractiveConsole()
    raw_input = build_raw_input(
        [
            # evaluate simple expression and recall the value
            "x = 1",
            "print(f'{x=}')",
            # Literal gets printed
            "'hello'",
            # define and call sync function
            "def func():",
            " print(x + 1)",
            "",
            "func()",
            # define and call async function
            "async def afunc():",
            " return 4",
            "",
            "await afunc()",
            # import works
            "import sys",
            "sys.stdout.write('hello stdout\\n')",
        ],
    )
    monkeypatch.setattr(console, "raw_input", raw_input)
    await trio._repl.run_repl(console)
    out, _err = capsys.readouterr()
    assert out.splitlines() == ["x=1", "'hello'", "2", "4", "hello stdout", "13"]


async def test_system_exits_quit_interpreter(monkeypatch: pytest.MonkeyPatch) -> None:
    """A bare ``raise SystemExit`` typed at the prompt propagates out of ``run_repl``."""
    console = trio._repl.TrioInteractiveConsole()
    raw_input = build_raw_input(
        [
            "raise SystemExit",
        ],
    )
    monkeypatch.setattr(console, "raw_input", raw_input)
    with pytest.raises(SystemExit):
        await trio._repl.run_repl(console)


async def test_KI_interrupts(
    capsys: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """SIGINT during an awaited command is reported and the next command still runs."""
    console = trio._repl.TrioInteractiveConsole()
    raw_input = build_raw_input(
        [
            "import signal, trio, trio.lowlevel",
            "async def f():",
            " trio.lowlevel.spawn_system_task("
            " trio.to_thread.run_sync,"
            " signal.raise_signal, signal.SIGINT,"
            " )",  # just awaiting this kills the test runner?!
            " await trio.sleep_forever()",
            " print('should not see this')",
            "",
            "await f()",
            "print('AFTER KeyboardInterrupt')",
        ],
    )
    monkeypatch.setattr(console, "raw_input", raw_input)
    await trio._repl.run_repl(console)
    out, err = capsys.readouterr()
    assert "KeyboardInterrupt" in err
    assert "should" not in out
    assert "AFTER KeyboardInterrupt" in out


async def test_system_exits_in_exc_group(
    capsys: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """SystemExit wrapped in an exception group does not quit the console."""
    console = trio._repl.TrioInteractiveConsole()
    raw_input = build_raw_input(
        [
            "import sys",
            "if sys.version_info < (3, 11):",
            " from exceptiongroup import BaseExceptionGroup",
            "",
            "raise BaseExceptionGroup('', [RuntimeError(), SystemExit()])",
            "print('AFTER BaseExceptionGroup')",
        ],
    )
    monkeypatch.setattr(console, "raw_input", raw_input)
    await trio._repl.run_repl(console)
    out, _err = capsys.readouterr()
    # assert that raise SystemExit in an exception group
    # doesn't quit
    assert "AFTER BaseExceptionGroup" in out


async def test_system_exits_in_nested_exc_group(
    capsys: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """SystemExit inside a nested exception group does not quit the console."""
    console = trio._repl.TrioInteractiveConsole()
    raw_input = build_raw_input(
        [
            "import sys",
            "if sys.version_info < (3, 11):",
            " from exceptiongroup import BaseExceptionGroup",
            "",
            "raise BaseExceptionGroup(",
            " '', [BaseExceptionGroup('', [RuntimeError(), SystemExit()])])",
            "print('AFTER BaseExceptionGroup')",
        ],
    )
    monkeypatch.setattr(console, "raw_input", raw_input)
    await trio._repl.run_repl(console)
    out, _err = capsys.readouterr()
    # assert that raise SystemExit in an exception group
    # doesn't quit
    assert "AFTER BaseExceptionGroup" in out


async def test_base_exception_captured(
    capsys: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A raised BaseException is reported without internal frames; the REPL continues."""
    console = trio._repl.TrioInteractiveConsole()
    raw_input = build_raw_input(
        [
            # The statement after raise should still get executed
            "raise BaseException",
            "print('AFTER BaseException')",
        ],
    )
    monkeypatch.setattr(console, "raw_input", raw_input)
    await trio._repl.run_repl(console)
    out, err = capsys.readouterr()
    assert "_threads.py" not in err
    assert "_repl.py" not in err
    assert "AFTER BaseException" in out


async def test_exc_group_captured(
    capsys: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A raised ExceptionGroup does not stop the REPL from running the next command."""
    console = trio._repl.TrioInteractiveConsole()
    raw_input = build_raw_input(
        [
            # ExceptionGroup is only a builtin on 3.11+; without the backport
            # the raise below would turn into a NameError and the test would
            # pass without exercising exception groups at all.
            "import sys",
            "if sys.version_info < (3, 11):",
            " from exceptiongroup import ExceptionGroup",
            "",
            # The statement after raise should still get executed
            "raise ExceptionGroup('', [KeyError()])",
            "print('AFTER ExceptionGroup')",
        ],
    )
    monkeypatch.setattr(console, "raw_input", raw_input)
    await trio._repl.run_repl(console)
    out, _err = capsys.readouterr()
    assert "AFTER ExceptionGroup" in out


async def test_base_exception_capture_from_coroutine(
    capsys: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A BaseException raised from an awaited coroutine is reported; the REPL continues."""
    console = trio._repl.TrioInteractiveConsole()
    raw_input = build_raw_input(
        [
            "async def async_func_raises_base_exception():",
            " raise BaseException",
            "",
            # This will raise, but the statement after should still
            # be executed
            "await async_func_raises_base_exception()",
            "print('AFTER BaseException')",
        ],
    )
    monkeypatch.setattr(console, "raw_input", raw_input)
    await trio._repl.run_repl(console)
    out, err = capsys.readouterr()
    assert "_threads.py" not in err
    assert "_repl.py" not in err
    assert "AFTER BaseException" in out


def test_main_entrypoint() -> None:
    """
    Basic smoke test when running via the package __main__ entrypoint.
    """
    repl = subprocess.run([sys.executable, "-m", "trio"], input=b"exit()")
    assert repl.returncode == 0


def should_try_newline_injection() -> bool:
    """
    Decide whether ``test_ki_newline_injection`` should run on this machine.

    Returns False off Linux. On Linux, returns True when the
    ``legacy_tiocsti`` sysctl file is absent, otherwise True only if that
    file reads ``1``.
    """
    if sys.platform != "linux":
        return False

    sysctl = pathlib.Path("/proc/sys/dev/tty/legacy_tiocsti")
    if not sysctl.exists():  # pragma: no cover
        return True

    # The file content ends with a newline, so strip before comparing;
    # a bare `== "1"` would never match.
    return sysctl.read_text().strip() == "1"


@pytest.mark.skipif(
    not should_try_newline_injection(),
    reason="the ioctl we use is disabled in CI",
)
def test_ki_newline_injection() -> None:  # TODO: test this line
    """Drive ``python -m trio`` under a pty and check ctrl+c at and after the prompt."""
    # TODO: we want to remove this functionality, eg by using vendored
    #       pyrepls.
    assert sys.platform != "win32"

    import pty

    # NOTE: this cannot be subprocess.Popen because pty.fork
    #       does some magic to set the controlling terminal.
    # (which I don't know how to replicate... so I copied this
    # structure from pty.spawn...)
    pid, pty_fd = pty.fork()  # type: ignore[attr-defined,unused-ignore]
    if pid == 0:
        os.execlp(sys.executable, *[sys.executable, "-u", "-m", "trio"])

    # setup:
    buffer = b""
    while not buffer.endswith(b"import trio\r\n>>> "):
        buffer += os.read(pty_fd, 4096)

    # sanity check:
    print(buffer.decode())
    buffer = b""
    os.write(pty_fd, b'print("hello!")\n')
    while not buffer.endswith(b">>> "):
        buffer += os.read(pty_fd, 4096)

    assert buffer.count(b"hello!") == 2

    # press ctrl+c
    print(buffer.decode())
    buffer = b""
    os.kill(pid, signal.SIGINT)
    while not buffer.endswith(b">>> "):
        buffer += os.read(pty_fd, 4096)

    assert b"KeyboardInterrupt" in buffer

    # press ctrl+c later
    print(buffer.decode())
    buffer = b""
    os.write(pty_fd, b'print("hello!")')
    os.kill(pid, signal.SIGINT)
    while not buffer.endswith(b">>> "):
        buffer += os.read(pty_fd, 4096)

    assert b"KeyboardInterrupt" in buffer
    print(buffer.decode())
    os.close(pty_fd)
    # Reap the child; its exit status is not checked.
    os.waitpid(pid, 0)


async def test_ki_in_repl() -> None:
    """Drive ``python -m trio`` over pipes and check ctrl+c handling in several states."""
    async with trio.open_nursery() as nursery:
        proc = await nursery.start(
            partial(
                trio.run_process,
                [sys.executable, "-u", "-m", "trio"],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                stdin=subprocess.PIPE,
                creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if sys.platform == "win32" else 0,  # type: ignore[attr-defined,unused-ignore]
            )
        )

        async with proc.stdout:
            # setup
            buffer = b""
            async for part in proc.stdout:  # pragma: no branch
                buffer += part
                # TODO: consider making run_process stdout have some universal newlines thing
                if buffer.replace(b"\r\n", b"\n").endswith(b"import trio\n>>> "):
                    break

            # ensure things work
            print(buffer.decode())
            buffer = b""
            await proc.stdin.send_all(b'print("hello!")\n')
            async for part in proc.stdout:  # pragma: no branch
                buffer += part
                if buffer.endswith(b">>> "):
                    break

            assert b"hello!" in buffer
            print(buffer.decode())

            # this seems to be necessary on Windows for reasons
            # (the parents of process groups ignore ctrl+c by default...)
            if sys.platform == "win32":
                buffer = b""
                await proc.stdin.send_all(
                    b"import ctypes; ctypes.windll.kernel32.SetConsoleCtrlHandler(None, False)\n"
                )
                async for part in proc.stdout:  # pragma: no branch
                    buffer += part
                    if buffer.endswith(b">>> "):
                        break

                print(buffer.decode())

            # try to decrease flakiness...
            buffer = b""
            await proc.stdin.send_all(
                b"import coverage; trio.lowlevel.enable_ki_protection(coverage.pytracer.PyTracer._trace)\n"
            )
            async for part in proc.stdout:  # pragma: no branch
                buffer += part
                if buffer.endswith(b">>> "):
                    break

            print(buffer.decode())

            # ensure that ctrl+c on a prompt works
            # NOTE: for some reason, signal.SIGINT doesn't work for this test.
            # Using CTRL_C_EVENT is also why we need subprocess.CREATE_NEW_PROCESS_GROUP
            signal_sent = signal.CTRL_C_EVENT if sys.platform == "win32" else signal.SIGINT  # type: ignore[attr-defined,unused-ignore]
            os.kill(proc.pid, signal_sent)
            # The same newline is sent on every platform:
            # - on Windows: we rely on EOFError which... doesn't happen with
            #   pipes. I'm not sure how to fix it...
            # - elsewhere: we test injection separately
            await proc.stdin.send_all(b"\n")

            buffer = b""
            async for part in proc.stdout:  # pragma: no branch
                buffer += part
                if buffer.endswith(b">>> "):
                    break

            assert b"KeyboardInterrupt" in buffer

            # ensure ctrl+c while a command runs works
            print(buffer.decode())
            await proc.stdin.send_all(b'print("READY"); await trio.sleep_forever()\n')
            killed = False
            buffer = b""
            async for part in proc.stdout:  # pragma: no branch
                buffer += part
                if buffer.replace(b"\r\n", b"\n").endswith(b"READY\n") and not killed:
                    os.kill(proc.pid, signal_sent)
                    killed = True
                if buffer.endswith(b">>> "):
                    break

            assert b"trio" in buffer
            assert b"KeyboardInterrupt" in buffer

            # make sure it works for sync commands too
            # (though this would be hard to break)
            print(buffer.decode())
            await proc.stdin.send_all(
                b'import time; print("READY"); time.sleep(99999)\n'
            )
            killed = False
            buffer = b""
            async for part in proc.stdout:  # pragma: no branch
                buffer += part
                if buffer.replace(b"\r\n", b"\n").endswith(b"READY\n") and not killed:
                    os.kill(proc.pid, signal_sent)
                    killed = True
                if buffer.endswith(b">>> "):
                    break

            assert b"Traceback" in buffer
            assert b"KeyboardInterrupt" in buffer

            print(buffer.decode())

        # kill the process
        nursery.cancel_scope.cancel()