initial commit
This commit is contained in:
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,247 @@
|
||||
#!/usr/bin/env python3
|
||||
"""This is a file that wraps calls to `pyright --verifytypes`, achieving two things:
|
||||
1. give an error if docstrings are missing.
|
||||
pyright will give a number of missing docstrings, and error messages, but not exit with a non-zero value.
|
||||
2. filter out specific errors we don't care about.
|
||||
this is largely due to 1, but also because Trio does some very complex stuff and --verifytypes has few to no ways of ignoring specific errors.
|
||||
|
||||
If this check is giving you false alarms, you can ignore them by adding logic to `has_docstring_at_runtime`, in the main loop in `check_type`, or by updating the json file.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# this file is not run as part of the tests, instead it's run standalone from check.sh
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
|
||||
# not needed if everything is working, but if somebody does something to generate
|
||||
# tons of errors, we can be nice and stop them from getting 3*tons of output
|
||||
printed_diagnostics: set[str] = set()
|
||||
|
||||
|
||||
# TODO: consider checking manually without `--ignoreexternal`, and/or
|
||||
# removing it from the below call later on.
|
||||
def run_pyright(platform: str) -> subprocess.CompletedProcess[bytes]:
|
||||
return subprocess.run(
|
||||
[
|
||||
"pyright",
|
||||
# Specify a platform and version to keep imported modules consistent.
|
||||
f"--pythonplatform={platform}",
|
||||
"--pythonversion=3.10",
|
||||
"--verifytypes=trio",
|
||||
"--outputjson",
|
||||
"--ignoreexternal",
|
||||
],
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
|
||||
def has_docstring_at_runtime(name: str) -> bool:
|
||||
"""Pyright gives us an object identifier of xx.yy.zz
|
||||
This function tries to decompose that into its constituent parts, such that we
|
||||
can resolve it, in order to check whether it has a `__doc__` at runtime and
|
||||
verifytypes misses it because we're doing overly fancy stuff.
|
||||
"""
|
||||
# This assert is solely for stopping isort from removing our imports of trio & trio.testing
|
||||
# It could also be done with isort:skip, but that'd also disable import sorting and the like.
|
||||
assert trio.testing is not None
|
||||
|
||||
# figure out what part of the name is the module, so we can "import" it
|
||||
name_parts = name.split(".")
|
||||
assert name_parts[0] == "trio"
|
||||
if name_parts[1] == "tests":
|
||||
return True
|
||||
|
||||
# traverse down the remaining identifiers with getattr
|
||||
obj = trio
|
||||
try:
|
||||
for obj_name in name_parts[1:]:
|
||||
obj = getattr(obj, obj_name)
|
||||
except AttributeError as exc:
|
||||
# asynciowrapper does funky getattr stuff
|
||||
if "AsyncIOWrapper" in str(exc) or name in (
|
||||
# Symbols not existing on all platforms, so we can't dynamically inspect them.
|
||||
# Manually confirmed to have docstrings but pyright doesn't see them due to
|
||||
# export shenanigans. TODO: actually manually confirm that.
|
||||
# In theory we could verify these at runtime, probably by running the script separately
|
||||
# on separate platforms. It might also be a decent idea to work the other way around,
|
||||
# a la test_static_tool_sees_class_members
|
||||
# darwin
|
||||
"trio.lowlevel.current_kqueue",
|
||||
"trio.lowlevel.monitor_kevent",
|
||||
"trio.lowlevel.wait_kevent",
|
||||
"trio._core._io_kqueue._KqueueStatistics",
|
||||
# windows
|
||||
"trio._socket.SocketType.share",
|
||||
"trio._core._io_windows._WindowsStatistics",
|
||||
"trio._core._windows_cffi.Handle",
|
||||
"trio.lowlevel.current_iocp",
|
||||
"trio.lowlevel.monitor_completion_key",
|
||||
"trio.lowlevel.readinto_overlapped",
|
||||
"trio.lowlevel.register_with_iocp",
|
||||
"trio.lowlevel.wait_overlapped",
|
||||
"trio.lowlevel.write_overlapped",
|
||||
"trio.lowlevel.WaitForSingleObject",
|
||||
"trio.socket.fromshare",
|
||||
# linux
|
||||
# this test will fail on linux, but I don't develop on linux. So the next
|
||||
# person to do so is very welcome to open a pull request and populate with
|
||||
# objects
|
||||
# TODO: these are erroring on all platforms, why?
|
||||
"trio._highlevel_generic.StapledStream.send_stream",
|
||||
"trio._highlevel_generic.StapledStream.receive_stream",
|
||||
"trio._ssl.SSLStream.transport_stream",
|
||||
"trio._file_io._HasFileNo",
|
||||
"trio._file_io._HasFileNo.fileno",
|
||||
):
|
||||
return True
|
||||
|
||||
else:
|
||||
print(
|
||||
f"Pyright sees {name} at runtime, but unable to getattr({obj.__name__}, {obj_name}).",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return False
|
||||
return bool(obj.__doc__)
|
||||
|
||||
|
||||
def check_type(
|
||||
platform: str,
|
||||
full_diagnostics_file: Path | None,
|
||||
expected_errors: list[object],
|
||||
) -> list[object]:
|
||||
# convince isort we use the trio import
|
||||
assert trio is not None
|
||||
|
||||
# run pyright, load output into json
|
||||
res = run_pyright(platform)
|
||||
current_result = json.loads(res.stdout)
|
||||
|
||||
if res.stderr:
|
||||
print(res.stderr, file=sys.stderr)
|
||||
|
||||
if full_diagnostics_file:
|
||||
with open(full_diagnostics_file, "a") as f:
|
||||
json.dump(current_result, f, sort_keys=True, indent=4)
|
||||
|
||||
errors = []
|
||||
|
||||
for symbol in current_result["typeCompleteness"]["symbols"]:
|
||||
diagnostics = symbol["diagnostics"]
|
||||
name = symbol["name"]
|
||||
for diagnostic in diagnostics:
|
||||
message = diagnostic["message"]
|
||||
if name in (
|
||||
"trio._path.PosixPath",
|
||||
"trio._path.WindowsPath",
|
||||
) and message.startswith("Type of base class "):
|
||||
continue
|
||||
|
||||
if name.startswith("trio._path.Path"):
|
||||
if message.startswith("No docstring found for"):
|
||||
continue
|
||||
if message.startswith(
|
||||
"Type is missing type annotation and could be inferred differently by type checkers",
|
||||
):
|
||||
continue
|
||||
|
||||
# ignore errors about missing docstrings if they're available at runtime
|
||||
if message.startswith("No docstring found for"):
|
||||
if has_docstring_at_runtime(symbol["name"]):
|
||||
continue
|
||||
else:
|
||||
# Missing docstring messages include the name of the object.
|
||||
# Other errors don't, so we add it.
|
||||
message = f"{name}: {message}"
|
||||
if message not in expected_errors and message not in printed_diagnostics:
|
||||
print(f"new error: {message}", file=sys.stderr)
|
||||
errors.append(message)
|
||||
printed_diagnostics.add(message)
|
||||
|
||||
continue
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> int:
|
||||
if args.full_diagnostics_file:
|
||||
full_diagnostics_file = Path(args.full_diagnostics_file)
|
||||
full_diagnostics_file.write_text("")
|
||||
else:
|
||||
full_diagnostics_file = None
|
||||
|
||||
errors_by_platform_file = Path(__file__).parent / "_check_type_completeness.json"
|
||||
if errors_by_platform_file.exists():
|
||||
with open(errors_by_platform_file) as f:
|
||||
errors_by_platform = json.load(f)
|
||||
else:
|
||||
errors_by_platform = {"Linux": [], "Windows": [], "Darwin": [], "all": []}
|
||||
|
||||
changed = False
|
||||
for platform in "Linux", "Windows", "Darwin":
|
||||
platform_errors = errors_by_platform[platform] + errors_by_platform["all"]
|
||||
print("*" * 20, f"\nChecking {platform}...")
|
||||
errors = check_type(platform, full_diagnostics_file, platform_errors)
|
||||
|
||||
new_errors = [e for e in errors if e not in platform_errors]
|
||||
missing_errors = [e for e in platform_errors if e not in errors]
|
||||
|
||||
if new_errors:
|
||||
print(
|
||||
f"New errors introduced in `pyright --verifytypes`. Fix them, or ignore them by modifying {errors_by_platform_file}, either manually or with '--overwrite-file'.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
changed = True
|
||||
if missing_errors:
|
||||
print(
|
||||
f"Congratulations, you have resolved existing errors! Please remove them from {errors_by_platform_file}, either manually or with '--overwrite-file'.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
changed = True
|
||||
print(missing_errors, file=sys.stderr)
|
||||
|
||||
errors_by_platform[platform] = errors
|
||||
print("*" * 20)
|
||||
|
||||
# cut down the size of the json file by a lot, and make it easier to parse for
|
||||
# humans, by moving errors that appear on all platforms to a separate category
|
||||
errors_by_platform["all"] = []
|
||||
for e in errors_by_platform["Linux"].copy():
|
||||
if e in errors_by_platform["Darwin"] and e in errors_by_platform["Windows"]:
|
||||
for platform in "Linux", "Windows", "Darwin":
|
||||
errors_by_platform[platform].remove(e)
|
||||
errors_by_platform["all"].append(e)
|
||||
|
||||
if changed and args.overwrite_file:
|
||||
with open(errors_by_platform_file, "w") as f:
|
||||
json.dump(errors_by_platform, f, indent=4, sort_keys=True)
|
||||
# newline at end of file
|
||||
f.write("\n")
|
||||
|
||||
# True -> 1 -> non-zero exit value -> error
|
||||
return changed
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--overwrite-file",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use this flag to overwrite the current stored results. Either in CI together with a diff check, or to avoid having to manually correct it.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full-diagnostics-file",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Use this for debugging, it will dump the output of all three pyright runs by platform into this file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert __name__ == "__main__", "This script should be run standalone"
|
||||
sys.exit(main(args))
|
||||
@@ -0,0 +1,22 @@
|
||||
regular = "hi"
|
||||
|
||||
import sys
|
||||
|
||||
from .. import _deprecate
|
||||
|
||||
_deprecate.deprecate_attributes(
|
||||
__name__,
|
||||
{
|
||||
"dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1),
|
||||
"dep2": _deprecate.DeprecatedAttribute(
|
||||
"value2",
|
||||
"1.2",
|
||||
issue=1,
|
||||
instead="instead-string",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
this_mod = sys.modules[__name__]
|
||||
assert this_mod.regular == "hi"
|
||||
assert "dep1" not in globals()
|
||||
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import NoReturn
|
||||
|
||||
import pytest
|
||||
|
||||
from ..testing import MockClock, trio_test
|
||||
|
||||
RUN_SLOW = True
|
||||
SKIP_OPTIONAL_IMPORTS = False
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
parser.addoption("--run-slow", action="store_true", help="run slow tests")
|
||||
parser.addoption(
|
||||
"--skip-optional-imports",
|
||||
action="store_true",
|
||||
help="skip tests that rely on libraries not required by trio itself",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
global RUN_SLOW
|
||||
RUN_SLOW = config.getoption("--run-slow", default=True)
|
||||
global SKIP_OPTIONAL_IMPORTS
|
||||
SKIP_OPTIONAL_IMPORTS = config.getoption("--skip-optional-imports", default=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clock() -> MockClock:
|
||||
return MockClock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def autojump_clock() -> MockClock:
|
||||
return MockClock(autojump_threshold=0)
|
||||
|
||||
|
||||
# FIXME: split off into a package (or just make part of Trio's public
|
||||
# interface?), with config file to enable? and I guess a mark option too; I
|
||||
# guess it's useful with the class- and file-level marking machinery (where
|
||||
# the raw @trio_test decorator isn't enough).
|
||||
@pytest.hookimpl(tryfirst=True)
|
||||
def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> None:
|
||||
if inspect.iscoroutinefunction(pyfuncitem.obj):
|
||||
pyfuncitem.obj = trio_test(pyfuncitem.obj)
|
||||
|
||||
|
||||
def skip_if_optional_else_raise(error: ImportError) -> NoReturn:
|
||||
if SKIP_OPTIONAL_IMPORTS:
|
||||
pytest.skip(error.msg, allow_module_level=True)
|
||||
else: # pragma: no cover
|
||||
raise error
|
||||
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from .. import abc as tabc
|
||||
from ..lowlevel import Task
|
||||
|
||||
|
||||
def test_instrument_implements_hook_methods() -> None:
|
||||
attrs = {
|
||||
"before_run": (),
|
||||
"after_run": (),
|
||||
"task_spawned": (Task,),
|
||||
"task_scheduled": (Task,),
|
||||
"before_task_step": (Task,),
|
||||
"after_task_step": (Task,),
|
||||
"task_exited": (Task,),
|
||||
"before_io_wait": (3.3,),
|
||||
"after_io_wait": (3.3,),
|
||||
}
|
||||
|
||||
mayonnaise = tabc.Instrument()
|
||||
|
||||
for method_name, args in attrs.items():
|
||||
assert hasattr(mayonnaise, method_name)
|
||||
method = getattr(mayonnaise, method_name)
|
||||
assert callable(method)
|
||||
method(*args)
|
||||
|
||||
|
||||
async def test_AsyncResource_defaults() -> None:
|
||||
@attrs.define(slots=False)
|
||||
class MyAR(tabc.AsyncResource):
|
||||
record: list[str] = attrs.Factory(list)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.record.append("ac")
|
||||
|
||||
async with MyAR() as myar:
|
||||
assert isinstance(myar, MyAR)
|
||||
assert myar.record == []
|
||||
|
||||
assert myar.record == ["ac"]
|
||||
|
||||
|
||||
def test_abc_generics() -> None:
|
||||
# Pythons below 3.5.2 had a typing.Generic that would throw
|
||||
# errors when instantiating or subclassing a parameterized
|
||||
# version of a class with any __slots__. This is why RunVar
|
||||
# (which has slots) is not generic. This tests that
|
||||
# the generic ABCs are fine, because while they are slotted
|
||||
# they don't actually define any slots.
|
||||
|
||||
class SlottedChannel(tabc.SendChannel[tabc.Stream]):
|
||||
__slots__ = ("x",)
|
||||
|
||||
def send_nowait(self, value: object) -> None:
|
||||
raise RuntimeError
|
||||
|
||||
async def send(self, value: object) -> None:
|
||||
raise RuntimeError # pragma: no cover
|
||||
|
||||
def clone(self) -> None:
|
||||
raise RuntimeError # pragma: no cover
|
||||
|
||||
async def aclose(self) -> None:
|
||||
pass # pragma: no cover
|
||||
|
||||
channel = SlottedChannel()
|
||||
with pytest.raises(RuntimeError):
|
||||
channel.send_nowait(None)
|
||||
@@ -0,0 +1,750 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio import EndOfChannel, as_safe_channel, open_memory_channel
|
||||
|
||||
from ..testing import assert_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import ExceptionGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
|
||||
async def test_channel() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
open_memory_channel(1.0)
|
||||
with pytest.raises(ValueError, match=r"^max_buffer_size must be >= 0$"):
|
||||
open_memory_channel(-1)
|
||||
|
||||
s, r = open_memory_channel[int | str | None](2)
|
||||
repr(s) # smoke test
|
||||
repr(r) # smoke test
|
||||
|
||||
s.send_nowait(1)
|
||||
with assert_checkpoints():
|
||||
await s.send(2)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(None)
|
||||
|
||||
with assert_checkpoints():
|
||||
assert await r.receive() == 1
|
||||
assert r.receive_nowait() == 2
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
||||
|
||||
s.send_nowait("last")
|
||||
await s.aclose()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s.send("too late")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.send_nowait("too late")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.clone()
|
||||
await s.aclose()
|
||||
|
||||
assert r.receive_nowait() == "last"
|
||||
with pytest.raises(EndOfChannel):
|
||||
await r.receive()
|
||||
await r.aclose()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r.receive_nowait()
|
||||
await r.aclose()
|
||||
|
||||
|
||||
async def test_553(autojump_clock: trio.abc.Clock) -> None:
|
||||
s, r = open_memory_channel[str](1)
|
||||
with trio.move_on_after(10) as timeout_scope:
|
||||
await r.receive()
|
||||
assert timeout_scope.cancelled_caught
|
||||
await s.send("Test for PR #553")
|
||||
|
||||
|
||||
async def test_channel_multiple_producers() -> None:
|
||||
async def producer(send_channel: trio.MemorySendChannel[int], i: int) -> None:
|
||||
# We close our handle when we're done with it
|
||||
async with send_channel:
|
||||
for j in range(3 * i, 3 * (i + 1)):
|
||||
await send_channel.send(j)
|
||||
|
||||
send_channel, receive_channel = open_memory_channel[int](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
# We hand out clones to all the new producers, and then close the
|
||||
# original.
|
||||
async with send_channel:
|
||||
for i in range(10):
|
||||
nursery.start_soon(producer, send_channel.clone(), i)
|
||||
|
||||
got = [value async for value in receive_channel]
|
||||
|
||||
got.sort()
|
||||
assert got == list(range(30))
|
||||
|
||||
|
||||
async def test_channel_multiple_consumers() -> None:
|
||||
successful_receivers = set()
|
||||
received = []
|
||||
|
||||
async def consumer(receive_channel: trio.MemoryReceiveChannel[int], i: int) -> None:
|
||||
async for value in receive_channel:
|
||||
successful_receivers.add(i)
|
||||
received.append(value)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
send_channel, receive_channel = trio.open_memory_channel[int](1)
|
||||
async with send_channel:
|
||||
for i in range(5):
|
||||
nursery.start_soon(consumer, receive_channel, i)
|
||||
await wait_all_tasks_blocked()
|
||||
for i in range(10):
|
||||
await send_channel.send(i)
|
||||
|
||||
assert successful_receivers == set(range(5))
|
||||
assert len(received) == 10
|
||||
assert set(received) == set(range(10))
|
||||
|
||||
|
||||
async def test_close_basics() -> None:
|
||||
async def send_block(
|
||||
s: trio.MemorySendChannel[None],
|
||||
expect: type[BaseException],
|
||||
) -> None:
|
||||
with pytest.raises(expect):
|
||||
await s.send(None)
|
||||
|
||||
# closing send -> other send gets ClosedResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.ClosedResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
await s.aclose()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# and receive gets EndOfChannel
|
||||
with pytest.raises(EndOfChannel):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(EndOfChannel):
|
||||
await r.receive()
|
||||
|
||||
# closing receive -> send gets BrokenResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.BrokenResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
await r.aclose()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# closing receive -> other receive gets ClosedResourceError
|
||||
async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
|
||||
_s2, r2 = open_memory_channel[int](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receive_block, r2)
|
||||
await wait_all_tasks_blocked()
|
||||
await r2.aclose()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r2.receive_nowait()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r2.receive()
|
||||
|
||||
|
||||
async def test_close_sync() -> None:
|
||||
async def send_block(
|
||||
s: trio.MemorySendChannel[None],
|
||||
expect: type[BaseException],
|
||||
) -> None:
|
||||
with pytest.raises(expect):
|
||||
await s.send(None)
|
||||
|
||||
# closing send -> other send gets ClosedResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.ClosedResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
s.close()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# and receive gets EndOfChannel
|
||||
with pytest.raises(EndOfChannel):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(EndOfChannel):
|
||||
await r.receive()
|
||||
|
||||
# closing receive -> send gets BrokenResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.BrokenResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
r.close()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# closing receive -> other receive gets ClosedResourceError
|
||||
async def receive_block(r: trio.MemoryReceiveChannel[None]) -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receive_block, r)
|
||||
await wait_all_tasks_blocked()
|
||||
r.close()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
|
||||
|
||||
async def test_receive_channel_clone_and_close() -> None:
|
||||
s, r = open_memory_channel[None](10)
|
||||
|
||||
r2 = r.clone()
|
||||
r3 = r.clone()
|
||||
|
||||
s.send_nowait(None)
|
||||
await r.aclose()
|
||||
with r2:
|
||||
pass
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r.clone()
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r2.clone()
|
||||
|
||||
# Can still send, r3 is still open
|
||||
s.send_nowait(None)
|
||||
|
||||
await r3.aclose()
|
||||
|
||||
# But now the receiver is really closed
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
s.send_nowait(None)
|
||||
|
||||
|
||||
async def test_close_multiple_send_handles() -> None:
|
||||
# With multiple send handles, closing one handle only wakes senders on
|
||||
# that handle, but others can continue just fine
|
||||
s1, r = open_memory_channel[str](0)
|
||||
s2 = s1.clone()
|
||||
|
||||
async def send_will_close() -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s1.send("nope")
|
||||
|
||||
async def send_will_succeed() -> None:
|
||||
await s2.send("ok")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_will_close)
|
||||
nursery.start_soon(send_will_succeed)
|
||||
await wait_all_tasks_blocked()
|
||||
await s1.aclose()
|
||||
assert await r.receive() == "ok"
|
||||
|
||||
|
||||
async def test_close_multiple_receive_handles() -> None:
|
||||
# With multiple receive handles, closing one handle only wakes receivers on
|
||||
# that handle, but others can continue just fine
|
||||
s, r1 = open_memory_channel[str](0)
|
||||
r2 = r1.clone()
|
||||
|
||||
async def receive_will_close() -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r1.receive()
|
||||
|
||||
async def receive_will_succeed() -> None:
|
||||
assert await r2.receive() == "ok"
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receive_will_close)
|
||||
nursery.start_soon(receive_will_succeed)
|
||||
await wait_all_tasks_blocked()
|
||||
await r1.aclose()
|
||||
await s.send("ok")
|
||||
|
||||
|
||||
async def test_inf_capacity() -> None:
|
||||
send, receive = open_memory_channel[int](float("inf"))
|
||||
|
||||
# It's accepted, and we can send all day without blocking
|
||||
with send:
|
||||
for i in range(10):
|
||||
send.send_nowait(i)
|
||||
|
||||
got = [i async for i in receive]
|
||||
assert got == list(range(10))
|
||||
|
||||
|
||||
async def test_statistics() -> None:
|
||||
s, r = open_memory_channel[None](2)
|
||||
|
||||
assert s.statistics() == r.statistics()
|
||||
stats = s.statistics()
|
||||
assert stats.current_buffer_used == 0
|
||||
assert stats.max_buffer_size == 2
|
||||
assert stats.open_send_channels == 1
|
||||
assert stats.open_receive_channels == 1
|
||||
assert stats.tasks_waiting_send == 0
|
||||
assert stats.tasks_waiting_receive == 0
|
||||
|
||||
s.send_nowait(None)
|
||||
assert s.statistics().current_buffer_used == 1
|
||||
|
||||
s2 = s.clone()
|
||||
assert s.statistics().open_send_channels == 2
|
||||
await s.aclose()
|
||||
assert s2.statistics().open_send_channels == 1
|
||||
|
||||
r2 = r.clone()
|
||||
assert s2.statistics().open_receive_channels == 2
|
||||
await r2.aclose()
|
||||
assert s2.statistics().open_receive_channels == 1
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
s2.send_nowait(None) # fill up the buffer
|
||||
assert s.statistics().current_buffer_used == 2
|
||||
nursery.start_soon(s2.send, None)
|
||||
nursery.start_soon(s2.send, None)
|
||||
await wait_all_tasks_blocked()
|
||||
assert s.statistics().tasks_waiting_send == 2
|
||||
nursery.cancel_scope.cancel()
|
||||
assert s.statistics().tasks_waiting_send == 0
|
||||
|
||||
# empty out the buffer again
|
||||
try:
|
||||
while True:
|
||||
r.receive_nowait()
|
||||
except trio.WouldBlock:
|
||||
pass
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(r.receive)
|
||||
await wait_all_tasks_blocked()
|
||||
assert s.statistics().tasks_waiting_receive == 1
|
||||
nursery.cancel_scope.cancel()
|
||||
assert s.statistics().tasks_waiting_receive == 0
|
||||
|
||||
|
||||
async def test_channel_fairness() -> None:
|
||||
# We can remove an item we just sent, and send an item back in after, if
|
||||
# no-one else is waiting.
|
||||
s, r = open_memory_channel[int | None](1)
|
||||
s.send_nowait(1)
|
||||
assert r.receive_nowait() == 1
|
||||
s.send_nowait(2)
|
||||
assert r.receive_nowait() == 2
|
||||
|
||||
# But if someone else is waiting to receive, then they "own" the item we
|
||||
# send, so we can't receive it (even though we run first):
|
||||
|
||||
result: int | None = None
|
||||
|
||||
async def do_receive(r: trio.MemoryReceiveChannel[int | None]) -> None:
|
||||
nonlocal result
|
||||
result = await r.receive()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(do_receive, r)
|
||||
await wait_all_tasks_blocked()
|
||||
s.send_nowait(2)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
||||
assert result == 2
|
||||
|
||||
# And the analogous situation for send: if we free up a space, we can't
|
||||
# immediately send something in it if someone is already waiting to do
|
||||
# that
|
||||
s, r = open_memory_channel[int | None](1)
|
||||
s.send_nowait(1)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(None)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(s.send, 2)
|
||||
await wait_all_tasks_blocked()
|
||||
assert r.receive_nowait() == 1
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(3)
|
||||
assert (await r.receive()) == 2
|
||||
|
||||
|
||||
async def test_unbuffered() -> None:
|
||||
s, r = open_memory_channel[int](0)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(1)
|
||||
|
||||
async def do_send(s: trio.MemorySendChannel[int], v: int) -> None:
|
||||
with assert_checkpoints():
|
||||
await s.send(v)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send, s, 1)
|
||||
with assert_checkpoints():
|
||||
assert await r.receive() == 1
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
||||
|
||||
|
||||
async def test_as_safe_channel_exhaust() -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[int]:
|
||||
yield 1
|
||||
|
||||
async with agen() as recv_chan:
|
||||
async for x in recv_chan:
|
||||
assert x == 1
|
||||
|
||||
|
||||
async def test_as_safe_channel_broken_resource() -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[int]:
|
||||
yield 1
|
||||
yield 2 # pragma: no cover
|
||||
|
||||
async with agen() as recv_chan:
|
||||
assert await recv_chan.__anext__() == 1
|
||||
|
||||
# close the receiving channel
|
||||
await recv_chan.aclose()
|
||||
|
||||
# trying to get the next element errors
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await recv_chan.__anext__()
|
||||
|
||||
# but we don't get an error on exit of the cm
|
||||
|
||||
|
||||
async def test_as_safe_channel_cancelled() -> None:
|
||||
with trio.CancelScope() as cs:
|
||||
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[None]: # pragma: no cover
|
||||
raise AssertionError(
|
||||
"cancel before consumption means generator should not be iterated"
|
||||
)
|
||||
yield # indicate that we're an iterator
|
||||
|
||||
async with agen():
|
||||
cs.cancel()
|
||||
|
||||
|
||||
async def test_as_safe_channel_no_race() -> None:
|
||||
# this previously led to a race condition due to
|
||||
# https://github.com/python-trio/trio/issues/1559
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[int]:
|
||||
yield 1
|
||||
raise ValueError("oae")
|
||||
|
||||
with pytest.raises(ValueError, match=r"^oae$"):
|
||||
async with agen() as recv_chan:
|
||||
async for x in recv_chan:
|
||||
assert x == 1
|
||||
|
||||
|
||||
async def test_as_safe_channel_buffer_size_too_small(
|
||||
autojump_clock: trio.testing.MockClock,
|
||||
) -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[int]:
|
||||
yield 1
|
||||
raise AssertionError(
|
||||
"buffer size 0 means we shouldn't be asked for another value"
|
||||
) # pragma: no cover
|
||||
|
||||
with trio.move_on_after(5):
|
||||
async with agen() as recv_chan:
|
||||
async for x in recv_chan: # pragma: no branch
|
||||
assert x == 1
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
async def test_as_safe_channel_no_interleave() -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[int]:
|
||||
yield 1
|
||||
raise AssertionError # pragma: no cover
|
||||
|
||||
async with agen() as recv_chan:
|
||||
assert await recv_chan.__anext__() == 1
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
async def test_as_safe_channel_genexit_finally() -> None:
|
||||
@as_safe_channel
|
||||
async def agen(events: list[str]) -> AsyncGenerator[int]:
|
||||
try:
|
||||
yield 1
|
||||
except BaseException as e:
|
||||
events.append(repr(e))
|
||||
raise
|
||||
finally:
|
||||
events.append("finally")
|
||||
raise ValueError("agen")
|
||||
|
||||
events: list[str] = []
|
||||
with pytest.RaisesGroup(
|
||||
pytest.RaisesExc(ValueError, match="^agen$"),
|
||||
pytest.RaisesExc(TypeError, match="^iterator$"),
|
||||
) as g:
|
||||
async with agen(events) as recv_chan:
|
||||
async for i in recv_chan: # pragma: no branch
|
||||
assert i == 1
|
||||
raise TypeError("iterator")
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
assert g.value.__notes__ == [
|
||||
"Encountered exception during cleanup of generator object, as "
|
||||
"well as exception in the contextmanager body - unable to unwrap."
|
||||
]
|
||||
|
||||
assert events == ["GeneratorExit()", "finally"]
|
||||
|
||||
|
||||
async def test_as_safe_channel_nested_loop() -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[int]:
|
||||
for i in range(2):
|
||||
yield i
|
||||
|
||||
ii = 0
|
||||
async with agen() as recv_chan1:
|
||||
async for i in recv_chan1:
|
||||
async with agen() as recv_chan:
|
||||
jj = 0
|
||||
async for j in recv_chan:
|
||||
assert (i, j) == (ii, jj)
|
||||
jj += 1
|
||||
ii += 1
|
||||
|
||||
|
||||
async def test_as_safe_channel_doesnt_leak_cancellation() -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[None]:
|
||||
yield
|
||||
with trio.CancelScope() as cscope:
|
||||
cscope.cancel()
|
||||
yield
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
async with agen() as recv_chan:
|
||||
async for _ in recv_chan:
|
||||
pass
|
||||
raise AssertionError("should be reachable")
|
||||
|
||||
|
||||
async def test_as_safe_channel_dont_unwrap_user_exceptiongroup() -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[None]:
|
||||
raise NotImplementedError("not entered")
|
||||
yield # pragma: no cover
|
||||
|
||||
with pytest.RaisesGroup(pytest.RaisesExc(ValueError, match="bar"), match="foo"):
|
||||
async with agen() as _:
|
||||
raise ExceptionGroup("foo", [ValueError("bar")])
|
||||
|
||||
|
||||
async def test_as_safe_channel_multiple_receiver() -> None:
|
||||
event = trio.Event()
|
||||
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[int]:
|
||||
await event.wait()
|
||||
yield 0
|
||||
yield 1
|
||||
|
||||
async def handle_value(
|
||||
recv_chan: trio.abc.ReceiveChannel[int],
|
||||
value: int,
|
||||
task_status: trio.TaskStatus,
|
||||
) -> None:
|
||||
task_status.started()
|
||||
assert await recv_chan.receive() == value
|
||||
|
||||
async with agen() as recv_chan:
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(handle_value, recv_chan, 0)
|
||||
await nursery.start(handle_value, recv_chan, 1)
|
||||
event.set()
|
||||
|
||||
|
||||
async def test_as_safe_channel_multi_cancel() -> None:
|
||||
@as_safe_channel
|
||||
async def agen(events: list[str]) -> AsyncGenerator[None]:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# this will give a warning of ASYNC120, although it's not technically a
|
||||
# problem of swallowing existing exceptions
|
||||
try:
|
||||
await trio.lowlevel.checkpoint()
|
||||
except trio.Cancelled:
|
||||
events.append("agen cancel")
|
||||
raise
|
||||
|
||||
events: list[str] = []
|
||||
with trio.CancelScope() as cs:
|
||||
with pytest.raises(trio.Cancelled):
|
||||
async with agen(events) as recv_chan:
|
||||
async for _ in recv_chan: # pragma: no branch
|
||||
cs.cancel()
|
||||
try:
|
||||
await trio.lowlevel.checkpoint()
|
||||
except trio.Cancelled:
|
||||
events.append("body cancel")
|
||||
raise
|
||||
assert events == ["body cancel", "agen cancel"]
|
||||
|
||||
|
||||
async def test_as_safe_channel_genexit_exception_group() -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[None]:
|
||||
try:
|
||||
async with trio.open_nursery():
|
||||
yield
|
||||
except BaseException as e:
|
||||
assert pytest.RaisesGroup(GeneratorExit).matches(e) # noqa: PT017
|
||||
raise
|
||||
|
||||
async with agen() as g:
|
||||
async for _ in g:
|
||||
break
|
||||
|
||||
|
||||
async def test_as_safe_channel_does_not_suppress_nested_genexit() -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[None]:
|
||||
yield
|
||||
|
||||
with pytest.RaisesGroup(GeneratorExit):
|
||||
async with agen() as g, trio.open_nursery():
|
||||
await g.receive() # this is for coverage reasons
|
||||
raise GeneratorExit
|
||||
|
||||
|
||||
async def test_as_safe_channel_genexit_filter() -> None:
|
||||
async def wait_then_raise() -> None:
|
||||
try:
|
||||
await trio.sleep_forever()
|
||||
except trio.Cancelled:
|
||||
raise ValueError from None
|
||||
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[None]:
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(wait_then_raise)
|
||||
yield
|
||||
|
||||
with pytest.RaisesGroup(ValueError):
|
||||
async with agen() as g:
|
||||
async for _ in g:
|
||||
break
|
||||
|
||||
|
||||
async def test_as_safe_channel_swallowing_extra_exceptions() -> None:
|
||||
async def wait_then_raise(ex: type[BaseException]) -> None:
|
||||
try:
|
||||
await trio.sleep_forever()
|
||||
except trio.Cancelled:
|
||||
raise ex from None
|
||||
|
||||
@as_safe_channel
|
||||
async def agen(ex: type[BaseException]) -> AsyncGenerator[None]:
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(wait_then_raise, ex)
|
||||
nursery.start_soon(wait_then_raise, GeneratorExit)
|
||||
yield
|
||||
|
||||
with pytest.RaisesGroup(AssertionError):
|
||||
async with agen(GeneratorExit) as g:
|
||||
async for _ in g:
|
||||
break
|
||||
|
||||
with pytest.RaisesGroup(ValueError, AssertionError):
|
||||
async with agen(ValueError) as g:
|
||||
async for _ in g:
|
||||
break
|
||||
|
||||
|
||||
async def test_as_safe_channel_close_between_iteration() -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[None]:
|
||||
while True:
|
||||
yield
|
||||
|
||||
async with agen() as chan, trio.open_nursery() as nursery:
|
||||
|
||||
async def close_channel() -> None:
|
||||
await trio.lowlevel.checkpoint()
|
||||
await chan.aclose()
|
||||
|
||||
nursery.start_soon(close_channel)
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
async for _ in chan:
|
||||
pass
|
||||
|
||||
|
||||
async def test_as_safe_channel_close_before_iteration() -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[None]:
|
||||
raise AssertionError("should be unreachable") # pragma: no cover
|
||||
yield # pragma: no cover
|
||||
|
||||
async with agen() as chan:
|
||||
await chan.aclose()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await chan.receive()
|
||||
|
||||
|
||||
async def test_as_safe_channel_close_during_iteration() -> None:
|
||||
@as_safe_channel
|
||||
async def agen() -> AsyncGenerator[None]:
|
||||
yield
|
||||
await chan.aclose()
|
||||
while True:
|
||||
yield
|
||||
|
||||
async with agen() as chan:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
async for _ in chan:
|
||||
pass
|
||||
|
||||
# This is necessary to ensure that `chan` has been sent
|
||||
# to. Otherwise, this test sometimes passes on a broken
|
||||
# version of trio.
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
|
||||
from .. import _core
|
||||
|
||||
trio_testing_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar(
|
||||
"trio_testing_contextvar",
|
||||
)
|
||||
|
||||
|
||||
async def test_contextvars_default() -> None:
|
||||
trio_testing_contextvar.set("main")
|
||||
record: list[str] = []
|
||||
|
||||
async def child() -> None:
|
||||
value = trio_testing_contextvar.get()
|
||||
record.append(value)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child)
|
||||
assert record == ["main"]
|
||||
|
||||
|
||||
async def test_contextvars_set() -> None:
|
||||
trio_testing_contextvar.set("main")
|
||||
record: list[str] = []
|
||||
|
||||
async def child() -> None:
|
||||
trio_testing_contextvar.set("child")
|
||||
value = trio_testing_contextvar.get()
|
||||
record.append(value)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child)
|
||||
value = trio_testing_contextvar.get()
|
||||
assert record == ["child"]
|
||||
assert value == "main"
|
||||
|
||||
|
||||
async def test_contextvars_copy() -> None:
|
||||
trio_testing_contextvar.set("main")
|
||||
context = contextvars.copy_context()
|
||||
trio_testing_contextvar.set("second_main")
|
||||
record: list[str] = []
|
||||
|
||||
async def child() -> None:
|
||||
value = trio_testing_contextvar.get()
|
||||
record.append(value)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
context.run(nursery.start_soon, child)
|
||||
nursery.start_soon(child)
|
||||
value = trio_testing_contextvar.get()
|
||||
assert set(record) == {"main", "second_main"}
|
||||
assert value == "second_main"
|
||||
@@ -0,0 +1,277 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from types import ModuleType
|
||||
|
||||
import pytest
|
||||
|
||||
from .._deprecate import (
|
||||
TrioDeprecationWarning,
|
||||
deprecated,
|
||||
deprecated_alias,
|
||||
warn_deprecated,
|
||||
)
|
||||
from . import module_with_deprecations
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recwarn_always(recwarn: pytest.WarningsRecorder) -> pytest.WarningsRecorder:
|
||||
warnings.simplefilter("always")
|
||||
# ResourceWarnings about unclosed sockets can occur nondeterministically
|
||||
# (during GC) which throws off the tests in this file
|
||||
warnings.simplefilter("ignore", ResourceWarning)
|
||||
return recwarn
|
||||
|
||||
|
||||
def _here() -> tuple[str, int]:
|
||||
frame = inspect.currentframe()
|
||||
assert frame is not None
|
||||
assert frame.f_back is not None
|
||||
info = inspect.getframeinfo(frame.f_back)
|
||||
return (info.filename, info.lineno)
|
||||
|
||||
|
||||
def test_warn_deprecated(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
def deprecated_thing() -> None:
|
||||
warn_deprecated("ice", "1.2", issue=1, instead="water")
|
||||
|
||||
deprecated_thing()
|
||||
filename, lineno = _here()
|
||||
assert len(recwarn_always) == 1
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "ice is deprecated" in got.message.args[0]
|
||||
assert "Trio 1.2" in got.message.args[0]
|
||||
assert "water instead" in got.message.args[0]
|
||||
assert "/issues/1" in got.message.args[0]
|
||||
assert got.filename == filename
|
||||
assert got.lineno == lineno - 1
|
||||
|
||||
|
||||
def test_warn_deprecated_no_instead_or_issue(
|
||||
recwarn_always: pytest.WarningsRecorder,
|
||||
) -> None:
|
||||
# Explicitly no instead or issue
|
||||
warn_deprecated("water", "1.3", issue=None, instead=None)
|
||||
assert len(recwarn_always) == 1
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "water is deprecated" in got.message.args[0]
|
||||
assert "no replacement" in got.message.args[0]
|
||||
assert "Trio 1.3" in got.message.args[0]
|
||||
|
||||
|
||||
def test_warn_deprecated_stacklevel(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
def nested1() -> None:
|
||||
nested2()
|
||||
|
||||
def nested2() -> None:
|
||||
warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3)
|
||||
|
||||
filename, lineno = _here()
|
||||
nested1()
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert got.filename == filename
|
||||
assert got.lineno == lineno + 1
|
||||
|
||||
|
||||
def old() -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def new() -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def test_warn_deprecated_formatting(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
warn_deprecated(old, "1.0", issue=1, instead=new)
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.old is deprecated" in got.message.args[0]
|
||||
assert "test_deprecate.new instead" in got.message.args[0]
|
||||
|
||||
|
||||
@deprecated("1.5", issue=123, instead=new)
|
||||
def deprecated_old() -> int:
|
||||
return 3
|
||||
|
||||
|
||||
def test_deprecated_decorator(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
assert deprecated_old() == 3
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0]
|
||||
assert "1.5" in got.message.args[0]
|
||||
assert "test_deprecate.new" in got.message.args[0]
|
||||
assert "issues/123" in got.message.args[0]
|
||||
|
||||
|
||||
class Foo:
|
||||
@deprecated("1.0", issue=123, instead="crying")
|
||||
def method(self) -> int:
|
||||
return 7
|
||||
|
||||
|
||||
def test_deprecated_decorator_method(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
f = Foo()
|
||||
assert f.method() == 7
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.Foo.method is deprecated" in got.message.args[0]
|
||||
|
||||
|
||||
@deprecated("1.2", thing="the thing", issue=None, instead=None)
|
||||
def deprecated_with_thing() -> int:
|
||||
return 72
|
||||
|
||||
|
||||
def test_deprecated_decorator_with_explicit_thing(
|
||||
recwarn_always: pytest.WarningsRecorder,
|
||||
) -> None:
|
||||
assert deprecated_with_thing() == 72
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "the thing is deprecated" in got.message.args[0]
|
||||
|
||||
|
||||
def new_hotness() -> str:
|
||||
return "new hotness"
|
||||
|
||||
|
||||
old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1)
|
||||
|
||||
|
||||
def test_deprecated_alias(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
assert old_hotness() == "new hotness"
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.old_hotness is deprecated" in got.message.args[0]
|
||||
assert "1.23" in got.message.args[0]
|
||||
assert "test_deprecate.new_hotness instead" in got.message.args[0]
|
||||
assert "issues/1" in got.message.args[0]
|
||||
|
||||
assert isinstance(old_hotness.__doc__, str)
|
||||
assert ".. deprecated:: 1.23" in old_hotness.__doc__
|
||||
assert "test_deprecate.new_hotness instead" in old_hotness.__doc__
|
||||
assert "issues/1>`__" in old_hotness.__doc__
|
||||
|
||||
|
||||
class Alias:
|
||||
def new_hotness_method(self) -> str:
|
||||
return "new hotness method"
|
||||
|
||||
old_hotness_method = deprecated_alias(
|
||||
"Alias.old_hotness_method",
|
||||
new_hotness_method,
|
||||
"3.21",
|
||||
issue=1,
|
||||
)
|
||||
|
||||
|
||||
def test_deprecated_alias_method(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
obj = Alias()
|
||||
assert obj.old_hotness_method() == "new hotness method"
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
msg = got.message.args[0]
|
||||
assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg
|
||||
assert "test_deprecate.Alias.new_hotness_method instead" in msg
|
||||
|
||||
|
||||
@deprecated("2.1", issue=1, instead="hi")
|
||||
def docstring_test1() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
@deprecated("2.1", issue=None, instead="hi")
|
||||
def docstring_test2() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
@deprecated("2.1", issue=1, instead=None)
|
||||
def docstring_test3() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
@deprecated("2.1", issue=None, instead=None)
|
||||
def docstring_test4() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
def test_deprecated_docstring_munging() -> None:
|
||||
assert docstring_test1.__doc__ == """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
Use hi instead.
|
||||
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
|
||||
|
||||
"""
|
||||
|
||||
assert docstring_test2.__doc__ == """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
Use hi instead.
|
||||
|
||||
"""
|
||||
|
||||
assert docstring_test3.__doc__ == """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
|
||||
|
||||
"""
|
||||
|
||||
assert docstring_test4.__doc__ == """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def test_module_with_deprecations(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
assert module_with_deprecations.regular == "hi"
|
||||
assert len(recwarn_always) == 0
|
||||
|
||||
assert type(module_with_deprecations) is ModuleType
|
||||
|
||||
filename, lineno = _here()
|
||||
assert module_with_deprecations.dep1 == "value1" # type: ignore[attr-defined]
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert got.filename == filename
|
||||
assert got.lineno == lineno + 1
|
||||
|
||||
assert "module_with_deprecations.dep1" in got.message.args[0]
|
||||
assert "Trio 1.1" in got.message.args[0]
|
||||
assert "/issues/1" in got.message.args[0]
|
||||
assert "value1 instead" in got.message.args[0]
|
||||
|
||||
assert module_with_deprecations.dep2 == "value2" # type: ignore[attr-defined]
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "instead-string instead" in got.message.args[0]
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
module_with_deprecations.asdf # type: ignore[attr-defined] # noqa: B018 # "useless expression"
|
||||
|
||||
|
||||
def test_warning_class() -> None:
|
||||
with pytest.deprecated_call():
|
||||
warn_deprecated("foo", "bar", issue=None, instead=None)
|
||||
|
||||
# essentially the same as the above check
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match="^foo is deprecated since Trio bar with no replacement$",
|
||||
):
|
||||
warn_deprecated("foo", "bar", issue=None, instead=None)
|
||||
|
||||
with pytest.warns(TrioDeprecationWarning):
|
||||
warn_deprecated(
|
||||
"foo",
|
||||
"bar",
|
||||
issue=None,
|
||||
instead=None,
|
||||
use_triodeprecationwarning=True,
|
||||
)
|
||||
+64
@@ -0,0 +1,64 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
|
||||
|
||||
async def test_deprecation_warning_open_nursery() -> None:
|
||||
with pytest.warns(
|
||||
trio.TrioDeprecationWarning,
|
||||
match="strict_exception_groups=False",
|
||||
) as record:
|
||||
async with trio.open_nursery(strict_exception_groups=False):
|
||||
...
|
||||
assert len(record) == 1
|
||||
async with trio.open_nursery(strict_exception_groups=True):
|
||||
...
|
||||
async with trio.open_nursery():
|
||||
...
|
||||
|
||||
|
||||
def test_deprecation_warning_run() -> None:
|
||||
async def foo() -> None: ...
|
||||
|
||||
async def foo_nursery() -> None:
|
||||
# this should not raise a warning, even if it's implied loose
|
||||
async with trio.open_nursery():
|
||||
...
|
||||
|
||||
async def foo_loose_nursery() -> None:
|
||||
# this should raise a warning, even if specifying the parameter is redundant
|
||||
async with trio.open_nursery(strict_exception_groups=False):
|
||||
...
|
||||
|
||||
def helper(fun: Callable[[], Awaitable[None]], num: int) -> None:
|
||||
with pytest.warns(
|
||||
trio.TrioDeprecationWarning,
|
||||
match="strict_exception_groups=False",
|
||||
) as record:
|
||||
trio.run(fun, strict_exception_groups=False)
|
||||
assert len(record) == num
|
||||
|
||||
helper(foo, 1)
|
||||
helper(foo_nursery, 1)
|
||||
helper(foo_loose_nursery, 2)
|
||||
|
||||
|
||||
def test_deprecation_warning_start_guest_run() -> None:
|
||||
# "The simplest possible "host" loop."
|
||||
from .._core._tests.test_guest_mode import trivial_guest_run
|
||||
|
||||
async def trio_return(in_host: object) -> str:
|
||||
await trio.lowlevel.checkpoint()
|
||||
return "ok"
|
||||
|
||||
with pytest.warns(
|
||||
trio.TrioDeprecationWarning,
|
||||
match="strict_exception_groups=False",
|
||||
) as record:
|
||||
trivial_guest_run(
|
||||
trio_return,
|
||||
strict_exception_groups=False,
|
||||
)
|
||||
assert len(record) == 1
|
||||
@@ -0,0 +1,950 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from contextlib import asynccontextmanager
|
||||
from itertools import count
|
||||
from typing import TYPE_CHECKING, NoReturn
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from trio._tests.pytest_plugin import skip_if_optional_else_raise
|
||||
|
||||
try:
|
||||
import trustme
|
||||
from OpenSSL import SSL
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
from trio import DTLSChannel, DTLSEndpoint
|
||||
from trio.testing._fake_net import FakeNet, UDPPacket
|
||||
|
||||
from .._core._tests.tutil import binds_ipv6, gc_collect_harder, slow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
ca = trustme.CA()
|
||||
server_cert = ca.issue_cert("example.com")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_ctx() -> SSL.Context:
|
||||
ctx = SSL.Context(SSL.DTLS_METHOD)
|
||||
server_cert.configure_cert(ctx)
|
||||
return ctx
|
||||
|
||||
|
||||
def client_ctx_fn() -> SSL.Context:
|
||||
ctx = SSL.Context(SSL.DTLS_METHOD)
|
||||
ca.configure_trust(ctx)
|
||||
return ctx
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client_ctx() -> SSL.Context:
|
||||
return client_ctx_fn()
|
||||
|
||||
|
||||
parametrize_ipv6 = pytest.mark.parametrize(
|
||||
"ipv6",
|
||||
[False, pytest.param(True, marks=binds_ipv6)],
|
||||
ids=["ipv4", "ipv6"],
|
||||
)
|
||||
|
||||
|
||||
def endpoint(**kwargs: int | bool) -> DTLSEndpoint:
|
||||
ipv6 = kwargs.pop("ipv6", False)
|
||||
family = trio.socket.AF_INET6 if ipv6 else trio.socket.AF_INET
|
||||
sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family)
|
||||
return DTLSEndpoint(sock, **kwargs)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def dtls_echo_server(
|
||||
*,
|
||||
server_ctx: SSL.Context,
|
||||
autocancel: bool = True,
|
||||
mtu: int | None = None,
|
||||
ipv6: bool = False,
|
||||
) -> AsyncGenerator[tuple[DTLSEndpoint, tuple[str, int]], None]:
|
||||
with endpoint(ipv6=ipv6) as server:
|
||||
localhost = "::1" if ipv6 else "127.0.0.1"
|
||||
await server.socket.bind((localhost, 0))
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def echo_handler(dtls_channel: DTLSChannel) -> None:
|
||||
print(
|
||||
"echo handler started: "
|
||||
f"server {dtls_channel.endpoint.socket.getsockname()!r} "
|
||||
f"client {dtls_channel.peer_address!r}",
|
||||
)
|
||||
if mtu is not None:
|
||||
dtls_channel.set_ciphertext_mtu(mtu)
|
||||
try:
|
||||
print("server starting do_handshake")
|
||||
await dtls_channel.do_handshake()
|
||||
print("server finished do_handshake")
|
||||
# no branch for leaving this for loop because we only leave
|
||||
# a channel by cancellation.
|
||||
async for packet in dtls_channel: # pragma: no branch
|
||||
print(f"echoing {packet!r} -> {dtls_channel.peer_address!r}")
|
||||
await dtls_channel.send(packet)
|
||||
except trio.BrokenResourceError: # pragma: no cover
|
||||
print("echo handler channel broken")
|
||||
|
||||
await nursery.start(server.serve, server_ctx, echo_handler)
|
||||
|
||||
yield server, server.socket.getsockname()
|
||||
|
||||
if autocancel:
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@parametrize_ipv6
|
||||
async def test_smoke(
|
||||
ipv6: bool, server_ctx: SSL.Context, client_ctx: SSL.Context
|
||||
) -> None:
|
||||
async with dtls_echo_server(ipv6=ipv6, server_ctx=server_ctx) as (
|
||||
_server_endpoint,
|
||||
address,
|
||||
):
|
||||
with endpoint(ipv6=ipv6) as client_endpoint:
|
||||
client_channel = client_endpoint.connect(address, client_ctx)
|
||||
with pytest.raises(trio.NeedHandshakeError):
|
||||
client_channel.get_cleartext_mtu()
|
||||
|
||||
await client_channel.do_handshake()
|
||||
await client_channel.send(b"hello")
|
||||
assert await client_channel.receive() == b"hello"
|
||||
await client_channel.send(b"goodbye")
|
||||
assert await client_channel.receive() == b"goodbye"
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"^openssl doesn't support sending empty DTLS packets$",
|
||||
):
|
||||
await client_channel.send(b"")
|
||||
|
||||
client_channel.set_ciphertext_mtu(1234)
|
||||
cleartext_mtu_1234 = client_channel.get_cleartext_mtu()
|
||||
client_channel.set_ciphertext_mtu(4321)
|
||||
assert client_channel.get_cleartext_mtu() > cleartext_mtu_1234
|
||||
client_channel.set_ciphertext_mtu(1234)
|
||||
assert client_channel.get_cleartext_mtu() == cleartext_mtu_1234
|
||||
|
||||
|
||||
@slow
|
||||
async def test_handshake_over_terrible_network(
|
||||
autojump_clock: trio.testing.MockClock,
|
||||
server_ctx: SSL.Context,
|
||||
) -> None:
|
||||
HANDSHAKES = 100
|
||||
r = random.Random(0)
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
# avoid spurious timeouts on slow machines
|
||||
autojump_clock.autojump_threshold = 0.001
|
||||
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (_, address):
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def route_packet(packet: UDPPacket) -> None:
|
||||
while True:
|
||||
op = r.choices(
|
||||
["deliver", "drop", "dupe", "delay"],
|
||||
weights=[0.7, 0.1, 0.1, 0.1],
|
||||
)[0]
|
||||
print(f"{packet.source} -> {packet.destination}: {op}")
|
||||
if op == "drop":
|
||||
return
|
||||
elif op == "dupe":
|
||||
fn.send_packet(packet)
|
||||
elif op == "delay":
|
||||
await trio.sleep(r.random() * 3)
|
||||
# I wanted to test random packet corruption too, but it turns out
|
||||
# openssl has a bug in the following scenario:
|
||||
#
|
||||
# - client sends ClientHello
|
||||
# - server sends HelloVerifyRequest with cookie -- but cookie is
|
||||
# invalid b/c either the ClientHello or HelloVerifyRequest was
|
||||
# corrupted
|
||||
# - client re-sends ClientHello with invalid cookie
|
||||
# - server replies with new HelloVerifyRequest and correct cookie
|
||||
#
|
||||
# At this point, the client *should* switch to the new, valid
|
||||
# cookie. But OpenSSL doesn't; it stubbornly insists on re-sending
|
||||
# the original, invalid cookie over and over. In theory we could
|
||||
# work around this by detecting cookie changes and starting over
|
||||
# with a whole new SSL object, but (a) it doesn't seem worth it, (b)
|
||||
# when I tried then I ran into another issue where OpenSSL got stuck
|
||||
# in an infinite loop sending alerts over and over, which I didn't
|
||||
# dig into because see (a).
|
||||
#
|
||||
# elif op == "distort":
|
||||
# payload = bytearray(packet.payload)
|
||||
# payload[r.randrange(len(payload))] ^= 1 << r.randrange(8)
|
||||
# packet = attrs.evolve(packet, payload=payload)
|
||||
else:
|
||||
assert op == "deliver"
|
||||
print(
|
||||
f"{packet.source} -> {packet.destination}: delivered"
|
||||
f" {packet.payload.hex()}",
|
||||
)
|
||||
fn.deliver_packet(packet)
|
||||
break
|
||||
|
||||
def route_packet_wrapper(packet: UDPPacket) -> None:
|
||||
try: # noqa: SIM105 # suppressible-exception
|
||||
nursery.start_soon(route_packet, packet)
|
||||
except RuntimeError: # pragma: no cover
|
||||
# We're exiting the nursery, so any remaining packets can just get
|
||||
# dropped
|
||||
pass
|
||||
|
||||
fn.route_packet = route_packet_wrapper # type: ignore[assignment] # TODO: Fix FakeNet typing
|
||||
|
||||
for i in range(HANDSHAKES):
|
||||
print("#" * 80)
|
||||
print("#" * 80)
|
||||
print("#" * 80)
|
||||
with endpoint() as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx_fn())
|
||||
print("client starting do_handshake")
|
||||
await client.do_handshake()
|
||||
print("client finished do_handshake")
|
||||
msg = str(i).encode()
|
||||
# Make multiple attempts to send data, because the network might
|
||||
# drop it
|
||||
while True:
|
||||
with trio.move_on_after(10) as cscope:
|
||||
await client.send(msg)
|
||||
assert await client.receive() == msg
|
||||
if not cscope.cancelled_caught:
|
||||
break
|
||||
|
||||
|
||||
async def test_implicit_handshake(
|
||||
server_ctx: SSL.Context, client_ctx: SSL.Context
|
||||
) -> None:
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
|
||||
# Implicit handshake
|
||||
await client.send(b"xyz")
|
||||
assert await client.receive() == b"xyz"
|
||||
|
||||
|
||||
async def test_full_duplex(server_ctx: SSL.Context, client_ctx: SSL.Context) -> None:
|
||||
# Tests simultaneous send/receive, and also multiple methods implicitly invoking
|
||||
# do_handshake simultaneously.
|
||||
with endpoint() as server_endpoint, endpoint() as client_endpoint:
|
||||
await server_endpoint.socket.bind(("127.0.0.1", 0))
|
||||
async with trio.open_nursery() as server_nursery:
|
||||
|
||||
async def handler(channel: DTLSChannel) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(channel.send, b"from server")
|
||||
nursery.start_soon(channel.receive)
|
||||
|
||||
await server_nursery.start(server_endpoint.serve, server_ctx, handler)
|
||||
|
||||
client = client_endpoint.connect(
|
||||
server_endpoint.socket.getsockname(),
|
||||
client_ctx,
|
||||
)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(client.send, b"from client")
|
||||
nursery.start_soon(client.receive)
|
||||
|
||||
server_nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_channel_closing(
|
||||
server_ctx: SSL.Context, client_ctx: SSL.Context
|
||||
) -> None:
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
await client.do_handshake()
|
||||
client.close()
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client.send(b"abc")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client.receive()
|
||||
|
||||
# close is idempotent
|
||||
client.close()
|
||||
# can also aclose
|
||||
await client.aclose()
|
||||
|
||||
|
||||
async def test_serve_exits_cleanly_on_close(server_ctx: SSL.Context) -> None:
|
||||
async with dtls_echo_server(autocancel=False, server_ctx=server_ctx) as (
|
||||
server_endpoint,
|
||||
_address,
|
||||
):
|
||||
server_endpoint.close()
|
||||
# Testing that the nursery exits even without being cancelled
|
||||
# close is idempotent
|
||||
server_endpoint.close()
|
||||
|
||||
|
||||
async def test_client_multiplex(server_ctx: SSL.Context) -> None:
|
||||
async with (
|
||||
dtls_echo_server(server_ctx=server_ctx) as (_, address1),
|
||||
dtls_echo_server(server_ctx=server_ctx) as (_, address2),
|
||||
):
|
||||
with endpoint() as client_endpoint:
|
||||
client1 = client_endpoint.connect(address1, client_ctx_fn())
|
||||
client2 = client_endpoint.connect(address2, client_ctx_fn())
|
||||
|
||||
await client1.send(b"abc")
|
||||
await client2.send(b"xyz")
|
||||
assert await client2.receive() == b"xyz"
|
||||
assert await client1.receive() == b"abc"
|
||||
|
||||
client_endpoint.close()
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client1.send(b"xxx")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client2.receive()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
client_endpoint.connect(address1, client_ctx_fn())
|
||||
|
||||
async def null_handler(_: object) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await nursery.start(client_endpoint.serve, server_ctx, null_handler)
|
||||
|
||||
|
||||
async def test_dtls_over_dgram_only() -> None:
|
||||
with trio.socket.socket() as s:
|
||||
with pytest.raises(ValueError, match=r"^DTLS requires a SOCK_DGRAM socket$"):
|
||||
DTLSEndpoint(s)
|
||||
|
||||
|
||||
async def test_double_serve(server_ctx: SSL.Context) -> None:
|
||||
async def null_handler(_: object) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
with endpoint() as server_endpoint:
|
||||
await server_endpoint.socket.bind(("127.0.0.1", 0))
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
|
||||
with pytest.raises(trio.BusyResourceError):
|
||||
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_connect_to_non_server(
|
||||
autojump_clock: trio.abc.Clock, client_ctx: SSL.Context
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
with endpoint() as client1, endpoint() as client2:
|
||||
await client1.socket.bind(("127.0.0.1", 0))
|
||||
# This should just time out
|
||||
with trio.move_on_after(100) as cscope:
|
||||
channel = client2.connect(client1.socket.getsockname(), client_ctx)
|
||||
await channel.do_handshake()
|
||||
assert cscope.cancelled_caught
|
||||
|
||||
|
||||
@pytest.mark.parametrize("buffer_size", [10, 20])
|
||||
async def test_incoming_buffer_overflow(
|
||||
autojump_clock: trio.abc.Clock,
|
||||
server_ctx: SSL.Context,
|
||||
client_ctx: SSL.Context,
|
||||
buffer_size: int,
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (_, address):
|
||||
with endpoint(incoming_packets_buffer=buffer_size) as client_endpoint:
|
||||
assert client_endpoint.incoming_packets_buffer == buffer_size
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
for i in range(buffer_size + 15):
|
||||
await client.send(str(i).encode())
|
||||
await trio.sleep(1)
|
||||
stats = client.statistics()
|
||||
assert stats.incoming_packets_dropped_in_trio == 15
|
||||
for i in range(buffer_size):
|
||||
assert await client.receive() == str(i).encode()
|
||||
await client.send(b"buffer clear now")
|
||||
assert await client.receive() == b"buffer clear now"
|
||||
|
||||
|
||||
async def test_server_socket_doesnt_crash_on_garbage(
|
||||
autojump_clock: trio.abc.Clock, server_ctx: SSL.Context
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
from trio._dtls import (
|
||||
ContentType,
|
||||
HandshakeFragment,
|
||||
HandshakeType,
|
||||
ProtocolVersion,
|
||||
Record,
|
||||
encode_handshake_fragment,
|
||||
encode_record,
|
||||
)
|
||||
|
||||
client_hello = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=encode_handshake_fragment(
|
||||
HandshakeFragment(
|
||||
msg_type=HandshakeType.client_hello,
|
||||
msg_len=10,
|
||||
msg_seq=0,
|
||||
frag_offset=0,
|
||||
frag_len=10,
|
||||
frag=bytes(10),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
client_hello_extended = client_hello + b"\x00"
|
||||
client_hello_short = client_hello[:-1]
|
||||
# cuts off in middle of handshake message header
|
||||
client_hello_really_short = client_hello[:14]
|
||||
client_hello_corrupt_record_len = bytearray(client_hello)
|
||||
client_hello_corrupt_record_len[11] = 0xFF
|
||||
|
||||
client_hello_fragmented = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=encode_handshake_fragment(
|
||||
HandshakeFragment(
|
||||
msg_type=HandshakeType.client_hello,
|
||||
msg_len=20,
|
||||
msg_seq=0,
|
||||
frag_offset=0,
|
||||
frag_len=10,
|
||||
frag=bytes(10),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
client_hello_trailing_data_in_record = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=encode_handshake_fragment(
|
||||
HandshakeFragment(
|
||||
msg_type=HandshakeType.client_hello,
|
||||
msg_len=20,
|
||||
msg_seq=0,
|
||||
frag_offset=0,
|
||||
frag_len=10,
|
||||
frag=bytes(10),
|
||||
),
|
||||
)
|
||||
+ b"\x00",
|
||||
),
|
||||
)
|
||||
|
||||
handshake_empty = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=b"",
|
||||
),
|
||||
)
|
||||
|
||||
client_hello_truncated_in_cookie = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=bytes(2 + 32 + 1) + b"\xff",
|
||||
),
|
||||
)
|
||||
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (_, address):
|
||||
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock:
|
||||
for bad_packet in [
|
||||
b"",
|
||||
b"xyz",
|
||||
client_hello_extended,
|
||||
client_hello_short,
|
||||
client_hello_really_short,
|
||||
client_hello_corrupt_record_len,
|
||||
client_hello_fragmented,
|
||||
client_hello_trailing_data_in_record,
|
||||
handshake_empty,
|
||||
client_hello_truncated_in_cookie,
|
||||
]:
|
||||
await sock.sendto(bad_packet, address)
|
||||
await trio.sleep(1)
|
||||
|
||||
|
||||
async def test_invalid_cookie_rejected(
|
||||
autojump_clock: trio.abc.Clock, server_ctx: SSL.Context, client_ctx: SSL.Context
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
from trio._dtls import BadPacket, decode_client_hello_untrusted
|
||||
|
||||
with trio.CancelScope() as cscope:
|
||||
# the first 11 bytes of ClientHello aren't protected by the cookie, so only test
|
||||
# corrupting bytes after that.
|
||||
offset_to_corrupt = count(11)
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
try:
|
||||
_, cookie, _ = decode_client_hello_untrusted(packet.payload)
|
||||
except BadPacket:
|
||||
pass
|
||||
else:
|
||||
if len(cookie) != 0:
|
||||
# this is a challenge response packet
|
||||
# let's corrupt the next offset so the handshake should fail
|
||||
payload = bytearray(packet.payload)
|
||||
offset = next(offset_to_corrupt)
|
||||
if offset >= len(payload):
|
||||
# We've tried all offsets. Clamp offset to the end of the
|
||||
# payload, and terminate the test.
|
||||
offset = len(payload) - 1
|
||||
cscope.cancel()
|
||||
payload[offset] ^= 0x01
|
||||
packet = attrs.evolve(packet, payload=payload)
|
||||
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO: Fix FakeNet typing
|
||||
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (_, address):
|
||||
while True:
|
||||
with endpoint() as client:
|
||||
channel = client.connect(address, client_ctx)
|
||||
await channel.do_handshake()
|
||||
assert cscope.cancelled_caught
|
||||
|
||||
|
||||
async def test_client_cancels_handshake_and_starts_new_one(
|
||||
autojump_clock: trio.abc.Clock, server_ctx: SSL.Context
|
||||
) -> None:
|
||||
# if a client disappears during the handshake, and then starts a new handshake from
|
||||
# scratch, then the first handler's channel should fail, and a new handler get
|
||||
# started
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
with endpoint() as server, endpoint() as client:
|
||||
await server.socket.bind(("127.0.0.1", 0))
|
||||
async with trio.open_nursery() as nursery:
|
||||
first_time = True
|
||||
|
||||
async def handler(channel: DTLSChannel) -> None:
|
||||
nonlocal first_time
|
||||
if first_time:
|
||||
first_time = False
|
||||
print("handler: first time, cancelling connect")
|
||||
connect_cscope.cancel()
|
||||
await trio.sleep(0.5)
|
||||
print("handler: handshake should fail now")
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await channel.do_handshake()
|
||||
else:
|
||||
print("handler: not first time, sending hello")
|
||||
await channel.send(b"hello")
|
||||
|
||||
await nursery.start(server.serve, server_ctx, handler)
|
||||
|
||||
print("client: starting first connect")
|
||||
with trio.CancelScope() as connect_cscope:
|
||||
channel = client.connect(server.socket.getsockname(), client_ctx_fn())
|
||||
await channel.do_handshake()
|
||||
assert connect_cscope.cancelled_caught
|
||||
|
||||
print("client: starting second connect")
|
||||
channel = client.connect(server.socket.getsockname(), client_ctx_fn())
|
||||
assert await channel.receive() == b"hello"
|
||||
|
||||
# Give handlers a chance to finish
|
||||
await trio.sleep(10)
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_swap_client_server(server_ctx: SSL.Context) -> None:
|
||||
with endpoint() as a, endpoint() as b:
|
||||
await a.socket.bind(("127.0.0.1", 0))
|
||||
await b.socket.bind(("127.0.0.1", 0))
|
||||
|
||||
async def echo_handler(channel: DTLSChannel) -> None:
|
||||
async for packet in channel:
|
||||
await channel.send(packet)
|
||||
|
||||
async def crashing_echo_handler(channel: DTLSChannel) -> None:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await echo_handler(channel)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(a.serve, server_ctx, crashing_echo_handler)
|
||||
await nursery.start(b.serve, server_ctx, echo_handler)
|
||||
|
||||
b_to_a = b.connect(a.socket.getsockname(), client_ctx_fn())
|
||||
await b_to_a.send(b"b as client")
|
||||
assert await b_to_a.receive() == b"b as client"
|
||||
|
||||
a_to_b = a.connect(b.socket.getsockname(), client_ctx_fn())
|
||||
await a_to_b.do_handshake()
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await b_to_a.send(b"association broken")
|
||||
await a_to_b.send(b"a as client")
|
||||
assert await a_to_b.receive() == b"a as client"
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@slow
|
||||
async def test_openssl_retransmit_doesnt_break_stuff(
|
||||
server_ctx: SSL.Context, client_ctx: SSL.Context
|
||||
) -> None:
|
||||
# can't use autojump_clock here, because the point of the test is to wait for
|
||||
# openssl's built-in retransmit timer to expire, which is hard-coded to use
|
||||
# wall-clock time.
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
blackholed = True
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
if blackholed:
|
||||
print("dropped packet", packet)
|
||||
return
|
||||
print("delivered packet", packet)
|
||||
# packets.append(
|
||||
# scapy.all.IP(
|
||||
# src=packet.source.ip.compressed, dst=packet.destination.ip.compressed
|
||||
# )
|
||||
# / scapy.all.UDP(sport=packet.source.port, dport=packet.destination.port)
|
||||
# / packet.payload
|
||||
# )
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (server_endpoint, address):
|
||||
with endpoint() as client_endpoint:
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def connecter() -> None:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
await client.do_handshake(initial_retransmit_timeout=1.5)
|
||||
await client.send(b"hi")
|
||||
assert await client.receive() == b"hi"
|
||||
|
||||
nursery.start_soon(connecter)
|
||||
|
||||
# openssl's default timeout is 1 second, so this ensures that it thinks
|
||||
# the timeout has expired
|
||||
await trio.sleep(1.1)
|
||||
# disable blackholing and send a garbage packet to wake up openssl so it
|
||||
# notices the timeout has expired
|
||||
blackholed = False
|
||||
await server_endpoint.socket.sendto(
|
||||
b"xxx",
|
||||
client_endpoint.socket.getsockname(),
|
||||
)
|
||||
# now the client task should finish connecting and exit cleanly
|
||||
|
||||
# scapy.all.wrpcap("/tmp/trace.pcap", packets)
|
||||
|
||||
|
||||
async def test_initial_retransmit_timeout_configuration(
|
||||
autojump_clock: trio.abc.Clock, server_ctx: SSL.Context
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
blackholed = True
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
nonlocal blackholed
|
||||
if blackholed:
|
||||
blackholed = False
|
||||
else:
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO: add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (_, address):
|
||||
for t in [1, 2, 4]:
|
||||
with endpoint() as client:
|
||||
before = trio.current_time()
|
||||
blackholed = True
|
||||
channel = client.connect(address, client_ctx_fn())
|
||||
await channel.do_handshake(initial_retransmit_timeout=t)
|
||||
after = trio.current_time()
|
||||
assert after - before == t
|
||||
|
||||
|
||||
async def test_explicit_tiny_mtu_is_respected(
|
||||
server_ctx: SSL.Context, client_ctx: SSL.Context
|
||||
) -> None:
|
||||
# ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to
|
||||
# be larger than that. (300 is still smaller than any real network though.)
|
||||
MTU = 300
|
||||
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
print(f"delivering {packet}")
|
||||
print(f"payload size: {len(packet.payload)}")
|
||||
assert len(packet.payload) <= MTU
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server(mtu=MTU, server_ctx=server_ctx) as (_server, address):
|
||||
with endpoint() as client:
|
||||
channel = client.connect(address, client_ctx)
|
||||
channel.set_ciphertext_mtu(MTU)
|
||||
await channel.do_handshake()
|
||||
await channel.send(b"hi")
|
||||
assert await channel.receive() == b"hi"
|
||||
|
||||
|
||||
@parametrize_ipv6
|
||||
async def test_handshake_handles_minimum_network_mtu(
|
||||
ipv6: bool,
|
||||
autojump_clock: trio.abc.Clock,
|
||||
server_ctx: SSL.Context,
|
||||
client_ctx: SSL.Context,
|
||||
) -> None:
|
||||
# Fake network that has the minimum allowable MTU for whatever protocol we're using.
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
mtu = 1280 - 48 if ipv6 else 576 - 28
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
if len(packet.payload) > mtu:
|
||||
print(f"dropping {packet}")
|
||||
else:
|
||||
print(f"delivering {packet}")
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO: add type annotations for FakeNet
|
||||
|
||||
# See if we can successfully do a handshake -- some of the volleys will get dropped,
|
||||
# and the retransmit logic should detect this and back off the MTU to something
|
||||
# smaller until it succeeds.
|
||||
async with dtls_echo_server(ipv6=ipv6, server_ctx=server_ctx) as (_, address):
|
||||
with endpoint(ipv6=ipv6) as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
# the handshake mtu backoff shouldn't affect the return value from
|
||||
# get_cleartext_mtu, b/c that's under the user's control via
|
||||
# set_ciphertext_mtu
|
||||
client.set_ciphertext_mtu(9999)
|
||||
await client.send(b"xyz")
|
||||
assert await client.receive() == b"xyz"
|
||||
assert client.get_cleartext_mtu() > 9000 # as vegeta said
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
async def test_system_task_cleaned_up_on_gc(client_ctx: SSL.Context) -> None:
|
||||
before_tasks = trio.lowlevel.current_statistics().tasks_living
|
||||
|
||||
# We put this into a sub-function so that everything automatically becomes garbage
|
||||
# when the frame exits. For some reason just doing 'del e' wasn't enough on pypy
|
||||
# with coverage enabled -- I think we were hitting this bug:
|
||||
# https://foss.heptapod.net/pypy/pypy/-/issues/3656
|
||||
async def start_and_forget_endpoint() -> int:
|
||||
e = endpoint()
|
||||
|
||||
# This connection/handshake attempt can't succeed. The only purpose is to force
|
||||
# the endpoint to set up a receive loop.
|
||||
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
|
||||
await s.bind(("127.0.0.1", 0))
|
||||
c = e.connect(s.getsockname(), client_ctx)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(c.do_handshake)
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
during_tasks = trio.lowlevel.current_statistics().tasks_living
|
||||
return during_tasks
|
||||
|
||||
with pytest.warns(ResourceWarning): # noqa: PT031
|
||||
during_tasks = await start_and_forget_endpoint()
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
gc_collect_harder()
|
||||
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
after_tasks = trio.lowlevel.current_statistics().tasks_living
|
||||
assert before_tasks < during_tasks
|
||||
assert before_tasks == after_tasks
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
async def test_gc_before_system_task_starts() -> None:
|
||||
e = endpoint()
|
||||
|
||||
with pytest.warns(ResourceWarning): # noqa: PT031
|
||||
del e
|
||||
gc_collect_harder()
|
||||
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
async def test_gc_as_packet_received() -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
e = endpoint()
|
||||
await e.socket.bind(("127.0.0.1", 0))
|
||||
e._ensure_receive_loop()
|
||||
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
|
||||
await s.sendto(b"xxx", e.socket.getsockname())
|
||||
# At this point, the endpoint's receive loop has been marked runnable because it
|
||||
# just received a packet; closing the endpoint socket won't interrupt that. But by
|
||||
# the time it wakes up to process the packet, the endpoint will be gone.
|
||||
with pytest.warns(ResourceWarning): # noqa: PT031
|
||||
del e
|
||||
gc_collect_harder()
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
def test_gc_after_trio_exits() -> None:
|
||||
async def main() -> DTLSEndpoint:
|
||||
# We use fakenet just to make sure no real sockets can leak out of the test
|
||||
# case - on pypy somehow the socket was outliving the gc_collect_harder call
|
||||
# below. Since the test is just making sure DTLSEndpoint.__del__ doesn't explode
|
||||
# when called after trio exits, it doesn't need a real socket.
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
return endpoint()
|
||||
|
||||
e = trio.run(main)
|
||||
with pytest.warns(ResourceWarning): # noqa: PT031
|
||||
del e
|
||||
gc_collect_harder()
|
||||
|
||||
|
||||
async def test_already_closed_socket_doesnt_crash() -> None:
|
||||
with endpoint() as e:
|
||||
# We close the socket before checkpointing, so the socket will already be closed
|
||||
# when the system task starts up
|
||||
e.socket.close()
|
||||
# Now give it a chance to start up, and hopefully not crash
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
|
||||
async def test_socket_closed_while_processing_clienthello(
|
||||
autojump_clock: trio.abc.Clock, server_ctx: SSL.Context, client_ctx: SSL.Context
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
# Check what happens if the socket is discovered to be closed when sending a
|
||||
# HelloVerifyRequest, since that has its own sending logic
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (server, address):
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
fn.deliver_packet(packet)
|
||||
server.socket.close()
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
with endpoint() as client_endpoint:
|
||||
with trio.move_on_after(10):
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
await client.do_handshake()
|
||||
|
||||
|
||||
async def test_association_replaced_while_handshake_running(
|
||||
autojump_clock: trio.abc.Clock, server_ctx: SSL.Context
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
pass
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO: add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
# TODO: should this have the same exact client_ctx?
|
||||
c1 = client_endpoint.connect(address, client_ctx_fn())
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def doomed_handshake() -> None:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await c1.do_handshake()
|
||||
|
||||
nursery.start_soon(doomed_handshake)
|
||||
|
||||
await trio.sleep(10)
|
||||
|
||||
client_endpoint.connect(address, client_ctx_fn())
|
||||
|
||||
|
||||
async def test_association_replaced_before_handshake_starts(
|
||||
server_ctx: SSL.Context,
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
# This test shouldn't send any packets
|
||||
def route_packet(packet: UDPPacket) -> NoReturn: # pragma: no cover
|
||||
raise AssertionError()
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
# TODO: should this use the same client_ctx?
|
||||
c1 = client_endpoint.connect(address, client_ctx_fn())
|
||||
client_endpoint.connect(address, client_ctx_fn())
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await c1.do_handshake()
|
||||
|
||||
|
||||
async def test_send_to_closed_local_port(server_ctx: SSL.Context) -> None:
|
||||
# On Windows, sending a UDP packet to a closed local port can cause a weird
|
||||
# ECONNRESET error later, inside the receive task. Make sure we're handling it
|
||||
# properly.
|
||||
async with dtls_echo_server(server_ctx=server_ctx) as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(1, 10):
|
||||
channel = client_endpoint.connect(("127.0.0.1", i), client_ctx_fn())
|
||||
nursery.start_soon(channel.do_handshake)
|
||||
channel = client_endpoint.connect(address, client_ctx_fn())
|
||||
await channel.send(b"xxx")
|
||||
assert await channel.receive() == b"xxx"
|
||||
nursery.cancel_scope.cancel()
|
||||
@@ -0,0 +1,626 @@
|
||||
from __future__ import annotations # isort: split
|
||||
|
||||
import __future__ # Regular import, not special!
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import socket as stdlib_socket
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path, PurePath
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
from trio._tests.pytest_plugin import RUN_SLOW, skip_if_optional_else_raise
|
||||
|
||||
from .. import _core, _util
|
||||
from .._core._tests.tutil import slow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator
|
||||
|
||||
mypy_cache_updated = False
|
||||
|
||||
|
||||
try: # If installed, check both versions of this class.
|
||||
from typing_extensions import Protocol as Protocol_ext
|
||||
except ImportError: # pragma: no cover
|
||||
Protocol_ext = Protocol
|
||||
|
||||
|
||||
def _ensure_mypy_cache_updated() -> None:
|
||||
# This pollutes the `empty` dir. Should this be changed?
|
||||
try:
|
||||
from mypy.api import run
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
global mypy_cache_updated
|
||||
if not mypy_cache_updated:
|
||||
# mypy cache was *probably* already updated by the other tests,
|
||||
# but `pytest -k ...` might run just this test on its own
|
||||
result = run(
|
||||
[
|
||||
"--config-file=",
|
||||
"--cache-dir=./.mypy_cache",
|
||||
"--no-error-summary",
|
||||
"-c",
|
||||
"import trio",
|
||||
],
|
||||
)
|
||||
assert not result[1] # stderr
|
||||
assert not result[0] # stdout
|
||||
mypy_cache_updated = True
|
||||
|
||||
|
||||
def test_core_is_properly_reexported() -> None:
|
||||
# Each export from _core should be re-exported by exactly one of these
|
||||
# three modules:
|
||||
sources = [trio, trio.lowlevel, trio.testing]
|
||||
for symbol in dir(_core):
|
||||
if symbol.startswith("_"):
|
||||
continue
|
||||
found = 0
|
||||
for source in sources:
|
||||
if symbol in dir(source) and getattr(source, symbol) is getattr(
|
||||
_core,
|
||||
symbol,
|
||||
):
|
||||
found += 1
|
||||
print(symbol, found)
|
||||
assert found == 1
|
||||
|
||||
|
||||
def class_is_final(cls: type) -> bool:
|
||||
"""Check if a class cannot be subclassed."""
|
||||
try:
|
||||
# new_class() handles metaclasses properly, type(...) does not.
|
||||
types.new_class("SubclassTester", (cls,))
|
||||
except TypeError:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def iter_modules(
|
||||
module: types.ModuleType,
|
||||
only_public: bool,
|
||||
) -> Iterator[types.ModuleType]:
|
||||
yield module
|
||||
for name, class_ in module.__dict__.items():
|
||||
if name.startswith("_") and only_public:
|
||||
continue
|
||||
if not isinstance(class_, ModuleType):
|
||||
continue
|
||||
if not class_.__name__.startswith(module.__name__): # pragma: no cover
|
||||
continue
|
||||
if class_ is module: # pragma: no cover
|
||||
continue
|
||||
yield from iter_modules(class_, only_public)
|
||||
|
||||
|
||||
PUBLIC_MODULES = list(iter_modules(trio, only_public=True))
|
||||
ALL_MODULES = list(iter_modules(trio, only_public=False))
|
||||
PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES]
|
||||
|
||||
|
||||
# It doesn't make sense for downstream redistributors to run this test, since
|
||||
# they might be using a newer version of Python with additional symbols which
|
||||
# won't be reflected in trio.socket, and this shouldn't cause downstream test
|
||||
# runs to start failing.
|
||||
@pytest.mark.redistributors_should_skip
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info[:4] == (3, 14, 0, "beta"),
|
||||
# 12 pass, 16 fail
|
||||
reason="several tools don't support 3.14",
|
||||
)
|
||||
# Static analysis tools often have trouble with alpha releases, where Python's
|
||||
# internals are in flux, grammar may not have settled down, etc.
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info.releaselevel == "alpha",
|
||||
reason="skip static introspection tools on Python dev/alpha releases",
|
||||
)
|
||||
@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES)
|
||||
@pytest.mark.parametrize("tool", ["pylint", "jedi", "mypy", "pyright_verifytypes"])
|
||||
@pytest.mark.filterwarnings(
|
||||
# https://github.com/pypa/setuptools/issues/3274
|
||||
"ignore:module 'sre_constants' is deprecated:DeprecationWarning",
|
||||
)
|
||||
def test_static_tool_sees_all_symbols(tool: str, modname: str, tmp_path: Path) -> None:
|
||||
module = importlib.import_module(modname)
|
||||
|
||||
def no_underscores(symbols: Iterable[str]) -> set[str]:
|
||||
return {symbol for symbol in symbols if not symbol.startswith("_")}
|
||||
|
||||
runtime_names = no_underscores(dir(module))
|
||||
|
||||
# ignore deprecated module `tests` being invisible
|
||||
if modname == "trio":
|
||||
runtime_names.discard("tests")
|
||||
|
||||
# Ignore any __future__ feature objects, if imported under that name.
|
||||
for name in __future__.all_feature_names:
|
||||
if getattr(module, name, None) is getattr(__future__, name):
|
||||
runtime_names.remove(name)
|
||||
|
||||
if tool == "pylint":
|
||||
try:
|
||||
from pylint.lint import PyLinter
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
linter = PyLinter()
|
||||
assert module.__file__ is not None
|
||||
ast = linter.get_ast(module.__file__, modname)
|
||||
static_names = no_underscores(ast) # type: ignore[arg-type]
|
||||
elif tool == "jedi":
|
||||
if sys.implementation.name != "cpython":
|
||||
pytest.skip("jedi does not support pypy")
|
||||
|
||||
try:
|
||||
import jedi
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
# Simulate typing "import trio; trio.<TAB>"
|
||||
script = jedi.Script(f"import {modname}; {modname}.")
|
||||
completions = script.complete()
|
||||
static_names = no_underscores(c.name for c in completions)
|
||||
elif tool == "mypy":
|
||||
if sys.implementation.name != "cpython":
|
||||
# https://github.com/python/mypy/issues/20329
|
||||
pytest.skip("mypy does not support pypy")
|
||||
|
||||
if not RUN_SLOW: # pragma: no cover
|
||||
pytest.skip("use --run-slow to check against mypy")
|
||||
|
||||
cache = Path.cwd() / ".mypy_cache"
|
||||
|
||||
_ensure_mypy_cache_updated()
|
||||
|
||||
trio_cache = next(cache.glob("*/trio"))
|
||||
_, modname = (modname + ".").split(".", 1)
|
||||
modname = modname[:-1]
|
||||
mod_cache = trio_cache / modname if modname else trio_cache
|
||||
if mod_cache.is_dir(): # pragma: no coverage
|
||||
mod_cache = mod_cache / "__init__.data.json"
|
||||
else:
|
||||
mod_cache = trio_cache / (modname + ".data.json")
|
||||
|
||||
assert mod_cache.exists()
|
||||
assert mod_cache.is_file()
|
||||
with mod_cache.open() as cache_file:
|
||||
cache_json = json.loads(cache_file.read())
|
||||
static_names = no_underscores(
|
||||
key
|
||||
for key, value in cache_json["names"].items()
|
||||
if not key.startswith(".") and value["kind"] == "Gdef"
|
||||
)
|
||||
elif tool == "pyright_verifytypes":
|
||||
if not RUN_SLOW: # pragma: no cover
|
||||
pytest.skip("use --run-slow to check against pyright")
|
||||
|
||||
try:
|
||||
import pyright # noqa: F401
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
import subprocess
|
||||
|
||||
res = subprocess.run(
|
||||
["pyright", f"--verifytypes={modname}", "--outputjson"],
|
||||
capture_output=True,
|
||||
)
|
||||
current_result = json.loads(res.stdout)
|
||||
|
||||
static_names = {
|
||||
x["name"][len(modname) + 1 :]
|
||||
for x in current_result["typeCompleteness"]["symbols"]
|
||||
if x["name"].startswith(modname)
|
||||
}
|
||||
else: # pragma: no cover
|
||||
raise AssertionError()
|
||||
|
||||
# It's expected that the static set will contain more names than the
|
||||
# runtime set:
|
||||
# - static tools are sometimes sloppy and include deleted names
|
||||
# - some symbols are platform-specific at runtime, but always show up in
|
||||
# static analysis (e.g. in trio.socket or trio.lowlevel)
|
||||
# So we check that the runtime names are a subset of the static names.
|
||||
missing_names = runtime_names - static_names
|
||||
|
||||
# ignore warnings about deprecated module tests
|
||||
missing_names -= {"tests"}
|
||||
|
||||
if missing_names: # pragma: no cover
|
||||
print(f"{tool} can't see the following names in {modname}:")
|
||||
print()
|
||||
for name in sorted(missing_names):
|
||||
print(f" {name}")
|
||||
raise AssertionError()
|
||||
|
||||
|
||||
@slow
|
||||
# see comment on test_static_tool_sees_all_symbols
|
||||
@pytest.mark.redistributors_should_skip
|
||||
# Static analysis tools often have trouble with alpha releases, where Python's
|
||||
# internals are in flux, grammar may not have settled down, etc.
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info.releaselevel == "alpha",
|
||||
reason="skip static introspection tools on Python dev/alpha releases",
|
||||
)
|
||||
@pytest.mark.parametrize("module_name", PUBLIC_MODULE_NAMES)
|
||||
@pytest.mark.parametrize("tool", ["jedi", "mypy"])
|
||||
def test_static_tool_sees_class_members(
|
||||
tool: str,
|
||||
module_name: str,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)]
|
||||
|
||||
# ignore hidden, but not dunder, symbols
|
||||
def no_hidden(symbols: Iterable[str]) -> set[str]:
|
||||
return {
|
||||
symbol
|
||||
for symbol in symbols
|
||||
if (not symbol.startswith("_")) or symbol.startswith("__")
|
||||
}
|
||||
|
||||
if tool == "jedi" and sys.implementation.name != "cpython":
|
||||
pytest.skip("jedi does not support pypy")
|
||||
|
||||
if tool == "mypy" and sys.implementation.name != "cpython":
|
||||
# https://github.com/python/mypy/issues/20329
|
||||
pytest.skip("mypy does not support pypy")
|
||||
|
||||
if tool == "mypy":
|
||||
cache = Path.cwd() / ".mypy_cache"
|
||||
|
||||
_ensure_mypy_cache_updated()
|
||||
|
||||
trio_cache = next(cache.glob("*/trio"))
|
||||
modname = module_name
|
||||
_, modname = (modname + ".").split(".", 1)
|
||||
modname = modname[:-1]
|
||||
mod_cache = trio_cache / modname if modname else trio_cache
|
||||
if mod_cache.is_dir():
|
||||
mod_cache = mod_cache / "__init__.data.json"
|
||||
else:
|
||||
mod_cache = trio_cache / (modname + ".data.json")
|
||||
|
||||
assert mod_cache.exists()
|
||||
assert mod_cache.is_file()
|
||||
with mod_cache.open() as cache_file:
|
||||
cache_json = json.loads(cache_file.read())
|
||||
|
||||
# skip a bunch of file-system activity (probably can un-memoize?)
|
||||
@functools.lru_cache
|
||||
def lookup_symbol(symbol: str) -> dict[str, Any]: # type: ignore[misc, explicit-any]
|
||||
topname, *modname, name = symbol.split(".")
|
||||
version = next(cache.glob("3.*/"))
|
||||
mod_cache = version / topname
|
||||
if not mod_cache.is_dir():
|
||||
mod_cache = version / (topname + ".data.json")
|
||||
|
||||
if modname:
|
||||
for piece in modname[:-1]:
|
||||
mod_cache /= piece
|
||||
next_cache = mod_cache / modname[-1]
|
||||
if next_cache.is_dir(): # pragma: no coverage
|
||||
mod_cache = next_cache / "__init__.data.json"
|
||||
else:
|
||||
mod_cache = mod_cache / (modname[-1] + ".data.json")
|
||||
elif mod_cache.is_dir():
|
||||
mod_cache /= "__init__.data.json"
|
||||
with mod_cache.open() as f:
|
||||
return json.loads(f.read())["names"][name] # type: ignore[no-any-return]
|
||||
|
||||
errors: dict[str, object] = {}
|
||||
for class_name, class_ in module.__dict__.items():
|
||||
if not isinstance(class_, type):
|
||||
continue
|
||||
if module_name == "trio.socket" and class_name in dir(stdlib_socket):
|
||||
continue
|
||||
|
||||
# Ignore classes that don't use attrs, they only define their members once
|
||||
# __init__ is called (and reason they don't use attrs is because they're going
|
||||
# to be reimplemented in pytest).
|
||||
# Not 100% that's the case, and it works locally, so whatever /shrug
|
||||
if module_name == "trio.testing" and class_name in ("_RaisesGroup", "_Matcher"):
|
||||
continue
|
||||
|
||||
# dir() and inspect.getmembers doesn't display properties from the metaclass
|
||||
# also ignore some dunder methods that tend to differ but are of no consequence
|
||||
ignore_names = set(dir(type(class_))) | {
|
||||
"__annotations__",
|
||||
"__attrs_attrs__",
|
||||
"__attrs_own_setattr__",
|
||||
"__callable_proto_members_only__",
|
||||
"__class_getitem__",
|
||||
"__final__",
|
||||
"__getstate__",
|
||||
"__match_args__",
|
||||
"__order__",
|
||||
"__orig_bases__",
|
||||
"__parameters__",
|
||||
"__protocol_attrs__",
|
||||
"__setstate__",
|
||||
"__slots__",
|
||||
"__weakref__",
|
||||
# ignore errors about dunders inherited from stdlib that tools might
|
||||
# not see
|
||||
"__copy__",
|
||||
"__deepcopy__",
|
||||
}
|
||||
|
||||
if type(class_) is type:
|
||||
# C extension classes don't have these dunders, but Python classes do
|
||||
ignore_names.add("__firstlineno__")
|
||||
ignore_names.add("__static_attributes__")
|
||||
|
||||
# inspect.getmembers sees `name` and `value` in Enums, otherwise
|
||||
# it behaves the same way as `dir`
|
||||
# runtime_names = no_underscores(dir(class_))
|
||||
runtime_names = (
|
||||
no_hidden(x[0] for x in inspect.getmembers(class_)) - ignore_names
|
||||
)
|
||||
|
||||
if tool == "jedi":
|
||||
try:
|
||||
import jedi
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
script = jedi.Script(
|
||||
f"from {module_name} import {class_name}; {class_name}.",
|
||||
)
|
||||
completions = script.complete()
|
||||
static_names = no_hidden(c.name for c in completions) - ignore_names
|
||||
|
||||
elif tool == "mypy":
|
||||
# load the cached type information
|
||||
cached_type_info = cache_json["names"][class_name]
|
||||
assert (
|
||||
"node" not in cached_type_info
|
||||
), "previously this was an 'if' but it seems it's no longer possible for this cache to contain 'node', if this assert raises for you please let us know!"
|
||||
cached_type_info = lookup_symbol(cached_type_info["cross_ref"])
|
||||
|
||||
assert "node" in cached_type_info
|
||||
node = cached_type_info["node"]
|
||||
static_names = no_hidden(
|
||||
k for k in node.get("names", ()) if not k.startswith(".")
|
||||
)
|
||||
for symbol in node["mro"][1:]:
|
||||
node = lookup_symbol(symbol)["node"]
|
||||
static_names |= no_hidden(
|
||||
k for k in node.get("names", ()) if not k.startswith(".")
|
||||
)
|
||||
static_names -= ignore_names
|
||||
|
||||
else: # pragma: no cover
|
||||
raise AssertionError("unknown tool")
|
||||
|
||||
missing = runtime_names - static_names
|
||||
extra = static_names - runtime_names
|
||||
|
||||
# using .remove() instead of .delete() to get an error in case they start not
|
||||
# being missing
|
||||
|
||||
if (
|
||||
tool == "jedi"
|
||||
and BaseException in class_.__mro__
|
||||
and sys.version_info >= (3, 11)
|
||||
):
|
||||
missing.remove("add_note")
|
||||
|
||||
if (
|
||||
tool == "mypy"
|
||||
and BaseException in class_.__mro__
|
||||
and sys.version_info >= (3, 11)
|
||||
):
|
||||
extra.remove("__notes__")
|
||||
|
||||
if tool == "mypy" and attrs.has(class_):
|
||||
# e.g. __trio__core__run_CancelScope_AttrsAttributes__
|
||||
before = len(extra)
|
||||
extra = {e for e in extra if not e.endswith("AttrsAttributes__")}
|
||||
assert len(extra) == before - 1
|
||||
|
||||
if attrs.has(class_):
|
||||
# dynamically created attribute by attrs?
|
||||
missing.remove("__attrs_props__")
|
||||
|
||||
# dir does not see `__signature__` on enums until 3.14
|
||||
if (
|
||||
tool == "mypy"
|
||||
and enum.Enum in class_.__mro__
|
||||
and sys.version_info >= (3, 12)
|
||||
and sys.version_info < (3, 14)
|
||||
):
|
||||
extra.remove("__signature__")
|
||||
|
||||
# TODO: this *should* be visible via `dir`!!
|
||||
if tool == "mypy" and class_ == trio.Nursery:
|
||||
extra.remove("cancel_scope")
|
||||
|
||||
# These are (mostly? solely?) *runtime* attributes, often set in
|
||||
# __init__, which doesn't show up with dir() or inspect.getmembers,
|
||||
# but we get them in the way we query mypy & jedi
|
||||
EXTRAS = {
|
||||
trio.DTLSChannel: {"peer_address", "endpoint"},
|
||||
trio.DTLSEndpoint: {"socket", "incoming_packets_buffer"},
|
||||
trio.Process: {"args", "pid", "stderr", "stdin", "stdio", "stdout"},
|
||||
trio.SSLListener: {"transport_listener"},
|
||||
trio.SSLStream: {"transport_stream"},
|
||||
trio.SocketListener: {"socket"},
|
||||
trio.SocketStream: {"socket"},
|
||||
trio.testing.MemoryReceiveStream: {"close_hook", "receive_some_hook"},
|
||||
trio.testing.MemorySendStream: {
|
||||
"close_hook",
|
||||
"send_all_hook",
|
||||
"wait_send_all_might_not_block_hook",
|
||||
},
|
||||
}
|
||||
if tool == "mypy" and class_ in EXTRAS:
|
||||
before = len(extra)
|
||||
extra -= EXTRAS[class_]
|
||||
assert len(extra) == before - len(EXTRAS[class_])
|
||||
|
||||
# TODO: why is this? Is it a problem?
|
||||
# see https://github.com/python-trio/trio/pull/2631#discussion_r1185615916
|
||||
if class_ == trio.StapledStream:
|
||||
extra.remove("receive_stream")
|
||||
extra.remove("send_stream")
|
||||
|
||||
# I have not researched why these are missing, should maybe create an issue
|
||||
# upstream with jedi
|
||||
if tool == "jedi" and sys.version_info >= (3, 11):
|
||||
if class_ in (
|
||||
trio.DTLSChannel,
|
||||
trio.MemoryReceiveChannel,
|
||||
trio.MemorySendChannel,
|
||||
trio.SSLListener,
|
||||
trio.SocketListener,
|
||||
):
|
||||
missing.remove("__aenter__")
|
||||
missing.remove("__aexit__")
|
||||
if class_ in (trio.DTLSChannel, trio.MemoryReceiveChannel):
|
||||
missing.remove("__aiter__")
|
||||
missing.remove("__anext__")
|
||||
|
||||
if class_ in (trio.Path, trio.WindowsPath, trio.PosixPath):
|
||||
# These are from inherited subclasses.
|
||||
missing -= PurePath.__dict__.keys()
|
||||
# These are unix-only.
|
||||
if tool == "mypy" and sys.platform == "win32":
|
||||
missing -= {"owner", "is_mount", "group"}
|
||||
if tool == "jedi" and sys.platform == "win32":
|
||||
extra -= {"owner", "is_mount", "group"}
|
||||
|
||||
# not sure why jedi in particular ignores this (static?) method in 3.13
|
||||
if (
|
||||
tool == "jedi"
|
||||
and sys.version_info[:2] == (3, 13)
|
||||
and class_ in (trio.Path, trio.WindowsPath, trio.PosixPath)
|
||||
):
|
||||
missing.remove("with_segments")
|
||||
|
||||
# tuple subclasses are weird
|
||||
if issubclass(class_, tuple):
|
||||
extra.remove("__reversed__")
|
||||
missing.remove("__getnewargs__")
|
||||
|
||||
if sys.version_info >= (3, 13) and attrs.has(class_):
|
||||
missing.remove("__replace__")
|
||||
|
||||
if sys.version_info >= (3, 14):
|
||||
# these depend on whether a class has processed deferred annotations.
|
||||
# (which might or might not happen and we don't know)
|
||||
missing.discard("__annotate_func__")
|
||||
missing.discard("__annotations_cache__")
|
||||
|
||||
if missing or extra: # pragma: no cover
|
||||
errors[f"{module_name}.{class_name}"] = {
|
||||
"missing": missing,
|
||||
"extra": extra,
|
||||
}
|
||||
|
||||
# `assert not errors` will not print the full content of errors, even with
|
||||
# `--verbose`, so we manually print it
|
||||
if errors: # pragma: no cover
|
||||
from pprint import pprint
|
||||
|
||||
print(f"\n{tool} can't see the following symbols in {module_name}:")
|
||||
pprint(errors)
|
||||
assert not errors
|
||||
|
||||
|
||||
def test_nopublic_is_final() -> None:
|
||||
"""Check all NoPublicConstructor classes are also @final."""
|
||||
assert class_is_final(_util.NoPublicConstructor) # This is itself final.
|
||||
|
||||
for module in ALL_MODULES:
|
||||
for class_ in module.__dict__.values():
|
||||
if isinstance(class_, _util.NoPublicConstructor):
|
||||
assert class_is_final(class_)
|
||||
|
||||
|
||||
def test_classes_are_final() -> None:
|
||||
# Sanity checks.
|
||||
assert not class_is_final(object)
|
||||
assert class_is_final(bool)
|
||||
|
||||
for module in PUBLIC_MODULES:
|
||||
for name, class_ in module.__dict__.items():
|
||||
if not isinstance(class_, type):
|
||||
continue
|
||||
# Deprecated classes are exported with a leading underscore
|
||||
if name.startswith("_"): # pragma: no cover
|
||||
continue
|
||||
|
||||
# Abstract classes can be subclassed, because that's the whole
|
||||
# point of ABCs
|
||||
if inspect.isabstract(class_):
|
||||
continue
|
||||
# Same with protocols, but only direct children.
|
||||
if Protocol in class_.__bases__ or Protocol_ext in class_.__bases__:
|
||||
continue
|
||||
# Exceptions are allowed to be subclassed, because exception
|
||||
# subclassing isn't used to inherit behavior.
|
||||
if issubclass(class_, BaseException):
|
||||
continue
|
||||
# These are classes that are conceptually abstract, but
|
||||
# inspect.isabstract returns False for boring reasons.
|
||||
if class_ is trio.abc.Instrument or class_ is trio.socket.SocketType:
|
||||
continue
|
||||
# ... insert other special cases here ...
|
||||
|
||||
# The `Path` class needs to support inheritance to allow `WindowsPath` and `PosixPath`.
|
||||
if class_ is trio.Path:
|
||||
continue
|
||||
# don't care about the *Statistics classes
|
||||
if name.endswith("Statistics"):
|
||||
continue
|
||||
|
||||
assert class_is_final(class_)
|
||||
|
||||
|
||||
# Plugin might not be running, especially if running from an installed version.
|
||||
@pytest.mark.skipif(
|
||||
not hasattr(attrs.field, "trio_modded"),
|
||||
reason="Pytest plugin not installed.",
|
||||
)
|
||||
def test_pyright_recognizes_init_attributes() -> None:
|
||||
"""Check whether we provide `alias` for all underscore prefixed attributes.
|
||||
|
||||
Attrs always sets the `alias` attribute on fields, so a pytest plugin is used
|
||||
to monkeypatch `field()` to record whether an alias was defined in the metadata.
|
||||
See `_trio_check_attrs_aliases`.
|
||||
"""
|
||||
for module in PUBLIC_MODULES:
|
||||
for class_ in module.__dict__.values():
|
||||
if not attrs.has(class_):
|
||||
continue
|
||||
if isinstance(class_, _util.NoPublicConstructor):
|
||||
continue
|
||||
|
||||
attributes = [
|
||||
attr
|
||||
for attr in attrs.fields(class_)
|
||||
if attr.init
|
||||
if attr.alias
|
||||
not in (
|
||||
attr.name,
|
||||
# trio_original_args may not be present in autoattribs
|
||||
attr.metadata.get("trio_original_args", {}).get("alias"),
|
||||
)
|
||||
]
|
||||
|
||||
assert attributes == [], class_
|
||||
@@ -0,0 +1,317 @@
|
||||
import errno
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.testing._fake_net import FakeNet
|
||||
|
||||
# ENOTCONN gives different messages on different platforms
|
||||
if sys.platform == "linux":
|
||||
ENOTCONN_MSG = r"^\[Errno 107\] (Transport endpoint is|Socket) not connected$"
|
||||
elif sys.platform == "darwin":
|
||||
ENOTCONN_MSG = r"^\[Errno 57\] Socket is not connected$"
|
||||
else:
|
||||
ENOTCONN_MSG = r"^\[Errno 10057\] Unknown error$"
|
||||
|
||||
|
||||
def fn() -> FakeNet:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
return fn
|
||||
|
||||
|
||||
async def test_basic_udp() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
ip, port = s1.getsockname()
|
||||
assert ip == "127.0.0.1"
|
||||
assert port != 0
|
||||
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^\[\w+ \d+\] Invalid argument$",
|
||||
) as exc: # Cannot rebind.
|
||||
await s1.bind(("192.0.2.1", 0))
|
||||
assert exc.value.errno == errno.EINVAL
|
||||
|
||||
# Cannot bind multiple sockets to the same address
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^\[\w+ \d+\] (Address (already )?in use|Unknown error)$",
|
||||
) as exc:
|
||||
await s2.bind(("127.0.0.1", port))
|
||||
assert exc.value.errno == errno.EADDRINUSE
|
||||
|
||||
await s2.sendto(b"xyz", s1.getsockname())
|
||||
data, addr = await s1.recvfrom(10)
|
||||
assert data == b"xyz"
|
||||
assert addr == s2.getsockname()
|
||||
await s1.sendto(b"abc", s2.getsockname())
|
||||
data, addr = await s2.recvfrom(10)
|
||||
assert data == b"abc"
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
|
||||
async def test_msg_trunc() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
await s2.sendto(b"xyz", s1.getsockname())
|
||||
await s1.recvfrom(10)
|
||||
|
||||
|
||||
async def test_recv_methods() -> None:
|
||||
"""Test all recv methods for codecov"""
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
|
||||
# receiving on an unbound socket is a bad idea (I think?)
|
||||
with pytest.raises(NotImplementedError, match="code will most likely hang"):
|
||||
await s2.recv(10)
|
||||
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
ip, port = s1.getsockname()
|
||||
assert ip == "127.0.0.1"
|
||||
assert port != 0
|
||||
|
||||
# recvfrom
|
||||
await s2.sendto(b"abc", s1.getsockname())
|
||||
data, addr = await s1.recvfrom(10)
|
||||
assert data == b"abc"
|
||||
assert addr == s2.getsockname()
|
||||
|
||||
# recv
|
||||
await s1.sendto(b"def", s2.getsockname())
|
||||
data = await s2.recv(10)
|
||||
assert data == b"def"
|
||||
|
||||
# recvfrom_into
|
||||
assert await s1.sendto(b"ghi", s2.getsockname()) == 3
|
||||
buf = bytearray(10)
|
||||
|
||||
with pytest.raises(NotImplementedError, match=r"^partial recvfrom_into$"):
|
||||
nbytes, addr = await s2.recvfrom_into(buf, nbytes=2)
|
||||
|
||||
nbytes, addr = await s2.recvfrom_into(buf)
|
||||
assert nbytes == 3
|
||||
assert buf == b"ghi" + b"\x00" * 7
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
# recv_into
|
||||
assert await s1.sendto(b"jkl", s2.getsockname()) == 3
|
||||
buf2 = bytearray(10)
|
||||
nbytes = await s2.recv_into(buf2)
|
||||
assert nbytes == 3
|
||||
assert buf2 == b"jkl" + b"\x00" * 7
|
||||
|
||||
if sys.platform == "linux" and sys.implementation.name == "cpython":
|
||||
flags: int = socket.MSG_MORE
|
||||
else:
|
||||
flags = 1
|
||||
|
||||
# Send seems explicitly non-functional
|
||||
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
|
||||
await s2.send(b"mno")
|
||||
assert exc.value.errno == errno.ENOTCONN
|
||||
with pytest.raises(
|
||||
NotImplementedError, match=r"^FakeNet send flags must be 0, not"
|
||||
):
|
||||
await s2.send(b"mno", flags)
|
||||
|
||||
# sendto errors
|
||||
# it's successfully used earlier
|
||||
with pytest.raises(
|
||||
NotImplementedError, match=r"^FakeNet send flags must be 0, not"
|
||||
):
|
||||
await s2.sendto(b"mno", flags, s1.getsockname())
|
||||
with pytest.raises(TypeError, match=r"wrong number of arguments$"):
|
||||
await s2.sendto(b"mno", flags, s1.getsockname(), "extra arg") # type: ignore[call-overload]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == "win32",
|
||||
reason="functions not in socket on windows",
|
||||
)
|
||||
async def test_nonwindows_functionality() -> None:
|
||||
# mypy doesn't support a good way of aborting typechecking on different platforms
|
||||
if sys.platform != "win32": # pragma: no branch
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
await s2.bind(("127.0.0.1", 0))
|
||||
|
||||
# sendmsg
|
||||
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
|
||||
await s2.sendmsg([b"mno"])
|
||||
assert exc.value.errno == errno.ENOTCONN
|
||||
|
||||
assert await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) == 3
|
||||
data, ancdata, msg_flags, addr = await s2.recvmsg(10)
|
||||
assert data == b"jkl"
|
||||
assert ancdata == []
|
||||
assert msg_flags == 0
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
# TODO: recvmsg
|
||||
|
||||
# recvmsg_into
|
||||
assert await s1.sendto(b"xyzw", s2.getsockname()) == 4
|
||||
buf1 = bytearray(2)
|
||||
buf2 = bytearray(3)
|
||||
ret = await s2.recvmsg_into([buf1, buf2])
|
||||
nbytes, ancdata, msg_flags, addr = ret
|
||||
assert nbytes == 4
|
||||
assert buf1 == b"xy"
|
||||
assert buf2 == b"zw" + b"\x00"
|
||||
assert ancdata == []
|
||||
assert msg_flags == 0
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
# recvmsg_into with MSG_TRUNC set
|
||||
assert await s1.sendto(b"xyzwv", s2.getsockname()) == 5
|
||||
buf1 = bytearray(2)
|
||||
ret = await s2.recvmsg_into([buf1])
|
||||
nbytes, ancdata, msg_flags, addr = ret
|
||||
assert nbytes == 2
|
||||
assert buf1 == b"xy"
|
||||
assert ancdata == []
|
||||
assert msg_flags == socket.MSG_TRUNC
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match=r"^'FakeSocket' object has no attribute 'share'$",
|
||||
):
|
||||
await s1.share(0) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "win32",
|
||||
reason="windows-specific fakesocket testing",
|
||||
)
|
||||
async def test_windows_functionality() -> None:
|
||||
# mypy doesn't support a good way of aborting typechecking on different platforms
|
||||
if sys.platform == "win32": # pragma: no branch
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match=r"^'FakeSocket' object has no attribute 'sendmsg'$",
|
||||
):
|
||||
await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) # type: ignore[attr-defined]
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match=r"^'FakeSocket' object has no attribute 'recvmsg'$",
|
||||
):
|
||||
s2.recvmsg(0) # type: ignore[attr-defined]
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match=r"^'FakeSocket' object has no attribute 'recvmsg_into'$",
|
||||
):
|
||||
s2.recvmsg_into([]) # type: ignore[attr-defined]
|
||||
with pytest.raises(NotImplementedError):
|
||||
s1.share(0)
|
||||
|
||||
|
||||
async def test_basic_tcp() -> None:
|
||||
fn()
|
||||
with pytest.raises(NotImplementedError):
|
||||
trio.socket.socket()
|
||||
|
||||
|
||||
async def test_not_implemented_functions() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
|
||||
# getsockopt
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^FakeNet doesn't implement getsockopt\(\d, \d\)$",
|
||||
):
|
||||
s1.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
|
||||
|
||||
# setsockopt
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match=r"^FakeNet always has IPV6_V6ONLY=True$",
|
||||
):
|
||||
s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$",
|
||||
):
|
||||
s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True)
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$",
|
||||
):
|
||||
s1.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
|
||||
# set_inheritable
|
||||
s1.set_inheritable(False)
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match=r"^FakeNet can't make inheritable sockets$",
|
||||
):
|
||||
s1.set_inheritable(True)
|
||||
|
||||
# get_inheritable
|
||||
assert not s1.get_inheritable()
|
||||
|
||||
|
||||
async def test_getpeername() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
|
||||
s1.getpeername()
|
||||
assert exc.value.errno == errno.ENOTCONN
|
||||
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match=r"^This method seems to assume that self._binding has a remote UDPEndpoint$",
|
||||
):
|
||||
s1.getpeername()
|
||||
|
||||
|
||||
async def test_init() -> None:
|
||||
fn()
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match=re.escape(
|
||||
f"FakeNet doesn't (yet) support type={trio.socket.SOCK_STREAM}",
|
||||
),
|
||||
):
|
||||
s1 = trio.socket.socket()
|
||||
|
||||
# getsockname on unbound ipv4 socket
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
assert s1.getsockname() == ("0.0.0.0", 0)
|
||||
|
||||
# getsockname on bound ipv4 socket
|
||||
await s1.bind(("0.0.0.0", 0))
|
||||
ip, port = s1.getsockname()
|
||||
assert ip == "127.0.0.1"
|
||||
assert port != 0
|
||||
|
||||
# getsockname on unbound ipv6 socket
|
||||
s2 = trio.socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM)
|
||||
assert s2.getsockname() == ("::", 0)
|
||||
|
||||
# getsockname on bound ipv6 socket
|
||||
await s2.bind(("::", 0))
|
||||
ip, port, *_ = s2.getsockname()
|
||||
assert ip == "::1"
|
||||
assert port != 0
|
||||
assert _ == [0, 0]
|
||||
@@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest import mock
|
||||
from unittest.mock import sentinel
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio import _core, _file_io
|
||||
from trio._file_io import _FILE_ASYNC_METHODS, _FILE_SYNC_ATTRS, AsyncIOWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pathlib
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def path(tmp_path: pathlib.Path) -> str:
|
||||
return os.fspath(tmp_path / "test")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wrapped() -> mock.Mock:
|
||||
return mock.Mock(spec_set=io.StringIO)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_file(wrapped: mock.Mock) -> AsyncIOWrapper[mock.Mock]:
|
||||
return trio.wrap_file(wrapped)
|
||||
|
||||
|
||||
def test_wrap_invalid() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
trio.wrap_file("")
|
||||
|
||||
|
||||
def test_wrap_non_iobase() -> None:
|
||||
class FakeFile:
|
||||
def close(self) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
def write(self) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
wrapped = FakeFile()
|
||||
assert not isinstance(wrapped, io.IOBase)
|
||||
|
||||
async_file = trio.wrap_file(wrapped)
|
||||
assert isinstance(async_file, AsyncIOWrapper)
|
||||
|
||||
del FakeFile.write
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
trio.wrap_file(FakeFile())
|
||||
|
||||
|
||||
def test_wrapped_property(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
assert async_file.wrapped is wrapped
|
||||
|
||||
|
||||
def test_dir_matches_wrapped(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS)
|
||||
|
||||
# all supported attrs in wrapped should be available in async_file
|
||||
assert all(attr in dir(async_file) for attr in attrs if attr in dir(wrapped))
|
||||
# all supported attrs not in wrapped should not be available in async_file
|
||||
assert not any(
|
||||
attr in dir(async_file) for attr in attrs if attr not in dir(wrapped)
|
||||
)
|
||||
|
||||
|
||||
def test_unsupported_not_forwarded() -> None:
|
||||
class FakeFile(io.RawIOBase):
|
||||
def unsupported_attr(self) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
async_file = trio.wrap_file(FakeFile())
|
||||
|
||||
assert hasattr(async_file.wrapped, "unsupported_attr")
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
# B018 "useless expression"
|
||||
async_file.unsupported_attr # type: ignore[attr-defined] # noqa: B018
|
||||
|
||||
|
||||
def test_type_stubs_match_lists() -> None:
|
||||
"""Check the manual stubs match the list of wrapped methods."""
|
||||
# Fetch the module's source code.
|
||||
assert _file_io.__spec__ is not None
|
||||
loader = _file_io.__spec__.loader
|
||||
assert isinstance(loader, importlib.abc.SourceLoader)
|
||||
source = io.StringIO(loader.get_source("trio._file_io"))
|
||||
|
||||
# Find the class, then find the TYPE_CHECKING block.
|
||||
for line in source:
|
||||
if "class AsyncIOWrapper" in line:
|
||||
break
|
||||
else: # pragma: no cover - should always find this
|
||||
pytest.fail("No class definition line?")
|
||||
|
||||
for line in source:
|
||||
if "if TYPE_CHECKING" in line:
|
||||
break
|
||||
else: # pragma: no cover - should always find this
|
||||
pytest.fail("No TYPE CHECKING line?")
|
||||
|
||||
# Now we should be at the type checking block.
|
||||
found: list[tuple[str, str]] = []
|
||||
for line in source: # pragma: no branch - expected to break early
|
||||
if line.strip() and not line.startswith(" " * 8):
|
||||
break # Dedented out of the if TYPE_CHECKING block.
|
||||
match = re.match(r"\s*(async )?def ([a-zA-Z0-9_]+)\(", line)
|
||||
if match is not None:
|
||||
kind = "async" if match.group(1) is not None else "sync"
|
||||
found.append((match.group(2), kind))
|
||||
|
||||
# Compare two lists so that we can easily see duplicates, and see what is different overall.
|
||||
expected = [(fname, "async") for fname in _FILE_ASYNC_METHODS]
|
||||
expected += [(fname, "sync") for fname in _FILE_SYNC_ATTRS]
|
||||
# Ignore order, error if duplicates are present.
|
||||
found.sort()
|
||||
expected.sort()
|
||||
assert found == expected
|
||||
|
||||
|
||||
def test_sync_attrs_forwarded(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for attr_name in _FILE_SYNC_ATTRS:
|
||||
if attr_name not in dir(async_file):
|
||||
continue
|
||||
|
||||
assert getattr(async_file, attr_name) is getattr(wrapped, attr_name)
|
||||
|
||||
|
||||
def test_sync_attrs_match_wrapper(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for attr_name in _FILE_SYNC_ATTRS:
|
||||
if attr_name in dir(async_file):
|
||||
continue
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(async_file, attr_name)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(wrapped, attr_name)
|
||||
|
||||
|
||||
def test_async_methods_generated_once(async_file: AsyncIOWrapper[mock.Mock]) -> None:
|
||||
for meth_name in _FILE_ASYNC_METHODS:
|
||||
if meth_name not in dir(async_file):
|
||||
continue
|
||||
|
||||
assert getattr(async_file, meth_name) is getattr(async_file, meth_name)
|
||||
|
||||
|
||||
# I gave up on typing this one
|
||||
def test_async_methods_signature(async_file: AsyncIOWrapper[mock.Mock]) -> None:
|
||||
# use read as a representative of all async methods
|
||||
assert async_file.read.__name__ == "read"
|
||||
assert async_file.read.__qualname__ == "AsyncIOWrapper.read"
|
||||
|
||||
assert async_file.read.__doc__ is not None
|
||||
assert "io.StringIO.read" in async_file.read.__doc__
|
||||
|
||||
|
||||
async def test_async_methods_wrap(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for meth_name in _FILE_ASYNC_METHODS:
|
||||
if meth_name not in dir(async_file):
|
||||
continue
|
||||
|
||||
meth = getattr(async_file, meth_name)
|
||||
wrapped_meth = getattr(wrapped, meth_name)
|
||||
|
||||
value = await meth(sentinel.argument, keyword=sentinel.keyword)
|
||||
|
||||
wrapped_meth.assert_called_once_with(
|
||||
sentinel.argument,
|
||||
keyword=sentinel.keyword,
|
||||
)
|
||||
assert value == wrapped_meth()
|
||||
|
||||
wrapped.reset_mock()
|
||||
|
||||
|
||||
def test_async_methods_match_wrapper(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for meth_name in _FILE_ASYNC_METHODS:
|
||||
if meth_name in dir(async_file):
|
||||
continue
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(async_file, meth_name)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(wrapped, meth_name)
|
||||
|
||||
|
||||
async def test_open(path: pathlib.Path) -> None:
|
||||
f = await trio.open_file(path, "w")
|
||||
|
||||
assert isinstance(f, AsyncIOWrapper)
|
||||
|
||||
await f.aclose()
|
||||
|
||||
|
||||
async def test_open_context_manager(path: pathlib.Path) -> None:
|
||||
async with await trio.open_file(path, "w") as f:
|
||||
assert isinstance(f, AsyncIOWrapper)
|
||||
assert not f.closed
|
||||
|
||||
assert f.closed
|
||||
|
||||
|
||||
async def test_async_iter() -> None:
|
||||
async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar"))
|
||||
expected = list(async_file.wrapped)
|
||||
async_file.wrapped.seek(0)
|
||||
|
||||
result = [line async for line in async_file]
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
async def test_aclose_cancelled(path: pathlib.Path) -> None:
|
||||
with _core.CancelScope() as cscope:
|
||||
f = await trio.open_file(path, "w")
|
||||
cscope.cancel()
|
||||
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await f.write("a")
|
||||
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await f.aclose()
|
||||
|
||||
assert f.closed
|
||||
|
||||
|
||||
async def test_detach_rewraps_asynciobase(tmp_path: pathlib.Path) -> None:
|
||||
tmp_file = tmp_path / "filename"
|
||||
tmp_file.touch()
|
||||
# flake8-async does not like opening files in async mode
|
||||
with open(tmp_file, mode="rb", buffering=0) as raw: # noqa: ASYNC230
|
||||
buffered = io.BufferedReader(raw)
|
||||
|
||||
async_file = trio.wrap_file(buffered)
|
||||
|
||||
detached = await async_file.detach()
|
||||
|
||||
assert isinstance(detached, AsyncIOWrapper)
|
||||
assert detached.wrapped is raw
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NoReturn
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from .._highlevel_generic import StapledStream
|
||||
from ..abc import ReceiveStream, SendStream
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class RecordSendStream(SendStream):
|
||||
record: list[str | tuple[str, object]] = attrs.Factory(list)
|
||||
|
||||
async def send_all(self, data: object) -> None:
|
||||
self.record.append(("send_all", data))
|
||||
|
||||
async def wait_send_all_might_not_block(self) -> None:
|
||||
self.record.append("wait_send_all_might_not_block")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.record.append("aclose")
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class RecordReceiveStream(ReceiveStream):
|
||||
record: list[str | tuple[str, int | None]] = attrs.Factory(list)
|
||||
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytes:
|
||||
self.record.append(("receive_some", max_bytes))
|
||||
return b""
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.record.append("aclose")
|
||||
|
||||
|
||||
async def test_StapledStream() -> None:
|
||||
send_stream = RecordSendStream()
|
||||
receive_stream = RecordReceiveStream()
|
||||
stapled = StapledStream(send_stream, receive_stream)
|
||||
|
||||
assert stapled.send_stream is send_stream
|
||||
assert stapled.receive_stream is receive_stream
|
||||
|
||||
await stapled.send_all(b"foo")
|
||||
await stapled.wait_send_all_might_not_block()
|
||||
assert send_stream.record == [
|
||||
("send_all", b"foo"),
|
||||
"wait_send_all_might_not_block",
|
||||
]
|
||||
send_stream.record.clear()
|
||||
|
||||
await stapled.send_eof()
|
||||
assert send_stream.record == ["aclose"]
|
||||
send_stream.record.clear()
|
||||
|
||||
async def fake_send_eof() -> None:
|
||||
send_stream.record.append("send_eof")
|
||||
|
||||
send_stream.send_eof = fake_send_eof # type: ignore[attr-defined]
|
||||
await stapled.send_eof()
|
||||
assert send_stream.record == ["send_eof"]
|
||||
|
||||
send_stream.record.clear()
|
||||
assert receive_stream.record == []
|
||||
|
||||
await stapled.receive_some(1234)
|
||||
assert receive_stream.record == [("receive_some", 1234)]
|
||||
assert send_stream.record == []
|
||||
receive_stream.record.clear()
|
||||
|
||||
await stapled.aclose()
|
||||
assert receive_stream.record == ["aclose"]
|
||||
assert send_stream.record == ["aclose"]
|
||||
|
||||
|
||||
async def test_StapledStream_with_erroring_close() -> None:
|
||||
# Make sure that if one of the aclose methods errors out, then the other
|
||||
# one still gets called.
|
||||
class BrokenSendStream(RecordSendStream):
|
||||
async def aclose(self) -> NoReturn:
|
||||
await super().aclose()
|
||||
raise ValueError("send error")
|
||||
|
||||
class BrokenReceiveStream(RecordReceiveStream):
|
||||
async def aclose(self) -> NoReturn:
|
||||
await super().aclose()
|
||||
raise ValueError("recv error")
|
||||
|
||||
stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())
|
||||
|
||||
with pytest.raises(ValueError, match=r"^(send|recv) error$") as excinfo:
|
||||
await stapled.aclose()
|
||||
assert isinstance(excinfo.value.__context__, ValueError)
|
||||
|
||||
assert stapled.send_stream.record == ["aclose"]
|
||||
assert stapled.receive_stream.record == ["aclose"]
|
||||
@@ -0,0 +1,419 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import socket as stdlib_socket
|
||||
import sys
|
||||
from socket import AddressFamily, SocketKind
|
||||
from typing import TYPE_CHECKING, cast, overload
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio import (
|
||||
SocketListener,
|
||||
open_tcp_listeners,
|
||||
open_tcp_stream,
|
||||
serve_tcp,
|
||||
)
|
||||
from trio.abc import HostnameResolver, SendStream, SocketFactory
|
||||
from trio.testing import open_stream_to_socket_listener
|
||||
|
||||
from .. import socket as tsocket
|
||||
from .._core._tests.tutil import binds_ipv6, slow
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from typing_extensions import Buffer
|
||||
|
||||
from trio._socket import AddressFormat
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_basic() -> None:
|
||||
listeners = await open_tcp_listeners(0)
|
||||
assert isinstance(listeners, list)
|
||||
for obj in listeners:
|
||||
assert isinstance(obj, SocketListener)
|
||||
# Binds to wildcard address by default
|
||||
assert obj.socket.family in [tsocket.AF_INET, tsocket.AF_INET6]
|
||||
assert obj.socket.getsockname()[0] in ["0.0.0.0", "::"]
|
||||
|
||||
listener = listeners[0]
|
||||
# Make sure the backlog is at least 2
|
||||
c1 = await open_stream_to_socket_listener(listener)
|
||||
c2 = await open_stream_to_socket_listener(listener)
|
||||
|
||||
s1 = await listener.accept()
|
||||
s2 = await listener.accept()
|
||||
|
||||
# Note that we don't know which client stream is connected to which server
|
||||
# stream
|
||||
await s1.send_all(b"x")
|
||||
await s2.send_all(b"x")
|
||||
assert await c1.receive_some(1) == b"x"
|
||||
assert await c2.receive_some(1) == b"x"
|
||||
|
||||
for resource in [c1, c2, s1, s2, *listeners]:
|
||||
await resource.aclose()
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_specific_port_specific_host() -> None:
|
||||
# Pick a port
|
||||
sock = tsocket.socket()
|
||||
await sock.bind(("127.0.0.1", 0))
|
||||
host, port = sock.getsockname()
|
||||
sock.close()
|
||||
|
||||
(listener,) = await open_tcp_listeners(port, host=host)
|
||||
async with listener:
|
||||
assert listener.socket.getsockname() == (host, port)
|
||||
|
||||
|
||||
@binds_ipv6
|
||||
@slow
|
||||
async def test_open_tcp_listeners_ipv6_v6only() -> None:
|
||||
# Check IPV6_V6ONLY is working properly
|
||||
(ipv6_listener,) = await open_tcp_listeners(0, host="::1")
|
||||
async with ipv6_listener:
|
||||
_, port, *_ = ipv6_listener.socket.getsockname()
|
||||
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"(Error|all attempts to) connect(ing)* to (\(')*127\.0\.0\.1(', |:)\d+(\): Connection refused| failed)$",
|
||||
):
|
||||
# Windows retries failed connections so this takes seconds
|
||||
# (and that's why this is marked @slow)
|
||||
await open_tcp_stream("127.0.0.1", port)
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_rebind() -> None:
|
||||
(l1,) = await open_tcp_listeners(0, host="127.0.0.1")
|
||||
sockaddr1 = l1.socket.getsockname()
|
||||
|
||||
# Plain old rebinding while it's still there should fail, even if we have
|
||||
# SO_REUSEADDR set
|
||||
with stdlib_socket.socket() as probe:
|
||||
probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"(Address (already )?in use|An attempt was made to access a socket in a way forbidden by its access permissions)$",
|
||||
):
|
||||
probe.bind(sockaddr1)
|
||||
|
||||
# Now use the first listener to set up some connections in various states,
|
||||
# and make sure that they don't create any obstacle to rebinding a second
|
||||
# listener after the first one is closed.
|
||||
c_established = await open_stream_to_socket_listener(l1)
|
||||
s_established = await l1.accept()
|
||||
|
||||
c_time_wait = await open_stream_to_socket_listener(l1)
|
||||
s_time_wait = await l1.accept()
|
||||
# Server-initiated close leaves socket in TIME_WAIT
|
||||
await s_time_wait.aclose()
|
||||
|
||||
await l1.aclose()
|
||||
(l2,) = await open_tcp_listeners(sockaddr1[1], host="127.0.0.1")
|
||||
sockaddr2 = l2.socket.getsockname()
|
||||
|
||||
assert sockaddr1 == sockaddr2
|
||||
assert s_established.socket.getsockname() == sockaddr2
|
||||
assert c_time_wait.socket.getpeername() == sockaddr2
|
||||
|
||||
for resource in [
|
||||
l1,
|
||||
l2,
|
||||
c_established,
|
||||
s_established,
|
||||
c_time_wait,
|
||||
s_time_wait,
|
||||
]:
|
||||
await resource.aclose()
|
||||
|
||||
|
||||
class FakeOSError(OSError):
|
||||
pass
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class FakeSocket(tsocket.SocketType):
|
||||
_family: AddressFamily = attrs.field(converter=AddressFamily)
|
||||
_type: SocketKind = attrs.field(converter=SocketKind)
|
||||
_proto: int
|
||||
|
||||
closed: bool = False
|
||||
poison_listen: bool = False
|
||||
backlog: int | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> SocketKind:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def family(self) -> AddressFamily:
|
||||
return self._family
|
||||
|
||||
@property
|
||||
def proto(self) -> int: # pragma: no cover
|
||||
return self._proto
|
||||
|
||||
@overload
|
||||
def getsockopt(self, /, level: int, optname: int) -> int: ...
|
||||
|
||||
@overload
|
||||
def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ...
|
||||
|
||||
def getsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
buflen: int | None = None,
|
||||
) -> int | bytes:
|
||||
if (level, optname) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN):
|
||||
return True
|
||||
raise AssertionError() # pragma: no cover
|
||||
|
||||
@overload
|
||||
def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ...
|
||||
|
||||
@overload
|
||||
def setsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: None,
|
||||
optlen: int,
|
||||
) -> None: ...
|
||||
|
||||
def setsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: int | Buffer | None,
|
||||
optlen: int | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def bind(self, address: AddressFormat) -> None:
|
||||
pass
|
||||
|
||||
def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None:
|
||||
assert self.backlog is None
|
||||
assert backlog is not None
|
||||
self.backlog = backlog
|
||||
if self.poison_listen:
|
||||
raise FakeOSError("whoops")
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class FakeSocketFactory(SocketFactory):
|
||||
poison_after: int
|
||||
sockets: list[tsocket.SocketType] = attrs.Factory(list)
|
||||
raise_on_family: dict[AddressFamily, int] = attrs.Factory(dict) # family => errno
|
||||
|
||||
def socket(
|
||||
self,
|
||||
family: AddressFamily | int | None = None,
|
||||
type_: SocketKind | int | None = None,
|
||||
proto: int = 0,
|
||||
) -> tsocket.SocketType:
|
||||
assert family is not None
|
||||
assert type_ is not None
|
||||
if isinstance(family, int) and not isinstance(family, AddressFamily):
|
||||
family = AddressFamily(family) # pragma: no cover
|
||||
if family in self.raise_on_family:
|
||||
raise OSError(self.raise_on_family[family], "nope")
|
||||
sock = FakeSocket(family, type_, proto)
|
||||
self.poison_after -= 1
|
||||
if self.poison_after == 0:
|
||||
sock.poison_listen = True
|
||||
self.sockets.append(sock)
|
||||
return sock
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class FakeHostnameResolver(HostnameResolver):
|
||||
family_addr_pairs: Sequence[tuple[AddressFamily, str]]
|
||||
|
||||
async def getaddrinfo(
|
||||
self,
|
||||
host: bytes | None,
|
||||
port: bytes | str | int | None,
|
||||
family: int = 0,
|
||||
type: int = 0,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
) -> list[
|
||||
tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
|
||||
]
|
||||
]:
|
||||
assert isinstance(port, int)
|
||||
return [
|
||||
(family, tsocket.SOCK_STREAM, 0, "", (addr, port))
|
||||
for family, addr in self.family_addr_pairs
|
||||
]
|
||||
|
||||
async def getnameinfo(
|
||||
self,
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int],
|
||||
flags: int,
|
||||
) -> tuple[str, str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None:
|
||||
# If we were trying to bind to multiple hosts and one of them failed, they
|
||||
# call get cleaned up before returning
|
||||
fsf = FakeSocketFactory(3)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
tsocket.set_custom_hostname_resolver(
|
||||
FakeHostnameResolver(
|
||||
[
|
||||
(tsocket.AF_INET, "1.1.1.1"),
|
||||
(tsocket.AF_INET, "2.2.2.2"),
|
||||
(tsocket.AF_INET, "3.3.3.3"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(FakeOSError):
|
||||
await open_tcp_listeners(80, host="example.org")
|
||||
|
||||
assert len(fsf.sockets) == 3
|
||||
for sock in fsf.sockets:
|
||||
# property only exists on FakeSocket
|
||||
assert sock.closed # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_port_checking() -> None:
|
||||
for host in ["127.0.0.1", None]:
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_listeners(None, host=host) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_listeners(b"80", host=host) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_listeners("http", host=host) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def test_serve_tcp() -> None:
|
||||
async def handler(stream: SendStream) -> None:
|
||||
await stream.send_all(b"x")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# nursery.start is incorrectly typed, awaiting #2773
|
||||
value = await nursery.start(serve_tcp, handler, 0)
|
||||
assert isinstance(value, list)
|
||||
listeners = cast("list[SocketListener]", value)
|
||||
stream = await open_stream_to_socket_listener(listeners[0])
|
||||
async with stream:
|
||||
assert await stream.receive_some(1) == b"x"
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"try_families",
|
||||
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"fail_families",
|
||||
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
|
||||
)
|
||||
async def test_open_tcp_listeners_some_address_families_unavailable(
|
||||
try_families: set[AddressFamily],
|
||||
fail_families: set[AddressFamily],
|
||||
) -> None:
|
||||
fsf = FakeSocketFactory(
|
||||
10,
|
||||
raise_on_family=dict.fromkeys(fail_families, errno.EAFNOSUPPORT),
|
||||
)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
tsocket.set_custom_hostname_resolver(
|
||||
FakeHostnameResolver([(family, "foo") for family in try_families]),
|
||||
)
|
||||
|
||||
should_succeed = try_families - fail_families
|
||||
|
||||
if not should_succeed:
|
||||
with pytest.raises(OSError, match="This system doesn't support") as exc_info:
|
||||
await open_tcp_listeners(80, host="example.org")
|
||||
|
||||
# open_listeners always creates an exceptiongroup with the
|
||||
# unsupported address families, regardless of the value of
|
||||
# strict_exception_groups or number of unsupported families.
|
||||
assert isinstance(exc_info.value.__cause__, BaseExceptionGroup)
|
||||
for subexc in exc_info.value.__cause__.exceptions:
|
||||
assert "nope" in str(subexc)
|
||||
else:
|
||||
listeners = await open_tcp_listeners(80)
|
||||
for listener in listeners:
|
||||
should_succeed.remove(listener.socket.family)
|
||||
assert not should_succeed
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None:
|
||||
fsf = FakeSocketFactory(
|
||||
10,
|
||||
raise_on_family={
|
||||
tsocket.AF_INET: errno.EAFNOSUPPORT,
|
||||
tsocket.AF_INET6: errno.EINVAL,
|
||||
},
|
||||
)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
tsocket.set_custom_hostname_resolver(
|
||||
FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")]),
|
||||
)
|
||||
|
||||
with pytest.raises(OSError, match="nope") as exc_info:
|
||||
await open_tcp_listeners(80, host="example.org")
|
||||
assert exc_info.value.errno == errno.EINVAL
|
||||
assert exc_info.value.__cause__ is None
|
||||
assert "nope" in str(exc_info.value)
|
||||
|
||||
|
||||
# We used to have an elaborate test that opened a real TCP listening socket
|
||||
# and then tried to measure its backlog by making connections to it. And most
|
||||
# of the time, it worked. But no matter what we tried, it was always fragile,
|
||||
# because it had to do things like use timeouts to guess when the listening
|
||||
# queue was full, sometimes the CI hosts go into SYN-cookie mode (where there
|
||||
# effectively is no backlog), sometimes the host might not be enough resources
|
||||
# to give us the full requested backlog... it was a mess. So now we just check
|
||||
# that the backlog argument is passed through correctly.
|
||||
async def test_open_tcp_listeners_backlog() -> None:
|
||||
fsf = FakeSocketFactory(99)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
for given, expected in [
|
||||
(None, 0xFFFF),
|
||||
(99999999, 0xFFFF),
|
||||
(10, 10),
|
||||
(1, 1),
|
||||
]:
|
||||
listeners = await open_tcp_listeners(0, backlog=given)
|
||||
assert listeners
|
||||
for listener in listeners:
|
||||
# `backlog` only exists on FakeSocket
|
||||
assert listener.socket.backlog == expected # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_backlog_float_error() -> None:
|
||||
fsf = FakeSocketFactory(99)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
for should_fail in (0.0, 2.18, 3.15, 9.75):
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=f"backlog must be an int or None, not {should_fail!r}",
|
||||
):
|
||||
await open_tcp_listeners(0, backlog=should_fail) # type: ignore[arg-type]
|
||||
@@ -0,0 +1,693 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import sys
|
||||
from socket import AddressFamily, SocketKind
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio._highlevel_open_tcp_stream import (
|
||||
close_all,
|
||||
format_host_port,
|
||||
open_tcp_stream,
|
||||
reorder_for_rfc_6555_section_5_4,
|
||||
)
|
||||
from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from trio.testing import MockClock
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
|
||||
def test_close_all() -> None:
|
||||
class CloseMe(SocketType):
|
||||
closed = False
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
class CloseKiller(SocketType):
|
||||
def close(self) -> None:
|
||||
raise OSError("os error text")
|
||||
|
||||
c: CloseMe = CloseMe()
|
||||
with close_all() as to_close:
|
||||
to_close.add(c)
|
||||
assert c.closed
|
||||
|
||||
c = CloseMe()
|
||||
with pytest.raises(RuntimeError):
|
||||
with close_all() as to_close:
|
||||
to_close.add(c)
|
||||
raise RuntimeError
|
||||
assert c.closed
|
||||
|
||||
c = CloseMe()
|
||||
with pytest.raises(OSError, match="os error text"):
|
||||
with close_all() as to_close:
|
||||
to_close.add(CloseKiller())
|
||||
to_close.add(c)
|
||||
assert c.closed
|
||||
|
||||
|
||||
def test_reorder_for_rfc_6555_section_5_4() -> None:
|
||||
def fake4(
|
||||
i: int,
|
||||
) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
|
||||
return (
|
||||
AF_INET,
|
||||
SOCK_STREAM,
|
||||
IPPROTO_TCP,
|
||||
"",
|
||||
(f"10.0.0.{i}", 80),
|
||||
)
|
||||
|
||||
def fake6(
|
||||
i: int,
|
||||
) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
|
||||
return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", (f"::{i}", 80))
|
||||
|
||||
for fake in fake4, fake6:
|
||||
# No effect on homogeneous lists
|
||||
targets = [fake(0), fake(1), fake(2)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake(0), fake(1), fake(2)]
|
||||
|
||||
# Single item lists also OK
|
||||
targets = [fake(0)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake(0)]
|
||||
|
||||
# If the list starts out with different families in positions 0 and 1,
|
||||
# then it's left alone
|
||||
orig = [fake4(0), fake6(0), fake4(1), fake6(1)]
|
||||
targets = list(orig)
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == orig
|
||||
|
||||
# If not, it's reordered
|
||||
targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)]
|
||||
|
||||
|
||||
def test_format_host_port() -> None:
|
||||
assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80"
|
||||
assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80"
|
||||
assert format_host_port("example.com", 443) == "example.com:443"
|
||||
assert format_host_port(b"example.com", 443) == "example.com:443"
|
||||
assert format_host_port("::1", "http") == "[::1]:http"
|
||||
assert format_host_port(b"::1", "http") == "[::1]:http"
|
||||
|
||||
|
||||
# Make sure we can connect to localhost using real kernel sockets
|
||||
async def test_open_tcp_stream_real_socket_smoketest() -> None:
|
||||
listen_sock = trio.socket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
_, listen_port = listen_sock.getsockname()
|
||||
listen_sock.listen(1)
|
||||
client_stream = await open_tcp_stream("127.0.0.1", listen_port)
|
||||
server_sock, _ = await listen_sock.accept()
|
||||
await client_stream.send_all(b"x")
|
||||
assert await server_sock.recv(1) == b"x"
|
||||
await client_stream.aclose()
|
||||
server_sock.close()
|
||||
|
||||
listen_sock.close()
|
||||
|
||||
|
||||
async def test_open_tcp_stream_input_validation() -> None:
|
||||
with pytest.raises(ValueError, match=r"^host must be str or bytes, not None$"):
|
||||
await open_tcp_stream(None, 80) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_stream("127.0.0.1", b"80") # type: ignore[arg-type]
|
||||
|
||||
|
||||
def can_bind_127_0_0_2() -> bool:
|
||||
with socket.socket() as s:
|
||||
try:
|
||||
s.bind(("127.0.0.2", 0))
|
||||
except OSError:
|
||||
return False
|
||||
# s.getsockname() is typed as returning Any
|
||||
return s.getsockname()[0] == "127.0.0.2" # type: ignore[no-any-return]
|
||||
|
||||
|
||||
async def test_local_address_real() -> None:
|
||||
with trio.socket.socket() as listener:
|
||||
await listener.bind(("127.0.0.1", 0))
|
||||
listener.listen()
|
||||
|
||||
# It's hard to test local_address properly, because you need multiple
|
||||
# local addresses that you can bind to. Fortunately, on most Linux
|
||||
# systems, you can bind to any 127.*.*.* address, and they all go
|
||||
# through the loopback interface. So we can use a non-standard
|
||||
# loopback address. On other systems, the only address we know for
|
||||
# certain we have is 127.0.0.1, so we can't really test local_address=
|
||||
# properly -- passing local_address=127.0.0.1 is indistinguishable
|
||||
# from not passing local_address= at all. But, we can still do a smoke
|
||||
# test to make sure the local_address= code doesn't crash.
|
||||
local_address = "127.0.0.2" if can_bind_127_0_0_2() else "127.0.0.1"
|
||||
|
||||
async with await open_tcp_stream(
|
||||
*listener.getsockname(),
|
||||
local_address=local_address,
|
||||
) as client_stream:
|
||||
assert client_stream.socket.getsockname()[0] == local_address
|
||||
if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"):
|
||||
assert client_stream.socket.getsockopt(
|
||||
trio.socket.IPPROTO_IP,
|
||||
trio.socket.IP_BIND_ADDRESS_NO_PORT,
|
||||
)
|
||||
server_sock, remote_addr = await listener.accept()
|
||||
await client_stream.aclose()
|
||||
server_sock.close()
|
||||
# accept returns tuple[SocketType, object], due to typeshed returning `Any`
|
||||
assert remote_addr[0] == local_address
|
||||
|
||||
# Trying to connect to an ipv4 address with the ipv6 wildcard
|
||||
# local_address should fail
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^all attempts to connect* to *127\.0\.0\.\d:\d+ failed$",
|
||||
):
|
||||
await open_tcp_stream(*listener.getsockname(), local_address="::")
|
||||
|
||||
# But the ipv4 wildcard address should work
|
||||
async with await open_tcp_stream(
|
||||
*listener.getsockname(),
|
||||
local_address="0.0.0.0",
|
||||
) as client_stream:
|
||||
server_sock, remote_addr = await listener.accept()
|
||||
server_sock.close()
|
||||
assert remote_addr == client_stream.socket.getsockname()
|
||||
|
||||
|
||||
# Now, thorough tests using fake sockets
|
||||
|
||||
|
||||
@attrs.define(eq=False, slots=False)
|
||||
class FakeSocket(trio.socket.SocketType):
|
||||
scenario: Scenario
|
||||
_family: AddressFamily
|
||||
_type: SocketKind
|
||||
_proto: int
|
||||
|
||||
ip: str | int | None = None
|
||||
port: str | int | None = None
|
||||
succeeded: bool = False
|
||||
closed: bool = False
|
||||
failing: bool = False
|
||||
|
||||
@property
|
||||
def type(self) -> SocketKind:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def family(self) -> AddressFamily: # pragma: no cover
|
||||
return self._family
|
||||
|
||||
@property
|
||||
def proto(self) -> int: # pragma: no cover
|
||||
return self._proto
|
||||
|
||||
async def connect(self, sockaddr: tuple[str | int, str | int | None]) -> None:
|
||||
self.ip = sockaddr[0]
|
||||
self.port = sockaddr[1]
|
||||
assert self.ip not in self.scenario.sockets
|
||||
self.scenario.sockets[self.ip] = self
|
||||
self.scenario.connect_times[self.ip] = trio.current_time()
|
||||
delay, result = self.scenario.ip_dict[self.ip]
|
||||
await trio.sleep(delay)
|
||||
if result == "error":
|
||||
raise OSError("sorry")
|
||||
if result == "postconnect_fail":
|
||||
self.failing = True
|
||||
self.succeeded = True
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
# called when SocketStream is constructed
|
||||
def setsockopt(self, *args: object, **kwargs: object) -> None:
|
||||
if self.failing:
|
||||
# raise something that isn't OSError as SocketStream
|
||||
# ignores those
|
||||
raise KeyboardInterrupt
|
||||
|
||||
|
||||
class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver):
|
||||
def __init__(
|
||||
self,
|
||||
port: int,
|
||||
ip_list: Sequence[tuple[str, float, str]],
|
||||
supported_families: set[AddressFamily],
|
||||
) -> None:
|
||||
# ip_list have to be unique
|
||||
ip_order = [ip for (ip, _, _) in ip_list]
|
||||
assert len(set(ip_order)) == len(ip_list)
|
||||
ip_dict: dict[str | int, tuple[float, str]] = {}
|
||||
for ip, delay, result in ip_list:
|
||||
assert delay >= 0
|
||||
assert result in ["error", "success", "postconnect_fail"]
|
||||
ip_dict[ip] = (delay, result)
|
||||
|
||||
self.port = port
|
||||
self.ip_order = ip_order
|
||||
self.ip_dict = ip_dict
|
||||
self.supported_families = supported_families
|
||||
self.socket_count = 0
|
||||
self.sockets: dict[str | int, FakeSocket] = {}
|
||||
self.connect_times: dict[str | int, float] = {}
|
||||
|
||||
def socket(
|
||||
self,
|
||||
family: AddressFamily | int | None = None,
|
||||
type_: SocketKind | int | None = None,
|
||||
proto: int | None = None,
|
||||
) -> SocketType:
|
||||
assert isinstance(family, AddressFamily)
|
||||
assert isinstance(type_, SocketKind)
|
||||
assert proto is not None
|
||||
if family not in self.supported_families:
|
||||
raise OSError("pretending not to support this family")
|
||||
self.socket_count += 1
|
||||
return FakeSocket(self, family, type_, proto)
|
||||
|
||||
def _ip_to_gai_entry(self, ip: str) -> tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int, int, int] | tuple[str, int] | tuple[int, bytes],
|
||||
]:
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes]
|
||||
if ":" in ip:
|
||||
family = trio.socket.AF_INET6
|
||||
sockaddr = (ip, self.port, 0, 0)
|
||||
else:
|
||||
family = trio.socket.AF_INET
|
||||
sockaddr = (ip, self.port)
|
||||
return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)
|
||||
|
||||
async def getaddrinfo(
|
||||
self,
|
||||
host: bytes | None,
|
||||
port: bytes | str | int | None,
|
||||
family: int = -1,
|
||||
type: int = -1,
|
||||
proto: int = -1,
|
||||
flags: int = -1,
|
||||
) -> list[
|
||||
tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int, int, int] | tuple[str, int] | tuple[int, bytes],
|
||||
]
|
||||
]:
|
||||
assert host == b"test.example.com"
|
||||
assert port == self.port
|
||||
assert family == trio.socket.AF_UNSPEC
|
||||
assert type == trio.socket.SOCK_STREAM
|
||||
assert proto == 0
|
||||
assert flags == 0
|
||||
return [self._ip_to_gai_entry(ip) for ip in self.ip_order]
|
||||
|
||||
async def getnameinfo(
|
||||
self,
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int],
|
||||
flags: int,
|
||||
) -> tuple[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def check(self, succeeded: SocketType | None) -> None:
|
||||
# sockets only go into self.sockets when connect is called; make sure
|
||||
# all the sockets that were created did in fact go in there.
|
||||
assert self.socket_count == len(self.sockets)
|
||||
|
||||
for ip, socket_ in self.sockets.items():
|
||||
assert ip in self.ip_dict
|
||||
if socket_ is not succeeded:
|
||||
assert socket_.closed
|
||||
assert socket_.port == self.port
|
||||
|
||||
|
||||
async def run_scenario(
|
||||
# The port to connect to
|
||||
port: int,
|
||||
# A list of
|
||||
# (ip, delay, result)
|
||||
# tuples, where delay is in seconds and result is "success" or "error"
|
||||
# The ip's will be returned from getaddrinfo in this order, and then
|
||||
# connect() calls to them will have the given result.
|
||||
ip_list: Sequence[tuple[str, float, str]],
|
||||
*,
|
||||
# If False, AF_INET4/6 sockets error out on creation, before connect is
|
||||
# even called.
|
||||
ipv4_supported: bool = True,
|
||||
ipv6_supported: bool = True,
|
||||
# Normally, we return (winning_sock, scenario object)
|
||||
# If this is True, we require there to be an exception, and return
|
||||
# (exception, scenario object)
|
||||
expect_error: tuple[type[BaseException], ...] | type[BaseException] = (),
|
||||
happy_eyeballs_delay: float | None = 0.25,
|
||||
local_address: str | None = None,
|
||||
) -> tuple[SocketType, Scenario] | tuple[BaseException, Scenario]:
|
||||
supported_families = set()
|
||||
if ipv4_supported:
|
||||
supported_families.add(trio.socket.AF_INET)
|
||||
if ipv6_supported:
|
||||
supported_families.add(trio.socket.AF_INET6)
|
||||
scenario = Scenario(port, ip_list, supported_families)
|
||||
trio.socket.set_custom_hostname_resolver(scenario)
|
||||
trio.socket.set_custom_socket_factory(scenario)
|
||||
|
||||
try:
|
||||
stream = await open_tcp_stream(
|
||||
"test.example.com",
|
||||
port,
|
||||
happy_eyeballs_delay=happy_eyeballs_delay,
|
||||
local_address=local_address,
|
||||
)
|
||||
assert expect_error == ()
|
||||
scenario.check(stream.socket)
|
||||
return (stream.socket, scenario)
|
||||
except AssertionError: # pragma: no cover
|
||||
raise
|
||||
except expect_error as exc:
|
||||
scenario.check(None)
|
||||
return (exc, scenario)
|
||||
|
||||
|
||||
async def test_one_host_quick_success(autojump_clock: MockClock) -> None:
|
||||
sock, _scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")])
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "1.2.3.4"
|
||||
assert trio.current_time() == 0.123
|
||||
|
||||
|
||||
async def test_one_host_slow_success(autojump_clock: MockClock) -> None:
|
||||
sock, _scenario = await run_scenario(81, [("1.2.3.4", 100, "success")])
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "1.2.3.4"
|
||||
assert trio.current_time() == 100
|
||||
|
||||
|
||||
async def test_one_host_quick_fail(autojump_clock: MockClock) -> None:
|
||||
exc, _scenario = await run_scenario(
|
||||
82,
|
||||
[("1.2.3.4", 0.123, "error")],
|
||||
expect_error=OSError,
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
assert trio.current_time() == 0.123
|
||||
|
||||
|
||||
async def test_one_host_slow_fail(autojump_clock: MockClock) -> None:
|
||||
exc, _scenario = await run_scenario(
|
||||
83,
|
||||
[("1.2.3.4", 100, "error")],
|
||||
expect_error=OSError,
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
assert trio.current_time() == 100
|
||||
|
||||
|
||||
async def test_one_host_failed_after_connect(autojump_clock: MockClock) -> None:
|
||||
exc, _scenario = await run_scenario(
|
||||
83,
|
||||
[("1.2.3.4", 1, "postconnect_fail")],
|
||||
expect_error=KeyboardInterrupt,
|
||||
)
|
||||
assert isinstance(exc, KeyboardInterrupt)
|
||||
|
||||
|
||||
# With the default 0.250 second delay, the third attempt will win
|
||||
async def test_basic_fallthrough(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "3.3.3.3"
|
||||
# current time is default time + default time + connection time
|
||||
assert trio.current_time() == (0.250 + 0.250 + 0.2)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.250,
|
||||
"3.3.3.3": 0.500,
|
||||
}
|
||||
|
||||
|
||||
async def test_early_success(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 0.1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "2.2.2.2"
|
||||
assert trio.current_time() == (0.250 + 0.1)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.250,
|
||||
# 3.3.3.3 was never even started
|
||||
}
|
||||
|
||||
|
||||
# With a 0.450 second delay, the first attempt will win
|
||||
async def test_custom_delay(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=0.450,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "1.1.1.1"
|
||||
assert trio.current_time() == 1
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.450,
|
||||
"3.3.3.3": 0.900,
|
||||
}
|
||||
|
||||
|
||||
async def test_none_default(autojump_clock: MockClock) -> None:
|
||||
"""Copy of test_basic_fallthrough, but specifying the delay =None"""
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=None,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "3.3.3.3"
|
||||
# current time is default time + default time + connection time
|
||||
assert trio.current_time() == (0.250 + 0.250 + 0.2)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.250,
|
||||
"3.3.3.3": 0.500,
|
||||
}
|
||||
|
||||
|
||||
async def test_custom_errors_expedite(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.1, "error"),
|
||||
("2.2.2.2", 0.2, "error"),
|
||||
("3.3.3.3", 10, "success"),
|
||||
# .25 is the default timeout
|
||||
("4.4.4.4", 0.25, "success"),
|
||||
],
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "4.4.4.4"
|
||||
assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.1,
|
||||
"3.3.3.3": 0.1 + 0.2,
|
||||
"4.4.4.4": 0.1 + 0.2 + 0.25,
|
||||
}
|
||||
|
||||
|
||||
async def test_all_fail(autojump_clock: MockClock) -> None:
|
||||
exc, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.1, "error"),
|
||||
("2.2.2.2", 0.2, "error"),
|
||||
("3.3.3.3", 10, "error"),
|
||||
("4.4.4.4", 0.250, "error"),
|
||||
],
|
||||
expect_error=OSError,
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
|
||||
subexceptions = (pytest.RaisesExc(OSError, match="^sorry$"),) * 4
|
||||
assert pytest.RaisesGroup(
|
||||
*subexceptions,
|
||||
match="all attempts to connect to test.example.com:80 failed",
|
||||
).matches(exc.__cause__)
|
||||
|
||||
assert trio.current_time() == (0.1 + 0.2 + 10)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.1,
|
||||
"3.3.3.3": 0.1 + 0.2,
|
||||
"4.4.4.4": 0.1 + 0.2 + 0.25,
|
||||
}
|
||||
|
||||
|
||||
async def test_multi_success(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.5, "error"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("3.3.3.3", 10 - 1, "success"),
|
||||
("4.4.4.4", 10 - 2, "success"),
|
||||
("5.5.5.5", 0.5, "error"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
assert not scenario.sockets["1.1.1.1"].succeeded
|
||||
assert (
|
||||
scenario.sockets["2.2.2.2"].succeeded
|
||||
or scenario.sockets["3.3.3.3"].succeeded
|
||||
or scenario.sockets["4.4.4.4"].succeeded
|
||||
)
|
||||
assert not scenario.sockets["5.5.5.5"].succeeded
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"]
|
||||
assert trio.current_time() == (0.5 + 10)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.5,
|
||||
"3.3.3.3": 1.5,
|
||||
"4.4.4.4": 2.5,
|
||||
"5.5.5.5": 3.5,
|
||||
}
|
||||
|
||||
|
||||
async def test_does_reorder(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 10, "error"),
|
||||
# This would win if we tried it first...
|
||||
("2.2.2.2", 1, "success"),
|
||||
# But in fact we try this first, because of section 5.4
|
||||
("::3", 0.5, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "::3"
|
||||
assert trio.current_time() == 1 + 0.5
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"::3": 1,
|
||||
}
|
||||
|
||||
|
||||
async def test_handles_no_ipv4(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
# Here the ipv6 addresses fail at socket creation time, so the connect
|
||||
# configuration doesn't matter
|
||||
[
|
||||
("::1", 10, "success"),
|
||||
("2.2.2.2", 0, "success"),
|
||||
("::3", 0.1, "success"),
|
||||
("4.4.4.4", 0, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
ipv4_supported=False,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "::3"
|
||||
assert trio.current_time() == 1 + 0.1
|
||||
assert scenario.connect_times == {
|
||||
"::1": 0,
|
||||
"::3": 1.0,
|
||||
}
|
||||
|
||||
|
||||
async def test_handles_no_ipv6(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
# Here the ipv6 addresses fail at socket creation time, so the connect
|
||||
# configuration doesn't matter
|
||||
[
|
||||
("::1", 0, "success"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("::3", 0, "success"),
|
||||
("4.4.4.4", 0.1, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
ipv6_supported=False,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "4.4.4.4"
|
||||
assert trio.current_time() == 1 + 0.1
|
||||
assert scenario.connect_times == {
|
||||
"2.2.2.2": 0,
|
||||
"4.4.4.4": 1.0,
|
||||
}
|
||||
|
||||
|
||||
async def test_no_hosts(autojump_clock: MockClock) -> None:
|
||||
exc, _scenario = await run_scenario(80, [], expect_error=OSError)
|
||||
assert "no results found" in str(exc)
|
||||
|
||||
|
||||
async def test_cancel(autojump_clock: MockClock) -> None:
|
||||
with trio.move_on_after(5) as cancel_scope:
|
||||
exc, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 10, "success"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("3.3.3.3", 10, "success"),
|
||||
("4.4.4.4", 10, "success"),
|
||||
],
|
||||
expect_error=BaseExceptionGroup,
|
||||
)
|
||||
assert isinstance(exc, BaseException)
|
||||
# What comes out should be 1 or more Cancelled errors that all belong
|
||||
# to this cancel_scope; this is the easiest way to check that
|
||||
raise exc
|
||||
assert cancel_scope.cancelled_caught
|
||||
|
||||
assert trio.current_time() == 5
|
||||
|
||||
# This should have been called already, but just to make sure, since the
|
||||
# exception-handling logic in run_scenario is a bit complicated and the
|
||||
# main thing we care about here is that all the sockets were cleaned up.
|
||||
scenario.check(succeeded=None)
|
||||
@@ -0,0 +1,86 @@
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from trio import Path, open_unix_socket
|
||||
from trio._highlevel_open_unix_stream import close_on_error
|
||||
|
||||
assert not TYPE_CHECKING or sys.platform != "win32"
|
||||
|
||||
skip_if_not_unix = pytest.mark.skipif(
|
||||
not hasattr(socket, "AF_UNIX"),
|
||||
reason="Needs unix socket support",
|
||||
)
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
def test_close_on_error() -> None:
|
||||
class CloseMe:
|
||||
closed = False
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
with close_on_error(CloseMe()) as c:
|
||||
pass
|
||||
assert not c.closed
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
with close_on_error(CloseMe()) as c:
|
||||
raise RuntimeError
|
||||
assert c.closed
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
@pytest.mark.parametrize("filename", [4, 4.5])
|
||||
async def test_open_with_bad_filename_type(filename: float) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
await open_unix_socket(filename) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
async def test_open_bad_socket() -> None:
|
||||
# mktemp is marked as insecure, but that's okay, we don't want the file to
|
||||
# exist
|
||||
name = tempfile.mktemp()
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await open_unix_socket(name)
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
async def test_open_unix_socket() -> None:
|
||||
for name_type in [Path, str]:
|
||||
name = tempfile.mktemp()
|
||||
serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
with serv_sock:
|
||||
serv_sock.bind(name)
|
||||
try:
|
||||
serv_sock.listen(1)
|
||||
|
||||
# The actual function we're testing
|
||||
unix_socket = await open_unix_socket(name_type(name))
|
||||
|
||||
async with unix_socket:
|
||||
client, _ = serv_sock.accept()
|
||||
with client:
|
||||
await unix_socket.send_all(b"test")
|
||||
assert client.recv(2048) == b"test"
|
||||
|
||||
client.sendall(b"response")
|
||||
received = await unix_socket.receive_some(2048)
|
||||
assert received == b"response"
|
||||
finally:
|
||||
os.unlink(name)
|
||||
|
||||
|
||||
@pytest.mark.skipif(hasattr(socket, "AF_UNIX"), reason="Test for non-unix platforms")
|
||||
async def test_error_on_no_unix() -> None:
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match=r"^Unix sockets are not supported on this platform$",
|
||||
):
|
||||
await open_unix_socket("")
|
||||
@@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, NoReturn, cast
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio import Nursery, StapledStream, TaskStatus
|
||||
from trio.testing import (
|
||||
MemoryReceiveStream,
|
||||
MemorySendStream,
|
||||
MockClock,
|
||||
memory_stream_pair,
|
||||
wait_all_tasks_blocked,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from trio._channel import MemoryReceiveChannel, MemorySendChannel
|
||||
from trio.abc import Stream
|
||||
|
||||
# types are somewhat tentative - I just bruteforced them until I got something that didn't
|
||||
# give errors
|
||||
StapledMemoryStream = StapledStream[MemorySendStream, MemoryReceiveStream]
|
||||
|
||||
|
||||
@attrs.define(eq=False, slots=False)
|
||||
class MemoryListener(trio.abc.Listener[StapledMemoryStream]):
|
||||
closed: bool = False
|
||||
accepted_streams: list[trio.abc.Stream] = attrs.Factory(list)
|
||||
queued_streams: tuple[
|
||||
MemorySendChannel[StapledMemoryStream],
|
||||
MemoryReceiveChannel[StapledMemoryStream],
|
||||
] = attrs.Factory(lambda: trio.open_memory_channel[StapledMemoryStream](1))
|
||||
accept_hook: Callable[[], Awaitable[object]] | None = None
|
||||
|
||||
async def connect(self) -> StapledMemoryStream:
|
||||
assert not self.closed
|
||||
client, server = memory_stream_pair()
|
||||
await self.queued_streams[0].send(server)
|
||||
return client
|
||||
|
||||
async def accept(self) -> StapledMemoryStream:
|
||||
await trio.lowlevel.checkpoint()
|
||||
assert not self.closed
|
||||
if self.accept_hook is not None:
|
||||
await self.accept_hook()
|
||||
stream = await self.queued_streams[1].receive()
|
||||
self.accepted_streams.append(stream)
|
||||
return stream
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.closed = True
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
async def test_serve_listeners_basic() -> None:
|
||||
listeners = [MemoryListener(), MemoryListener()]
|
||||
|
||||
record = []
|
||||
|
||||
def close_hook() -> None:
|
||||
# Make sure this is a forceful close
|
||||
assert trio.current_effective_deadline() == float("-inf")
|
||||
record.append("closed")
|
||||
|
||||
async def handler(stream: StapledMemoryStream) -> None:
|
||||
await stream.send_all(b"123")
|
||||
assert await stream.receive_some(10) == b"456"
|
||||
stream.send_stream.close_hook = close_hook
|
||||
stream.receive_stream.close_hook = close_hook
|
||||
|
||||
async def client(listener: MemoryListener) -> None:
|
||||
s = await listener.connect()
|
||||
assert await s.receive_some(10) == b"123"
|
||||
await s.send_all(b"456")
|
||||
|
||||
async def do_tests(parent_nursery: Nursery) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for listener in listeners:
|
||||
for _ in range(3):
|
||||
nursery.start_soon(client, listener)
|
||||
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# verifies that all 6 streams x 2 directions each were closed ok
|
||||
assert len(record) == 12
|
||||
|
||||
parent_nursery.cancel_scope.cancel()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
value = await nursery.start(
|
||||
trio.serve_listeners,
|
||||
handler,
|
||||
listeners,
|
||||
)
|
||||
assert isinstance(value, list)
|
||||
l2 = cast("list[MemoryListener]", value)
|
||||
assert l2 == listeners
|
||||
# This is just split into another function because gh-136 isn't
|
||||
# implemented yet
|
||||
nursery.start_soon(do_tests, nursery)
|
||||
|
||||
for listener in listeners:
|
||||
assert listener.closed
|
||||
|
||||
|
||||
async def test_serve_listeners_accept_unrecognized_error() -> None:
|
||||
for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]:
|
||||
listener = MemoryListener()
|
||||
|
||||
async def raise_error() -> NoReturn:
|
||||
raise error # noqa: B023 # Set from loop
|
||||
|
||||
def check_error(e: BaseException) -> bool:
|
||||
return e is error # noqa: B023
|
||||
|
||||
listener.accept_hook = raise_error
|
||||
|
||||
with pytest.RaisesGroup(pytest.RaisesExc(check=check_error)):
|
||||
await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def test_serve_listeners_accept_capacity_error(
|
||||
autojump_clock: MockClock,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
listener = MemoryListener()
|
||||
|
||||
async def raise_EMFILE() -> NoReturn:
|
||||
raise OSError(errno.EMFILE, "out of file descriptors")
|
||||
|
||||
listener.accept_hook = raise_EMFILE
|
||||
|
||||
# It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900
|
||||
# = 10 times total
|
||||
with trio.move_on_after(0.950):
|
||||
await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
|
||||
|
||||
assert len(caplog.records) == 10
|
||||
for record in caplog.records:
|
||||
assert "retrying" in record.msg
|
||||
assert record.exc_info is not None
|
||||
assert isinstance(record.exc_info[1], OSError)
|
||||
assert record.exc_info[1].errno == errno.EMFILE
|
||||
|
||||
|
||||
async def test_serve_listeners_connection_nursery(autojump_clock: MockClock) -> None:
|
||||
listener = MemoryListener()
|
||||
|
||||
async def handler(stream: Stream) -> None:
|
||||
await trio.sleep(1)
|
||||
|
||||
class Done(Exception):
|
||||
pass
|
||||
|
||||
async def connection_watcher(
|
||||
*,
|
||||
task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED,
|
||||
) -> NoReturn:
|
||||
async with trio.open_nursery() as nursery:
|
||||
task_status.started(nursery)
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(nursery.child_tasks) == 10
|
||||
raise Done
|
||||
|
||||
# the exception is wrapped twice because we open two nested nurseries
|
||||
with pytest.RaisesGroup(pytest.RaisesGroup(Done)):
|
||||
async with trio.open_nursery() as nursery:
|
||||
value = await nursery.start(connection_watcher)
|
||||
assert isinstance(value, trio.Nursery)
|
||||
handler_nursery: trio.Nursery = value
|
||||
await nursery.start(
|
||||
partial(
|
||||
trio.serve_listeners,
|
||||
handler,
|
||||
[listener],
|
||||
handler_nursery=handler_nursery,
|
||||
),
|
||||
)
|
||||
for _ in range(10):
|
||||
nursery.start_soon(listener.connect)
|
||||
@@ -0,0 +1,336 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import socket as stdlib_socket
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import _core, socket as tsocket
|
||||
from .._highlevel_socket import *
|
||||
from ..testing import (
|
||||
assert_checkpoints,
|
||||
check_half_closeable_stream,
|
||||
wait_all_tasks_blocked,
|
||||
)
|
||||
from .test_socket import setsockopt_tests
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
async def test_SocketStream_basics() -> None:
|
||||
# stdlib socket bad (even if connected)
|
||||
stdlib_a, stdlib_b = stdlib_socket.socketpair()
|
||||
with stdlib_a, stdlib_b:
|
||||
with pytest.raises(TypeError):
|
||||
SocketStream(stdlib_a) # type: ignore[arg-type]
|
||||
|
||||
# DGRAM socket bad
|
||||
with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock:
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"^SocketStream requires a SOCK_STREAM socket$",
|
||||
):
|
||||
# TODO: does not raise an error?
|
||||
SocketStream(sock)
|
||||
|
||||
a, b = tsocket.socketpair()
|
||||
with a, b:
|
||||
s = SocketStream(a)
|
||||
assert s.socket is a
|
||||
|
||||
# Use a real, connected socket to test socket options, because
|
||||
# socketpair() might give us a unix socket that doesn't support any of
|
||||
# these options
|
||||
with tsocket.socket() as listen_sock:
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(1)
|
||||
with tsocket.socket() as client_sock:
|
||||
await client_sock.connect(listen_sock.getsockname())
|
||||
|
||||
s = SocketStream(client_sock)
|
||||
|
||||
# TCP_NODELAY enabled by default
|
||||
assert s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
|
||||
# We can disable it though
|
||||
s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
|
||||
assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
|
||||
|
||||
res = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1)
|
||||
assert isinstance(res, bytes)
|
||||
|
||||
setsockopt_tests(s)
|
||||
|
||||
|
||||
async def test_SocketStream_send_all() -> None:
|
||||
BIG = 10000000
|
||||
|
||||
a_sock, b_sock = tsocket.socketpair()
|
||||
with a_sock, b_sock:
|
||||
a = SocketStream(a_sock)
|
||||
b = SocketStream(b_sock)
|
||||
|
||||
# Check a send_all that has to be split into multiple parts (on most
|
||||
# platforms... on Windows every send() either succeeds or fails as a
|
||||
# whole)
|
||||
async def sender() -> None:
|
||||
data = bytearray(BIG)
|
||||
await a.send_all(data)
|
||||
# send_all uses memoryviews internally, which temporarily "lock"
|
||||
# the object they view. If it doesn't clean them up properly, then
|
||||
# some bytearray operations might raise an error afterwards, which
|
||||
# would be a pretty weird and annoying side-effect to spring on
|
||||
# users. So test that this doesn't happen, by forcing the
|
||||
# bytearray's underlying buffer to be realloc'ed:
|
||||
data += bytes(BIG)
|
||||
# (Note: the above line of code doesn't do a very good job at
|
||||
# testing anything, because:
|
||||
# - on CPython, the refcount GC generally cleans up memoryviews
|
||||
# for us even if we're sloppy.
|
||||
# - on PyPy3, at least as of 5.7.0, the memoryview code and the
|
||||
# bytearray code conspire so that resizing never fails – if
|
||||
# resizing forces the bytearray's internal buffer to move, then
|
||||
# all memoryview references are automagically updated (!!).
|
||||
# See:
|
||||
# https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
|
||||
# But I'm leaving the test here in hopes that if this ever changes
|
||||
# and we break our implementation of send_all, then we'll get some
|
||||
# early warning...)
|
||||
|
||||
async def receiver() -> None:
|
||||
# Make sure the sender fills up the kernel buffers and blocks
|
||||
await wait_all_tasks_blocked()
|
||||
nbytes = 0
|
||||
while nbytes < BIG:
|
||||
nbytes += len(await b.receive_some(BIG))
|
||||
assert nbytes == BIG
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sender)
|
||||
nursery.start_soon(receiver)
|
||||
|
||||
# We know that we received BIG bytes of NULs so far. Make sure that
|
||||
# was all the data in there.
|
||||
await a.send_all(b"e")
|
||||
assert await b.receive_some(10) == b"e"
|
||||
await a.send_eof()
|
||||
assert await b.receive_some(10) == b""
|
||||
|
||||
|
||||
async def fill_stream(s: SocketStream) -> None:
|
||||
async def sender() -> None:
|
||||
while True:
|
||||
await s.send_all(b"x" * 10000)
|
||||
|
||||
async def waiter(nursery: _core.Nursery) -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sender)
|
||||
nursery.start_soon(waiter, nursery)
|
||||
|
||||
|
||||
async def test_SocketStream_generic() -> None:
|
||||
async def stream_maker() -> tuple[
|
||||
SocketStream,
|
||||
SocketStream,
|
||||
]:
|
||||
left, right = tsocket.socketpair()
|
||||
return SocketStream(left), SocketStream(right)
|
||||
|
||||
async def clogged_stream_maker() -> tuple[SocketStream, SocketStream]:
|
||||
left, right = await stream_maker()
|
||||
await fill_stream(left)
|
||||
await fill_stream(right)
|
||||
return left, right
|
||||
|
||||
await check_half_closeable_stream(stream_maker, clogged_stream_maker)
|
||||
|
||||
|
||||
async def test_SocketListener() -> None:
|
||||
# Not a Trio socket
|
||||
with stdlib_socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
s.listen(10)
|
||||
with pytest.raises(TypeError):
|
||||
SocketListener(s) # type: ignore[arg-type]
|
||||
|
||||
# Not a SOCK_STREAM
|
||||
with tsocket.socket(type=tsocket.SOCK_DGRAM) as s:
|
||||
await s.bind(("127.0.0.1", 0))
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"^SocketListener requires a SOCK_STREAM socket$",
|
||||
) as excinfo:
|
||||
SocketListener(s)
|
||||
excinfo.match(r".*SOCK_STREAM")
|
||||
|
||||
# Didn't call .listen()
|
||||
# macOS has no way to check for this, so skip testing it there.
|
||||
if sys.platform != "darwin":
|
||||
with tsocket.socket() as s:
|
||||
await s.bind(("127.0.0.1", 0))
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"^SocketListener requires a listening socket$",
|
||||
) as excinfo:
|
||||
SocketListener(s)
|
||||
excinfo.match(r".*listen")
|
||||
|
||||
listen_sock = tsocket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(10)
|
||||
listener = SocketListener(listen_sock)
|
||||
|
||||
assert listener.socket is listen_sock
|
||||
|
||||
client_sock = tsocket.socket()
|
||||
await client_sock.connect(listen_sock.getsockname())
|
||||
with assert_checkpoints():
|
||||
server_stream = await listener.accept()
|
||||
assert isinstance(server_stream, SocketStream)
|
||||
assert server_stream.socket.getsockname() == listen_sock.getsockname()
|
||||
assert server_stream.socket.getpeername() == client_sock.getsockname()
|
||||
|
||||
with assert_checkpoints():
|
||||
await listener.aclose()
|
||||
|
||||
with assert_checkpoints():
|
||||
await listener.aclose()
|
||||
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await listener.accept()
|
||||
|
||||
client_sock.close()
|
||||
await server_stream.aclose()
|
||||
|
||||
|
||||
async def test_SocketListener_socket_closed_underfoot() -> None:
|
||||
listen_sock = tsocket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(10)
|
||||
listener = SocketListener(listen_sock)
|
||||
|
||||
# Close the socket, not the listener
|
||||
listen_sock.close()
|
||||
|
||||
# SocketListener gives correct error
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await listener.accept()
|
||||
|
||||
|
||||
async def test_SocketListener_accept_errors() -> None:
|
||||
class FakeSocket(tsocket.SocketType):
|
||||
def __init__(self, events: Sequence[SocketType | BaseException]) -> None:
|
||||
self._events = iter(events)
|
||||
|
||||
type = tsocket.SOCK_STREAM
|
||||
|
||||
# Fool the check for SO_ACCEPTCONN in SocketListener.__init__
|
||||
@overload
|
||||
def getsockopt(self, /, level: int, optname: int) -> int: ...
|
||||
|
||||
@overload
|
||||
def getsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
buflen: int,
|
||||
) -> bytes: ...
|
||||
|
||||
def getsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
buflen: int | None = None,
|
||||
) -> int | bytes:
|
||||
return True
|
||||
|
||||
@overload
|
||||
def setsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: int | Buffer,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def setsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: None,
|
||||
optlen: int,
|
||||
) -> None: ...
|
||||
|
||||
def setsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: int | Buffer | None,
|
||||
optlen: int | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def accept(self) -> tuple[SocketType, object]:
|
||||
await _core.checkpoint()
|
||||
event = next(self._events)
|
||||
if isinstance(event, BaseException):
|
||||
raise event
|
||||
else:
|
||||
return event, None
|
||||
|
||||
fake_server_sock = FakeSocket([])
|
||||
|
||||
fake_listen_sock = FakeSocket(
|
||||
[
|
||||
OSError(errno.ECONNABORTED, "Connection aborted"),
|
||||
OSError(errno.EPERM, "Permission denied"),
|
||||
OSError(errno.EPROTO, "Bad protocol"),
|
||||
fake_server_sock,
|
||||
OSError(errno.EMFILE, "Out of file descriptors"),
|
||||
OSError(errno.EFAULT, "attempt to write to read-only memory"),
|
||||
OSError(errno.ENOBUFS, "out of buffers"),
|
||||
fake_server_sock,
|
||||
],
|
||||
)
|
||||
|
||||
listener = SocketListener(fake_listen_sock)
|
||||
|
||||
with assert_checkpoints():
|
||||
stream = await listener.accept()
|
||||
assert stream.socket is fake_server_sock
|
||||
|
||||
for code, match in {
|
||||
errno.EMFILE: r"\[\w+ \d+\] Out of file descriptors$",
|
||||
errno.EFAULT: r"\[\w+ \d+\] attempt to write to read-only memory$",
|
||||
errno.ENOBUFS: r"\[\w+ \d+\] out of buffers$",
|
||||
}.items():
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(OSError, match=match) as excinfo:
|
||||
await listener.accept()
|
||||
assert excinfo.value.errno == code
|
||||
|
||||
with assert_checkpoints():
|
||||
stream = await listener.accept()
|
||||
assert stream.socket is fake_server_sock
|
||||
|
||||
|
||||
async def test_socket_stream_works_when_peer_has_already_closed() -> None:
|
||||
sock_a, sock_b = tsocket.socketpair()
|
||||
with sock_a, sock_b:
|
||||
await sock_b.send(b"x")
|
||||
sock_b.close()
|
||||
stream = SocketStream(sock_a)
|
||||
assert await stream.receive_some(1) == b"x"
|
||||
assert await stream.receive_some(1) == b""
|
||||
@@ -0,0 +1,169 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, NoReturn, cast
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM
|
||||
|
||||
from .._highlevel_ssl_helpers import (
|
||||
open_ssl_over_tcp_listeners,
|
||||
open_ssl_over_tcp_stream,
|
||||
serve_ssl_over_tcp,
|
||||
)
|
||||
|
||||
# using noqa because linters don't understand how pytest fixtures work.
|
||||
from .test_ssl import SERVER_CTX, client_ctx # noqa: F401
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from socket import AddressFamily, SocketKind
|
||||
from ssl import SSLContext
|
||||
|
||||
from trio.abc import Stream
|
||||
|
||||
from .._highlevel_socket import SocketListener
|
||||
from .._ssl import SSLListener
|
||||
|
||||
|
||||
async def echo_handler(stream: Stream) -> None:
|
||||
async with stream:
|
||||
try:
|
||||
while True:
|
||||
data = await stream.receive_some(10000)
|
||||
if not data:
|
||||
break
|
||||
await stream.send_all(data)
|
||||
except trio.BrokenResourceError:
|
||||
pass
|
||||
|
||||
|
||||
# Resolver that always returns the given sockaddr, no matter what host/port
|
||||
# you ask for.
|
||||
@attrs.define(slots=False)
|
||||
class FakeHostnameResolver(trio.abc.HostnameResolver):
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes]
|
||||
|
||||
async def getaddrinfo(
|
||||
self,
|
||||
host: bytes | None,
|
||||
port: bytes | str | int | None,
|
||||
family: int = 0,
|
||||
type: int = 0,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
) -> list[
|
||||
tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
|
||||
]
|
||||
]:
|
||||
return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)]
|
||||
|
||||
async def getnameinfo(
|
||||
self,
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int],
|
||||
flags: int,
|
||||
) -> NoReturn: # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
|
||||
# using noqa because linters don't understand how pytest fixtures work.
|
||||
async def test_open_ssl_over_tcp_stream_and_everything_else(
|
||||
client_ctx: SSLContext, # noqa: F811 # linters doesn't understand fixture
|
||||
) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# TODO: this function wraps an SSLListener around a SocketListener, this is illegal
|
||||
# according to current type hints, and probably for good reason. But there should
|
||||
# maybe be a different wrapper class/function that could be used instead?
|
||||
value = await nursery.start(
|
||||
partial(
|
||||
serve_ssl_over_tcp,
|
||||
echo_handler,
|
||||
0,
|
||||
SERVER_CTX,
|
||||
host="127.0.0.1",
|
||||
),
|
||||
)
|
||||
assert isinstance(value, list)
|
||||
res = cast("list[SSLListener[SocketListener]]", value) # type: ignore[type-var]
|
||||
(listener,) = res
|
||||
async with listener:
|
||||
# listener.transport_listener is of type Listener[Stream]
|
||||
tp_listener: SocketListener = listener.transport_listener # type: ignore[assignment]
|
||||
|
||||
sockaddr = tp_listener.socket.getsockname()
|
||||
hostname_resolver = FakeHostnameResolver(sockaddr)
|
||||
trio.socket.set_custom_hostname_resolver(hostname_resolver)
|
||||
|
||||
# We don't have the right trust set up
|
||||
# (checks that ssl_context=None is doing some validation)
|
||||
stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
|
||||
async with stream:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await stream.do_handshake()
|
||||
|
||||
# We have the trust but not the hostname
|
||||
# (checks custom ssl_context + hostname checking)
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"xyzzy.example.org",
|
||||
80,
|
||||
ssl_context=client_ctx,
|
||||
)
|
||||
async with stream:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await stream.do_handshake()
|
||||
|
||||
# This one should work!
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"trio-test-1.example.org",
|
||||
80,
|
||||
ssl_context=client_ctx,
|
||||
)
|
||||
async with stream:
|
||||
assert isinstance(stream, trio.SSLStream)
|
||||
assert stream.server_hostname == "trio-test-1.example.org"
|
||||
await stream.send_all(b"x")
|
||||
assert await stream.receive_some(1) == b"x"
|
||||
|
||||
# Check https_compatible settings are being passed through
|
||||
assert not stream._https_compatible
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"trio-test-1.example.org",
|
||||
80,
|
||||
ssl_context=client_ctx,
|
||||
https_compatible=True,
|
||||
# also, smoke test happy_eyeballs_delay
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
async with stream:
|
||||
assert stream._https_compatible
|
||||
|
||||
# Stop the echo server
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_open_ssl_over_tcp_listeners() -> None:
|
||||
(listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1")
|
||||
async with listener:
|
||||
assert isinstance(listener, trio.SSLListener)
|
||||
tl = listener.transport_listener
|
||||
assert isinstance(tl, trio.SocketListener)
|
||||
assert tl.socket.getsockname()[0] == "127.0.0.1"
|
||||
|
||||
assert not listener._https_compatible
|
||||
|
||||
(listener,) = await open_ssl_over_tcp_listeners(
|
||||
0,
|
||||
SERVER_CTX,
|
||||
host="127.0.0.1",
|
||||
https_compatible=True,
|
||||
)
|
||||
async with listener:
|
||||
assert listener._https_compatible
|
||||
@@ -0,0 +1,279 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio._file_io import AsyncIOWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def path(tmp_path: pathlib.Path) -> trio.Path:
|
||||
return trio.Path(tmp_path / "test")
|
||||
|
||||
|
||||
def method_pair(
|
||||
path: str,
|
||||
method_name: str,
|
||||
) -> tuple[Callable[[], object], Callable[[], Awaitable[object]]]:
|
||||
sync_path = pathlib.Path(path)
|
||||
async_path = trio.Path(path)
|
||||
return getattr(sync_path, method_name), getattr(async_path, method_name)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name == "nt", reason="OS is not posix")
|
||||
def test_instantiate_posix() -> None:
|
||||
assert isinstance(trio.Path(), trio.PosixPath)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name != "nt", reason="OS is not Windows")
|
||||
def test_instantiate_windows() -> None:
|
||||
assert isinstance(trio.Path(), trio.WindowsPath)
|
||||
|
||||
|
||||
async def test_open_is_async_context_manager(path: trio.Path) -> None:
|
||||
async with await path.open("w") as f:
|
||||
assert isinstance(f, AsyncIOWrapper)
|
||||
|
||||
assert f.closed
|
||||
|
||||
|
||||
def test_magic() -> None:
|
||||
path = trio.Path("test")
|
||||
|
||||
assert str(path) == "test"
|
||||
assert bytes(path) == b"test"
|
||||
|
||||
|
||||
EitherPathType = type[trio.Path] | type[pathlib.Path]
|
||||
PathOrStrType = EitherPathType | type[str]
|
||||
cls_pairs: list[tuple[EitherPathType, EitherPathType]] = [
|
||||
(trio.Path, pathlib.Path),
|
||||
(pathlib.Path, trio.Path),
|
||||
(trio.Path, trio.Path),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs)
|
||||
def test_cmp_magic(cls_a: EitherPathType, cls_b: EitherPathType) -> None:
|
||||
a, b = cls_a(""), cls_b("")
|
||||
assert a == b
|
||||
assert not a != b # noqa: SIM202 # negate-not-equal-op
|
||||
|
||||
a, b = cls_a("a"), cls_b("b")
|
||||
assert a < b
|
||||
assert b > a
|
||||
|
||||
# this is intentionally testing equivalence with none, due to the
|
||||
# other=sentinel logic in _forward_magic
|
||||
assert not a == None # noqa
|
||||
assert not b == None # noqa
|
||||
|
||||
|
||||
# upstream python3.8 bug: we should also test (pathlib.Path, trio.Path), but
|
||||
# __*div__ does not properly raise NotImplementedError like the other comparison
|
||||
# magic, so trio.Path's implementation does not get dispatched
|
||||
cls_pairs_str: list[tuple[PathOrStrType, PathOrStrType]] = [
|
||||
(trio.Path, pathlib.Path),
|
||||
(trio.Path, trio.Path),
|
||||
(trio.Path, str),
|
||||
(str, trio.Path),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs_str)
|
||||
def test_div_magic(cls_a: PathOrStrType, cls_b: PathOrStrType) -> None:
|
||||
a, b = cls_a("a"), cls_b("b")
|
||||
|
||||
result = a / b # type: ignore[operator]
|
||||
# Type checkers think str / str could happen. Check each combo manually in type_tests/.
|
||||
assert isinstance(result, trio.Path)
|
||||
assert str(result) == os.path.join("a", "b")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("cls_a", "cls_b"),
|
||||
[(trio.Path, pathlib.Path), (trio.Path, trio.Path)],
|
||||
)
|
||||
@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"])
|
||||
def test_hash_magic(
|
||||
cls_a: EitherPathType,
|
||||
cls_b: EitherPathType,
|
||||
path: str,
|
||||
) -> None:
|
||||
a, b = cls_a(path), cls_b(path)
|
||||
assert hash(a) == hash(b)
|
||||
|
||||
|
||||
def test_forwarded_properties(path: trio.Path) -> None:
|
||||
# use `name` as a representative of forwarded properties
|
||||
|
||||
assert "name" in dir(path)
|
||||
assert path.name == "test"
|
||||
|
||||
|
||||
def test_async_method_signature(path: trio.Path) -> None:
|
||||
# use `resolve` as a representative of wrapped methods
|
||||
|
||||
assert path.resolve.__name__ == "resolve"
|
||||
assert path.resolve.__qualname__ == "Path.resolve"
|
||||
|
||||
assert path.resolve.__doc__ is not None
|
||||
assert path.resolve.__qualname__ in path.resolve.__doc__
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["is_dir", "is_file"])
|
||||
async def test_compare_async_stat_methods(method_name: str) -> None:
|
||||
method, async_method = method_pair(".", method_name)
|
||||
|
||||
result = method()
|
||||
async_result = await async_method()
|
||||
|
||||
assert result == async_result
|
||||
|
||||
|
||||
def test_invalid_name_not_wrapped(path: trio.Path) -> None:
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(path, "invalid_fake_attr") # noqa: B009 # "get-attr-with-constant"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["absolute", "resolve"])
|
||||
async def test_async_methods_rewrap(method_name: str) -> None:
|
||||
method, async_method = method_pair(".", method_name)
|
||||
|
||||
result = method()
|
||||
async_result = await async_method()
|
||||
|
||||
assert isinstance(async_result, trio.Path)
|
||||
assert str(result) == str(async_result)
|
||||
|
||||
|
||||
def test_forward_methods_rewrap(path: trio.Path, tmp_path: pathlib.Path) -> None:
|
||||
with_name = path.with_name("foo")
|
||||
with_suffix = path.with_suffix(".py")
|
||||
|
||||
assert isinstance(with_name, trio.Path)
|
||||
assert with_name == tmp_path / "foo"
|
||||
assert isinstance(with_suffix, trio.Path)
|
||||
assert with_suffix == tmp_path / "test.py"
|
||||
|
||||
|
||||
def test_forward_properties_rewrap(path: trio.Path) -> None:
|
||||
assert isinstance(path.parent, trio.Path)
|
||||
|
||||
|
||||
def test_forward_methods_without_rewrap(path: trio.Path) -> None:
|
||||
assert "totally-unique-path" in str(path.joinpath("totally-unique-path"))
|
||||
|
||||
|
||||
def test_repr() -> None:
|
||||
path = trio.Path(".")
|
||||
|
||||
assert repr(path) == "trio.Path('.')"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath])
|
||||
async def test_path_wraps_path(
|
||||
path: trio.Path,
|
||||
meth: Callable[[trio.Path, trio.Path], object],
|
||||
) -> None:
|
||||
wrapped = await path.absolute()
|
||||
result = meth(path, wrapped)
|
||||
if result is None:
|
||||
result = path
|
||||
|
||||
assert wrapped == result
|
||||
|
||||
|
||||
def test_path_nonpath() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
trio.Path(1) # type: ignore
|
||||
|
||||
|
||||
async def test_open_file_can_open_path(path: trio.Path) -> None:
|
||||
async with await trio.open_file(path, "w") as f:
|
||||
assert f.name == os.fspath(path)
|
||||
|
||||
|
||||
async def test_globmethods(path: trio.Path) -> None:
|
||||
# Populate a directory tree
|
||||
await path.mkdir()
|
||||
await (path / "foo").mkdir()
|
||||
await (path / "foo" / "_bar.txt").write_bytes(b"")
|
||||
await (path / "bar.txt").write_bytes(b"")
|
||||
await (path / "bar.dat").write_bytes(b"")
|
||||
|
||||
# Path.glob
|
||||
for pattern, results in {
|
||||
"*.txt": {"bar.txt"},
|
||||
"**/*.txt": {"_bar.txt", "bar.txt"},
|
||||
}.items():
|
||||
entries = set()
|
||||
for entry in await path.glob(pattern):
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == results
|
||||
|
||||
# Path.rglob
|
||||
entries = set()
|
||||
for entry in await path.rglob("*.txt"):
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == {"_bar.txt", "bar.txt"}
|
||||
|
||||
|
||||
async def test_as_uri(path: trio.Path) -> None:
|
||||
path = await path.parent.resolve()
|
||||
|
||||
assert path.as_uri().startswith("file:///")
|
||||
|
||||
|
||||
async def test_iterdir(path: trio.Path) -> None:
|
||||
# Populate a directory
|
||||
await path.mkdir()
|
||||
await (path / "foo").mkdir()
|
||||
await (path / "bar.txt").write_bytes(b"")
|
||||
|
||||
entries = set()
|
||||
for entry in await path.iterdir():
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == {"bar.txt", "foo"}
|
||||
|
||||
|
||||
async def test_classmethods() -> None:
|
||||
assert isinstance(await trio.Path.home(), trio.Path)
|
||||
|
||||
# pathlib.Path has only two classmethods
|
||||
assert str(await trio.Path.home()) == os.path.expanduser("~") # noqa: ASYNC240
|
||||
assert str(await trio.Path.cwd()) == os.getcwd()
|
||||
|
||||
# Wrapped method has docstring
|
||||
assert trio.Path.home.__doc__
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wrapper",
|
||||
[
|
||||
trio._path._wraps_async,
|
||||
trio._path._wrap_method,
|
||||
trio._path._wrap_method_path,
|
||||
trio._path._wrap_method_path_iterable,
|
||||
],
|
||||
)
|
||||
def test_wrapping_without_docstrings(
|
||||
wrapper: Callable[[Callable[[], None]], Callable[[], None]],
|
||||
) -> None:
|
||||
@wrapper
|
||||
def func_without_docstring() -> None: ... # pragma: no cover
|
||||
|
||||
assert func_without_docstring.__doc__ is None
|
||||
@@ -0,0 +1,428 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from functools import partial
|
||||
from typing import Protocol
|
||||
|
||||
import pytest
|
||||
|
||||
import trio._repl
|
||||
|
||||
|
||||
class RawInput(Protocol):
|
||||
def __call__(self, prompt: str = "") -> str: ...
|
||||
|
||||
|
||||
def build_raw_input(cmds: list[str]) -> RawInput:
|
||||
"""
|
||||
Pass in a list of strings.
|
||||
Returns a callable that returns each string, each time its called
|
||||
When there are not more strings to return, raise EOFError
|
||||
"""
|
||||
cmds_iter = iter(cmds)
|
||||
prompts = []
|
||||
|
||||
def _raw_helper(prompt: str = "") -> str:
|
||||
prompts.append(prompt)
|
||||
try:
|
||||
return next(cmds_iter)
|
||||
except StopIteration:
|
||||
raise EOFError from None
|
||||
|
||||
return _raw_helper
|
||||
|
||||
|
||||
def test_build_raw_input() -> None:
|
||||
"""Quick test of our helper function."""
|
||||
raw_input = build_raw_input(["cmd1"])
|
||||
assert raw_input() == "cmd1"
|
||||
with pytest.raises(EOFError):
|
||||
raw_input()
|
||||
|
||||
|
||||
async def test_basic_interaction(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Run some basic commands through the interpreter while capturing stdout.
|
||||
Ensure that the interpreted prints the expected results.
|
||||
"""
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# evaluate simple expression and recall the value
|
||||
"x = 1",
|
||||
"print(f'{x=}')",
|
||||
# Literal gets printed
|
||||
"'hello'",
|
||||
# define and call sync function
|
||||
"def func():",
|
||||
" print(x + 1)",
|
||||
"",
|
||||
"func()",
|
||||
# define and call async function
|
||||
"async def afunc():",
|
||||
" return 4",
|
||||
"",
|
||||
"await afunc()",
|
||||
# import works
|
||||
"import sys",
|
||||
"sys.stdout.write('hello stdout\\n')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, _err = capsys.readouterr()
|
||||
assert out.splitlines() == ["x=1", "'hello'", "2", "4", "hello stdout", "13"]
|
||||
|
||||
|
||||
async def test_system_exits_quit_interpreter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"raise SystemExit",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
with pytest.raises(SystemExit):
|
||||
await trio._repl.run_repl(console)
|
||||
|
||||
|
||||
async def test_KI_interrupts(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"import signal, trio, trio.lowlevel",
|
||||
"async def f():",
|
||||
" trio.lowlevel.spawn_system_task("
|
||||
" trio.to_thread.run_sync,"
|
||||
" signal.raise_signal, signal.SIGINT,"
|
||||
" )", # just awaiting this kills the test runner?!
|
||||
" await trio.sleep_forever()",
|
||||
" print('should not see this')",
|
||||
"",
|
||||
"await f()",
|
||||
"print('AFTER KeyboardInterrupt')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "KeyboardInterrupt" in err
|
||||
assert "should" not in out
|
||||
assert "AFTER KeyboardInterrupt" in out
|
||||
|
||||
|
||||
async def test_system_exits_in_exc_group(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"import sys",
|
||||
"if sys.version_info < (3, 11):",
|
||||
" from exceptiongroup import BaseExceptionGroup",
|
||||
"",
|
||||
"raise BaseExceptionGroup('', [RuntimeError(), SystemExit()])",
|
||||
"print('AFTER BaseExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, _err = capsys.readouterr()
|
||||
# assert that raise SystemExit in an exception group
|
||||
# doesn't quit
|
||||
assert "AFTER BaseExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_system_exits_in_nested_exc_group(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"import sys",
|
||||
"if sys.version_info < (3, 11):",
|
||||
" from exceptiongroup import BaseExceptionGroup",
|
||||
"",
|
||||
"raise BaseExceptionGroup(",
|
||||
" '', [BaseExceptionGroup('', [RuntimeError(), SystemExit()])])",
|
||||
"print('AFTER BaseExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, _err = capsys.readouterr()
|
||||
# assert that raise SystemExit in an exception group
|
||||
# doesn't quit
|
||||
assert "AFTER BaseExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_base_exception_captured(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# The statement after raise should still get executed
|
||||
"raise BaseException",
|
||||
"print('AFTER BaseException')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "_threads.py" not in err
|
||||
assert "_repl.py" not in err
|
||||
assert "AFTER BaseException" in out
|
||||
|
||||
|
||||
async def test_exc_group_captured(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# The statement after raise should still get executed
|
||||
"raise ExceptionGroup('', [KeyError()])",
|
||||
"print('AFTER ExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, _err = capsys.readouterr()
|
||||
assert "AFTER ExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_base_exception_capture_from_coroutine(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole()
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"async def async_func_raises_base_exception():",
|
||||
" raise BaseException",
|
||||
"",
|
||||
# This will raise, but the statement after should still
|
||||
# be executed
|
||||
"await async_func_raises_base_exception()",
|
||||
"print('AFTER BaseException')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "_threads.py" not in err
|
||||
assert "_repl.py" not in err
|
||||
assert "AFTER BaseException" in out
|
||||
|
||||
|
||||
def test_main_entrypoint() -> None:
|
||||
"""
|
||||
Basic smoke test when running via the package __main__ entrypoint.
|
||||
"""
|
||||
repl = subprocess.run([sys.executable, "-m", "trio"], input=b"exit()")
|
||||
assert repl.returncode == 0
|
||||
|
||||
|
||||
def should_try_newline_injection() -> bool:
|
||||
if sys.platform != "linux":
|
||||
return False
|
||||
|
||||
sysctl = pathlib.Path("/proc/sys/dev/tty/legacy_tiocsti")
|
||||
if not sysctl.exists(): # pragma: no cover
|
||||
return True
|
||||
|
||||
else:
|
||||
return sysctl.read_text() == "1"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not should_try_newline_injection(),
|
||||
reason="the ioctl we use is disabled in CI",
|
||||
)
|
||||
def test_ki_newline_injection() -> None: # TODO: test this line
|
||||
# TODO: we want to remove this functionality, eg by using vendored
|
||||
# pyrepls.
|
||||
assert sys.platform != "win32"
|
||||
|
||||
import pty
|
||||
|
||||
# NOTE: this cannot be subprocess.Popen because pty.fork
|
||||
# does some magic to set the controlling terminal.
|
||||
# (which I don't know how to replicate... so I copied this
|
||||
# structure from pty.spawn...)
|
||||
pid, pty_fd = pty.fork() # type: ignore[attr-defined,unused-ignore]
|
||||
if pid == 0:
|
||||
os.execlp(sys.executable, *[sys.executable, "-u", "-m", "trio"])
|
||||
|
||||
# setup:
|
||||
buffer = b""
|
||||
while not buffer.endswith(b"import trio\r\n>>> "):
|
||||
buffer += os.read(pty_fd, 4096)
|
||||
|
||||
# sanity check:
|
||||
print(buffer.decode())
|
||||
buffer = b""
|
||||
os.write(pty_fd, b'print("hello!")\n')
|
||||
while not buffer.endswith(b">>> "):
|
||||
buffer += os.read(pty_fd, 4096)
|
||||
|
||||
assert buffer.count(b"hello!") == 2
|
||||
|
||||
# press ctrl+c
|
||||
print(buffer.decode())
|
||||
buffer = b""
|
||||
os.kill(pid, signal.SIGINT)
|
||||
while not buffer.endswith(b">>> "):
|
||||
buffer += os.read(pty_fd, 4096)
|
||||
|
||||
assert b"KeyboardInterrupt" in buffer
|
||||
|
||||
# press ctrl+c later
|
||||
print(buffer.decode())
|
||||
buffer = b""
|
||||
os.write(pty_fd, b'print("hello!")')
|
||||
os.kill(pid, signal.SIGINT)
|
||||
while not buffer.endswith(b">>> "):
|
||||
buffer += os.read(pty_fd, 4096)
|
||||
|
||||
assert b"KeyboardInterrupt" in buffer
|
||||
print(buffer.decode())
|
||||
os.close(pty_fd)
|
||||
os.waitpid(pid, 0)[1]
|
||||
|
||||
|
||||
async def test_ki_in_repl() -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
proc = await nursery.start(
|
||||
partial(
|
||||
trio.run_process,
|
||||
[sys.executable, "-u", "-m", "trio"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE,
|
||||
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if sys.platform == "win32" else 0, # type: ignore[attr-defined,unused-ignore]
|
||||
)
|
||||
)
|
||||
|
||||
async with proc.stdout:
|
||||
# setup
|
||||
buffer = b""
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
# TODO: consider making run_process stdout have some universal newlines thing
|
||||
if buffer.replace(b"\r\n", b"\n").endswith(b"import trio\n>>> "):
|
||||
break
|
||||
|
||||
# ensure things work
|
||||
print(buffer.decode())
|
||||
buffer = b""
|
||||
await proc.stdin.send_all(b'print("hello!")\n')
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
assert b"hello!" in buffer
|
||||
print(buffer.decode())
|
||||
|
||||
# this seems to be necessary on Windows for reasons
|
||||
# (the parents of process groups ignore ctrl+c by default...)
|
||||
if sys.platform == "win32":
|
||||
buffer = b""
|
||||
await proc.stdin.send_all(
|
||||
b"import ctypes; ctypes.windll.kernel32.SetConsoleCtrlHandler(None, False)\n"
|
||||
)
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
print(buffer.decode())
|
||||
|
||||
# try to decrease flakiness...
|
||||
buffer = b""
|
||||
await proc.stdin.send_all(
|
||||
b"import coverage; trio.lowlevel.enable_ki_protection(coverage.pytracer.PyTracer._trace)\n"
|
||||
)
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
print(buffer.decode())
|
||||
|
||||
# ensure that ctrl+c on a prompt works
|
||||
# NOTE: for some reason, signal.SIGINT doesn't work for this test.
|
||||
# Using CTRL_C_EVENT is also why we need subprocess.CREATE_NEW_PROCESS_GROUP
|
||||
signal_sent = signal.CTRL_C_EVENT if sys.platform == "win32" else signal.SIGINT # type: ignore[attr-defined,unused-ignore]
|
||||
os.kill(proc.pid, signal_sent)
|
||||
if sys.platform == "win32":
|
||||
# we rely on EOFError which... doesn't happen with pipes.
|
||||
# I'm not sure how to fix it...
|
||||
await proc.stdin.send_all(b"\n")
|
||||
else:
|
||||
# we test injection separately
|
||||
await proc.stdin.send_all(b"\n")
|
||||
|
||||
buffer = b""
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
assert b"KeyboardInterrupt" in buffer
|
||||
|
||||
# ensure ctrl+c while a command runs works
|
||||
print(buffer.decode())
|
||||
await proc.stdin.send_all(b'print("READY"); await trio.sleep_forever()\n')
|
||||
killed = False
|
||||
buffer = b""
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.replace(b"\r\n", b"\n").endswith(b"READY\n") and not killed:
|
||||
os.kill(proc.pid, signal_sent)
|
||||
killed = True
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
assert b"trio" in buffer
|
||||
assert b"KeyboardInterrupt" in buffer
|
||||
|
||||
# make sure it works for sync commands too
|
||||
# (though this would be hard to break)
|
||||
print(buffer.decode())
|
||||
await proc.stdin.send_all(
|
||||
b'import time; print("READY"); time.sleep(99999)\n'
|
||||
)
|
||||
killed = False
|
||||
buffer = b""
|
||||
async for part in proc.stdout: # pragma: no branch
|
||||
buffer += part
|
||||
if buffer.replace(b"\r\n", b"\n").endswith(b"READY\n") and not killed:
|
||||
os.kill(proc.pid, signal_sent)
|
||||
killed = True
|
||||
if buffer.endswith(b">>> "):
|
||||
break
|
||||
|
||||
assert b"Traceback" in buffer
|
||||
assert b"KeyboardInterrupt" in buffer
|
||||
|
||||
print(buffer.decode())
|
||||
|
||||
# kill the process
|
||||
nursery.cancel_scope.cancel()
|
||||
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import trio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pytest
|
||||
|
||||
|
||||
async def scheduler_trace() -> tuple[tuple[str, int], ...]:
|
||||
"""Returns a scheduler-dependent value we can use to check determinism."""
|
||||
trace = []
|
||||
|
||||
async def tracer(name: str) -> None:
|
||||
for i in range(50):
|
||||
trace.append((name, i))
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(5):
|
||||
nursery.start_soon(tracer, str(i))
|
||||
|
||||
return tuple(trace)
|
||||
|
||||
|
||||
def test_the_trio_scheduler_is_not_deterministic() -> None:
|
||||
# At least, not yet. See https://github.com/python-trio/trio/issues/32
|
||||
traces = [trio.run(scheduler_trace) for _ in range(10)]
|
||||
assert len(set(traces)) == len(traces)
|
||||
|
||||
|
||||
def test_the_trio_scheduler_is_deterministic_if_seeded(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True)
|
||||
traces = []
|
||||
for _ in range(10):
|
||||
state = trio._core._run._r.getstate()
|
||||
try:
|
||||
trio._core._run._r.seed(0)
|
||||
traces.append(trio.run(scheduler_trace))
|
||||
finally:
|
||||
trio._core._run._r.setstate(state)
|
||||
|
||||
assert len(traces) == 10
|
||||
assert len(set(traces)) == 1
|
||||
@@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import signal
|
||||
from typing import TYPE_CHECKING, NoReturn
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
|
||||
from .. import _core
|
||||
from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import FrameType
|
||||
|
||||
|
||||
async def test_open_signal_receiver() -> None:
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
# Raise it a few times, to exercise signal coalescing, both at the
|
||||
# call_soon level and at the SignalQueue level
|
||||
signal.raise_signal(signal.SIGILL)
|
||||
signal.raise_signal(signal.SIGILL)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
signal.raise_signal(signal.SIGILL)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
async for signum in receiver: # pragma: no branch
|
||||
assert signum == signal.SIGILL
|
||||
break
|
||||
assert get_pending_signal_count(receiver) == 0
|
||||
signal.raise_signal(signal.SIGILL)
|
||||
async for signum in receiver: # pragma: no branch
|
||||
assert signum == signal.SIGILL
|
||||
break
|
||||
assert get_pending_signal_count(receiver) == 0
|
||||
with pytest.raises(RuntimeError):
|
||||
await receiver.__anext__()
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None:
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"(signal number out of range|invalid signal value)$",
|
||||
):
|
||||
with open_signal_receiver(signal.SIGILL, 1234567):
|
||||
pass # pragma: no cover
|
||||
# Still restored even if we errored out
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
def test_open_signal_receiver_empty_fail() -> None:
|
||||
with pytest.raises(TypeError, match="No signals were provided"):
|
||||
with open_signal_receiver():
|
||||
pass
|
||||
|
||||
|
||||
async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None:
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGILL):
|
||||
pass
|
||||
# Still restored correctly
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
async def test_catch_signals_wrong_thread() -> None:
|
||||
async def naughty() -> None:
|
||||
with open_signal_receiver(signal.SIGINT):
|
||||
pass # pragma: no cover
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await trio.to_thread.run_sync(trio.run, naughty)
|
||||
|
||||
|
||||
async def test_open_signal_receiver_conflict() -> None:
|
||||
with pytest.RaisesGroup(trio.BusyResourceError):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver.__anext__)
|
||||
nursery.start_soon(receiver.__anext__)
|
||||
|
||||
|
||||
# Blocks until all previous calls to run_sync_soon(idempotent=True) have been
|
||||
# processed.
|
||||
async def wait_run_sync_soon_idempotent_queue_barrier() -> None:
|
||||
ev = trio.Event()
|
||||
token = _core.current_trio_token()
|
||||
token.run_sync_soon(ev.set, idempotent=True)
|
||||
await ev.wait()
|
||||
|
||||
|
||||
async def test_open_signal_receiver_no_starvation() -> None:
|
||||
# Set up a situation where there are always 2 pending signals available to
|
||||
# report, and make sure that instead of getting the same signal reported
|
||||
# over and over, it alternates between reporting both of them.
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
try:
|
||||
print(signal.getsignal(signal.SIGILL))
|
||||
previous = None
|
||||
for _ in range(10):
|
||||
signal.raise_signal(signal.SIGILL)
|
||||
signal.raise_signal(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
if previous is None:
|
||||
previous = await receiver.__anext__()
|
||||
else:
|
||||
got = await receiver.__anext__()
|
||||
assert got in [signal.SIGILL, signal.SIGFPE]
|
||||
assert got != previous
|
||||
previous = got
|
||||
# Clear out the last signal so that it doesn't get redelivered
|
||||
while get_pending_signal_count(receiver) != 0:
|
||||
await receiver.__anext__()
|
||||
except BaseException: # pragma: no cover
|
||||
# If there's an unhandled exception above, then exiting the
|
||||
# open_signal_receiver block might cause the signal to be
|
||||
# redelivered and give us a core dump instead of a traceback...
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def test_catch_signals_race_condition_on_exit() -> None:
|
||||
delivered_directly: set[int] = set()
|
||||
|
||||
def direct_handler(signo: int, frame: FrameType | None) -> None:
|
||||
delivered_directly.add(signo)
|
||||
|
||||
print(1)
|
||||
# Test the version where the call_soon *doesn't* have a chance to run
|
||||
# before we exit the with block:
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal.raise_signal(signal.SIGILL)
|
||||
signal.raise_signal(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
|
||||
delivered_directly.clear()
|
||||
|
||||
print(2)
|
||||
# Test the version where the call_soon *does* have a chance to run before
|
||||
# we exit the with block:
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal.raise_signal(signal.SIGILL)
|
||||
signal.raise_signal(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert get_pending_signal_count(receiver) == 2
|
||||
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
|
||||
delivered_directly.clear()
|
||||
|
||||
# Again, but with a SIG_IGN signal:
|
||||
|
||||
print(3)
|
||||
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
signal.raise_signal(signal.SIGILL)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
# test passes if the process reaches this point without dying
|
||||
|
||||
print(4)
|
||||
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
signal.raise_signal(signal.SIGILL)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert get_pending_signal_count(receiver) == 1
|
||||
# test passes if the process reaches this point without dying
|
||||
|
||||
# Check exception chaining if there are multiple exception-raising
|
||||
# handlers
|
||||
def raise_handler(signum: int, frame: FrameType | None) -> NoReturn:
|
||||
raise RuntimeError(signum)
|
||||
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler):
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal.raise_signal(signal.SIGILL)
|
||||
signal.raise_signal(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert get_pending_signal_count(receiver) == 2
|
||||
exc = excinfo.value
|
||||
signums = {exc.args[0]}
|
||||
assert isinstance(exc.__context__, RuntimeError)
|
||||
signums.add(exc.__context__.args[0])
|
||||
assert signums == {signal.SIGILL, signal.SIGFPE}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,767 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from functools import partial
|
||||
from pathlib import Path as SyncPath
|
||||
from signal import Signals
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
NoReturn,
|
||||
TypeAlias,
|
||||
)
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
|
||||
from .. import (
|
||||
Event,
|
||||
Process,
|
||||
_core,
|
||||
fail_after,
|
||||
move_on_after,
|
||||
run_process,
|
||||
sleep,
|
||||
sleep_forever,
|
||||
)
|
||||
from .._core._tests.tutil import skip_if_fbsd_pipes_broken, slow
|
||||
from ..lowlevel import open_process
|
||||
from ..testing import MockClock, assert_no_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import FrameType
|
||||
|
||||
from .._abc import ReceiveStream
|
||||
|
||||
if sys.platform == "win32":
|
||||
SignalType: TypeAlias = None
|
||||
else:
|
||||
SignalType: TypeAlias = Signals
|
||||
|
||||
SIGKILL: SignalType
|
||||
SIGTERM: SignalType
|
||||
SIGUSR1: SignalType
|
||||
|
||||
posix = os.name == "posix"
|
||||
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
|
||||
from signal import SIGKILL, SIGTERM, SIGUSR1
|
||||
else:
|
||||
SIGKILL, SIGTERM, SIGUSR1 = None, None, None
|
||||
|
||||
|
||||
# Since Windows has very few command-line utilities generally available,
|
||||
# all of our subprocesses are Python processes running short bits of
|
||||
# (mostly) cross-platform code.
|
||||
def python(code: str) -> list[str]:
|
||||
return [sys.executable, "-u", "-c", "import sys; " + code]
|
||||
|
||||
|
||||
EXIT_TRUE = python("sys.exit(0)")
|
||||
EXIT_FALSE = python("sys.exit(1)")
|
||||
CAT = python("sys.stdout.buffer.write(sys.stdin.buffer.read())")
|
||||
|
||||
if posix:
|
||||
|
||||
def SLEEP(seconds: int) -> list[str]:
|
||||
return ["sleep", str(seconds)]
|
||||
|
||||
else:
|
||||
|
||||
def SLEEP(seconds: int) -> list[str]:
|
||||
return python(f"import time; time.sleep({seconds})")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def open_process_then_kill( # type: ignore[misc, explicit-any]
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Process]:
|
||||
proc = await open_process(*args, **kwargs)
|
||||
try:
|
||||
yield proc
|
||||
finally:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def run_process_in_nursery( # type: ignore[misc, explicit-any]
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Process]:
|
||||
async with _core.open_nursery() as nursery:
|
||||
kwargs.setdefault("check", False)
|
||||
value = await nursery.start(partial(run_process, *args, **kwargs))
|
||||
assert isinstance(value, Process)
|
||||
proc: Process = value
|
||||
yield proc
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
background_process_param = pytest.mark.parametrize(
|
||||
"background_process",
|
||||
[open_process_then_kill, run_process_in_nursery],
|
||||
ids=["open_process", "run_process in nursery"],
|
||||
)
|
||||
|
||||
BackgroundProcessType: TypeAlias = Callable[ # type: ignore[explicit-any]
|
||||
...,
|
||||
AbstractAsyncContextManager[Process],
|
||||
]
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_basic(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(EXIT_TRUE) as proc:
|
||||
await proc.wait()
|
||||
assert isinstance(proc, Process)
|
||||
assert proc._pidfd is None
|
||||
assert proc.returncode == 0
|
||||
assert repr(proc) == f"<trio.Process {EXIT_TRUE}: exited with status 0>"
|
||||
|
||||
async with background_process(EXIT_FALSE) as proc:
|
||||
await proc.wait()
|
||||
assert proc.returncode == 1
|
||||
assert repr(proc) == "<trio.Process {!r}: {}>".format(
|
||||
EXIT_FALSE,
|
||||
"exited with status 1",
|
||||
)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_basic_no_pidfd(background_process: BackgroundProcessType) -> None:
|
||||
with mock.patch("trio._subprocess.can_try_pidfd_open", new=False):
|
||||
async with background_process(EXIT_TRUE) as proc:
|
||||
assert proc._pidfd is None
|
||||
await proc.wait()
|
||||
assert isinstance(proc, Process)
|
||||
assert proc._pidfd is None
|
||||
assert proc.returncode == 0
|
||||
assert repr(proc) == f"<trio.Process {EXIT_TRUE}: exited with status 0>"
|
||||
|
||||
async with background_process(EXIT_FALSE) as proc:
|
||||
await proc.wait()
|
||||
assert proc.returncode == 1
|
||||
assert repr(proc) == "<trio.Process {!r}: {}>".format(
|
||||
EXIT_FALSE,
|
||||
"exited with status 1",
|
||||
)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_auto_update_returncode(
|
||||
background_process: BackgroundProcessType,
|
||||
) -> None:
|
||||
async with background_process(SLEEP(9999)) as p:
|
||||
assert p.returncode is None
|
||||
assert "running" in repr(p)
|
||||
p.kill()
|
||||
p._proc.wait()
|
||||
assert p.returncode is not None
|
||||
assert "exited" in repr(p)
|
||||
assert p._pidfd is None
|
||||
assert p.returncode is not None
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_multi_wait(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(SLEEP(10)) as proc:
|
||||
# Check that wait (including multi-wait) tolerates being cancelled
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# Now try waiting for real
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
proc.kill()
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_multi_wait_no_pidfd(background_process: BackgroundProcessType) -> None:
|
||||
with mock.patch("trio._subprocess.can_try_pidfd_open", new=False):
|
||||
async with background_process(SLEEP(10)) as proc:
|
||||
# Check that wait (including multi-wait) tolerates being cancelled
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# Now try waiting for real
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
proc.kill()
|
||||
|
||||
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR = python(
|
||||
"data = sys.stdin.buffer.read(); "
|
||||
"sys.stdout.buffer.write(data); "
|
||||
"sys.stderr.buffer.write(data[::-1])",
|
||||
)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_pipes(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
) as proc:
|
||||
msg = b"the quick brown fox jumps over the lazy dog"
|
||||
|
||||
async def feed_input() -> None:
|
||||
assert proc.stdin is not None
|
||||
await proc.stdin.send_all(msg)
|
||||
await proc.stdin.aclose()
|
||||
|
||||
async def check_output(stream: ReceiveStream, expected: bytes) -> None:
|
||||
seen = bytearray()
|
||||
async for chunk in stream:
|
||||
seen += chunk
|
||||
assert seen == expected
|
||||
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is not None
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
# fail eventually if something is broken
|
||||
nursery.cancel_scope.deadline = _core.current_time() + 30.0
|
||||
nursery.start_soon(feed_input)
|
||||
nursery.start_soon(check_output, proc.stdout, msg)
|
||||
nursery.start_soon(check_output, proc.stderr, msg[::-1])
|
||||
|
||||
assert not nursery.cancel_scope.cancelled_caught
|
||||
assert await proc.wait() == 0
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_interactive(background_process: BackgroundProcessType) -> None:
|
||||
# Test some back-and-forth with a subprocess. This one works like so:
|
||||
# in: 32\n
|
||||
# out: 0000...0000\n (32 zeroes)
|
||||
# err: 1111...1111\n (64 ones)
|
||||
# in: 10\n
|
||||
# out: 2222222222\n (10 twos)
|
||||
# err: 3333....3333\n (20 threes)
|
||||
# in: EOF
|
||||
# out: EOF
|
||||
# err: EOF
|
||||
|
||||
async with background_process(
|
||||
python(
|
||||
"idx = 0\n"
|
||||
"while True:\n"
|
||||
" line = sys.stdin.readline()\n"
|
||||
" if line == '': break\n"
|
||||
" request = int(line.strip())\n"
|
||||
" print(str(idx * 2) * request)\n"
|
||||
" print(str(idx * 2 + 1) * request * 2, file=sys.stderr)\n"
|
||||
" idx += 1\n",
|
||||
),
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
) as proc:
|
||||
newline = b"\n" if posix else b"\r\n"
|
||||
|
||||
async def expect(idx: int, request: int) -> None:
|
||||
async with _core.open_nursery() as nursery:
|
||||
|
||||
async def drain_one(
|
||||
stream: ReceiveStream,
|
||||
count: int,
|
||||
digit: int,
|
||||
) -> None:
|
||||
while count > 0:
|
||||
result = await stream.receive_some(count)
|
||||
assert result == (f"{digit}".encode() * len(result))
|
||||
count -= len(result)
|
||||
assert count == 0
|
||||
assert await stream.receive_some(len(newline)) == newline
|
||||
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is not None
|
||||
nursery.start_soon(drain_one, proc.stdout, request, idx * 2)
|
||||
nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1)
|
||||
|
||||
assert proc.stdin is not None
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is not None
|
||||
with fail_after(5):
|
||||
await proc.stdin.send_all(b"12")
|
||||
await sleep(0.1)
|
||||
await proc.stdin.send_all(b"345" + newline)
|
||||
await expect(0, 12345)
|
||||
await proc.stdin.send_all(b"100" + newline + b"200" + newline)
|
||||
await expect(1, 100)
|
||||
await expect(2, 200)
|
||||
await proc.stdin.send_all(b"0" + newline)
|
||||
await expect(3, 0)
|
||||
await proc.stdin.send_all(b"999999")
|
||||
with move_on_after(0.1) as scope:
|
||||
await expect(4, 0)
|
||||
assert scope.cancelled_caught
|
||||
await proc.stdin.send_all(newline)
|
||||
await expect(4, 999999)
|
||||
await proc.stdin.aclose()
|
||||
assert await proc.stdout.receive_some(1) == b""
|
||||
assert await proc.stderr.receive_some(1) == b""
|
||||
await proc.wait()
|
||||
|
||||
assert proc.returncode == 0
|
||||
|
||||
|
||||
async def test_run() -> None:
|
||||
data = bytes(random.randint(0, 255) for _ in range(2**18))
|
||||
|
||||
result = await run_process(
|
||||
CAT,
|
||||
stdin=data,
|
||||
capture_stdout=True,
|
||||
capture_stderr=True,
|
||||
)
|
||||
assert result.args == CAT
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == data
|
||||
assert result.stderr == b""
|
||||
|
||||
result = await run_process(CAT, capture_stdout=True)
|
||||
assert result.args == CAT
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == b""
|
||||
assert result.stderr is None
|
||||
|
||||
result = await run_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=data,
|
||||
capture_stdout=True,
|
||||
capture_stderr=True,
|
||||
)
|
||||
assert result.args == COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == data
|
||||
assert result.stderr == data[::-1]
|
||||
|
||||
# invalid combinations
|
||||
with pytest.raises(UnicodeError):
|
||||
await run_process(CAT, stdin="oh no, it's text")
|
||||
|
||||
pipe_stdout_error = r"^stdout=subprocess\.PIPE is only valid with nursery\.start, since that's the only way to access the pipe(; use nursery\.start or pass the data you want to write directly)*$"
|
||||
with pytest.raises(ValueError, match=pipe_stdout_error):
|
||||
await run_process(CAT, stdin=subprocess.PIPE)
|
||||
with pytest.raises(ValueError, match=pipe_stdout_error):
|
||||
await run_process(CAT, stdout=subprocess.PIPE)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=pipe_stdout_error.replace("stdout", "stderr", 1),
|
||||
):
|
||||
await run_process(CAT, stderr=subprocess.PIPE)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"^can't specify both stdout and capture_stdout$",
|
||||
):
|
||||
await run_process(CAT, capture_stdout=True, stdout=subprocess.DEVNULL)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"^can't specify both stderr and capture_stderr$",
|
||||
):
|
||||
await run_process(CAT, capture_stderr=True, stderr=None)
|
||||
|
||||
|
||||
async def test_run_check() -> None:
|
||||
cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)")
|
||||
with pytest.raises(subprocess.CalledProcessError) as excinfo:
|
||||
await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True)
|
||||
assert excinfo.value.cmd == cmd
|
||||
assert excinfo.value.returncode == 1
|
||||
assert excinfo.value.stderr == b"test\n"
|
||||
assert excinfo.value.stdout is None
|
||||
|
||||
result = await run_process(
|
||||
cmd,
|
||||
capture_stdout=True,
|
||||
capture_stderr=True,
|
||||
check=False,
|
||||
)
|
||||
assert result.args == cmd
|
||||
assert result.stdout == b""
|
||||
assert result.stderr == b"test\n"
|
||||
assert result.returncode == 1
|
||||
|
||||
|
||||
@skip_if_fbsd_pipes_broken
|
||||
async def test_run_with_broken_pipe() -> None:
|
||||
result = await run_process(
|
||||
[sys.executable, "-c", "import sys; sys.stdin.close()"],
|
||||
stdin=b"x" * 131072,
|
||||
)
|
||||
assert result.returncode == 0
|
||||
assert result.stdout is result.stderr is None
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_stderr_stdout(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as proc:
|
||||
assert proc.stdio is not None
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is None
|
||||
await proc.stdio.send_all(b"1234")
|
||||
await proc.stdio.send_eof()
|
||||
|
||||
output = []
|
||||
while True:
|
||||
chunk = await proc.stdio.receive_some(16)
|
||||
if chunk == b"":
|
||||
break
|
||||
output.append(chunk)
|
||||
assert b"".join(output) == b"12344321"
|
||||
assert proc.returncode == 0
|
||||
|
||||
# equivalent test with run_process()
|
||||
result = await run_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=b"1234",
|
||||
capture_stdout=True,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == b"12344321"
|
||||
assert result.stderr is None
|
||||
|
||||
# this one hits the branch where stderr=STDOUT but stdout
|
||||
# is not redirected
|
||||
async with background_process(
|
||||
CAT,
|
||||
stdin=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as proc:
|
||||
assert proc.stdout is None
|
||||
assert proc.stderr is None
|
||||
await proc.stdin.aclose()
|
||||
await proc.wait()
|
||||
assert proc.returncode == 0
|
||||
|
||||
if posix:
|
||||
try:
|
||||
r, w = os.pipe()
|
||||
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=w,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as proc:
|
||||
os.close(w)
|
||||
assert proc.stdio is None
|
||||
assert proc.stdout is None
|
||||
assert proc.stderr is None
|
||||
await proc.stdin.send_all(b"1234")
|
||||
await proc.stdin.aclose()
|
||||
assert await proc.wait() == 0
|
||||
assert os.read(r, 4096) == b"12344321"
|
||||
assert os.read(r, 4096) == b""
|
||||
finally:
|
||||
os.close(r)
|
||||
|
||||
|
||||
async def test_errors() -> None:
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
# call-overload on unix, call-arg on windows
|
||||
await open_process(["ls"], encoding="utf-8") # type: ignore
|
||||
assert "unbuffered byte streams" in str(excinfo.value)
|
||||
assert "the 'encoding' option is not supported" in str(excinfo.value)
|
||||
|
||||
if posix:
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
await open_process(["ls"], shell=True)
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
await open_process("ls", shell=False)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_signals(background_process: BackgroundProcessType) -> None:
|
||||
async def test_one_signal(
|
||||
send_it: Callable[[Process], None],
|
||||
signum: signal.Signals | None,
|
||||
) -> None:
|
||||
with move_on_after(1.0) as scope:
|
||||
async with background_process(SLEEP(3600)) as proc:
|
||||
send_it(proc)
|
||||
await proc.wait()
|
||||
assert not scope.cancelled_caught
|
||||
if posix:
|
||||
assert signum is not None
|
||||
assert proc.returncode == -signum
|
||||
else:
|
||||
assert proc.returncode != 0
|
||||
|
||||
await test_one_signal(Process.kill, SIGKILL)
|
||||
await test_one_signal(Process.terminate, SIGTERM)
|
||||
# Test that we can send arbitrary signals.
|
||||
#
|
||||
# We used to use SIGINT here, but it turns out that the Python interpreter
|
||||
# has race conditions that can cause it to explode in weird ways if it
|
||||
# tries to handle SIGINT during startup. SIGUSR1's default disposition is
|
||||
# to terminate the target process, and Python doesn't try to do anything
|
||||
# clever to handle it.
|
||||
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
|
||||
await test_one_signal(lambda proc: proc.send_signal(SIGUSR1), SIGUSR1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not posix, reason="POSIX specific")
|
||||
@background_process_param
|
||||
async def test_wait_reapable_fails(background_process: BackgroundProcessType) -> None:
|
||||
if TYPE_CHECKING and sys.platform == "win32":
|
||||
return
|
||||
old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
|
||||
try:
|
||||
# With SIGCHLD disabled, the wait() syscall will wait for the
|
||||
# process to exit but then fail with ECHILD. Make sure we
|
||||
# support this case as the stdlib subprocess module does.
|
||||
async with background_process(SLEEP(3600)) as proc:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
proc.kill()
|
||||
nursery.cancel_scope.deadline = _core.current_time() + 1.0
|
||||
assert not nursery.cancel_scope.cancelled_caught
|
||||
assert proc.returncode == 0 # exit status unknowable, so...
|
||||
finally:
|
||||
signal.signal(signal.SIGCHLD, old_sigchld)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not posix, reason="POSIX specific")
|
||||
@background_process_param
|
||||
async def test_wait_reapable_fails_no_pidfd(
|
||||
background_process: BackgroundProcessType,
|
||||
) -> None:
|
||||
if TYPE_CHECKING and sys.platform == "win32":
|
||||
return
|
||||
with mock.patch("trio._subprocess.can_try_pidfd_open", new=False):
|
||||
old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
|
||||
try:
|
||||
# With SIGCHLD disabled, the wait() syscall will wait for the
|
||||
# process to exit but then fail with ECHILD. Make sure we
|
||||
# support this case as the stdlib subprocess module does.
|
||||
async with background_process(SLEEP(3600)) as proc:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
proc.kill()
|
||||
nursery.cancel_scope.deadline = _core.current_time() + 1.0
|
||||
assert not nursery.cancel_scope.cancelled_caught
|
||||
assert proc.returncode == 0 # exit status unknowable, so...
|
||||
finally:
|
||||
signal.signal(signal.SIGCHLD, old_sigchld)
|
||||
|
||||
|
||||
@slow
|
||||
def test_waitid_eintr() -> None:
|
||||
# This only matters on PyPy (where we're coding EINTR handling
|
||||
# ourselves) but the test works on all waitid platforms.
|
||||
from .._subprocess_platform import wait_child_exiting
|
||||
|
||||
if TYPE_CHECKING and (sys.platform == "win32" or sys.platform == "darwin"):
|
||||
return
|
||||
|
||||
if not wait_child_exiting.__module__.endswith("waitid"):
|
||||
pytest.skip("waitid only")
|
||||
|
||||
# despite the TYPE_CHECKING early return silencing warnings about signal.SIGALRM etc
|
||||
# this import is still checked on win32&darwin and raises [attr-defined].
|
||||
# Linux doesn't raise [attr-defined] though, so we need [unused-ignore]
|
||||
from .._subprocess_platform.waitid import ( # type: ignore[attr-defined, unused-ignore]
|
||||
sync_wait_reapable,
|
||||
)
|
||||
|
||||
got_alarm = False
|
||||
sleeper = subprocess.Popen(["sleep", "3600"])
|
||||
|
||||
def on_alarm(sig: int, frame: FrameType | None) -> None:
|
||||
nonlocal got_alarm
|
||||
got_alarm = True
|
||||
sleeper.kill()
|
||||
|
||||
old_sigalrm = signal.signal(signal.SIGALRM, on_alarm)
|
||||
try:
|
||||
signal.alarm(1)
|
||||
sync_wait_reapable(sleeper.pid)
|
||||
assert sleeper.wait(timeout=1) == -9
|
||||
finally:
|
||||
if sleeper.returncode is None: # pragma: no cover
|
||||
# We only get here if something fails in the above;
|
||||
# if the test passes, wait() will reap the process
|
||||
sleeper.kill()
|
||||
sleeper.wait()
|
||||
signal.signal(signal.SIGALRM, old_sigalrm)
|
||||
|
||||
|
||||
async def test_custom_deliver_cancel() -> None:
|
||||
custom_deliver_cancel_called = False
|
||||
|
||||
async def custom_deliver_cancel(proc: Process) -> None:
|
||||
nonlocal custom_deliver_cancel_called
|
||||
custom_deliver_cancel_called = True
|
||||
proc.terminate()
|
||||
# Make sure this does get cancelled when the process exits, and that
|
||||
# the process really exited.
|
||||
try:
|
||||
await sleep_forever()
|
||||
finally:
|
||||
assert proc.returncode is not None
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel),
|
||||
)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
assert custom_deliver_cancel_called
|
||||
|
||||
|
||||
def test_bad_deliver_cancel() -> None:
|
||||
async def custom_deliver_cancel(proc: Process) -> None:
|
||||
proc.terminate()
|
||||
raise ValueError("foo")
|
||||
|
||||
async def do_stuff() -> None:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel),
|
||||
)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# double wrap from our nursery + the internal nursery
|
||||
with pytest.RaisesGroup(
|
||||
pytest.RaisesGroup(pytest.RaisesExc(ValueError, match="^foo$"))
|
||||
):
|
||||
_core.run(do_stuff, strict_exception_groups=True)
|
||||
|
||||
|
||||
async def test_warn_on_failed_cancel_terminate(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
original_terminate = Process.terminate
|
||||
|
||||
def broken_terminate(self: Process) -> NoReturn:
|
||||
original_terminate(self)
|
||||
raise OSError("whoops")
|
||||
|
||||
monkeypatch.setattr(Process, "terminate", broken_terminate)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match=".*whoops.*"): # noqa: PT031
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(run_process, SLEEP(9999))
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not posix, reason="posix only")
|
||||
async def test_warn_on_cancel_SIGKILL_escalation(
|
||||
autojump_clock: MockClock,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(Process, "terminate", lambda *args: None)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"): # noqa: PT031
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(run_process, SLEEP(9999))
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
# the background_process_param exercises a lot of run_process cases, but it uses
|
||||
# check=False, so lets have a test that uses check=True as well
|
||||
async def test_run_process_background_fail() -> None:
|
||||
with pytest.RaisesGroup(subprocess.CalledProcessError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
value = await nursery.start(run_process, EXIT_FALSE)
|
||||
assert isinstance(value, Process)
|
||||
proc: Process = value
|
||||
assert proc.returncode == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not SyncPath("/dev/fd").exists(),
|
||||
reason="requires a way to iterate through open files",
|
||||
)
|
||||
async def test_for_leaking_fds() -> None:
|
||||
gc.collect() # address possible flakiness on PyPy
|
||||
|
||||
starting_fds = set(SyncPath("/dev/fd").iterdir()) # noqa: ASYNC240
|
||||
await run_process(EXIT_TRUE)
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds # noqa: ASYNC240
|
||||
|
||||
with pytest.raises(subprocess.CalledProcessError):
|
||||
await run_process(EXIT_FALSE)
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds # noqa: ASYNC240
|
||||
|
||||
with pytest.raises(PermissionError):
|
||||
await run_process(["/dev/fd/0"])
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds # noqa: ASYNC240
|
||||
|
||||
|
||||
async def test_run_process_internal_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# There's probably less extreme ways of triggering errors inside the nursery
|
||||
# in run_process.
|
||||
async def very_broken_open(*args: object, **kwargs: object) -> str:
|
||||
return "oops"
|
||||
|
||||
monkeypatch.setattr(trio._subprocess, "_open_process", very_broken_open)
|
||||
with pytest.RaisesGroup(AttributeError, AttributeError):
|
||||
await run_process(EXIT_TRUE, capture_stdout=True)
|
||||
|
||||
|
||||
# regression test for #2209
|
||||
async def test_subprocess_pidfd_unnotified() -> None:
|
||||
noticed_exit = None
|
||||
|
||||
async def wait_and_tell(proc: Process) -> None:
|
||||
nonlocal noticed_exit
|
||||
noticed_exit = Event()
|
||||
await proc.wait()
|
||||
noticed_exit.set()
|
||||
|
||||
proc = await open_process(SLEEP(9999))
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(wait_and_tell, proc)
|
||||
await wait_all_tasks_blocked()
|
||||
assert isinstance(noticed_exit, Event)
|
||||
proc.terminate()
|
||||
# without giving trio a chance to do so,
|
||||
with assert_no_checkpoints():
|
||||
# wait until the process has actually exited;
|
||||
proc._proc.wait()
|
||||
# force a call to poll (that closes the pidfd on linux)
|
||||
proc.poll()
|
||||
with move_on_after(5):
|
||||
# Some platforms use threads to wait for exit, so it might take a bit
|
||||
# for everything to notice
|
||||
await noticed_exit.wait()
|
||||
assert noticed_exit.is_set(), "child task wasn't woken after poll, DEADLOCK"
|
||||
@@ -0,0 +1,735 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import weakref
|
||||
from collections.abc import Callable
|
||||
from typing import TypeAlias
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import _core
|
||||
from .._core._parking_lot import GLOBAL_PARKING_LOT_BREAKER
|
||||
from .._sync import *
|
||||
from .._timeouts import sleep_forever
|
||||
from ..testing import assert_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
|
||||
async def test_Event() -> None:
|
||||
e = Event()
|
||||
assert not e.is_set()
|
||||
assert e.statistics().tasks_waiting == 0
|
||||
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match=r"trio\.Event\.__bool__ is deprecated since Trio 0\.31\.0; use trio\.Event\.is_set instead \(https://github.com/python-trio/trio/issues/3238\)",
|
||||
):
|
||||
e.__bool__()
|
||||
|
||||
e.set()
|
||||
assert e.is_set()
|
||||
with assert_checkpoints():
|
||||
await e.wait()
|
||||
|
||||
e = Event()
|
||||
|
||||
record = []
|
||||
|
||||
async def child() -> None:
|
||||
record.append("sleeping")
|
||||
await e.wait()
|
||||
record.append("woken")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child)
|
||||
nursery.start_soon(child)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["sleeping", "sleeping"]
|
||||
assert e.statistics().tasks_waiting == 2
|
||||
e.set()
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["sleeping", "sleeping", "woken", "woken"]
|
||||
|
||||
|
||||
async def test_CapacityLimiter() -> None:
|
||||
assert CapacityLimiter(0).total_tokens == 0
|
||||
with pytest.raises(TypeError):
|
||||
CapacityLimiter(1.0)
|
||||
with pytest.raises(ValueError, match=r"^total_tokens must be >= 0$"):
|
||||
CapacityLimiter(-1)
|
||||
c = CapacityLimiter(2)
|
||||
repr(c) # smoke test
|
||||
assert c.total_tokens == 2
|
||||
assert c.borrowed_tokens == 0
|
||||
assert c.available_tokens == 2
|
||||
with pytest.raises(RuntimeError):
|
||||
c.release()
|
||||
assert c.borrowed_tokens == 0
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
assert c.available_tokens == 1
|
||||
|
||||
stats = c.statistics()
|
||||
assert stats.borrowed_tokens == 1
|
||||
assert stats.total_tokens == 2
|
||||
assert stats.borrowers == [_core.current_task()]
|
||||
assert stats.tasks_waiting == 0
|
||||
|
||||
# Can't re-acquire when we already have it
|
||||
with pytest.raises(RuntimeError):
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
with pytest.raises(RuntimeError):
|
||||
await c.acquire()
|
||||
assert c.borrowed_tokens == 1
|
||||
|
||||
# We can acquire on behalf of someone else though
|
||||
with assert_checkpoints():
|
||||
await c.acquire_on_behalf_of("someone")
|
||||
|
||||
# But then we've run out of capacity
|
||||
assert c.borrowed_tokens == 2
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
c.acquire_on_behalf_of_nowait("third party")
|
||||
|
||||
assert set(c.statistics().borrowers) == {_core.current_task(), "someone"}
|
||||
|
||||
# Until we release one
|
||||
c.release_on_behalf_of(_core.current_task())
|
||||
assert c.statistics().borrowers == ["someone"]
|
||||
|
||||
c.release_on_behalf_of("someone")
|
||||
assert c.borrowed_tokens == 0
|
||||
with assert_checkpoints():
|
||||
async with c:
|
||||
assert c.borrowed_tokens == 1
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
await c.acquire_on_behalf_of("value 1")
|
||||
await c.acquire_on_behalf_of("value 2")
|
||||
nursery.start_soon(c.acquire_on_behalf_of, "value 3")
|
||||
await wait_all_tasks_blocked()
|
||||
assert c.borrowed_tokens == 2
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.release_on_behalf_of("value 2")
|
||||
# Fairness:
|
||||
assert c.borrowed_tokens == 2
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
c.acquire_nowait()
|
||||
|
||||
c.release_on_behalf_of("value 3")
|
||||
c.release_on_behalf_of("value 1")
|
||||
|
||||
|
||||
async def test_CapacityLimiter_inf() -> None:
|
||||
from math import inf
|
||||
|
||||
c = CapacityLimiter(inf)
|
||||
repr(c) # smoke test
|
||||
assert c.total_tokens == inf
|
||||
assert c.borrowed_tokens == 0
|
||||
assert c.available_tokens == inf
|
||||
with pytest.raises(RuntimeError):
|
||||
c.release()
|
||||
assert c.borrowed_tokens == 0
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
assert c.available_tokens == inf
|
||||
|
||||
|
||||
async def test_CapacityLimiter_change_total_tokens() -> None:
|
||||
c = CapacityLimiter(2)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
c.total_tokens = 1.0
|
||||
|
||||
with pytest.raises(ValueError, match=r"^total_tokens must be >= 0$"):
|
||||
c.total_tokens = -1
|
||||
|
||||
with pytest.raises(ValueError, match=r"^total_tokens must be >= 0$"):
|
||||
c.total_tokens = -10
|
||||
|
||||
assert c.total_tokens == 2
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(5):
|
||||
nursery.start_soon(c.acquire_on_behalf_of, i)
|
||||
await wait_all_tasks_blocked()
|
||||
assert set(c.statistics().borrowers) == {0, 1}
|
||||
assert c.statistics().tasks_waiting == 3
|
||||
c.total_tokens += 2
|
||||
assert set(c.statistics().borrowers) == {0, 1, 2, 3}
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.total_tokens -= 3
|
||||
assert c.borrowed_tokens == 4
|
||||
assert c.total_tokens == 1
|
||||
c.release_on_behalf_of(0)
|
||||
c.release_on_behalf_of(1)
|
||||
c.release_on_behalf_of(2)
|
||||
assert set(c.statistics().borrowers) == {3}
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.release_on_behalf_of(3)
|
||||
assert set(c.statistics().borrowers) == {4}
|
||||
assert c.statistics().tasks_waiting == 0
|
||||
|
||||
|
||||
# regression test for issue #548
|
||||
async def test_CapacityLimiter_memleak_548() -> None:
|
||||
limiter = CapacityLimiter(total_tokens=1)
|
||||
await limiter.acquire()
|
||||
|
||||
async with _core.open_nursery() as n:
|
||||
n.start_soon(limiter.acquire)
|
||||
await wait_all_tasks_blocked() # give it a chance to run the task
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
# if this is 1, the acquire call (despite being killed) is still there in the task, and will
|
||||
# leak memory all the while the limiter is active
|
||||
assert len(limiter._pending_borrowers) == 0
|
||||
|
||||
|
||||
async def test_CapacityLimiter_zero_limit_tokens() -> None:
|
||||
c = CapacityLimiter(5)
|
||||
|
||||
assert c.total_tokens == 5
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
c.total_tokens = 0
|
||||
|
||||
for i in range(5):
|
||||
nursery.start_soon(c.acquire_on_behalf_of, i)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
assert set(c.statistics().borrowers) == set()
|
||||
assert c.statistics().tasks_waiting == 5
|
||||
|
||||
c.total_tokens = 5
|
||||
|
||||
assert set(c.statistics().borrowers) == {0, 1, 2, 3, 4}
|
||||
|
||||
nursery.start_soon(c.acquire_on_behalf_of, 5)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
|
||||
for i in range(5):
|
||||
c.release_on_behalf_of(i)
|
||||
|
||||
assert c.statistics().tasks_waiting == 0
|
||||
c.release_on_behalf_of(5)
|
||||
|
||||
# making sure that zero limit capacity limiter doesn't let any tasks through
|
||||
|
||||
c.total_tokens = 0
|
||||
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
c.acquire_nowait()
|
||||
|
||||
nursery.start_soon(c.acquire_on_behalf_of, 6)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
assert c.statistics().borrowers == []
|
||||
|
||||
c.total_tokens = 1
|
||||
assert c.statistics().tasks_waiting == 0
|
||||
assert c.statistics().borrowers == [6]
|
||||
c.release_on_behalf_of(6)
|
||||
|
||||
await c.acquire_on_behalf_of(0) # total_tokens is 1
|
||||
|
||||
nursery.start_soon(c.acquire_on_behalf_of, 1)
|
||||
await wait_all_tasks_blocked()
|
||||
c.total_tokens = 0
|
||||
|
||||
assert c.statistics().borrowers == [0]
|
||||
|
||||
c.release_on_behalf_of(0)
|
||||
await wait_all_tasks_blocked()
|
||||
assert c.statistics().borrowers == []
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
|
||||
c.total_tokens = 1
|
||||
await wait_all_tasks_blocked()
|
||||
assert c.statistics().borrowers == [1]
|
||||
assert c.statistics().tasks_waiting == 0
|
||||
|
||||
c.release_on_behalf_of(1)
|
||||
|
||||
c.total_tokens = 0
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
assert c.total_tokens == 0
|
||||
assert c.statistics().borrowers == []
|
||||
assert c._pending_borrowers == {}
|
||||
|
||||
|
||||
async def test_Semaphore() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
Semaphore(1.0) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match=r"^initial value must be >= 0$"):
|
||||
Semaphore(-1)
|
||||
s = Semaphore(1)
|
||||
repr(s) # smoke test
|
||||
assert s.value == 1
|
||||
assert s.max_value is None
|
||||
s.release()
|
||||
assert s.value == 2
|
||||
assert s.statistics().tasks_waiting == 0
|
||||
s.acquire_nowait()
|
||||
assert s.value == 1
|
||||
with assert_checkpoints():
|
||||
await s.acquire()
|
||||
assert s.value == 0
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
s.acquire_nowait()
|
||||
|
||||
s.release()
|
||||
assert s.value == 1
|
||||
with assert_checkpoints():
|
||||
async with s:
|
||||
assert s.value == 0
|
||||
assert s.value == 1
|
||||
s.acquire_nowait()
|
||||
|
||||
record = []
|
||||
|
||||
async def do_acquire(s: Semaphore) -> None:
|
||||
record.append("started")
|
||||
await s.acquire()
|
||||
record.append("finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_acquire, s)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["started"]
|
||||
assert s.value == 0
|
||||
s.release()
|
||||
# Fairness:
|
||||
assert s.value == 0
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
s.acquire_nowait()
|
||||
assert record == ["started", "finished"]
|
||||
|
||||
|
||||
def test_Semaphore_bounded() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
Semaphore(1, max_value=1.0) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match=r"^max_values must be >= initial_value$"):
|
||||
Semaphore(2, max_value=1)
|
||||
bs = Semaphore(1, max_value=1)
|
||||
assert bs.max_value == 1
|
||||
repr(bs) # smoke test
|
||||
with pytest.raises(ValueError, match=r"^semaphore released too many times$"):
|
||||
bs.release()
|
||||
assert bs.value == 1
|
||||
bs.acquire_nowait()
|
||||
assert bs.value == 0
|
||||
bs.release()
|
||||
assert bs.value == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
|
||||
async def test_Lock_and_StrictFIFOLock(
|
||||
lockcls: type[Lock | StrictFIFOLock],
|
||||
) -> None:
|
||||
l = lockcls() # noqa
|
||||
assert not l.locked()
|
||||
|
||||
# make sure locks can be weakref'ed (gh-331)
|
||||
r = weakref.ref(l)
|
||||
assert r() is l
|
||||
|
||||
repr(l) # smoke test
|
||||
# make sure repr uses the right name for subclasses
|
||||
assert lockcls.__name__ in repr(l)
|
||||
with assert_checkpoints():
|
||||
async with l:
|
||||
assert l.locked()
|
||||
repr(l) # smoke test (repr branches on locked/unlocked)
|
||||
assert not l.locked()
|
||||
l.acquire_nowait()
|
||||
assert l.locked()
|
||||
l.release()
|
||||
assert not l.locked()
|
||||
with assert_checkpoints():
|
||||
await l.acquire()
|
||||
assert l.locked()
|
||||
l.release()
|
||||
assert not l.locked()
|
||||
|
||||
l.acquire_nowait()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Error out if we already own the lock
|
||||
l.acquire_nowait()
|
||||
l.release()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Error out if we don't own the lock
|
||||
l.release()
|
||||
|
||||
holder_task = None
|
||||
|
||||
async def holder() -> None:
|
||||
nonlocal holder_task
|
||||
holder_task = _core.current_task()
|
||||
async with l:
|
||||
await sleep_forever()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
assert not l.locked()
|
||||
nursery.start_soon(holder)
|
||||
await wait_all_tasks_blocked()
|
||||
assert l.locked()
|
||||
# WouldBlock if someone else holds the lock
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
l.acquire_nowait()
|
||||
# Can't release a lock someone else holds
|
||||
with pytest.raises(RuntimeError):
|
||||
l.release()
|
||||
|
||||
statistics = l.statistics()
|
||||
print(statistics)
|
||||
assert statistics.locked
|
||||
assert statistics.owner is holder_task
|
||||
assert statistics.tasks_waiting == 0
|
||||
|
||||
nursery.start_soon(holder)
|
||||
await wait_all_tasks_blocked()
|
||||
statistics = l.statistics()
|
||||
print(statistics)
|
||||
assert statistics.tasks_waiting == 1
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
statistics = l.statistics()
|
||||
assert not statistics.locked
|
||||
assert statistics.owner is None
|
||||
assert statistics.tasks_waiting == 0
|
||||
|
||||
|
||||
async def test_Condition() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
Condition(Semaphore(1)) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
Condition(StrictFIFOLock) # type: ignore[arg-type]
|
||||
l = Lock() # noqa
|
||||
c = Condition(l)
|
||||
assert not l.locked()
|
||||
assert not c.locked()
|
||||
with assert_checkpoints():
|
||||
await c.acquire()
|
||||
assert l.locked()
|
||||
assert c.locked()
|
||||
|
||||
c = Condition()
|
||||
assert not c.locked()
|
||||
c.acquire_nowait()
|
||||
assert c.locked()
|
||||
with pytest.raises(RuntimeError):
|
||||
c.acquire_nowait()
|
||||
c.release()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't wait without holding the lock
|
||||
await c.wait()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't notify without holding the lock
|
||||
c.notify()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't notify without holding the lock
|
||||
c.notify_all()
|
||||
|
||||
finished_waiters = set()
|
||||
|
||||
async def waiter(i: int) -> None:
|
||||
async with c:
|
||||
await c.wait()
|
||||
finished_waiters.add(i)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(3):
|
||||
nursery.start_soon(waiter, i)
|
||||
await wait_all_tasks_blocked()
|
||||
async with c:
|
||||
c.notify()
|
||||
assert c.locked()
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0}
|
||||
async with c:
|
||||
c.notify_all()
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0, 1, 2}
|
||||
|
||||
finished_waiters = set()
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(3):
|
||||
nursery.start_soon(waiter, i)
|
||||
await wait_all_tasks_blocked()
|
||||
async with c:
|
||||
c.notify(2)
|
||||
statistics = c.statistics()
|
||||
print(statistics)
|
||||
assert statistics.tasks_waiting == 1
|
||||
assert statistics.lock_statistics.tasks_waiting == 2
|
||||
# exiting the context manager hands off the lock to the first task
|
||||
assert c.statistics().lock_statistics.tasks_waiting == 1
|
||||
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0, 1}
|
||||
|
||||
async with c:
|
||||
c.notify_all()
|
||||
|
||||
# After being cancelled still hold the lock (!)
|
||||
# (Note that c.__aexit__ checks that we hold the lock as well)
|
||||
with _core.CancelScope() as scope:
|
||||
async with c:
|
||||
scope.cancel()
|
||||
try:
|
||||
await c.wait()
|
||||
finally:
|
||||
assert c.locked()
|
||||
|
||||
|
||||
from .._channel import open_memory_channel
|
||||
from .._sync import AsyncContextManagerMixin
|
||||
|
||||
# Three ways of implementing a Lock in terms of a channel. Used to let us put
|
||||
# the channel through the generic lock tests.
|
||||
|
||||
|
||||
class ChannelLock1(AsyncContextManagerMixin):
|
||||
def __init__(self, capacity: int) -> None:
|
||||
self.s, self.r = open_memory_channel[None](capacity)
|
||||
for _ in range(capacity - 1):
|
||||
self.s.send_nowait(None)
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
self.s.send_nowait(None)
|
||||
|
||||
async def acquire(self) -> None:
|
||||
await self.s.send(None)
|
||||
|
||||
def release(self) -> None:
|
||||
self.r.receive_nowait()
|
||||
|
||||
|
||||
class ChannelLock2(AsyncContextManagerMixin):
|
||||
def __init__(self) -> None:
|
||||
self.s, self.r = open_memory_channel[None](10)
|
||||
self.s.send_nowait(None)
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
self.r.receive_nowait()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
await self.r.receive()
|
||||
|
||||
def release(self) -> None:
|
||||
self.s.send_nowait(None)
|
||||
|
||||
|
||||
class ChannelLock3(AsyncContextManagerMixin):
|
||||
def __init__(self) -> None:
|
||||
self.s, self.r = open_memory_channel[None](0)
|
||||
# self.acquired is true when one task acquires the lock and
|
||||
# only becomes false when it's released and no tasks are
|
||||
# waiting to acquire.
|
||||
self.acquired = False
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
assert not self.acquired
|
||||
self.acquired = True
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if self.acquired:
|
||||
await self.s.send(None)
|
||||
else:
|
||||
self.acquired = True
|
||||
await _core.checkpoint()
|
||||
|
||||
def release(self) -> None:
|
||||
try:
|
||||
self.r.receive_nowait()
|
||||
except _core.WouldBlock:
|
||||
assert self.acquired
|
||||
self.acquired = False
|
||||
|
||||
|
||||
lock_factories = [
|
||||
lambda: CapacityLimiter(1),
|
||||
lambda: Semaphore(1),
|
||||
Lock,
|
||||
StrictFIFOLock,
|
||||
lambda: ChannelLock1(10),
|
||||
lambda: ChannelLock1(1),
|
||||
ChannelLock2,
|
||||
ChannelLock3,
|
||||
]
|
||||
lock_factory_names = [
|
||||
"CapacityLimiter(1)",
|
||||
"Semaphore(1)",
|
||||
"Lock",
|
||||
"StrictFIFOLock",
|
||||
"ChannelLock1(10)",
|
||||
"ChannelLock1(1)",
|
||||
"ChannelLock2",
|
||||
"ChannelLock3",
|
||||
]
|
||||
|
||||
generic_lock_test = pytest.mark.parametrize(
|
||||
"lock_factory",
|
||||
lock_factories,
|
||||
ids=lock_factory_names,
|
||||
)
|
||||
|
||||
LockLike: TypeAlias = (
|
||||
CapacityLimiter
|
||||
| Semaphore
|
||||
| Lock
|
||||
| StrictFIFOLock
|
||||
| ChannelLock1
|
||||
| ChannelLock2
|
||||
| ChannelLock3
|
||||
)
|
||||
LockFactory: TypeAlias = Callable[[], LockLike]
|
||||
|
||||
|
||||
# Spawn a bunch of workers that take a lock and then yield; make sure that
|
||||
# only one worker is ever in the critical section at a time.
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_exclusion(lock_factory: LockFactory) -> None:
|
||||
LOOPS = 10
|
||||
WORKERS = 5
|
||||
in_critical_section = False
|
||||
acquires = 0
|
||||
|
||||
async def worker(lock_like: LockLike) -> None:
|
||||
nonlocal in_critical_section, acquires
|
||||
for _ in range(LOOPS):
|
||||
async with lock_like:
|
||||
acquires += 1
|
||||
assert not in_critical_section
|
||||
in_critical_section = True
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
assert in_critical_section
|
||||
in_critical_section = False
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
lock_like = lock_factory()
|
||||
for _ in range(WORKERS):
|
||||
nursery.start_soon(worker, lock_like)
|
||||
assert not in_critical_section
|
||||
assert acquires == LOOPS * WORKERS
|
||||
|
||||
|
||||
# Several workers queue on the same lock; make sure they each get it, in
|
||||
# order.
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_fifo_fairness(lock_factory: LockFactory) -> None:
|
||||
initial_order = []
|
||||
record = []
|
||||
LOOPS = 5
|
||||
|
||||
async def loopy(name: int, lock_like: LockLike) -> None:
|
||||
# Record the order each task was initially scheduled in
|
||||
initial_order.append(name)
|
||||
for _ in range(LOOPS):
|
||||
async with lock_like:
|
||||
record.append(name)
|
||||
|
||||
lock_like = lock_factory()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(loopy, 1, lock_like)
|
||||
nursery.start_soon(loopy, 2, lock_like)
|
||||
nursery.start_soon(loopy, 3, lock_like)
|
||||
# The first three could be in any order due to scheduling randomness,
|
||||
# but after that they should repeat in the same order
|
||||
for i in range(LOOPS):
|
||||
assert record[3 * i : 3 * (i + 1)] == initial_order
|
||||
|
||||
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_acquire_nowait_blocks_acquire(
|
||||
lock_factory: LockFactory,
|
||||
) -> None:
|
||||
lock_like = lock_factory()
|
||||
|
||||
record = []
|
||||
|
||||
async def lock_taker() -> None:
|
||||
record.append("started")
|
||||
async with lock_like:
|
||||
pass
|
||||
record.append("finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
lock_like.acquire_nowait()
|
||||
nursery.start_soon(lock_taker)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["started"]
|
||||
lock_like.release()
|
||||
|
||||
|
||||
async def test_lock_acquire_unowned_lock() -> None:
|
||||
"""Test that trying to acquire a lock whose owner has exited raises an error.
|
||||
see https://github.com/python-trio/trio/issues/3035
|
||||
"""
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
lock = trio.Lock()
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(lock.acquire)
|
||||
owner_str = re.escape(str(lock._lot.broken_by[0]))
|
||||
with pytest.raises(
|
||||
trio.BrokenResourceError,
|
||||
match=f"^Owner of this lock exited without releasing: {owner_str}$",
|
||||
):
|
||||
await lock.acquire()
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
|
||||
|
||||
async def test_lock_multiple_acquire() -> None:
|
||||
"""Test for error if awaiting on a lock whose owner exits without releasing.
|
||||
see https://github.com/python-trio/trio/issues/3035"""
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
lock = trio.Lock()
|
||||
with pytest.RaisesGroup(
|
||||
pytest.RaisesExc(
|
||||
trio.BrokenResourceError,
|
||||
match="^Owner of this lock exited without releasing: ",
|
||||
),
|
||||
):
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(lock.acquire)
|
||||
nursery.start_soon(lock.acquire)
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
|
||||
|
||||
async def test_lock_handover() -> None:
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
child_task: Task | None = None
|
||||
lock = trio.Lock()
|
||||
|
||||
# this task acquires the lock
|
||||
lock.acquire_nowait()
|
||||
assert {
|
||||
_core.current_task(): [
|
||||
lock._lot,
|
||||
],
|
||||
} == GLOBAL_PARKING_LOT_BREAKER
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(lock.acquire)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# hand over the lock to the child task
|
||||
lock.release()
|
||||
|
||||
# check values, and get the identifier out of the dict for later check
|
||||
assert len(GLOBAL_PARKING_LOT_BREAKER) == 1
|
||||
child_task = next(iter(GLOBAL_PARKING_LOT_BREAKER))
|
||||
assert GLOBAL_PARKING_LOT_BREAKER[child_task] == [lock._lot]
|
||||
|
||||
assert lock._lot.broken_by == [child_task]
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
@@ -0,0 +1,682 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# XX this should get broken up, like testing.py did
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import _core, sleep, socket as tsocket
|
||||
from .._core._tests.tutil import can_bind_ipv6
|
||||
from .._highlevel_generic import StapledStream, aclose_forcefully
|
||||
from .._highlevel_socket import SocketListener
|
||||
from ..testing import *
|
||||
from ..testing._check_streams import _assert_raises
|
||||
from ..testing._memory_streams import _UnboundedByteQueue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trio import Nursery
|
||||
from trio.abc import ReceiveStream, SendStream
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked() -> None:
|
||||
record = []
|
||||
|
||||
async def busy_bee() -> None:
|
||||
for _ in range(10):
|
||||
await _core.checkpoint()
|
||||
record.append("busy bee exhausted")
|
||||
|
||||
async def waiting_for_bee_to_leave() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
record.append("quiet at last!")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(busy_bee)
|
||||
nursery.start_soon(waiting_for_bee_to_leave)
|
||||
nursery.start_soon(waiting_for_bee_to_leave)
|
||||
|
||||
# check cancellation
|
||||
record = []
|
||||
|
||||
async def cancelled_while_waiting() -> None:
|
||||
try:
|
||||
await wait_all_tasks_blocked()
|
||||
except _core.Cancelled:
|
||||
record.append("ok")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(cancelled_while_waiting)
|
||||
nursery.cancel_scope.cancel()
|
||||
assert record == ["ok"]
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked_with_timeouts(mock_clock: MockClock) -> None:
|
||||
record = []
|
||||
|
||||
async def timeout_task() -> None:
|
||||
record.append("tt start")
|
||||
await sleep(5)
|
||||
record.append("tt finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(timeout_task)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["tt start"]
|
||||
mock_clock.jump(10)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["tt start", "tt finished"]
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked_with_cushion() -> None:
|
||||
record = []
|
||||
|
||||
async def blink() -> None:
|
||||
record.append("blink start")
|
||||
await sleep(0.01)
|
||||
await sleep(0.01)
|
||||
await sleep(0.01)
|
||||
record.append("blink end")
|
||||
|
||||
async def wait_no_cushion() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
record.append("wait_no_cushion end")
|
||||
|
||||
async def wait_small_cushion() -> None:
|
||||
await wait_all_tasks_blocked(0.02)
|
||||
record.append("wait_small_cushion end")
|
||||
|
||||
async def wait_big_cushion() -> None:
|
||||
await wait_all_tasks_blocked(0.03)
|
||||
record.append("wait_big_cushion end")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(blink)
|
||||
nursery.start_soon(wait_no_cushion)
|
||||
nursery.start_soon(wait_small_cushion)
|
||||
nursery.start_soon(wait_small_cushion)
|
||||
nursery.start_soon(wait_big_cushion)
|
||||
|
||||
assert record == [
|
||||
"blink start",
|
||||
"wait_no_cushion end",
|
||||
"blink end",
|
||||
"wait_small_cushion end",
|
||||
"wait_small_cushion end",
|
||||
"wait_big_cushion end",
|
||||
]
|
||||
|
||||
|
||||
################################################################
|
||||
|
||||
|
||||
async def test_assert_checkpoints(recwarn: pytest.WarningsRecorder) -> None:
|
||||
with assert_checkpoints():
|
||||
await _core.checkpoint()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_checkpoints():
|
||||
1 + 1 # noqa: B018 # "useless expression"
|
||||
|
||||
# partial yield cases
|
||||
# if you have a schedule point but not a cancel point, or vice-versa, then
|
||||
# that's not a checkpoint.
|
||||
for partial_yield in [
|
||||
_core.checkpoint_if_cancelled,
|
||||
_core.cancel_shielded_checkpoint,
|
||||
]:
|
||||
print(partial_yield)
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_checkpoints():
|
||||
await partial_yield()
|
||||
|
||||
# But both together count as a checkpoint
|
||||
with assert_checkpoints():
|
||||
await _core.checkpoint_if_cancelled()
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
|
||||
|
||||
async def test_assert_no_checkpoints(recwarn: pytest.WarningsRecorder) -> None:
|
||||
with assert_no_checkpoints():
|
||||
1 + 1 # noqa: B018 # "useless expression"
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await _core.checkpoint()
|
||||
|
||||
# partial yield cases
|
||||
# if you have a schedule point but not a cancel point, or vice-versa, then
|
||||
# that doesn't make *either* version of assert_{no_,}yields happy.
|
||||
for partial_yield in [
|
||||
_core.checkpoint_if_cancelled,
|
||||
_core.cancel_shielded_checkpoint,
|
||||
]:
|
||||
print(partial_yield)
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await partial_yield()
|
||||
|
||||
# And both together also count as a checkpoint
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await _core.checkpoint_if_cancelled()
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
|
||||
|
||||
################################################################
|
||||
|
||||
|
||||
async def test_Sequencer() -> None:
|
||||
record = []
|
||||
|
||||
def t(val: object) -> None:
|
||||
print(val)
|
||||
record.append(val)
|
||||
|
||||
async def f1(seq: Sequencer) -> None:
|
||||
async with seq(1):
|
||||
t(("f1", 1))
|
||||
async with seq(3):
|
||||
t(("f1", 3))
|
||||
async with seq(4):
|
||||
t(("f1", 4))
|
||||
|
||||
async def f2(seq: Sequencer) -> None:
|
||||
async with seq(0):
|
||||
t(("f2", 0))
|
||||
async with seq(2):
|
||||
t(("f2", 2))
|
||||
|
||||
seq = Sequencer()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(f1, seq)
|
||||
nursery.start_soon(f2, seq)
|
||||
async with seq(5):
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)]
|
||||
|
||||
seq = Sequencer()
|
||||
# Catches us if we try to reuse a sequence point:
|
||||
async with seq(0):
|
||||
pass
|
||||
with pytest.raises(RuntimeError):
|
||||
async with seq(0):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
async def test_Sequencer_cancel() -> None:
|
||||
# Killing a blocked task makes everything blow up
|
||||
record = []
|
||||
seq = Sequencer()
|
||||
|
||||
async def child(i: int) -> None:
|
||||
with _core.CancelScope() as scope:
|
||||
if i == 1:
|
||||
scope.cancel()
|
||||
try:
|
||||
async with seq(i):
|
||||
pass # pragma: no cover
|
||||
except RuntimeError:
|
||||
record.append(f"seq({i}) RuntimeError")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child, 1)
|
||||
nursery.start_soon(child, 2)
|
||||
async with seq(0):
|
||||
pass # pragma: no cover
|
||||
|
||||
assert record == ["seq(1) RuntimeError", "seq(2) RuntimeError"]
|
||||
|
||||
# Late arrivals also get errors
|
||||
with pytest.raises(RuntimeError):
|
||||
async with seq(3):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
################################################################
|
||||
def test__assert_raises() -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
with _assert_raises(RuntimeError):
|
||||
1 + 1 # noqa: B018 # "useless expression"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
with _assert_raises(RuntimeError):
|
||||
"foo" + 1 # type: ignore[operator] # noqa: B018 # "useless expression"
|
||||
|
||||
with _assert_raises(RuntimeError):
|
||||
raise RuntimeError
|
||||
|
||||
|
||||
# This is a private implementation detail, but it's complex enough to be worth
|
||||
# testing directly
|
||||
async def test__UnboundeByteQueue() -> None:
|
||||
ubq = _UnboundedByteQueue()
|
||||
|
||||
ubq.put(b"123")
|
||||
ubq.put(b"456")
|
||||
assert ubq.get_nowait(1) == b"1"
|
||||
assert ubq.get_nowait(10) == b"23456"
|
||||
ubq.put(b"789")
|
||||
assert ubq.get_nowait() == b"789"
|
||||
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
ubq.get_nowait(10)
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
ubq.get_nowait()
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
ubq.put("string") # type: ignore[arg-type]
|
||||
|
||||
ubq.put(b"abc")
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get(10) == b"abc"
|
||||
ubq.put(b"def")
|
||||
ubq.put(b"ghi")
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get(1) == b"d"
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get() == b"efghi"
|
||||
|
||||
async def putter(data: bytes) -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
ubq.put(data)
|
||||
|
||||
async def getter(expect: bytes) -> None:
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get() == expect
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"xyz")
|
||||
nursery.start_soon(putter, b"xyz")
|
||||
|
||||
# Two gets at the same time -> BusyResourceError
|
||||
with pytest.RaisesGroup(_core.BusyResourceError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"asdf")
|
||||
nursery.start_soon(getter, b"asdf")
|
||||
|
||||
# Closing
|
||||
|
||||
ubq.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
ubq.put(b"---")
|
||||
|
||||
assert ubq.get_nowait(10) == b""
|
||||
assert ubq.get_nowait() == b""
|
||||
assert await ubq.get(10) == b""
|
||||
assert await ubq.get() == b""
|
||||
|
||||
# close is idempotent
|
||||
ubq.close()
|
||||
|
||||
# close wakes up blocked getters
|
||||
ubq2 = _UnboundedByteQueue()
|
||||
|
||||
async def closer() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
ubq2.close()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"")
|
||||
nursery.start_soon(closer)
|
||||
|
||||
|
||||
async def test_MemorySendStream() -> None:
|
||||
mss = MemorySendStream()
|
||||
|
||||
async def do_send_all(data: bytes) -> None:
|
||||
with assert_checkpoints():
|
||||
await mss.send_all(data)
|
||||
|
||||
await do_send_all(b"123")
|
||||
assert mss.get_data_nowait(1) == b"1"
|
||||
assert mss.get_data_nowait() == b"23"
|
||||
|
||||
with assert_checkpoints():
|
||||
await mss.wait_send_all_might_not_block()
|
||||
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
mss.get_data_nowait()
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
mss.get_data_nowait(10)
|
||||
|
||||
await do_send_all(b"456")
|
||||
with assert_checkpoints():
|
||||
assert await mss.get_data() == b"456"
|
||||
|
||||
# Call send_all twice at once; one should get BusyResourceError and one
|
||||
# should succeed. But we can't let the error propagate, because it might
|
||||
# cause the other to be cancelled before it can finish doing its thing,
|
||||
# and we don't know which one will get the error.
|
||||
resource_busy_count = 0
|
||||
|
||||
async def do_send_all_count_resourcebusy() -> None:
|
||||
nonlocal resource_busy_count
|
||||
try:
|
||||
await do_send_all(b"xxx")
|
||||
except _core.BusyResourceError:
|
||||
resource_busy_count += 1
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send_all_count_resourcebusy)
|
||||
nursery.start_soon(do_send_all_count_resourcebusy)
|
||||
|
||||
assert resource_busy_count == 1
|
||||
|
||||
with assert_checkpoints():
|
||||
await mss.aclose()
|
||||
|
||||
assert await mss.get_data() == b"xxx"
|
||||
assert await mss.get_data() == b""
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await do_send_all(b"---")
|
||||
|
||||
# hooks
|
||||
|
||||
assert mss.send_all_hook is None
|
||||
assert mss.wait_send_all_might_not_block_hook is None
|
||||
assert mss.close_hook is None
|
||||
|
||||
record = []
|
||||
|
||||
async def send_all_hook() -> None:
|
||||
# hook runs after send_all does its work (can pull data out)
|
||||
assert mss2.get_data_nowait() == b"abc"
|
||||
record.append("send_all_hook")
|
||||
|
||||
async def wait_send_all_might_not_block_hook() -> None:
|
||||
record.append("wait_send_all_might_not_block_hook")
|
||||
|
||||
def close_hook() -> None:
|
||||
record.append("close_hook")
|
||||
|
||||
mss2 = MemorySendStream(
|
||||
send_all_hook,
|
||||
wait_send_all_might_not_block_hook,
|
||||
close_hook,
|
||||
)
|
||||
|
||||
assert mss2.send_all_hook is send_all_hook
|
||||
assert mss2.wait_send_all_might_not_block_hook is wait_send_all_might_not_block_hook
|
||||
assert mss2.close_hook is close_hook
|
||||
|
||||
await mss2.send_all(b"abc")
|
||||
await mss2.wait_send_all_might_not_block()
|
||||
await aclose_forcefully(mss2)
|
||||
mss2.close()
|
||||
|
||||
assert record == [
|
||||
"send_all_hook",
|
||||
"wait_send_all_might_not_block_hook",
|
||||
"close_hook",
|
||||
"close_hook",
|
||||
]
|
||||
|
||||
|
||||
async def test_MemoryReceiveStream() -> None:
|
||||
mrs = MemoryReceiveStream()
|
||||
|
||||
async def do_receive_some(max_bytes: int | None) -> bytes:
|
||||
with assert_checkpoints():
|
||||
return await mrs.receive_some(max_bytes)
|
||||
|
||||
mrs.put_data(b"abc")
|
||||
assert await do_receive_some(1) == b"a"
|
||||
assert await do_receive_some(10) == b"bc"
|
||||
mrs.put_data(b"abc")
|
||||
assert await do_receive_some(None) == b"abc"
|
||||
|
||||
with pytest.RaisesGroup(_core.BusyResourceError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_receive_some, 10)
|
||||
nursery.start_soon(do_receive_some, 10)
|
||||
|
||||
assert mrs.receive_some_hook is None
|
||||
|
||||
mrs.put_data(b"def")
|
||||
mrs.put_eof()
|
||||
mrs.put_eof()
|
||||
|
||||
assert await do_receive_some(10) == b"def"
|
||||
assert await do_receive_some(10) == b""
|
||||
assert await do_receive_some(10) == b""
|
||||
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
mrs.put_data(b"---")
|
||||
|
||||
async def receive_some_hook() -> None:
|
||||
mrs2.put_data(b"xxx")
|
||||
|
||||
record = []
|
||||
|
||||
def close_hook() -> None:
|
||||
record.append("closed")
|
||||
|
||||
mrs2 = MemoryReceiveStream(receive_some_hook, close_hook)
|
||||
assert mrs2.receive_some_hook is receive_some_hook
|
||||
assert mrs2.close_hook is close_hook
|
||||
|
||||
mrs2.put_data(b"yyy")
|
||||
assert await mrs2.receive_some(10) == b"yyyxxx"
|
||||
assert await mrs2.receive_some(10) == b"xxx"
|
||||
assert await mrs2.receive_some(10) == b"xxx"
|
||||
|
||||
mrs2.put_data(b"zzz")
|
||||
mrs2.receive_some_hook = None
|
||||
assert await mrs2.receive_some(10) == b"zzz"
|
||||
|
||||
mrs2.put_data(b"lost on close")
|
||||
with assert_checkpoints():
|
||||
await mrs2.aclose()
|
||||
assert record == ["closed"]
|
||||
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await mrs2.receive_some(10)
|
||||
|
||||
|
||||
async def test_MemoryRecvStream_closing() -> None:
|
||||
mrs = MemoryReceiveStream()
|
||||
# close with no pending data
|
||||
mrs.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
assert await mrs.receive_some(10) == b""
|
||||
# repeated closes ok
|
||||
mrs.close()
|
||||
# put_data now fails
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
mrs.put_data(b"123")
|
||||
|
||||
mrs2 = MemoryReceiveStream()
|
||||
# close with pending data
|
||||
mrs2.put_data(b"xyz")
|
||||
mrs2.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await mrs2.receive_some(10)
|
||||
|
||||
|
||||
async def test_memory_stream_pump() -> None:
|
||||
mss = MemorySendStream()
|
||||
mrs = MemoryReceiveStream()
|
||||
|
||||
# no-op if no data present
|
||||
memory_stream_pump(mss, mrs)
|
||||
|
||||
await mss.send_all(b"123")
|
||||
memory_stream_pump(mss, mrs)
|
||||
assert await mrs.receive_some(10) == b"123"
|
||||
|
||||
await mss.send_all(b"456")
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert await mrs.receive_some(10) == b"4"
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert not memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert await mrs.receive_some(10) == b"56"
|
||||
|
||||
mss.close()
|
||||
memory_stream_pump(mss, mrs)
|
||||
assert await mrs.receive_some(10) == b""
|
||||
|
||||
|
||||
async def test_memory_stream_one_way_pair() -> None:
|
||||
s, r = memory_stream_one_way_pair()
|
||||
assert s.send_all_hook is not None
|
||||
assert s.wait_send_all_might_not_block_hook is None
|
||||
assert s.close_hook is not None
|
||||
assert r.receive_some_hook is None
|
||||
await s.send_all(b"123")
|
||||
assert await r.receive_some(10) == b"123"
|
||||
|
||||
async def receiver(expected: bytes) -> None:
|
||||
assert await r.receive_some(10) == expected
|
||||
|
||||
# This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"abc")
|
||||
await wait_all_tasks_blocked()
|
||||
await s.send_all(b"abc")
|
||||
|
||||
# And this fails if we don't pump from close_hook
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"")
|
||||
await wait_all_tasks_blocked()
|
||||
await s.aclose()
|
||||
|
||||
s, r = memory_stream_one_way_pair()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"")
|
||||
await wait_all_tasks_blocked()
|
||||
s.close()
|
||||
|
||||
s, r = memory_stream_one_way_pair()
|
||||
|
||||
old = s.send_all_hook
|
||||
s.send_all_hook = None
|
||||
await s.send_all(b"456")
|
||||
|
||||
async def cancel_after_idle(nursery: Nursery) -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async def check_for_cancel() -> None:
|
||||
with pytest.raises(_core.Cancelled):
|
||||
# This should block forever... or until cancelled. Even though we
|
||||
# sent some data on the send stream.
|
||||
await r.receive_some(10)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(cancel_after_idle, nursery)
|
||||
nursery.start_soon(check_for_cancel)
|
||||
|
||||
s.send_all_hook = old
|
||||
await s.send_all(b"789")
|
||||
assert await r.receive_some(10) == b"456789"
|
||||
|
||||
|
||||
async def test_memory_stream_pair() -> None:
|
||||
a, b = memory_stream_pair()
|
||||
await a.send_all(b"123")
|
||||
await b.send_all(b"abc")
|
||||
assert await b.receive_some(10) == b"123"
|
||||
assert await a.receive_some(10) == b"abc"
|
||||
|
||||
await a.send_eof()
|
||||
assert await b.receive_some(10) == b""
|
||||
|
||||
async def sender() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
await b.send_all(b"xyz")
|
||||
|
||||
async def receiver() -> None:
|
||||
assert await a.receive_some(10) == b"xyz"
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver)
|
||||
nursery.start_soon(sender)
|
||||
|
||||
|
||||
async def test_memory_streams_with_generic_tests() -> None:
|
||||
async def one_way_stream_maker() -> tuple[MemorySendStream, MemoryReceiveStream]:
|
||||
return memory_stream_one_way_pair()
|
||||
|
||||
await check_one_way_stream(one_way_stream_maker, None)
|
||||
|
||||
async def half_closeable_stream_maker() -> tuple[
|
||||
StapledStream[MemorySendStream, MemoryReceiveStream],
|
||||
StapledStream[MemorySendStream, MemoryReceiveStream],
|
||||
]:
|
||||
return memory_stream_pair()
|
||||
|
||||
await check_half_closeable_stream(half_closeable_stream_maker, None)
|
||||
|
||||
|
||||
async def test_lockstep_streams_with_generic_tests() -> None:
|
||||
async def one_way_stream_maker() -> tuple[SendStream, ReceiveStream]:
|
||||
return lockstep_stream_one_way_pair()
|
||||
|
||||
await check_one_way_stream(one_way_stream_maker, one_way_stream_maker)
|
||||
|
||||
async def two_way_stream_maker() -> tuple[
|
||||
StapledStream[SendStream, ReceiveStream],
|
||||
StapledStream[SendStream, ReceiveStream],
|
||||
]:
|
||||
return lockstep_stream_pair()
|
||||
|
||||
await check_two_way_stream(two_way_stream_maker, two_way_stream_maker)
|
||||
|
||||
|
||||
async def test_open_stream_to_socket_listener() -> None:
|
||||
async def check(listener: SocketListener) -> None:
|
||||
async with listener:
|
||||
client_stream = await open_stream_to_socket_listener(listener)
|
||||
async with client_stream:
|
||||
server_stream = await listener.accept()
|
||||
async with server_stream:
|
||||
await client_stream.send_all(b"x")
|
||||
assert await server_stream.receive_some(1) == b"x"
|
||||
|
||||
# Listener bound to localhost
|
||||
sock = tsocket.socket()
|
||||
await sock.bind(("127.0.0.1", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
# Listener bound to IPv4 wildcard (needs special handling)
|
||||
sock = tsocket.socket()
|
||||
await sock.bind(("0.0.0.0", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
# true on all CI systems
|
||||
if can_bind_ipv6: # pragma: no branch
|
||||
# Listener bound to IPv6 wildcard (needs special handling)
|
||||
sock = tsocket.socket(family=tsocket.AF_INET6)
|
||||
await sock.bind(("::", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
if hasattr(tsocket, "AF_UNIX"):
|
||||
# Listener bound to Unix-domain socket
|
||||
sock = tsocket.socket(family=tsocket.AF_UNIX)
|
||||
# can't use pytest's tmpdir; if we try then macOS says "OSError:
|
||||
# AF_UNIX path too long"
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = f"{tmpdir}/sock"
|
||||
await sock.bind(path)
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
|
||||
def test_trio_test() -> None:
|
||||
async def busy_kitchen(
|
||||
*,
|
||||
mock_clock: object,
|
||||
autojump_clock: object,
|
||||
) -> None: ... # pragma: no cover
|
||||
|
||||
with pytest.raises(ValueError, match=r"^too many clocks spoil the broth!$"):
|
||||
trio_test(busy_kitchen)(
|
||||
mock_clock=MockClock(),
|
||||
autojump_clock=MockClock(autojump_threshold=0),
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,281 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Protocol, TypeVar
|
||||
|
||||
import outcome
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
|
||||
from .. import _core
|
||||
from .._core._tests.tutil import slow
|
||||
from .._timeouts import (
|
||||
TooSlowError,
|
||||
fail_after,
|
||||
fail_at,
|
||||
move_on_after,
|
||||
move_on_at,
|
||||
sleep,
|
||||
sleep_forever,
|
||||
sleep_until,
|
||||
)
|
||||
from ..testing import assert_checkpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def check_takes_about(f: Callable[[], Awaitable[T]], expected_dur: float) -> T:
|
||||
start = time.perf_counter()
|
||||
result = await outcome.acapture(f)
|
||||
dur = time.perf_counter() - start
|
||||
print(dur / expected_dur)
|
||||
# 1.5 is an arbitrary fudge factor because there's always some delay
|
||||
# between when we become eligible to wake up and when we actually do. We
|
||||
# used to sleep for 0.05, and regularly observed overruns of 1.6x on
|
||||
# Appveyor, and then started seeing overruns of 2.3x on Travis's macOS, so
|
||||
# now we bumped up the sleep to 1 second, marked the tests as slow, and
|
||||
# hopefully now the proportional error will be less huge.
|
||||
#
|
||||
# We also also for durations that are a hair shorter than expected. For
|
||||
# example, here's a run on Windows where a 1.0 second sleep was measured
|
||||
# to take 0.9999999999999858 seconds:
|
||||
# https://ci.appveyor.com/project/njsmith/trio/build/1.0.768/job/3lbdyxl63q3h9s21
|
||||
# I believe that what happened here is that Windows's low clock resolution
|
||||
# meant that our calls to time.monotonic() returned exactly the same
|
||||
# values as the calls inside the actual run loop, but the two subtractions
|
||||
# returned slightly different values because the run loop's clock adds a
|
||||
# random floating point offset to both times, which should cancel out, but
|
||||
# lol floating point we got slightly different rounding errors. (That
|
||||
# value above is exactly 128 ULPs below 1.0, which would make sense if it
|
||||
# started as a 1 ULP error at a different dynamic range.)
|
||||
assert (1 - 1e-8) <= (dur / expected_dur) < 1.5
|
||||
|
||||
return result.unwrap()
|
||||
|
||||
|
||||
# How long to (attempt to) sleep for when testing. Smaller numbers make the
|
||||
# test suite go faster.
|
||||
TARGET = 1.0
|
||||
|
||||
|
||||
@slow
|
||||
async def test_sleep() -> None:
|
||||
async def sleep_1() -> None:
|
||||
await sleep_until(_core.current_time() + TARGET)
|
||||
|
||||
await check_takes_about(sleep_1, TARGET)
|
||||
|
||||
async def sleep_2() -> None:
|
||||
await sleep(TARGET)
|
||||
|
||||
await check_takes_about(sleep_2, TARGET)
|
||||
|
||||
with assert_checkpoints():
|
||||
await sleep(0)
|
||||
# This also serves as a test of the trivial move_on_at
|
||||
with move_on_at(_core.current_time()):
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await sleep(0)
|
||||
|
||||
|
||||
@slow
|
||||
async def test_move_on_after() -> None:
|
||||
async def sleep_3() -> None:
|
||||
with move_on_after(TARGET):
|
||||
await sleep(100)
|
||||
|
||||
await check_takes_about(sleep_3, TARGET)
|
||||
|
||||
|
||||
async def test_cannot_wake_sleep_forever() -> None:
|
||||
# Test an error occurs if you manually wake sleep_forever().
|
||||
task = trio.lowlevel.current_task()
|
||||
|
||||
async def wake_task() -> None:
|
||||
await trio.lowlevel.checkpoint()
|
||||
trio.lowlevel.reschedule(task, outcome.Value(None))
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(wake_task)
|
||||
with pytest.raises(RuntimeError):
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
class TimeoutScope(Protocol):
|
||||
def __call__(self, seconds: float, *, shield: bool) -> trio.CancelScope: ...
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scope", [move_on_after, fail_after])
|
||||
async def test_context_shields_from_outer(scope: TimeoutScope) -> None:
|
||||
with _core.CancelScope() as outer, scope(TARGET, shield=True) as inner:
|
||||
outer.cancel()
|
||||
try:
|
||||
await trio.lowlevel.checkpoint()
|
||||
except trio.Cancelled: # pragma: no cover
|
||||
pytest.fail("shield didn't work")
|
||||
inner.shield = False
|
||||
with pytest.raises(trio.Cancelled):
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
@slow
|
||||
async def test_move_on_after_moves_on_even_if_shielded() -> None:
|
||||
async def task() -> None:
|
||||
with _core.CancelScope() as outer, move_on_after(TARGET, shield=True):
|
||||
outer.cancel()
|
||||
# The outer scope is cancelled, but this task is protected by the
|
||||
# shield, so it manages to get to sleep until deadline is met
|
||||
await sleep_forever()
|
||||
|
||||
await check_takes_about(task, TARGET)
|
||||
|
||||
|
||||
@slow
|
||||
async def test_fail_after_fails_even_if_shielded() -> None:
|
||||
async def task() -> None:
|
||||
with (
|
||||
pytest.raises(TooSlowError),
|
||||
_core.CancelScope() as outer,
|
||||
fail_after(
|
||||
TARGET,
|
||||
shield=True,
|
||||
),
|
||||
):
|
||||
outer.cancel()
|
||||
# The outer scope is cancelled, but this task is protected by the
|
||||
# shield, so it manages to get to sleep until deadline is met
|
||||
await sleep_forever()
|
||||
|
||||
await check_takes_about(task, TARGET)
|
||||
|
||||
|
||||
@slow
|
||||
async def test_fail() -> None:
|
||||
async def sleep_4() -> None:
|
||||
with fail_at(_core.current_time() + TARGET):
|
||||
await sleep(100)
|
||||
|
||||
with pytest.raises(TooSlowError):
|
||||
await check_takes_about(sleep_4, TARGET)
|
||||
|
||||
with fail_at(_core.current_time() + 100):
|
||||
await sleep(0)
|
||||
|
||||
async def sleep_5() -> None:
|
||||
with fail_after(TARGET):
|
||||
await sleep(100)
|
||||
|
||||
with pytest.raises(TooSlowError):
|
||||
await check_takes_about(sleep_5, TARGET)
|
||||
|
||||
with fail_after(100):
|
||||
await sleep(0)
|
||||
|
||||
|
||||
async def test_timeouts_raise_value_error() -> None:
|
||||
# deadlines are allowed to be negative, but not delays.
|
||||
# neither delays nor deadlines are allowed to be NaN
|
||||
|
||||
nan = float("nan")
|
||||
|
||||
for fun, val in (
|
||||
(sleep, -1),
|
||||
(sleep, nan),
|
||||
(sleep_until, nan),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"^(deadline|`seconds`) must (not )*be (non-negative|NaN)$",
|
||||
):
|
||||
await fun(val)
|
||||
|
||||
for cm, val in (
|
||||
(fail_after, -1),
|
||||
(fail_after, nan),
|
||||
(fail_at, nan),
|
||||
(move_on_after, -1),
|
||||
(move_on_after, nan),
|
||||
(move_on_at, nan),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"^(deadline|`seconds`) must (not )*be (non-negative|NaN)$",
|
||||
):
|
||||
with cm(val):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
async def test_timeout_deadline_on_entry(mock_clock: _core.MockClock) -> None:
|
||||
rcs = move_on_after(5)
|
||||
assert rcs.relative_deadline == 5
|
||||
|
||||
mock_clock.jump(3)
|
||||
start = _core.current_time()
|
||||
with rcs as cs:
|
||||
assert cs.is_relative is None
|
||||
|
||||
# This would previously be start+2
|
||||
assert cs.deadline == start + 5
|
||||
assert cs.relative_deadline == 5
|
||||
|
||||
cs.deadline = start + 3
|
||||
assert cs.deadline == start + 3
|
||||
assert cs.relative_deadline == 3
|
||||
|
||||
cs.relative_deadline = 4
|
||||
assert cs.deadline == start + 4
|
||||
assert cs.relative_deadline == 4
|
||||
|
||||
rcs = move_on_after(5)
|
||||
assert rcs.shield is False
|
||||
rcs.shield = True
|
||||
assert rcs.shield is True
|
||||
|
||||
mock_clock.jump(3)
|
||||
start = _core.current_time()
|
||||
with rcs as cs:
|
||||
assert cs.deadline == start + 5
|
||||
|
||||
assert rcs is cs
|
||||
|
||||
|
||||
async def test_invalid_access_unentered(mock_clock: _core.MockClock) -> None:
|
||||
cs = move_on_after(5)
|
||||
mock_clock.jump(3)
|
||||
start = _core.current_time()
|
||||
|
||||
match_str = "^unentered relative cancel scope does not have an absolute deadline"
|
||||
with pytest.warns(DeprecationWarning, match=match_str):
|
||||
assert cs.deadline == start + 5
|
||||
mock_clock.jump(1)
|
||||
# this is hella sketchy, but they *have* been warned
|
||||
with pytest.warns(DeprecationWarning, match=match_str):
|
||||
assert cs.deadline == start + 6
|
||||
|
||||
with pytest.warns(DeprecationWarning, match=match_str):
|
||||
cs.deadline = 7
|
||||
# now transformed into absolute
|
||||
assert cs.deadline == 7
|
||||
assert not cs.is_relative
|
||||
|
||||
cs = move_on_at(5)
|
||||
|
||||
match_str = (
|
||||
"^unentered non-relative cancel scope does not have a relative deadline$"
|
||||
)
|
||||
with pytest.raises(RuntimeError, match=match_str):
|
||||
assert cs.relative_deadline
|
||||
with pytest.raises(RuntimeError, match=match_str):
|
||||
cs.relative_deadline = 7
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="not implemented")
|
||||
async def test_fail_access_before_entering() -> None: # pragma: no cover
|
||||
my_fail_at = fail_at(5)
|
||||
assert my_fail_at.deadline # type: ignore[attr-defined]
|
||||
my_fail_after = fail_after(5)
|
||||
assert my_fail_after.relative_deadline # type: ignore[attr-defined]
|
||||
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import trio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
|
||||
async def coro1(event: trio.Event) -> None:
|
||||
event.set()
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
async def coro2(event: trio.Event) -> None:
|
||||
await coro1(event)
|
||||
|
||||
|
||||
async def coro3(event: trio.Event) -> None:
|
||||
await coro2(event)
|
||||
|
||||
|
||||
async def coro2_async_gen(event: trio.Event) -> AsyncGenerator[None, None]:
|
||||
# mypy does not like `yield await trio.lowlevel.checkpoint()` - but that
|
||||
# should be equivalent to splitting the statement
|
||||
await trio.lowlevel.checkpoint()
|
||||
yield
|
||||
await coro1(event)
|
||||
yield # pragma: no cover
|
||||
await trio.lowlevel.checkpoint() # pragma: no cover
|
||||
yield # pragma: no cover
|
||||
|
||||
|
||||
async def coro3_async_gen(event: trio.Event) -> None:
|
||||
async for _ in coro2_async_gen(event):
|
||||
pass
|
||||
|
||||
|
||||
async def test_task_iter_await_frames() -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
event = trio.Event()
|
||||
nursery.start_soon(coro3, event)
|
||||
await event.wait()
|
||||
|
||||
(task,) = nursery.child_tasks
|
||||
|
||||
assert [frame.f_code.co_name for frame, _ in task.iter_await_frames()][:3] == [
|
||||
"coro3",
|
||||
"coro2",
|
||||
"coro1",
|
||||
]
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_task_iter_await_frames_async_gen() -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
event = trio.Event()
|
||||
nursery.start_soon(coro3_async_gen, event)
|
||||
await event.wait()
|
||||
|
||||
(task,) = nursery.child_tasks
|
||||
|
||||
assert [frame.f_code.co_name for frame, _ in task.iter_await_frames()][:3] == [
|
||||
"coro3_async_gen",
|
||||
"coro2_async_gen",
|
||||
"coro1",
|
||||
]
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_closed_task_iter_await_frames() -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
task = object()
|
||||
|
||||
async def capture_task() -> None:
|
||||
nonlocal task
|
||||
task = trio.lowlevel.current_task()
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
nursery.start_soon(capture_task)
|
||||
|
||||
# Task has completed, so coro.cr_frame should be None, thus no frames
|
||||
assert isinstance(task, trio.lowlevel.Task) # Ran `capture_task`
|
||||
assert task.coro.cr_frame is None # and the task was over, but
|
||||
assert list(task.iter_await_frames()) == [] # look, no crash!
|
||||
@@ -0,0 +1,8 @@
|
||||
def test_trio_import() -> None:
|
||||
import sys
|
||||
|
||||
for module in list(sys.modules.keys()):
|
||||
if module.startswith("trio"):
|
||||
del sys.modules[module]
|
||||
|
||||
import trio # noqa: F401
|
||||
@@ -0,0 +1,288 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import os
|
||||
import select
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import _core
|
||||
from .._core._tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken
|
||||
from ..testing import check_one_way_stream, wait_all_tasks_blocked
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._file_io import _HasFileNo
|
||||
|
||||
posix = os.name == "posix"
|
||||
pytestmark = pytest.mark.skipif(not posix, reason="posix only")
|
||||
|
||||
assert not TYPE_CHECKING or sys.platform == "unix"
|
||||
|
||||
if posix:
|
||||
from .._unix_pipes import FdStream
|
||||
|
||||
|
||||
async def make_pipe() -> tuple[FdStream, FdStream]:
|
||||
"""Makes a new pair of pipes."""
|
||||
r, w = os.pipe()
|
||||
return FdStream(w), FdStream(r)
|
||||
|
||||
|
||||
async def make_clogged_pipe() -> tuple[FdStream, FdStream]:
|
||||
s, r = await make_pipe()
|
||||
try:
|
||||
while True:
|
||||
# We want to totally fill up the pipe buffer.
|
||||
# This requires working around a weird feature that POSIX pipes
|
||||
# have.
|
||||
# If you do a write of <= PIPE_BUF bytes, then it's guaranteed
|
||||
# to either complete entirely, or not at all. So if we tried to
|
||||
# write PIPE_BUF bytes, and the buffer's free space is only
|
||||
# PIPE_BUF/2, then the write will raise BlockingIOError... even
|
||||
# though a smaller write could still succeed! To avoid this,
|
||||
# make sure to write >PIPE_BUF bytes each time, which disables
|
||||
# the special behavior.
|
||||
# For details, search for PIPE_BUF here:
|
||||
# http://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html
|
||||
|
||||
# for the getattr:
|
||||
# https://bitbucket.org/pypy/pypy/issues/2876/selectpipe_buf-is-missing-on-pypy3
|
||||
buf_size = getattr(select, "PIPE_BUF", 8192)
|
||||
os.write(s.fileno(), b"x" * buf_size * 2)
|
||||
except BlockingIOError:
|
||||
pass
|
||||
return s, r
|
||||
|
||||
|
||||
async def test_send_pipe() -> None:
|
||||
r, w = os.pipe()
|
||||
async with FdStream(w) as send:
|
||||
assert send.fileno() == w
|
||||
await send.send_all(b"123")
|
||||
assert (os.read(r, 8)) == b"123"
|
||||
|
||||
os.close(r)
|
||||
|
||||
|
||||
async def test_receive_pipe() -> None:
|
||||
r, w = os.pipe()
|
||||
async with FdStream(r) as recv:
|
||||
assert (recv.fileno()) == r
|
||||
os.write(w, b"123")
|
||||
assert (await recv.receive_some(8)) == b"123"
|
||||
|
||||
os.close(w)
|
||||
|
||||
|
||||
async def test_pipes_combined() -> None:
|
||||
write, read = await make_pipe()
|
||||
count = 2**20
|
||||
|
||||
async def sender() -> None:
|
||||
big = bytearray(count)
|
||||
await write.send_all(big)
|
||||
|
||||
async def reader() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
received = 0
|
||||
while received < count:
|
||||
received += len(await read.receive_some(4096))
|
||||
|
||||
assert received == count
|
||||
|
||||
async with _core.open_nursery() as n:
|
||||
n.start_soon(sender)
|
||||
n.start_soon(reader)
|
||||
|
||||
await read.aclose()
|
||||
await write.aclose()
|
||||
|
||||
|
||||
async def test_pipe_errors() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
FdStream(None)
|
||||
|
||||
r, w = os.pipe()
|
||||
os.close(w)
|
||||
async with FdStream(r) as s:
|
||||
with pytest.raises(ValueError, match=r"^max_bytes must be integer >= 1$"):
|
||||
await s.receive_some(0)
|
||||
|
||||
|
||||
async def test_del() -> None:
|
||||
w, r = await make_pipe()
|
||||
f1, f2 = w.fileno(), r.fileno()
|
||||
del w, r
|
||||
gc_collect_harder()
|
||||
|
||||
with pytest.raises(OSError, match=r"Bad file descriptor$") as excinfo:
|
||||
os.close(f1)
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
with pytest.raises(OSError, match=r"Bad file descriptor$") as excinfo:
|
||||
os.close(f2)
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
|
||||
async def test_async_with() -> None:
|
||||
w, r = await make_pipe()
|
||||
async with w, r:
|
||||
pass
|
||||
|
||||
assert w.fileno() == -1
|
||||
assert r.fileno() == -1
|
||||
|
||||
with pytest.raises(OSError, match=r"Bad file descriptor$") as excinfo:
|
||||
os.close(w.fileno())
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
with pytest.raises(OSError, match=r"Bad file descriptor$") as excinfo:
|
||||
os.close(r.fileno())
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
|
||||
async def test_misdirected_aclose_regression() -> None:
|
||||
# https://github.com/python-trio/trio/issues/661#issuecomment-456582356
|
||||
w, r = await make_pipe()
|
||||
old_r_fd = r.fileno()
|
||||
|
||||
# Close the original objects
|
||||
await w.aclose()
|
||||
await r.aclose()
|
||||
|
||||
# Do a little dance to get a new pipe whose receive handle matches the old
|
||||
# receive handle.
|
||||
r2_fd, w2_fd = os.pipe()
|
||||
if r2_fd != old_r_fd: # pragma: no cover
|
||||
os.dup2(r2_fd, old_r_fd)
|
||||
os.close(r2_fd)
|
||||
async with FdStream(old_r_fd) as r2:
|
||||
assert r2.fileno() == old_r_fd
|
||||
|
||||
# And now set up a background task that's working on the new receive
|
||||
# handle
|
||||
async def expect_eof() -> None:
|
||||
assert await r2.receive_some(10) == b""
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_eof)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# Here's the key test: does calling aclose() again on the *old*
|
||||
# handle, cause the task blocked on the *new* handle to raise
|
||||
# ClosedResourceError?
|
||||
await r.aclose()
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# Guess we survived! Close the new write handle so that the task
|
||||
# gets an EOF and can exit cleanly.
|
||||
os.close(w2_fd)
|
||||
|
||||
|
||||
async def test_close_at_bad_time_for_receive_some(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# We used to have race conditions where if one task was using the pipe,
|
||||
# and another closed it at *just* the wrong moment, it would give an
|
||||
# unexpected error instead of ClosedResourceError:
|
||||
# https://github.com/python-trio/trio/issues/661
|
||||
#
|
||||
# This tests what happens if the pipe gets closed in the moment *between*
|
||||
# when receive_some wakes up, and when it tries to call os.read
|
||||
async def expect_closedresourceerror() -> None:
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await r.receive_some(10)
|
||||
|
||||
orig_wait_readable = _core._run.TheIOManager.wait_readable
|
||||
|
||||
async def patched_wait_readable(
|
||||
self: _core._run.TheIOManager,
|
||||
fd: int | _HasFileNo,
|
||||
) -> None:
|
||||
await orig_wait_readable(self, fd)
|
||||
await r.aclose()
|
||||
|
||||
monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable)
|
||||
s, r = await make_pipe()
|
||||
async with s, r:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_closedresourceerror)
|
||||
await wait_all_tasks_blocked()
|
||||
# Trigger everything by waking up the receiver
|
||||
await s.send_all(b"x")
|
||||
|
||||
|
||||
async def test_close_at_bad_time_for_send_all(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# We used to have race conditions where if one task was using the pipe,
|
||||
# and another closed it at *just* the wrong moment, it would give an
|
||||
# unexpected error instead of ClosedResourceError:
|
||||
# https://github.com/python-trio/trio/issues/661
|
||||
#
|
||||
# This tests what happens if the pipe gets closed in the moment *between*
|
||||
# when send_all wakes up, and when it tries to call os.write
|
||||
async def expect_closedresourceerror() -> None:
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await s.send_all(b"x" * 100)
|
||||
|
||||
orig_wait_writable = _core._run.TheIOManager.wait_writable
|
||||
|
||||
async def patched_wait_writable(
|
||||
self: _core._run.TheIOManager,
|
||||
fd: int | _HasFileNo,
|
||||
) -> None:
|
||||
await orig_wait_writable(self, fd)
|
||||
await s.aclose()
|
||||
|
||||
monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable)
|
||||
s, r = await make_clogged_pipe()
|
||||
async with s, r:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_closedresourceerror)
|
||||
await wait_all_tasks_blocked()
|
||||
# Trigger everything by waking up the sender. On ppc64el, PIPE_BUF
|
||||
# is 8192 but make_clogged_pipe() ends up writing a total of
|
||||
# 1048576 bytes before the pipe is full, and then a subsequent
|
||||
# receive_some(10000) isn't sufficient for orig_wait_writable() to
|
||||
# return for our subsequent aclose() call. It's necessary to empty
|
||||
# the pipe further before this happens. So we loop here until the
|
||||
# pipe is empty to make sure that the sender wakes up even in this
|
||||
# case. Otherwise patched_wait_writable() never gets to the
|
||||
# aclose(), so expect_closedresourceerror() never returns, the
|
||||
# nursery never finishes all tasks and this test hangs.
|
||||
received_data = await r.receive_some(10000)
|
||||
while received_data:
|
||||
received_data = await r.receive_some(10000)
|
||||
|
||||
|
||||
# On FreeBSD, directories are readable, and we haven't found any other trick
|
||||
# for making an unreadable fd, so there's no way to run this test. Fortunately
|
||||
# the logic this is testing doesn't depend on the platform, so testing on
|
||||
# other platforms is probably good enough.
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("freebsd"),
|
||||
reason="no way to make read() return a bizarro error on FreeBSD",
|
||||
)
|
||||
async def test_bizarro_OSError_from_receive() -> None:
|
||||
# Make sure that if the read syscall returns some bizarro error, then we
|
||||
# get a BrokenResourceError. This is incredibly unlikely; there's almost
|
||||
# no way to trigger a failure here intentionally (except for EBADF, but we
|
||||
# exploit that to detect file closure, so it takes a different path). So
|
||||
# we set up a strange scenario where the pipe fd somehow transmutes into a
|
||||
# directory fd, causing os.read to raise IsADirectoryError (yes, that's a
|
||||
# real built-in exception type).
|
||||
s, r = await make_pipe()
|
||||
async with s, r:
|
||||
dir_fd = os.open("/", os.O_DIRECTORY, 0)
|
||||
try:
|
||||
os.dup2(dir_fd, r.fileno())
|
||||
with pytest.raises(_core.BrokenResourceError):
|
||||
await r.receive_some(10)
|
||||
finally:
|
||||
os.close(dir_fd)
|
||||
|
||||
|
||||
@skip_if_fbsd_pipes_broken
|
||||
async def test_pipe_fully() -> None:
|
||||
await check_one_way_stream(make_pipe, make_clogged_pipe)
|
||||
@@ -0,0 +1,349 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator, Coroutine, Generator
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.testing import _Matcher as Matcher, _RaisesGroup as RaisesGroup
|
||||
|
||||
from .. import _core
|
||||
from .._core._tests.tutil import (
|
||||
create_asyncio_future_in_new_loop,
|
||||
ignore_coroutine_never_awaited_warnings,
|
||||
)
|
||||
from .._util import (
|
||||
ConflictDetector,
|
||||
MultipleExceptionError,
|
||||
NoPublicConstructor,
|
||||
coroutine_or_error,
|
||||
final,
|
||||
fixup_module_metadata,
|
||||
is_main_thread,
|
||||
raise_single_exception_from_group,
|
||||
)
|
||||
from ..testing import wait_all_tasks_blocked
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def test_ConflictDetector() -> None:
|
||||
ul1 = ConflictDetector("ul1")
|
||||
ul2 = ConflictDetector("ul2")
|
||||
|
||||
with ul1:
|
||||
with ul2:
|
||||
print("ok")
|
||||
|
||||
with pytest.raises(_core.BusyResourceError, match="ul1"):
|
||||
with ul1:
|
||||
with ul1:
|
||||
pass # pragma: no cover
|
||||
|
||||
async def wait_with_ul1() -> None:
|
||||
with ul1:
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
with RaisesGroup(Matcher(_core.BusyResourceError, "ul1")):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(wait_with_ul1)
|
||||
nursery.start_soon(wait_with_ul1)
|
||||
|
||||
|
||||
def test_module_metadata_is_fixed_up() -> None:
|
||||
import trio
|
||||
import trio.testing
|
||||
|
||||
assert trio.Cancelled.__module__ == "trio"
|
||||
assert trio.open_nursery.__module__ == "trio"
|
||||
assert trio.abc.Stream.__module__ == "trio.abc"
|
||||
assert trio.lowlevel.wait_task_rescheduled.__module__ == "trio.lowlevel"
|
||||
assert trio.testing.trio_test.__module__ == "trio.testing"
|
||||
|
||||
# Also check methods
|
||||
assert trio.lowlevel.ParkingLot.__init__.__module__ == "trio.lowlevel"
|
||||
assert trio.abc.Stream.send_all.__module__ == "trio.abc"
|
||||
|
||||
# And names
|
||||
assert trio.Cancelled.__name__ == "Cancelled"
|
||||
assert trio.Cancelled.__qualname__ == "Cancelled"
|
||||
assert trio.abc.SendStream.send_all.__name__ == "send_all"
|
||||
assert trio.abc.SendStream.send_all.__qualname__ == "SendStream.send_all"
|
||||
assert trio.to_thread.__name__ == "trio.to_thread"
|
||||
assert trio.to_thread.run_sync.__name__ == "run_sync"
|
||||
assert trio.to_thread.run_sync.__qualname__ == "run_sync"
|
||||
|
||||
|
||||
async def test_is_main_thread() -> None:
|
||||
assert is_main_thread()
|
||||
|
||||
def not_main_thread() -> None:
|
||||
assert not is_main_thread()
|
||||
|
||||
await trio.to_thread.run_sync(not_main_thread)
|
||||
|
||||
|
||||
# @coroutine is deprecated since python 3.8, which is fine with us.
|
||||
@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning")
|
||||
def test_coroutine_or_error() -> None:
|
||||
class Deferred:
|
||||
"Just kidding"
|
||||
|
||||
with ignore_coroutine_never_awaited_warnings():
|
||||
|
||||
async def f() -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(f()) # type: ignore[arg-type, unused-coroutine]
|
||||
assert "expecting an async function" in str(excinfo.value)
|
||||
|
||||
import asyncio
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
|
||||
@asyncio.coroutine
|
||||
def generator_based_coro() -> (
|
||||
Generator[Coroutine[None, None, None], None, None]
|
||||
): # pragma: no cover
|
||||
yield from asyncio.sleep(1)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(generator_based_coro()) # type: ignore[arg-type, unused-coroutine]
|
||||
assert "asyncio" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(create_asyncio_future_in_new_loop()) # type: ignore[arg-type, unused-coroutine]
|
||||
assert "asyncio" in str(excinfo.value)
|
||||
|
||||
# does not raise arg-type error
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(create_asyncio_future_in_new_loop) # type: ignore[unused-coroutine]
|
||||
assert "asyncio" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(Deferred()) # type: ignore[arg-type, unused-coroutine]
|
||||
assert "twisted" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(lambda: Deferred()) # type: ignore[arg-type, unused-coroutine, return-value]
|
||||
assert "twisted" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(len, [[1, 2, 3]]) # type: ignore[arg-type, unused-coroutine]
|
||||
|
||||
assert "appears to be synchronous" in str(excinfo.value)
|
||||
|
||||
async def async_gen(
|
||||
_: object,
|
||||
) -> AsyncGenerator[None, None]: # pragma: no cover
|
||||
yield
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(async_gen, [0]) # type: ignore[arg-type,unused-coroutine]
|
||||
msg = "expected an async function but got an async generator"
|
||||
assert msg in str(excinfo.value)
|
||||
|
||||
# Make sure no references are kept around to keep anything alive
|
||||
del excinfo
|
||||
|
||||
|
||||
def test_final_decorator() -> None:
|
||||
"""Test that subclassing a @final-annotated class is not allowed.
|
||||
|
||||
This checks both runtime results, and verifies that type checkers detect
|
||||
the error statically through the type-ignore comment.
|
||||
"""
|
||||
|
||||
@final
|
||||
class FinalClass:
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class SubClass(FinalClass): # type: ignore[misc]
|
||||
pass
|
||||
|
||||
|
||||
def test_no_public_constructor_metaclass() -> None:
|
||||
"""The NoPublicConstructor metaclass prevents calling the constructor directly."""
|
||||
|
||||
class SpecialClass(metaclass=NoPublicConstructor):
|
||||
def __init__(self, a: int, b: float) -> None:
|
||||
"""Check arguments can be passed to __init__."""
|
||||
assert a == 8
|
||||
assert b == 3.15
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
SpecialClass(8, 3.15)
|
||||
|
||||
# Private constructor should not raise, and passes args to __init__.
|
||||
assert isinstance(SpecialClass._create(8, b=3.15), SpecialClass)
|
||||
|
||||
|
||||
def test_fixup_module_metadata() -> None:
|
||||
# Ignores modules not in the trio.X tree.
|
||||
non_trio_module = types.ModuleType("not_trio")
|
||||
non_trio_module.some_func = lambda: None # type: ignore[attr-defined]
|
||||
non_trio_module.some_func.__name__ = "some_func"
|
||||
non_trio_module.some_func.__qualname__ = "some_func"
|
||||
|
||||
fixup_module_metadata(non_trio_module.__name__, vars(non_trio_module))
|
||||
|
||||
assert non_trio_module.some_func.__name__ == "some_func"
|
||||
assert non_trio_module.some_func.__qualname__ == "some_func"
|
||||
|
||||
# Bulild up a fake module to test. Just use lambdas since all we care about is the names.
|
||||
mod = types.ModuleType("trio._somemodule_impl")
|
||||
mod.some_func = lambda: None # type: ignore[attr-defined]
|
||||
mod.some_func.__name__ = "_something_else"
|
||||
mod.some_func.__qualname__ = "_something_else"
|
||||
|
||||
# No __module__ means it's unchanged.
|
||||
mod.not_funclike = types.SimpleNamespace() # type: ignore[attr-defined]
|
||||
mod.not_funclike.__name__ = "not_funclike"
|
||||
|
||||
# Check __qualname__ being absent works.
|
||||
mod.only_has_name = types.SimpleNamespace() # type: ignore[attr-defined]
|
||||
mod.only_has_name.__module__ = "trio._somemodule_impl"
|
||||
mod.only_has_name.__name__ = "only_name"
|
||||
|
||||
# Underscored names are unchanged.
|
||||
mod._private = lambda: None # type: ignore[attr-defined]
|
||||
mod._private.__module__ = "trio._somemodule_impl"
|
||||
mod._private.__name__ = mod._private.__qualname__ = "_private"
|
||||
|
||||
# We recurse into classes.
|
||||
mod.SomeClass = type( # type: ignore[attr-defined]
|
||||
"SomeClass",
|
||||
(),
|
||||
{
|
||||
"__init__": lambda self: None,
|
||||
"method": lambda self: None,
|
||||
},
|
||||
)
|
||||
# Reference loop is fine.
|
||||
mod.SomeClass.recursion = mod.SomeClass
|
||||
|
||||
fixup_module_metadata("trio.somemodule", vars(mod))
|
||||
assert mod.some_func.__name__ == "some_func"
|
||||
assert mod.some_func.__module__ == "trio.somemodule"
|
||||
assert mod.some_func.__qualname__ == "some_func"
|
||||
|
||||
assert mod.not_funclike.__name__ == "not_funclike"
|
||||
assert mod._private.__name__ == "_private"
|
||||
assert mod._private.__module__ == "trio._somemodule_impl"
|
||||
assert mod._private.__qualname__ == "_private"
|
||||
|
||||
assert mod.only_has_name.__name__ == "only_has_name"
|
||||
assert mod.only_has_name.__module__ == "trio.somemodule"
|
||||
assert not hasattr(mod.only_has_name, "__qualname__")
|
||||
|
||||
assert mod.SomeClass.method.__name__ == "method"
|
||||
assert mod.SomeClass.method.__module__ == "trio.somemodule"
|
||||
assert mod.SomeClass.method.__qualname__ == "SomeClass.method"
|
||||
# Make coverage happy.
|
||||
non_trio_module.some_func()
|
||||
mod.some_func()
|
||||
mod._private()
|
||||
mod.SomeClass().method()
|
||||
|
||||
|
||||
async def test_raise_single_exception_from_group() -> None:
|
||||
excinfo: pytest.ExceptionInfo[BaseException]
|
||||
|
||||
exc = ValueError("foo")
|
||||
cause = SyntaxError("cause")
|
||||
context = TypeError("context")
|
||||
exc.__cause__ = cause
|
||||
exc.__context__ = context
|
||||
cancelled = trio.Cancelled._create(source="deadline")
|
||||
|
||||
with pytest.raises(ValueError, match="foo") as excinfo:
|
||||
raise_single_exception_from_group(ExceptionGroup("", [exc]))
|
||||
assert excinfo.value.__cause__ == cause
|
||||
assert excinfo.value.__context__ == context
|
||||
|
||||
# only unwraps one layer of exceptiongroup
|
||||
inner_eg = ExceptionGroup("inner eg", [exc])
|
||||
inner_cause = SyntaxError("inner eg cause")
|
||||
inner_context = TypeError("inner eg context")
|
||||
inner_eg.__cause__ = inner_cause
|
||||
inner_eg.__context__ = inner_context
|
||||
with RaisesGroup(Matcher(ValueError, match="^foo$"), match="^inner eg$") as eginfo:
|
||||
raise_single_exception_from_group(ExceptionGroup("", [inner_eg]))
|
||||
assert eginfo.value.__cause__ == inner_cause
|
||||
assert eginfo.value.__context__ == inner_context
|
||||
|
||||
with pytest.raises(ValueError, match="foo") as excinfo:
|
||||
raise_single_exception_from_group(
|
||||
BaseExceptionGroup("", [cancelled, cancelled, exc])
|
||||
)
|
||||
assert excinfo.value.__cause__ == cause
|
||||
assert excinfo.value.__context__ == context
|
||||
|
||||
# multiple non-cancelled
|
||||
eg = ExceptionGroup("", [ValueError("foo"), ValueError("bar")])
|
||||
with pytest.raises(
|
||||
MultipleExceptionError,
|
||||
match=r"^Attempted to unwrap exceptiongroup with multiple non-cancelled exceptions. This is often caused by a bug in the caller.$",
|
||||
) as excinfo:
|
||||
raise_single_exception_from_group(eg)
|
||||
assert excinfo.value.__cause__ is eg
|
||||
assert excinfo.value.__context__ is None
|
||||
|
||||
# keyboardinterrupt overrides everything
|
||||
eg_ki = BaseExceptionGroup(
|
||||
"",
|
||||
[
|
||||
ValueError("foo"),
|
||||
ValueError("bar"),
|
||||
KeyboardInterrupt("preserve error msg"),
|
||||
],
|
||||
)
|
||||
with pytest.raises(
|
||||
KeyboardInterrupt,
|
||||
match=r"^preserve error msg$",
|
||||
) as excinfo:
|
||||
raise_single_exception_from_group(eg_ki)
|
||||
|
||||
assert excinfo.value.__cause__ is eg_ki
|
||||
assert excinfo.value.__context__ is None
|
||||
|
||||
# and same for SystemExit but verify code too
|
||||
systemexit_ki = BaseExceptionGroup(
|
||||
"",
|
||||
[
|
||||
ValueError("foo"),
|
||||
ValueError("bar"),
|
||||
SystemExit(2),
|
||||
],
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
raise_single_exception_from_group(systemexit_ki)
|
||||
|
||||
assert excinfo.value.code == 2
|
||||
assert excinfo.value.__cause__ is systemexit_ki
|
||||
assert excinfo.value.__context__ is None
|
||||
|
||||
# if we only got cancelled, first one is reraised
|
||||
with pytest.raises(trio.Cancelled, match=r"^cancelled due to deadline$") as excinfo:
|
||||
raise_single_exception_from_group(
|
||||
BaseExceptionGroup(
|
||||
"", [cancelled, trio.Cancelled._create(source="explicit")]
|
||||
)
|
||||
)
|
||||
assert excinfo.value is cancelled
|
||||
assert excinfo.value.__cause__ is None
|
||||
assert excinfo.value.__context__ is None
|
||||
@@ -0,0 +1,225 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
on_windows = os.name == "nt"
|
||||
# Mark all the tests in this file as being windows-only
|
||||
pytestmark = pytest.mark.skipif(not on_windows, reason="windows only")
|
||||
|
||||
import trio
|
||||
|
||||
from .. import _core, _timeouts
|
||||
from .._core._tests.tutil import slow
|
||||
|
||||
if on_windows:
|
||||
from .._core._windows_cffi import Handle, ffi, kernel32
|
||||
from .._wait_for_object import WaitForMultipleObjects_sync, WaitForSingleObject
|
||||
|
||||
|
||||
def test_WaitForMultipleObjects_sync() -> None:
|
||||
# This does a series of tests where we set/close the handle before
|
||||
# initiating the waiting for it.
|
||||
#
|
||||
# Note that closing the handle (not signaling) will cause the
|
||||
# *initiation* of a wait to return immediately. But closing a handle
|
||||
# that is already being waited on will not stop whatever is waiting
|
||||
# for it.
|
||||
|
||||
# One handle
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle1)
|
||||
WaitForMultipleObjects_sync(handle1)
|
||||
kernel32.CloseHandle(handle1)
|
||||
print("test_WaitForMultipleObjects_sync one OK")
|
||||
|
||||
# Two handles, signal first
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle1)
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync set first OK")
|
||||
|
||||
# Two handles, signal second
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle2)
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync set second OK")
|
||||
|
||||
# Two handles, close first
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.CloseHandle(handle1)
|
||||
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync close first OK")
|
||||
|
||||
# Two handles, close second
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.CloseHandle(handle2)
|
||||
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle1)
|
||||
print("test_WaitForMultipleObjects_sync close second OK")
|
||||
|
||||
|
||||
@slow
|
||||
async def test_WaitForMultipleObjects_sync_slow() -> None:
|
||||
# This does a series of test in which the main thread sync-waits for
|
||||
# handles, while we spawn a thread to set the handles after a short while.
|
||||
|
||||
TIMEOUT = 0.3
|
||||
|
||||
# One handle
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
trio.to_thread.run_sync,
|
||||
WaitForMultipleObjects_sync,
|
||||
handle1,
|
||||
)
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
# If we would comment the line below, the above thread will be stuck,
|
||||
# and Trio won't exit this scope
|
||||
kernel32.SetEvent(handle1)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
kernel32.CloseHandle(handle1)
|
||||
print("test_WaitForMultipleObjects_sync_slow one OK")
|
||||
|
||||
# Two handles, signal first
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
trio.to_thread.run_sync,
|
||||
WaitForMultipleObjects_sync,
|
||||
handle1,
|
||||
handle2,
|
||||
)
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
kernel32.SetEvent(handle1)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync_slow thread-set first OK")
|
||||
|
||||
# Two handles, signal second
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
trio.to_thread.run_sync,
|
||||
WaitForMultipleObjects_sync,
|
||||
handle1,
|
||||
handle2,
|
||||
)
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
kernel32.SetEvent(handle2)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync_slow thread-set second OK")
|
||||
|
||||
|
||||
async def test_WaitForSingleObject() -> None:
|
||||
# This does a series of test for setting/closing the handle before
|
||||
# initiating the wait.
|
||||
|
||||
# Test already set
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle)
|
||||
await WaitForSingleObject(handle) # should return at once
|
||||
kernel32.CloseHandle(handle)
|
||||
print("test_WaitForSingleObject already set OK")
|
||||
|
||||
# Test already set, as int
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle_int = int(ffi.cast("intptr_t", handle))
|
||||
kernel32.SetEvent(handle)
|
||||
await WaitForSingleObject(handle_int) # should return at once
|
||||
kernel32.CloseHandle(handle)
|
||||
print("test_WaitForSingleObject already set OK")
|
||||
|
||||
# Test already closed
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.CloseHandle(handle)
|
||||
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
|
||||
await WaitForSingleObject(handle) # should return at once
|
||||
print("test_WaitForSingleObject already closed OK")
|
||||
|
||||
# Not a handle
|
||||
with pytest.raises(TypeError):
|
||||
await WaitForSingleObject("not a handle") # type: ignore[arg-type] # Wrong type
|
||||
# with pytest.raises(OSError):
|
||||
# await WaitForSingleObject(99) # If you're unlucky, it actually IS a handle :(
|
||||
print("test_WaitForSingleObject not a handle OK")
|
||||
|
||||
|
||||
@slow
|
||||
async def test_WaitForSingleObject_slow() -> None:
|
||||
# This does a series of test for setting the handle in another task,
|
||||
# and cancelling the wait task.
|
||||
|
||||
# Set the timeout used in the tests. We test the waiting time against
|
||||
# the timeout with a certain margin.
|
||||
TIMEOUT = 0.3
|
||||
|
||||
async def signal_soon_async(handle: Handle) -> None:
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
kernel32.SetEvent(handle)
|
||||
|
||||
# Test handle is SET after TIMEOUT in separate coroutine
|
||||
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(WaitForSingleObject, handle)
|
||||
nursery.start_soon(signal_soon_async, handle)
|
||||
|
||||
kernel32.CloseHandle(handle)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
print("test_WaitForSingleObject_slow set from task OK")
|
||||
|
||||
# Test handle is SET after TIMEOUT in separate coroutine, as int
|
||||
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle_int = int(ffi.cast("intptr_t", handle))
|
||||
t0 = _core.current_time()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(WaitForSingleObject, handle_int)
|
||||
nursery.start_soon(signal_soon_async, handle)
|
||||
|
||||
kernel32.CloseHandle(handle)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
print("test_WaitForSingleObject_slow set from task as int OK")
|
||||
|
||||
# Test handle is CLOSED after 1 sec - NOPE see comment above
|
||||
|
||||
# Test cancellation
|
||||
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
|
||||
with _timeouts.move_on_after(TIMEOUT):
|
||||
await WaitForSingleObject(handle)
|
||||
|
||||
kernel32.CloseHandle(handle)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
print("test_WaitForSingleObject_slow cancellation OK")
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import _core
|
||||
from ..testing import check_one_way_stream, wait_all_tasks_blocked
|
||||
|
||||
# Mark all the tests in this file as being windows-only
|
||||
pytestmark = pytest.mark.skipif(sys.platform != "win32", reason="windows only")
|
||||
|
||||
assert ( # Skip type checking when not on Windows
|
||||
sys.platform == "win32" or not TYPE_CHECKING
|
||||
)
|
||||
|
||||
if sys.platform == "win32":
|
||||
from asyncio.windows_utils import pipe
|
||||
|
||||
from .._core._windows_cffi import _handle, kernel32
|
||||
from .._windows_pipes import PipeReceiveStream, PipeSendStream
|
||||
|
||||
|
||||
async def make_pipe() -> tuple[PipeSendStream, PipeReceiveStream]:
|
||||
"""Makes a new pair of pipes."""
|
||||
r, w = pipe()
|
||||
return PipeSendStream(w), PipeReceiveStream(r)
|
||||
|
||||
|
||||
def test_pipe_typecheck() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
PipeSendStream(1.0) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
PipeReceiveStream(None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def test_pipe_error_on_close() -> None:
|
||||
# Make sure we correctly handle a failure from kernel32.CloseHandle
|
||||
r, w = pipe()
|
||||
|
||||
send_stream = PipeSendStream(w)
|
||||
receive_stream = PipeReceiveStream(r)
|
||||
|
||||
assert kernel32.CloseHandle(_handle(r))
|
||||
assert kernel32.CloseHandle(_handle(w))
|
||||
|
||||
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
|
||||
await send_stream.aclose()
|
||||
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
|
||||
await receive_stream.aclose()
|
||||
|
||||
|
||||
async def test_pipes_combined() -> None:
|
||||
write, read = await make_pipe()
|
||||
count = 2**20
|
||||
replicas = 3
|
||||
|
||||
async def sender() -> None:
|
||||
async with write:
|
||||
big = bytearray(count)
|
||||
for _ in range(replicas):
|
||||
await write.send_all(big)
|
||||
|
||||
async def reader() -> None:
|
||||
async with read:
|
||||
await wait_all_tasks_blocked()
|
||||
total_received = 0
|
||||
while True:
|
||||
# 5000 is chosen because it doesn't evenly divide 2**20
|
||||
received = len(await read.receive_some(5000))
|
||||
if not received:
|
||||
break
|
||||
total_received += received
|
||||
|
||||
assert total_received == count * replicas
|
||||
|
||||
async with _core.open_nursery() as n:
|
||||
n.start_soon(sender)
|
||||
n.start_soon(reader)
|
||||
|
||||
|
||||
async def test_async_with() -> None:
|
||||
w, r = await make_pipe()
|
||||
async with w, r:
|
||||
pass
|
||||
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await w.send_all(b"")
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await r.receive_some(10)
|
||||
|
||||
|
||||
async def test_close_during_write() -> None:
|
||||
w, _r = await make_pipe()
|
||||
async with _core.open_nursery() as nursery:
|
||||
|
||||
async def write_forever() -> None:
|
||||
with pytest.raises(_core.ClosedResourceError) as excinfo:
|
||||
while True:
|
||||
await w.send_all(b"x" * 4096)
|
||||
assert "another task" in str(excinfo.value)
|
||||
|
||||
nursery.start_soon(write_forever)
|
||||
await wait_all_tasks_blocked(0.1)
|
||||
await w.aclose()
|
||||
|
||||
|
||||
async def test_pipe_fully() -> None:
|
||||
# passing make_clogged_pipe tests wait_send_all_might_not_block, and we
|
||||
# can't implement that on Windows
|
||||
await check_one_way_stream(make_pipe, None)
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,179 @@
|
||||
import ast
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from trio._tests.pytest_plugin import skip_if_optional_else_raise
|
||||
|
||||
# imports in gen_exports that are not in `install_requires` in setup.py
|
||||
try:
|
||||
import astor # noqa: F401
|
||||
import isort # noqa: F401
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
|
||||
from trio._tools.gen_exports import (
|
||||
File,
|
||||
create_passthrough_args,
|
||||
get_public_methods,
|
||||
process,
|
||||
run_linters,
|
||||
run_ruff,
|
||||
)
|
||||
|
||||
from ..._core._tests.tutil import slow
|
||||
|
||||
SOURCE = '''from _run import _public
|
||||
from collections import Counter
|
||||
|
||||
class Test:
|
||||
@_public
|
||||
def public_func(self):
|
||||
"""With doc string"""
|
||||
|
||||
@ignore_this
|
||||
@_public
|
||||
@another_decorator
|
||||
async def public_async_func(self) -> Counter:
|
||||
pass # no doc string
|
||||
|
||||
def not_public(self):
|
||||
pass
|
||||
|
||||
async def not_public_async(self):
|
||||
pass
|
||||
'''
|
||||
|
||||
IMPORT_1 = """\
|
||||
from collections import Counter
|
||||
"""
|
||||
|
||||
IMPORT_2 = """\
|
||||
from collections import Counter
|
||||
import os
|
||||
"""
|
||||
|
||||
IMPORT_3 = """\
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from collections import Counter
|
||||
"""
|
||||
|
||||
|
||||
def test_get_public_methods() -> None:
|
||||
methods = list(get_public_methods(ast.parse(SOURCE)))
|
||||
assert {m.name for m in methods} == {"public_func", "public_async_func"}
|
||||
|
||||
|
||||
def test_create_pass_through_args() -> None:
|
||||
testcases = [
|
||||
("def f()", "()"),
|
||||
("def f(one)", "(one)"),
|
||||
("def f(one, two)", "(one, two)"),
|
||||
("def f(one, *args)", "(one, *args)"),
|
||||
(
|
||||
"def f(one, *args, kw1, kw2=None, **kwargs)",
|
||||
"(one, *args, kw1=kw1, kw2=kw2, **kwargs)",
|
||||
),
|
||||
]
|
||||
|
||||
for funcdef, expected in testcases:
|
||||
func_node = ast.parse(funcdef + ":\n pass").body[0]
|
||||
assert isinstance(func_node, ast.FunctionDef)
|
||||
assert create_passthrough_args(func_node) == expected
|
||||
|
||||
|
||||
skip_lints = pytest.mark.skipif(
|
||||
sys.implementation.name != "cpython",
|
||||
reason="gen_exports is internal, black/isort only runs on CPython",
|
||||
)
|
||||
|
||||
|
||||
@slow
|
||||
@skip_lints
|
||||
@pytest.mark.parametrize("imports", [IMPORT_1, IMPORT_2, IMPORT_3])
|
||||
def test_process(
|
||||
tmp_path: Path,
|
||||
imports: str,
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
) -> None:
|
||||
try:
|
||||
import black # noqa: F401
|
||||
# there's no dedicated CI run that has astor+isort, but lacks black.
|
||||
except ImportError as error: # pragma: no cover
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
modpath = tmp_path / "_module.py"
|
||||
genpath = tmp_path / "_generated_module.py"
|
||||
modpath.write_text(SOURCE, encoding="utf-8")
|
||||
file = File(modpath, "runner", platform="linux", imports=imports)
|
||||
assert not genpath.exists()
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
process([file], do_test=True)
|
||||
assert excinfo.value.code == 1
|
||||
captured = capsys.readouterr()
|
||||
assert "Generated sources are outdated. Please regenerate." in captured.out
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
process([file], do_test=False)
|
||||
assert excinfo.value.code == 1
|
||||
captured = capsys.readouterr()
|
||||
assert "Regenerated sources successfully." in captured.out
|
||||
assert genpath.exists()
|
||||
process([file], do_test=True)
|
||||
# But if we change the lookup path it notices
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
process(
|
||||
[File(modpath, "runner.io_manager", platform="linux", imports=imports)],
|
||||
do_test=True,
|
||||
)
|
||||
assert excinfo.value.code == 1
|
||||
# Also if the platform is changed.
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
process([File(modpath, "runner", imports=imports)], do_test=True)
|
||||
assert excinfo.value.code == 1
|
||||
|
||||
|
||||
@skip_lints
|
||||
def test_run_ruff(tmp_path: Path) -> None:
|
||||
"""Test that processing properly fails if ruff does."""
|
||||
try:
|
||||
import ruff # noqa: F401
|
||||
except ImportError as error: # pragma: no cover
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
file = File(tmp_path / "module.py", "module")
|
||||
|
||||
success, _ = run_ruff(file, "class not valid code ><")
|
||||
assert not success
|
||||
|
||||
test_function = '''def combine_and(data: list[str]) -> str:
|
||||
"""Join values of text, and have 'and' with the last one properly."""
|
||||
if len(data) >= 2:
|
||||
data[-1] = 'and ' + data[-1]
|
||||
if len(data) > 2:
|
||||
return ', '.join(data)
|
||||
return ' '.join(data)'''
|
||||
|
||||
success, response = run_ruff(file, test_function)
|
||||
assert success
|
||||
assert response == test_function
|
||||
|
||||
|
||||
@skip_lints
|
||||
def test_lint_failure(tmp_path: Path) -> None:
|
||||
"""Test that processing properly fails if black or ruff does."""
|
||||
try:
|
||||
import black # noqa: F401
|
||||
import ruff # noqa: F401
|
||||
except ImportError as error: # pragma: no cover
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
file = File(tmp_path / "module.py", "module")
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
run_linters(file, "class not valid code ><")
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
run_linters(file, "import waffle\n;import trio")
|
||||
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from trio._tools.mypy_annotate import Result, export, main, process_line
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("src", "expected"),
|
||||
[
|
||||
("", None),
|
||||
("a regular line\n", None),
|
||||
(
|
||||
"package\\filename.py:42:8: note: Some info\n",
|
||||
Result(
|
||||
kind="notice",
|
||||
filename="package\\filename.py",
|
||||
start_line=42,
|
||||
start_col=8,
|
||||
end_line=None,
|
||||
end_col=None,
|
||||
message=" Some info",
|
||||
),
|
||||
),
|
||||
(
|
||||
"package/filename.py:42:1:46:3: error: Type error here [code]\n",
|
||||
Result(
|
||||
kind="error",
|
||||
filename="package/filename.py",
|
||||
start_line=42,
|
||||
start_col=1,
|
||||
end_line=46,
|
||||
end_col=3,
|
||||
message=" Type error here [code]",
|
||||
),
|
||||
),
|
||||
(
|
||||
"package/module.py:87: warn: Bad code\n",
|
||||
Result(
|
||||
kind="warning",
|
||||
filename="package/module.py",
|
||||
start_line=87,
|
||||
message=" Bad code",
|
||||
),
|
||||
),
|
||||
],
|
||||
ids=["blank", "normal", "note-wcol", "error-wend", "warn-lineonly"],
|
||||
)
|
||||
def test_processing(src: str, expected: Result | None) -> None:
|
||||
result = process_line(src)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_export(capsys: pytest.CaptureFixture[str]) -> None:
|
||||
results = {
|
||||
Result(
|
||||
kind="notice",
|
||||
filename="package\\filename.py",
|
||||
start_line=42,
|
||||
start_col=8,
|
||||
end_line=None,
|
||||
end_col=None,
|
||||
message=" Some info",
|
||||
): ["Windows", "Mac"],
|
||||
Result(
|
||||
kind="error",
|
||||
filename="package/filename.py",
|
||||
start_line=42,
|
||||
start_col=1,
|
||||
end_line=46,
|
||||
end_col=3,
|
||||
message=" Type error here [code]",
|
||||
): ["Linux", "Mac"],
|
||||
Result(
|
||||
kind="warning",
|
||||
filename="package/module.py",
|
||||
start_line=87,
|
||||
message=" Bad code",
|
||||
): ["Linux"],
|
||||
}
|
||||
export(results)
|
||||
std = capsys.readouterr()
|
||||
assert std.err == ""
|
||||
assert std.out == (
|
||||
"::notice file=package\\filename.py,line=42,col=8,"
|
||||
"title=Mypy-Windows+Mac::package\\filename.py:(42:8): Some info"
|
||||
"\n"
|
||||
"::error file=package/filename.py,line=42,col=1,endLine=46,endColumn=3,"
|
||||
"title=Mypy-Linux+Mac::package/filename.py:(42:1 - 46:3): Type error here [code]"
|
||||
"\n"
|
||||
"::warning file=package/module.py,line=87,"
|
||||
"title=Mypy-Linux::package/module.py:87: Bad code\n"
|
||||
)
|
||||
|
||||
|
||||
def test_endtoend(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
) -> None:
|
||||
import trio._tools.mypy_annotate as mypy_annotate
|
||||
|
||||
inp_text = """\
|
||||
Mypy begun
|
||||
trio/core.py:15: error: Bad types here [misc]
|
||||
trio/package/module.py:48:4:56:18: warn: Missing annotations [no-untyped-def]
|
||||
Found 3 errors in 29 files
|
||||
"""
|
||||
result_file = tmp_path / "dump.dat"
|
||||
assert not result_file.exists()
|
||||
with monkeypatch.context():
|
||||
monkeypatch.setattr(sys, "stdin", io.StringIO(inp_text))
|
||||
|
||||
mypy_annotate.main(
|
||||
["--dumpfile", str(result_file), "--platform", "SomePlatform"],
|
||||
)
|
||||
|
||||
std = capsys.readouterr()
|
||||
assert std.err == ""
|
||||
assert std.out == inp_text # Echos the original.
|
||||
|
||||
assert result_file.exists()
|
||||
|
||||
main(["--dumpfile", str(result_file)])
|
||||
|
||||
std = capsys.readouterr()
|
||||
assert std.err == ""
|
||||
assert std.out == (
|
||||
"::error file=trio/core.py,line=15,title=Mypy-SomePlatform::trio/core.py:15: Bad types here [misc]\n"
|
||||
"::warning file=trio/package/module.py,line=48,col=4,endLine=56,endColumn=18,"
|
||||
"title=Mypy-SomePlatform::trio/package/module.py:(48:4 - 56:18): Missing "
|
||||
"annotations [no-untyped-def]\n"
|
||||
)
|
||||
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trio._tests.pytest_plugin import skip_if_optional_else_raise
|
||||
|
||||
# imports in gen_exports that are not in `install_requires` in requirements
|
||||
try:
|
||||
import yaml # noqa: F401
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
from trio._tools.sync_requirements import (
|
||||
update_requirements,
|
||||
yield_pre_commit_version_data,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_yield_pre_commit_version_data() -> None:
|
||||
text = """
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.0
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 25.1.0
|
||||
- bad: data
|
||||
"""
|
||||
results = tuple(yield_pre_commit_version_data(text))
|
||||
assert results == (
|
||||
("ruff-pre-commit", "0.11.0"),
|
||||
("black-pre-commit-mirror", "25.1.0"),
|
||||
)
|
||||
|
||||
|
||||
def test_update_requirements(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
requirements_file = tmp_path / "requirements.txt"
|
||||
assert not requirements_file.exists()
|
||||
requirements_file.write_text(
|
||||
"""# comment
|
||||
# also comment but spaces line start
|
||||
waffles are delicious no equals
|
||||
black==3.1.4 ; specific version thingy
|
||||
mypy==1.15.0
|
||||
ruff==1.2.5
|
||||
# required by soupy cat""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
assert update_requirements(requirements_file, {"black": "3.1.5", "ruff": "1.2.7"})
|
||||
assert requirements_file.read_text(encoding="utf-8") == """# comment
|
||||
# also comment but spaces line start
|
||||
waffles are delicious no equals
|
||||
black==3.1.5 ; specific version thingy
|
||||
mypy==1.15.0
|
||||
ruff==1.2.7
|
||||
# required by soupy cat"""
|
||||
|
||||
|
||||
def test_update_requirements_no_changes(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
requirements_file = tmp_path / "requirements.txt"
|
||||
assert not requirements_file.exists()
|
||||
original = """# comment
|
||||
# also comment but spaces line start
|
||||
waffles are delicious no equals
|
||||
black==3.1.4 ; specific version thingy
|
||||
mypy==1.15.0
|
||||
ruff==1.2.5
|
||||
# required by soupy cat"""
|
||||
requirements_file.write_text(original, encoding="utf-8")
|
||||
assert not update_requirements(
|
||||
requirements_file, {"black": "3.1.4", "ruff": "1.2.5"}
|
||||
)
|
||||
assert requirements_file.read_text(encoding="utf-8") == original
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,9 @@
|
||||
# https://github.com/python-trio/trio/issues/2775#issuecomment-1702267589
|
||||
# (except platform independent...)
|
||||
import trio
|
||||
from typing_extensions import assert_type
|
||||
|
||||
|
||||
async def fn(s: trio.SocketStream) -> None:
|
||||
result = await s.socket.sendto(b"a", "h")
|
||||
assert_type(result, int)
|
||||
@@ -0,0 +1,4 @@
|
||||
# https://github.com/python-trio/trio/issues/2873
|
||||
import trio
|
||||
|
||||
s, r = trio.open_memory_channel[int](0)
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Path wrapping is quite complex, ensure all methods are understood as wrapped correctly."""
|
||||
|
||||
import io
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import IO, Any, BinaryIO
|
||||
|
||||
import trio
|
||||
from trio._file_io import AsyncIOWrapper
|
||||
from typing_extensions import assert_type
|
||||
|
||||
|
||||
def operator_checks(text: str, tpath: trio.Path, ppath: pathlib.Path) -> None:
|
||||
"""Verify operators produce the right results."""
|
||||
assert_type(tpath / ppath, trio.Path)
|
||||
assert_type(tpath / tpath, trio.Path)
|
||||
assert_type(tpath / text, trio.Path)
|
||||
assert_type(text / tpath, trio.Path)
|
||||
|
||||
assert_type(tpath > tpath, bool)
|
||||
assert_type(tpath >= tpath, bool)
|
||||
assert_type(tpath < tpath, bool)
|
||||
assert_type(tpath <= tpath, bool)
|
||||
|
||||
assert_type(tpath > ppath, bool)
|
||||
assert_type(tpath >= ppath, bool)
|
||||
assert_type(tpath < ppath, bool)
|
||||
assert_type(tpath <= ppath, bool)
|
||||
|
||||
assert_type(ppath > tpath, bool)
|
||||
assert_type(ppath >= tpath, bool)
|
||||
assert_type(ppath < tpath, bool)
|
||||
assert_type(ppath <= tpath, bool)
|
||||
|
||||
|
||||
def sync_attrs(path: trio.Path) -> None:
|
||||
assert_type(path.parts, tuple[str, ...])
|
||||
assert_type(path.drive, str)
|
||||
assert_type(path.root, str)
|
||||
assert_type(path.anchor, str)
|
||||
assert_type(path.parents[3], trio.Path)
|
||||
assert_type(path.parent, trio.Path)
|
||||
assert_type(path.name, str)
|
||||
assert_type(path.suffix, str)
|
||||
assert_type(path.suffixes, list[str])
|
||||
assert_type(path.stem, str)
|
||||
assert_type(path.as_posix(), str)
|
||||
assert_type(path.as_uri(), str)
|
||||
assert_type(path.is_absolute(), bool)
|
||||
assert_type(path.is_relative_to(path), bool)
|
||||
assert_type(path.is_reserved(), bool)
|
||||
assert_type(path.joinpath(path, "folder"), trio.Path)
|
||||
assert_type(path.match("*.py"), bool)
|
||||
assert_type(path.relative_to("/usr"), trio.Path)
|
||||
if sys.version_info >= (3, 12):
|
||||
assert_type(path.relative_to("/", walk_up=True), trio.Path)
|
||||
assert_type(path.with_name("filename.txt"), trio.Path)
|
||||
assert_type(path.with_stem("readme"), trio.Path)
|
||||
assert_type(path.with_suffix(".log"), trio.Path)
|
||||
|
||||
|
||||
async def async_attrs(path: trio.Path) -> None:
|
||||
assert_type(await trio.Path.cwd(), trio.Path)
|
||||
assert_type(await trio.Path.home(), trio.Path)
|
||||
assert_type(await path.stat(), os.stat_result)
|
||||
assert_type(await path.chmod(0o777), None)
|
||||
assert_type(await path.exists(), bool)
|
||||
assert_type(await path.expanduser(), trio.Path)
|
||||
for result in await path.glob("*.py"):
|
||||
assert_type(result, trio.Path)
|
||||
if sys.platform != "win32":
|
||||
assert_type(await path.group(), str)
|
||||
assert_type(await path.is_dir(), bool)
|
||||
assert_type(await path.is_file(), bool)
|
||||
if sys.version_info >= (3, 12):
|
||||
assert_type(await path.is_junction(), bool)
|
||||
if sys.platform != "win32":
|
||||
assert_type(await path.is_mount(), bool)
|
||||
assert_type(await path.is_symlink(), bool)
|
||||
assert_type(await path.is_socket(), bool)
|
||||
assert_type(await path.is_fifo(), bool)
|
||||
assert_type(await path.is_block_device(), bool)
|
||||
assert_type(await path.is_char_device(), bool)
|
||||
for child_iter in await path.iterdir():
|
||||
assert_type(child_iter, trio.Path)
|
||||
# TODO: Path.walk() in 3.12
|
||||
assert_type(await path.lchmod(0o111), None)
|
||||
assert_type(await path.lstat(), os.stat_result)
|
||||
assert_type(await path.mkdir(mode=0o777, parents=True, exist_ok=False), None)
|
||||
# Open done separately.
|
||||
if sys.platform != "win32":
|
||||
assert_type(await path.owner(), str)
|
||||
assert_type(await path.read_bytes(), bytes)
|
||||
assert_type(await path.read_text(encoding="utf16", errors="replace"), str)
|
||||
assert_type(await path.readlink(), trio.Path)
|
||||
assert_type(await path.rename("another"), trio.Path)
|
||||
assert_type(await path.replace(path), trio.Path)
|
||||
assert_type(await path.resolve(), trio.Path)
|
||||
for child_glob in await path.glob("*.py"):
|
||||
assert_type(child_glob, trio.Path)
|
||||
for child_rglob in await path.rglob("*.py"):
|
||||
assert_type(child_rglob, trio.Path)
|
||||
assert_type(await path.rmdir(), None)
|
||||
assert_type(await path.samefile("something_else"), bool)
|
||||
assert_type(await path.symlink_to("somewhere"), None)
|
||||
assert_type(await path.hardlink_to("elsewhere"), None)
|
||||
assert_type(await path.touch(), None)
|
||||
assert_type(await path.unlink(missing_ok=True), None)
|
||||
assert_type(await path.write_bytes(b"123"), int)
|
||||
assert_type(
|
||||
await path.write_text("hello", encoding="utf32le", errors="ignore"),
|
||||
int,
|
||||
)
|
||||
|
||||
|
||||
async def open_results(path: trio.Path, some_int: int, some_str: str) -> None:
|
||||
# Check the overloads.
|
||||
assert_type(await path.open(), AsyncIOWrapper[io.TextIOWrapper])
|
||||
assert_type(await path.open("r"), AsyncIOWrapper[io.TextIOWrapper])
|
||||
assert_type(await path.open("r+"), AsyncIOWrapper[io.TextIOWrapper])
|
||||
assert_type(await path.open("w"), AsyncIOWrapper[io.TextIOWrapper])
|
||||
assert_type(await path.open("rb", buffering=0), AsyncIOWrapper[io.FileIO])
|
||||
assert_type(await path.open("rb+"), AsyncIOWrapper[io.BufferedRandom])
|
||||
assert_type(await path.open("wb"), AsyncIOWrapper[io.BufferedWriter])
|
||||
assert_type(await path.open("rb"), AsyncIOWrapper[io.BufferedReader])
|
||||
assert_type(await path.open("rb", buffering=some_int), AsyncIOWrapper[BinaryIO])
|
||||
assert_type(await path.open(some_str), AsyncIOWrapper[IO[Any]])
|
||||
|
||||
# Check they produce the right types.
|
||||
file_bin = await path.open("rb+")
|
||||
assert_type(await file_bin.read(), bytes)
|
||||
assert_type(await file_bin.write(b"test"), int)
|
||||
assert_type(await file_bin.seek(32), int)
|
||||
|
||||
file_text = await path.open("r+t")
|
||||
assert_type(await file_text.read(), str)
|
||||
assert_type(await file_text.write("test"), int)
|
||||
# TODO: report mypy bug: equiv to https://github.com/microsoft/pyright/issues/6833
|
||||
assert_type(await file_text.readlines(), list[str])
|
||||
@@ -0,0 +1,23 @@
|
||||
import sys
|
||||
|
||||
import trio
|
||||
|
||||
|
||||
async def test() -> None:
|
||||
# this could test more by using platform checks, but currently this
|
||||
# is just regression tests + sanity checks.
|
||||
await trio.run_process("python", executable="ls")
|
||||
await trio.lowlevel.open_process("python", executable="ls")
|
||||
|
||||
# note: there's no error code on the type ignore as it varies
|
||||
# between platforms.
|
||||
await trio.run_process("python", capture_stdout=True)
|
||||
await trio.lowlevel.open_process("python", capture_stdout=True) # type: ignore
|
||||
|
||||
if sys.platform != "win32" and sys.version_info >= (3, 9):
|
||||
await trio.run_process("python", extra_groups=[5])
|
||||
await trio.lowlevel.open_process("python", extra_groups=[5])
|
||||
|
||||
# 3.11+:
|
||||
await trio.run_process("python", process_group=5) # type: ignore
|
||||
await trio.lowlevel.open_process("python", process_group=5) # type: ignore
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Check that started() can only be called for TaskStatus[None]."""
|
||||
|
||||
from trio import TaskStatus
|
||||
from typing_extensions import assert_type
|
||||
|
||||
|
||||
def check_status(
|
||||
none_status_explicit: TaskStatus[None],
|
||||
none_status_implicit: TaskStatus,
|
||||
int_status: TaskStatus[int],
|
||||
) -> None:
|
||||
assert_type(none_status_explicit, TaskStatus[None])
|
||||
assert_type(none_status_implicit, TaskStatus[None]) # Default typevar
|
||||
assert_type(int_status, TaskStatus[int])
|
||||
|
||||
# Omitting the parameter is only allowed for None.
|
||||
none_status_explicit.started()
|
||||
none_status_implicit.started()
|
||||
int_status.started() # type: ignore
|
||||
|
||||
# Explicit None is allowed.
|
||||
none_status_explicit.started(None)
|
||||
none_status_implicit.started(None)
|
||||
int_status.started(None) # type: ignore
|
||||
|
||||
none_status_explicit.started(42) # type: ignore
|
||||
none_status_implicit.started(42) # type: ignore
|
||||
int_status.started(42)
|
||||
int_status.started(True)
|
||||
Reference in New Issue
Block a user