initial commit
This commit is contained in:
@@ -0,0 +1,428 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from functools import partial
|
||||
from typing import Protocol
|
||||
|
||||
import pytest
|
||||
|
||||
import trio._repl
|
||||
|
||||
|
||||
class RawInput(Protocol):
|
||||
def __call__(self, prompt: str = "") -> str: ...
|
||||
|
||||
|
||||
def build_raw_input(cmds: list[str]) -> RawInput:
|
||||
"""
|
||||
Pass in a list of strings.
|
||||
Returns a callable that returns each string, each time its called
|
||||
When there are not more strings to return, raise EOFError
|
||||
"""
|
||||
cmds_iter = iter(cmds)
|
||||
prompts = []
|
||||
|
||||
def _raw_helper(prompt: str = "") -> str:
|
||||
prompts.append(prompt)
|
||||
try:
|
||||
return next(cmds_iter)
|
||||
except StopIteration:
|
||||
raise EOFError from None
|
||||
|
||||
return _raw_helper
|
||||
|
||||
|
||||
def test_build_raw_input() -> None:
|
||||
"""Quick test of our helper function."""
|
||||
raw_input = build_raw_input(["cmd1"])
|
||||
assert raw_input() == "cmd1"
|
||||
with pytest.raises(EOFError):
|
||||
raw_input()
|
||||
|
||||
|
||||
async def test_basic_interaction(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Run some basic commands through the interpreter while capturing stdout.
|
||||
Ensure that the interpreted prints the expected results.
|
||||
"""
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# evaluate simple expression and recall the value
|
||||
"x = 1",
|
||||
"print(f'{x=}')",
|
||||
# Literal gets printed
|
||||
"'hello'",
|
||||
# define and call sync function
|
||||
"def func():",
|
||||
" print(x + 1)",
|
||||
"",
|
||||
"func()",
|
||||
# define and call async function
|
||||
"async def afunc():",
|
||||
" return 4",
|
||||
"",
|
||||
"await afunc()",
|
||||
# import works
|
||||
"import sys",
|
||||
"sys.stdout.write('hello stdout\\n')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, _err = capsys.readouterr()
|
||||
assert out.splitlines() == ["x=1", "'hello'", "2", "4", "hello stdout", "13"]
|
||||
|
||||
|
||||
async def test_system_exits_quit_interpreter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"raise SystemExit",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
with pytest.raises(SystemExit):
|
||||
await trio._repl.run_repl(console)
|
||||
|
||||
|
||||
async def test_KI_interrupts(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"import signal, trio, trio.lowlevel",
|
||||
"async def f():",
|
||||
" trio.lowlevel.spawn_system_task("
|
||||
" trio.to_thread.run_sync,"
|
||||
" signal.raise_signal, signal.SIGINT,"
|
||||
" )", # just awaiting this kills the test runner?!
|
||||
" await trio.sleep_forever()",
|
||||
" print('should not see this')",
|
||||
"",
|
||||
"await f()",
|
||||
"print('AFTER KeyboardInterrupt')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "KeyboardInterrupt" in err
|
||||
assert "should" not in out
|
||||
assert "AFTER KeyboardInterrupt" in out
|
||||
|
||||
|
||||
async def test_system_exits_in_exc_group(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"import sys",
|
||||
"if sys.version_info < (3, 11):",
|
||||
" from exceptiongroup import BaseExceptionGroup",
|
||||
"",
|
||||
"raise BaseExceptionGroup('', [RuntimeError(), SystemExit()])",
|
||||
"print('AFTER BaseExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, _err = capsys.readouterr()
|
||||
# assert that raise SystemExit in an exception group
|
||||
# doesn't quit
|
||||
assert "AFTER BaseExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_system_exits_in_nested_exc_group(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"import sys",
|
||||
"if sys.version_info < (3, 11):",
|
||||
" from exceptiongroup import BaseExceptionGroup",
|
||||
"",
|
||||
"raise BaseExceptionGroup(",
|
||||
" '', [BaseExceptionGroup('', [RuntimeError(), SystemExit()])])",
|
||||
"print('AFTER BaseExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, _err = capsys.readouterr()
|
||||
# assert that raise SystemExit in an exception group
|
||||
# doesn't quit
|
||||
assert "AFTER BaseExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_base_exception_captured(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# The statement after raise should still get executed
|
||||
"raise BaseException",
|
||||
"print('AFTER BaseException')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "_threads.py" not in err
|
||||
assert "_repl.py" not in err
|
||||
assert "AFTER BaseException" in out
|
||||
|
||||
|
||||
async def test_exc_group_captured(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# The statement after raise should still get executed
|
||||
"raise ExceptionGroup('', [KeyError()])",
|
||||
"print('AFTER ExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, _err = capsys.readouterr()
|
||||
assert "AFTER ExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_base_exception_capture_from_coroutine(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"async def async_func_raises_base_exception():",
|
||||
" raise BaseException",
|
||||
"",
|
||||
# This will raise, but the statement after should still
|
||||
# be executed
|
||||
"await async_func_raises_base_exception()",
|
||||
"print('AFTER BaseException')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "_threads.py" not in err
|
||||
assert "_repl.py" not in err
|
||||
assert "AFTER BaseException" in out
|
||||
|
||||
|
||||
def test_main_entrypoint() -> None:
|
||||
"""
|
||||
Basic smoke test when running via the package __main__ entrypoint.
|
||||
"""
|
||||
repl = subprocess.run([sys.executable, "-m", "trio"], input=b"exit()")
|
||||
assert repl.returncode == 0
|
||||
|
||||
|
||||
def should_try_newline_injection() -> bool:
|
||||
if sys.platform != "linux":
|
||||
return False
|
||||
|
||||
sysctl = pathlib.Path("/proc/sys/dev/tty/legacy_tiocsti")
|
||||
if not sysctl.exists(): # pragma: no cover
|
||||
return True
|
||||
|
||||
else:
|
||||
return sysctl.read_text() == "1"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not should_try_newline_injection(),
|
||||
reason="the ioctl we use is disabled in CI",
|
||||
)
|
||||
def test_ki_newline_injection() -> None: # TODO: test this line
|
||||
# TODO: we want to remove this functionality, eg by using vendored
|
||||
# pyrepls.
|
||||
assert sys.platform != "win32"
|
||||
|
||||
import pty
|
||||
|
||||
# NOTE: this cannot be subprocess.Popen because pty.fork
|
||||
# does some magic to set the controlling terminal.
|
||||
# (which I don't know how to replicate... so I copied this
|
||||
# structure from pty.spawn...)
|
||||
pid, pty_fd = pty.fork() # type: ignore[attr-defined,unused-ignore]
|
||||
if pid == 0:
|
||||
os.execlp(sys.executable, *[sys.executable, "-u", "-m", "trio"])
|
||||
|
||||
# setup:
|
||||
buffer = b""
|
||||
while not buffer.endswith(b"import trio\r\n>>> "):
|
||||
buffer += os.read(pty_fd, 4096)
|
||||
|
||||
# sanity check:
|
||||
print(buffer.decode())
|
||||
buffer = b""
|
||||
os.write(pty_fd, b'print("hello!")\n')
|
||||
while not buffer.endswith(b">>> "):
|
||||
buffer += os.read(pty_fd, 4096)
|
||||
|
||||
assert buffer.count(b"hello!") == 2
|
||||
|
||||
# press ctrl+c
|
||||
print(buffer.decode())
|
||||
buffer = b""
|
||||
os.kill(pid, signal.SIGINT)
|
||||
while not buffer.endswith(b">>> "):
|
||||
buffer += os.read(pty_fd, 4096)
|
||||
|
||||
assert b"KeyboardInterrupt" in buffer
|
||||
|
||||
# press ctrl+c later
|
||||
print(buffer.decode())
|
||||
buffer = b""
|
||||
os.write(pty_fd, b'print("hello!")')
|
||||
os.kill(pid, signal.SIGINT)
|
||||
while not buffer.endswith(b">>> "):
|
||||
buffer += os.read(pty_fd, 4096)
|
||||
|
||||
assert b"KeyboardInterrupt" in buffer
|
||||
print(buffer.decode())
|
||||
os.close(pty_fd)
|
||||
os.waitpid(pid, 0)[1]
|
||||
|
||||
|
||||
async def test_ki_in_repl() -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
proc = await nursery.start(
|
||||
partial(
|
||||
trio.run_process,
|
||||
[sys.executable, "-u", "-m", "trio"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE,
|
||||
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if sys.platform == "win32" else 0, # type: ignore[attr-defined,unused-ignore]
|
||||
)
|
||||
)
|
||||
|
||||
async with proc.stdout:
|
||||
# setup
|
||||
buffer = b""
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
# TODO: consider making run_process stdout have some universal newlines thing
|
||||
if buffer.replace(b"\r\n", b"\n").endswith(b"import trio\n>>> "):
|
||||
break
|
||||
|
||||
# ensure things work
|
||||
print(buffer.decode())
|
||||
buffer = b""
|
||||
await proc.stdin.send_all(b'print("hello!")\n')
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
assert b"hello!" in buffer
|
||||
print(buffer.decode())
|
||||
|
||||
# this seems to be necessary on Windows for reasons
|
||||
# (the parents of process groups ignore ctrl+c by default...)
|
||||
if sys.platform == "win32":
|
||||
buffer = b""
|
||||
await proc.stdin.send_all(
|
||||
b"import ctypes; ctypes.windll.kernel32.SetConsoleCtrlHandler(None, False)\n"
|
||||
)
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
print(buffer.decode())
|
||||
|
||||
# try to decrease flakiness...
|
||||
buffer = b""
|
||||
await proc.stdin.send_all(
|
||||
b"import coverage; trio.lowlevel.enable_ki_protection(coverage.pytracer.PyTracer._trace)\n"
|
||||
)
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
print(buffer.decode())
|
||||
|
||||
# ensure that ctrl+c on a prompt works
|
||||
# NOTE: for some reason, signal.SIGINT doesn't work for this test.
|
||||
# Using CTRL_C_EVENT is also why we need subprocess.CREATE_NEW_PROCESS_GROUP
|
||||
signal_sent = signal.CTRL_C_EVENT if sys.platform == "win32" else signal.SIGINT # type: ignore[attr-defined,unused-ignore]
|
||||
os.kill(proc.pid, signal_sent)
|
||||
if sys.platform == "win32":
|
||||
# we rely on EOFError which... doesn't happen with pipes.
|
||||
# I'm not sure how to fix it...
|
||||
await proc.stdin.send_all(b"\n")
|
||||
else:
|
||||
# we test injection separately
|
||||
await proc.stdin.send_all(b"\n")
|
||||
|
||||
buffer = b""
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
assert b"KeyboardInterrupt" in buffer
|
||||
|
||||
# ensure ctrl+c while a command runs works
|
||||
print(buffer.decode())
|
||||
await proc.stdin.send_all(b'print("READY"); await trio.sleep_forever()\n')
|
||||
killed = False
|
||||
buffer = b""
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.replace(b"\r\n", b"\n").endswith(b"READY\n") and not killed:
|
||||
os.kill(proc.pid, signal_sent)
|
||||
killed = True
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
assert b"trio" in buffer
|
||||
assert b"KeyboardInterrupt" in buffer
|
||||
|
||||
# make sure it works for sync commands too
|
||||
# (though this would be hard to break)
|
||||
print(buffer.decode())
|
||||
await proc.stdin.send_all(
|
||||
b'import time; print("READY"); time.sleep(99999)\n'
|
||||
)
|
||||
killed = False
|
||||
buffer = b""
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.replace(b"\r\n", b"\n").endswith(b"READY\n") and not killed:
|
||||
os.kill(proc.pid, signal_sent)
|
||||
killed = True
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
assert b"Traceback" in buffer
|
||||
assert b"KeyboardInterrupt" in buffer
|
||||
|
||||
print(buffer.decode())
|
||||
|
||||
# kill the process
|
||||
nursery.cancel_scope.cancel()
|
||||
Reference in New Issue
Block a user