initial commit

This commit is contained in:
2026-06-25 21:29:21 +00:00
commit 0d0a7456de
2738 changed files with 542622 additions and 0 deletions
@@ -0,0 +1,58 @@
# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625)
from .. import _deprecate as _deprecate
from .._core import (
MockClock as MockClock,
wait_all_tasks_blocked as wait_all_tasks_blocked,
)
from .._threads import (
active_thread_count as active_thread_count,
wait_all_threads_completed as wait_all_threads_completed,
)
from .._util import fixup_module_metadata
from ._check_streams import (
check_half_closeable_stream as check_half_closeable_stream,
check_one_way_stream as check_one_way_stream,
check_two_way_stream as check_two_way_stream,
)
from ._checkpoints import (
assert_checkpoints as assert_checkpoints,
assert_no_checkpoints as assert_no_checkpoints,
)
from ._memory_streams import (
MemoryReceiveStream as MemoryReceiveStream,
MemorySendStream as MemorySendStream,
lockstep_stream_one_way_pair as lockstep_stream_one_way_pair,
lockstep_stream_pair as lockstep_stream_pair,
memory_stream_one_way_pair as memory_stream_one_way_pair,
memory_stream_pair as memory_stream_pair,
memory_stream_pump as memory_stream_pump,
)
from ._network import open_stream_to_socket_listener as open_stream_to_socket_listener
from ._raises_group import Matcher as _Matcher, RaisesGroup as _RaisesGroup
from ._sequencer import Sequencer as Sequencer
from ._trio_test import trio_test as trio_test
################################################################
_deprecate.deprecate_attributes(
__name__,
{
"RaisesGroup": _deprecate.DeprecatedAttribute(
_RaisesGroup,
version="0.33.0",
issue=3326,
instead="See https://docs.pytest.org/en/stable/reference/reference.html#pytest.RaisesGroup",
),
"Matcher": _deprecate.DeprecatedAttribute(
_Matcher,
version="0.33.0",
issue=3326,
instead="See https://docs.pytest.org/en/stable/reference/reference.html#pytest.RaisesExc",
),
},
)
fixup_module_metadata(__name__, globals())
del fixup_module_metadata
@@ -0,0 +1,570 @@
# Generic stream tests
from __future__ import annotations
import random
import sys
from collections.abc import Awaitable, Callable, Generator
from contextlib import contextmanager, suppress
from typing import (
TYPE_CHECKING,
Generic,
TypeAlias,
TypeVar,
)
from .. import CancelScope, _core
from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream
from .._highlevel_generic import aclose_forcefully
from ._checkpoints import assert_checkpoints
if TYPE_CHECKING:
from types import TracebackType
from typing_extensions import ParamSpec
ArgsT = ParamSpec("ArgsT")
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
Res1 = TypeVar("Res1", bound=AsyncResource)
Res2 = TypeVar("Res2", bound=AsyncResource)
StreamMaker: TypeAlias = Callable[[], Awaitable[tuple[Res1, Res2]]]
class _ForceCloseBoth(Generic[Res1, Res2]):
def __init__(self, both: tuple[Res1, Res2]) -> None:
self._first, self._second = both
async def __aenter__(self) -> tuple[Res1, Res2]:
return self._first, self._second
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
try:
await aclose_forcefully(self._first)
finally:
await aclose_forcefully(self._second)
# This is used in this file instead of pytest.raises in order to avoid a dependency
# on pytest, as the check_* functions are publicly exported.
@contextmanager
def _assert_raises(
expected_exc: type[BaseException],
wrapped: bool = False,
) -> Generator[None, None, None]:
__tracebackhide__ = True
try:
yield
except BaseExceptionGroup as exc:
assert wrapped, "caught exceptiongroup, but expected an unwrapped exception"
# assert in except block ignored below
assert len(exc.exceptions) == 1 # noqa: PT017
assert isinstance(exc.exceptions[0], expected_exc) # noqa: PT017
except expected_exc:
assert not wrapped, "caught exception, but expected an exceptiongroup"
else:
raise AssertionError(f"expected exception: {expected_exc}")
async def check_one_way_stream(
stream_maker: StreamMaker[SendStream, ReceiveStream],
clogged_stream_maker: StreamMaker[SendStream, ReceiveStream] | None,
) -> None:
"""Perform a number of generic tests on a custom one-way stream
implementation.
Args:
stream_maker: An async (!) function which returns a connected
(:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`)
pair.
clogged_stream_maker: Either None, or an async function similar to
stream_maker, but with the extra property that the returned stream
is in a state where ``send_all`` and
``wait_send_all_might_not_block`` will block until ``receive_some``
has been called. This allows for more thorough testing of some edge
cases, especially around ``wait_send_all_might_not_block``.
Raises:
AssertionError: if a test fails.
"""
async with _ForceCloseBoth(await stream_maker()) as (s, r):
assert isinstance(s, SendStream)
assert isinstance(r, ReceiveStream)
async def do_send_all(data: bytes | bytearray | memoryview) -> None:
with assert_checkpoints(): # We're testing that it doesn't return anything.
assert await s.send_all(data) is None # type: ignore[func-returns-value]
async def do_receive_some(max_bytes: int | None = None) -> bytes | bytearray:
with assert_checkpoints():
return await r.receive_some(max_bytes)
async def checked_receive_1(expected: bytes) -> None:
assert await do_receive_some(1) == expected
async def do_aclose(resource: AsyncResource) -> None:
with assert_checkpoints():
await resource.aclose()
# Simple sending/receiving
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, b"x")
nursery.start_soon(checked_receive_1, b"x")
async def send_empty_then_y() -> None:
# Streams should tolerate sending b"" without giving it any
# special meaning.
await do_send_all(b"")
await do_send_all(b"y")
async with _core.open_nursery() as nursery:
nursery.start_soon(send_empty_then_y)
nursery.start_soon(checked_receive_1, b"y")
# ---- Checking various argument types ----
# send_all accepts bytearray and memoryview
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, bytearray(b"1"))
nursery.start_soon(checked_receive_1, b"1")
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, memoryview(b"2"))
nursery.start_soon(checked_receive_1, b"2")
# max_bytes must be a positive integer
with _assert_raises(ValueError):
await r.receive_some(-1)
with _assert_raises(ValueError):
await r.receive_some(0)
with _assert_raises(TypeError):
await r.receive_some(1.5) # type: ignore[arg-type]
# it can also be missing or None
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, b"x")
assert await do_receive_some() == b"x"
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, b"x")
assert await do_receive_some(None) == b"x"
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(do_receive_some, 1)
nursery.start_soon(do_receive_some, 1)
# Method always has to exist, and an empty stream with a blocked
# receive_some should *always* allow send_all. (Technically it's legal
# for send_all to wait until receive_some is called to run, though; a
# stream doesn't *have* to have any internal buffering. That's why we
# start a concurrent receive_some call, then cancel it.)
async def simple_check_wait_send_all_might_not_block(
scope: CancelScope,
) -> None:
with assert_checkpoints():
await s.wait_send_all_might_not_block()
scope.cancel()
async with _core.open_nursery() as nursery:
nursery.start_soon(
simple_check_wait_send_all_might_not_block,
nursery.cancel_scope,
)
nursery.start_soon(do_receive_some, 1)
# closing the r side leads to BrokenResourceError on the s side
# (eventually)
async def expect_broken_stream_on_send() -> None:
with _assert_raises(_core.BrokenResourceError):
while True:
await do_send_all(b"x" * 100)
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_broken_stream_on_send)
nursery.start_soon(do_aclose, r)
# once detected, the stream stays broken
with _assert_raises(_core.BrokenResourceError):
await do_send_all(b"x" * 100)
# r closed -> ClosedResourceError on the receive side
with _assert_raises(_core.ClosedResourceError):
await do_receive_some(4096)
# we can close the same stream repeatedly, it's fine
await do_aclose(r)
await do_aclose(r)
# closing the sender side
await do_aclose(s)
# now trying to send raises ClosedResourceError
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"x" * 100)
# even if it's an empty send
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"")
# ditto for wait_send_all_might_not_block
with _assert_raises(_core.ClosedResourceError):
with assert_checkpoints():
await s.wait_send_all_might_not_block()
# and again, repeated closing is fine
await do_aclose(s)
await do_aclose(s)
async with _ForceCloseBoth(await stream_maker()) as (s, r):
# if send-then-graceful-close, receiver gets data then b""
async def send_then_close() -> None:
await do_send_all(b"y")
await do_aclose(s)
async def receive_send_then_close() -> None:
# We want to make sure that if the sender closes the stream before
# we read anything, then we still get all the data. But some
# streams might block on the do_send_all call. So we let the
# sender get as far as it can, then we receive.
await _core.wait_all_tasks_blocked()
await checked_receive_1(b"y")
await checked_receive_1(b"")
await do_aclose(r)
async with _core.open_nursery() as nursery:
nursery.start_soon(send_then_close)
nursery.start_soon(receive_send_then_close)
async with _ForceCloseBoth(await stream_maker()) as (s, r):
await aclose_forcefully(r)
with _assert_raises(_core.BrokenResourceError):
while True:
await do_send_all(b"x" * 100)
with _assert_raises(_core.ClosedResourceError):
await do_receive_some(4096)
async with _ForceCloseBoth(await stream_maker()) as (s, r):
await aclose_forcefully(s)
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"123")
# after the sender does a forceful close, the receiver might either
# get BrokenResourceError or a clean b""; either is OK. Not OK would be
# if it freezes, or returns data.
with suppress(_core.BrokenResourceError):
await checked_receive_1(b"")
# cancelled aclose still closes
async with _ForceCloseBoth(await stream_maker()) as (s, r):
with _core.CancelScope() as scope:
scope.cancel()
await r.aclose()
with _core.CancelScope() as scope:
scope.cancel()
await s.aclose()
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"123")
with _assert_raises(_core.ClosedResourceError):
await do_receive_some(4096)
# Check that we can still gracefully close a stream after an operation has
# been cancelled. This can be challenging if cancellation can leave the
# stream internals in an inconsistent state, e.g. for
# SSLStream. Unfortunately this test isn't very thorough; the really
# challenging case for something like SSLStream is it gets cancelled
# *while* it's sending data on the underlying, not before. But testing
# that requires some special-case handling of the particular stream setup;
# we can't do it here. Maybe we could do a bit better with
# https://github.com/python-trio/trio/issues/77
async with _ForceCloseBoth(await stream_maker()) as (s, r):
async def expect_cancelled(
afn: Callable[ArgsT, Awaitable[object]],
*args: ArgsT.args,
**kwargs: ArgsT.kwargs,
) -> None:
with _assert_raises(_core.Cancelled):
await afn(*args, **kwargs)
with _core.CancelScope() as scope:
scope.cancel()
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_cancelled, do_send_all, b"x")
nursery.start_soon(expect_cancelled, do_receive_some, 1)
async with _core.open_nursery() as nursery:
nursery.start_soon(do_aclose, s)
nursery.start_soon(do_aclose, r)
# Check that if a task is blocked in receive_some, then closing the
# receive stream causes it to wake up.
async with _ForceCloseBoth(await stream_maker()) as (s, r):
async def receive_expecting_closed() -> None:
with _assert_raises(_core.ClosedResourceError):
await r.receive_some(10)
async with _core.open_nursery() as nursery:
nursery.start_soon(receive_expecting_closed)
await _core.wait_all_tasks_blocked()
await aclose_forcefully(r)
# check wait_send_all_might_not_block, if we can
if clogged_stream_maker is not None:
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
record: list[str] = []
async def waiter(cancel_scope: CancelScope) -> None:
record.append("waiter sleeping")
with assert_checkpoints():
await s.wait_send_all_might_not_block()
record.append("waiter wokeup")
cancel_scope.cancel()
async def receiver() -> None:
# give wait_send_all_might_not_block a chance to block
await _core.wait_all_tasks_blocked()
record.append("receiver starting")
while True:
await r.receive_some(16834)
async with _core.open_nursery() as nursery:
nursery.start_soon(waiter, nursery.cancel_scope)
await _core.wait_all_tasks_blocked()
nursery.start_soon(receiver)
assert record == [
"waiter sleeping",
"receiver starting",
"waiter wokeup",
]
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
# simultaneous wait_send_all_might_not_block fails
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(s.wait_send_all_might_not_block)
nursery.start_soon(s.wait_send_all_might_not_block)
# and simultaneous send_all and wait_send_all_might_not_block (NB
# this test might destroy the stream b/c we end up cancelling
# send_all and e.g. SSLStream can't handle that, so we have to
# recreate afterwards)
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(s.wait_send_all_might_not_block)
nursery.start_soon(s.send_all, b"123")
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
# send_all and send_all blocked simultaneously should also raise
# (but again this might destroy the stream)
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(s.send_all, b"123")
nursery.start_soon(s.send_all, b"123")
# closing the receiver causes wait_send_all_might_not_block to return,
# with or without an exception
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
async def sender() -> None:
try:
with assert_checkpoints():
await s.wait_send_all_might_not_block()
except _core.BrokenResourceError: # pragma: no cover
pass
async def receiver() -> None:
await _core.wait_all_tasks_blocked()
await aclose_forcefully(r)
async with _core.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(receiver)
# and again with the call starting after the close
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
await aclose_forcefully(r)
try:
with assert_checkpoints():
await s.wait_send_all_might_not_block()
except _core.BrokenResourceError: # pragma: no cover
pass
# Check that if a task is blocked in a send-side method, then closing
# the send stream causes it to wake up.
async def close_soon(s: SendStream) -> None:
await _core.wait_all_tasks_blocked()
await aclose_forcefully(s)
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
async with _core.open_nursery() as nursery:
nursery.start_soon(close_soon, s)
with _assert_raises(_core.ClosedResourceError):
await s.send_all(b"xyzzy")
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
async with _core.open_nursery() as nursery:
nursery.start_soon(close_soon, s)
with _assert_raises(_core.ClosedResourceError):
await s.wait_send_all_might_not_block()
async def check_two_way_stream(
stream_maker: StreamMaker[Stream, Stream],
clogged_stream_maker: StreamMaker[Stream, Stream] | None,
) -> None:
"""Perform a number of generic tests on a custom two-way stream
implementation.
This is similar to :func:`check_one_way_stream`, except that the maker
functions are expected to return objects implementing the
:class:`~trio.abc.Stream` interface.
This function tests a *superset* of what :func:`check_one_way_stream`
checks if you call this, then you don't need to also call
:func:`check_one_way_stream`.
"""
await check_one_way_stream(stream_maker, clogged_stream_maker)
async def flipped_stream_maker() -> tuple[Stream, Stream]:
return (await stream_maker())[::-1]
flipped_clogged_stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]] | None
if clogged_stream_maker is not None:
async def flipped_clogged_stream_maker() -> tuple[Stream, Stream]:
return (await clogged_stream_maker())[::-1]
else:
flipped_clogged_stream_maker = None
await check_one_way_stream(flipped_stream_maker, flipped_clogged_stream_maker)
async with _ForceCloseBoth(await stream_maker()) as (s1, s2):
assert isinstance(s1, Stream)
assert isinstance(s2, Stream)
# Duplex can be a bit tricky, might as well check it as well
DUPLEX_TEST_SIZE = 2**20
CHUNK_SIZE_MAX = 2**14
r = random.Random(0)
i = r.getrandbits(8 * DUPLEX_TEST_SIZE)
test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little")
async def sender(
s: Stream,
data: bytes | bytearray | memoryview,
seed: int,
) -> None:
r = random.Random(seed)
m = memoryview(data)
while m:
chunk_size = r.randint(1, CHUNK_SIZE_MAX)
await s.send_all(m[:chunk_size])
m = m[chunk_size:]
async def receiver(s: Stream, data: bytes | bytearray, seed: int) -> None:
r = random.Random(seed)
got = bytearray()
while len(got) < len(data):
chunk = await s.receive_some(r.randint(1, CHUNK_SIZE_MAX))
assert chunk
got += chunk
assert got == data
async with _core.open_nursery() as nursery:
nursery.start_soon(sender, s1, test_data, 0)
nursery.start_soon(sender, s2, test_data[::-1], 1)
nursery.start_soon(receiver, s1, test_data[::-1], 2)
nursery.start_soon(receiver, s2, test_data, 3)
async def expect_receive_some_empty() -> None:
assert await s2.receive_some(10) == b""
await s2.aclose()
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_receive_some_empty)
nursery.start_soon(s1.aclose)
async def check_half_closeable_stream(
stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream],
clogged_stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream] | None,
) -> None:
"""Perform a number of generic tests on a custom half-closeable stream
implementation.
This is similar to :func:`check_two_way_stream`, except that the maker
functions are expected to return objects that implement the
:class:`~trio.abc.HalfCloseableStream` interface.
This function tests a *superset* of what :func:`check_two_way_stream`
checks if you call this, then you don't need to also call
:func:`check_two_way_stream`.
"""
await check_two_way_stream(stream_maker, clogged_stream_maker)
async with _ForceCloseBoth(await stream_maker()) as (s1, s2):
assert isinstance(s1, HalfCloseableStream)
assert isinstance(s2, HalfCloseableStream)
async def send_x_then_eof(s: HalfCloseableStream) -> None:
await s.send_all(b"x")
with assert_checkpoints():
await s.send_eof()
async def expect_x_then_eof(r: HalfCloseableStream) -> None:
await _core.wait_all_tasks_blocked()
assert await r.receive_some(10) == b"x"
assert await r.receive_some(10) == b""
async with _core.open_nursery() as nursery:
nursery.start_soon(send_x_then_eof, s1)
nursery.start_soon(expect_x_then_eof, s2)
# now sending is disallowed
with _assert_raises(_core.ClosedResourceError):
await s1.send_all(b"y")
# but we can do send_eof again
with assert_checkpoints():
await s1.send_eof()
# and we can still send stuff back the other way
async with _core.open_nursery() as nursery:
nursery.start_soon(send_x_then_eof, s2)
nursery.start_soon(expect_x_then_eof, s1)
if clogged_stream_maker is not None:
async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2):
# send_all and send_eof simultaneously is not ok
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(s1.send_all, b"x")
await _core.wait_all_tasks_blocked()
nursery.start_soon(s1.send_eof)
async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2):
# wait_send_all_might_not_block and send_eof simultaneously is not
# ok either
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(s1.wait_send_all_might_not_block)
await _core.wait_all_tasks_blocked()
nursery.start_soon(s1.send_eof)
@@ -0,0 +1,69 @@
from __future__ import annotations
from contextlib import AbstractContextManager, contextmanager
from typing import TYPE_CHECKING
from .. import _core
if TYPE_CHECKING:
from collections.abc import Generator
@contextmanager
def _assert_yields_or_not(expected: bool) -> Generator[None, None, None]:
"""Check if checkpoints are executed in a block of code."""
__tracebackhide__ = True
task = _core.current_task()
orig_cancel = task._cancel_points
orig_schedule = task._schedule_points
try:
yield
if expected and (
task._cancel_points == orig_cancel or task._schedule_points == orig_schedule
):
raise AssertionError("assert_checkpoints block did not yield!")
finally:
if not expected and (
task._cancel_points != orig_cancel or task._schedule_points != orig_schedule
):
raise AssertionError("assert_no_checkpoints block yielded!")
def assert_checkpoints() -> AbstractContextManager[None]:
"""Use as a context manager to check that the code inside the ``with``
block either exits with an exception or executes at least one
:ref:`checkpoint <checkpoints>`.
Raises:
AssertionError: if no checkpoint was executed.
Example:
Check that :func:`trio.sleep` is a checkpoint, even if it doesn't
block::
with trio.testing.assert_checkpoints():
await trio.sleep(0)
"""
__tracebackhide__ = True
return _assert_yields_or_not(True)
def assert_no_checkpoints() -> AbstractContextManager[None]:
"""Use as a context manager to check that the code inside the ``with``
block does not execute any :ref:`checkpoints <checkpoints>`.
Raises:
AssertionError: if a checkpoint was executed.
Example:
Synchronous code never contains any checkpoints, but we can double-check
that::
send_channel, receive_channel = trio.open_memory_channel(10)
with trio.testing.assert_no_checkpoints():
send_channel.send_nowait(None)
"""
__tracebackhide__ = True
return _assert_yields_or_not(False)
@@ -0,0 +1,584 @@
# This should eventually be cleaned up and become public, but for right now I'm just
# implementing enough to test DTLS.
# TODO:
# - user-defined routers
# - TCP
# - UDP broadcast
from __future__ import annotations
import contextlib
import errno
import ipaddress
import os
import socket
import sys
from typing import TYPE_CHECKING, Any, NoReturn, TypeAlias, overload
import attrs
import trio
from trio._util import NoPublicConstructor, final
if TYPE_CHECKING:
import builtins
from collections.abc import Iterable
from socket import AddressFamily, SocketKind
from types import TracebackType
from typing_extensions import Buffer, Self
from trio._socket import AddressFormat
IPAddress: TypeAlias = ipaddress.IPv4Address | ipaddress.IPv6Address
def _family_for(ip: IPAddress) -> int:
if isinstance(ip, ipaddress.IPv4Address):
return trio.socket.AF_INET
elif isinstance(ip, ipaddress.IPv6Address):
return trio.socket.AF_INET6
raise NotImplementedError("Unhandled IPAddress instance type") # pragma: no cover
def _wildcard_ip_for(family: int) -> IPAddress:
if family == trio.socket.AF_INET:
return ipaddress.ip_address("0.0.0.0")
elif family == trio.socket.AF_INET6:
return ipaddress.ip_address("::")
raise NotImplementedError("Unhandled ip address family") # pragma: no cover
# not used anywhere
def _localhost_ip_for(family: int) -> IPAddress: # pragma: no cover
if family == trio.socket.AF_INET:
return ipaddress.ip_address("127.0.0.1")
elif family == trio.socket.AF_INET6:
return ipaddress.ip_address("::1")
raise NotImplementedError("Unhandled ip address family")
def _fake_err(code: int) -> NoReturn:
raise OSError(code, os.strerror(code))
def _scatter(data: bytes, buffers: Iterable[Buffer]) -> int:
written = 0
for buf in buffers: # pragma: no branch
next_piece = data[written : written + memoryview(buf).nbytes]
with memoryview(buf) as mbuf:
mbuf[: len(next_piece)] = next_piece
written += len(next_piece)
if written == len(data): # pragma: no branch
break
return written
@attrs.frozen
class UDPEndpoint:
ip: IPAddress
port: int
def as_python_sockaddr(self) -> tuple[str, int] | tuple[str, int, int, int]:
sockaddr: tuple[str, int] | tuple[str, int, int, int] = (
self.ip.compressed,
self.port,
)
if isinstance(self.ip, ipaddress.IPv6Address):
sockaddr += (0, 0) # type: ignore[assignment]
return sockaddr
@classmethod
def from_python_sockaddr(
cls,
sockaddr: tuple[str, int] | tuple[str, int, int, int],
) -> UDPEndpoint:
ip, port = sockaddr[:2]
return cls(ip=ipaddress.ip_address(ip), port=port)
@attrs.frozen
class UDPBinding:
local: UDPEndpoint
# remote: UDPEndpoint # ??
@attrs.frozen
class UDPPacket:
source: UDPEndpoint
destination: UDPEndpoint
payload: bytes = attrs.field(repr=lambda p: p.hex())
# not used/tested anywhere
def reply(self, payload: bytes) -> UDPPacket: # pragma: no cover
return UDPPacket(
source=self.destination,
destination=self.source,
payload=payload,
)
@attrs.frozen
class FakeSocketFactory(trio.abc.SocketFactory):
fake_net: FakeNet
def socket(self, family: int, type_: int, proto: int) -> FakeSocket: # type: ignore[override]
return FakeSocket._create(self.fake_net, family, type_, proto)
@attrs.frozen
class FakeHostnameResolver(trio.abc.HostnameResolver):
fake_net: FakeNet
async def getaddrinfo(
self,
host: bytes | None,
port: bytes | str | int | None,
family: int = 0,
type: int = 0,
proto: int = 0,
flags: int = 0,
) -> list[
tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
]
]:
raise NotImplementedError("FakeNet doesn't do fake DNS yet")
async def getnameinfo(
self,
sockaddr: tuple[str, int] | tuple[str, int, int, int],
flags: int,
) -> tuple[str, str]:
raise NotImplementedError("FakeNet doesn't do fake DNS yet")
@final
class FakeNet:
def __init__(self) -> None:
# When we need to pick an arbitrary unique ip address/port, use these:
self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts() # untested
self._auto_ipv6_iter = ipaddress.IPv6Network("1::/16").hosts() # untested
self._auto_port_iter = iter(range(50000, 65535))
self._bound: dict[UDPBinding, FakeSocket] = {}
self.route_packet = None
def _bind(self, binding: UDPBinding, socket: FakeSocket) -> None:
if binding in self._bound:
_fake_err(errno.EADDRINUSE)
self._bound[binding] = socket
def enable(self) -> None:
trio.socket.set_custom_socket_factory(FakeSocketFactory(self))
trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self))
def send_packet(self, packet: UDPPacket) -> None:
if self.route_packet is None:
self.deliver_packet(packet)
else:
self.route_packet(packet)
def deliver_packet(self, packet: UDPPacket) -> None:
binding = UDPBinding(local=packet.destination)
if binding in self._bound:
self._bound[binding]._deliver_packet(packet)
else:
# No valid destination, so drop it
pass
@final
class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor):
def __init__(
self,
fake_net: FakeNet,
family: AddressFamily,
type: SocketKind,
proto: int,
) -> None:
self._fake_net = fake_net
if not family: # pragma: no cover
family = trio.socket.AF_INET
if not type: # pragma: no cover
type = trio.socket.SOCK_STREAM # noqa: A001 # name shadowing builtin
if family not in (trio.socket.AF_INET, trio.socket.AF_INET6):
raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}")
if type != trio.socket.SOCK_DGRAM:
raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}")
self._family = family
self._type = type
self._proto = proto
self._closed = False
self._packet_sender, self._packet_receiver = trio.open_memory_channel[
UDPPacket
](float("inf"))
# This is the source-of-truth for what port etc. this socket is bound to
self._binding: UDPBinding | None = None
@property
def type(self) -> SocketKind:
return self._type
@property
def family(self) -> AddressFamily:
return self._family
@property
def proto(self) -> int:
return self._proto
def _check_closed(self) -> None:
if self._closed:
_fake_err(errno.EBADF)
def close(self) -> None:
if self._closed:
return
self._closed = True
if self._binding is not None:
del self._fake_net._bound[self._binding]
self._packet_receiver.close()
async def _resolve_address_nocp(
self,
address: object,
*,
local: bool,
) -> tuple[str, int]:
return await trio._socket._resolve_address_nocp( # type: ignore[no-any-return]
self.type,
self.family,
self.proto,
address=address,
ipv6_v6only=False,
local=local,
)
def _deliver_packet(self, packet: UDPPacket) -> None:
# sending to a closed socket -- UDP packets get dropped
with contextlib.suppress(trio.BrokenResourceError):
self._packet_sender.send_nowait(packet)
################################################################
# Actual IO operation implementations
################################################################
async def bind(self, addr: object) -> None:
self._check_closed()
if self._binding is not None:
_fake_err(errno.EINVAL)
await trio.lowlevel.checkpoint()
ip_str, port, *_ = await self._resolve_address_nocp(addr, local=True)
assert _ == [], "TODO: handle other values?"
ip = ipaddress.ip_address(ip_str)
assert _family_for(ip) == self.family
# We convert binds to INET_ANY into binds to localhost
if ip == ipaddress.ip_address("0.0.0.0"):
ip = ipaddress.ip_address("127.0.0.1")
elif ip == ipaddress.ip_address("::"):
ip = ipaddress.ip_address("::1")
if port == 0:
port = next(self._fake_net._auto_port_iter)
binding = UDPBinding(local=UDPEndpoint(ip, port))
self._fake_net._bind(binding, self)
self._binding = binding
async def connect(self, peer: object) -> NoReturn:
raise NotImplementedError("FakeNet does not (yet) support connected sockets")
async def _sendmsg(
self,
buffers: Iterable[Buffer],
ancdata: Iterable[tuple[int, int, Buffer]] = (),
flags: int = 0,
address: AddressFormat | None = None,
) -> int:
self._check_closed()
await trio.lowlevel.checkpoint()
if address is not None:
address = await self._resolve_address_nocp(address, local=False)
if ancdata:
raise NotImplementedError("FakeNet doesn't support ancillary data")
if flags:
raise NotImplementedError(f"FakeNet send flags must be 0, not {flags}")
if address is None:
_fake_err(errno.ENOTCONN)
destination = UDPEndpoint.from_python_sockaddr(address)
if self._binding is None:
await self.bind((_wildcard_ip_for(self.family).compressed, 0))
payload = b"".join(buffers)
assert self._binding is not None
packet = UDPPacket(
source=self._binding.local,
destination=destination,
payload=payload,
)
self._fake_net.send_packet(packet)
return len(payload)
if sys.platform != "win32" or (
not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
):
sendmsg = _sendmsg
async def _recvmsg_into(
self,
buffers: Iterable[Buffer],
ancbufsize: int = 0,
flags: int = 0,
) -> tuple[
int,
list[tuple[int, int, bytes]],
int,
tuple[str, int] | tuple[str, int, int, int],
]:
if ancbufsize != 0:
raise NotImplementedError("FakeNet doesn't support ancillary data")
if flags != 0:
raise NotImplementedError("FakeNet doesn't support any recv flags")
if self._binding is None:
# I messed this up a few times when writing tests ... but it also never happens
# in any of the existing tests, so maybe it could be intentional...
raise NotImplementedError(
"The code will most likely hang if you try to receive on a fakesocket "
"without a binding. If that is not the case, or you explicitly want to "
"test that, remove this warning.",
)
self._check_closed()
ancdata: list[tuple[int, int, bytes]] = []
msg_flags = 0
packet = await self._packet_receiver.receive()
address = packet.source.as_python_sockaddr()
written = _scatter(packet.payload, buffers)
if written < len(packet.payload):
msg_flags |= trio.socket.MSG_TRUNC
return written, ancdata, msg_flags, address
if sys.platform != "win32" or (
not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
):
recvmsg_into = _recvmsg_into
################################################################
# Simple state query stuff
################################################################
def getsockname(self) -> tuple[str, int] | tuple[str, int, int, int]:
self._check_closed()
if self._binding is not None:
return self._binding.local.as_python_sockaddr()
elif self.family == trio.socket.AF_INET:
return ("0.0.0.0", 0)
else:
assert self.family == trio.socket.AF_INET6
return ("::", 0)
# TODO: This method is not tested, and seems to make incorrect assumptions. It should maybe raise NotImplementedError.
def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]:
self._check_closed()
if self._binding is not None:
assert hasattr(
self._binding,
"remote",
), "This method seems to assume that self._binding has a remote UDPEndpoint"
if self._binding.remote is not None: # pragma: no cover
assert isinstance(
self._binding.remote,
UDPEndpoint,
), "Self._binding.remote should be a UDPEndpoint"
return self._binding.remote.as_python_sockaddr()
_fake_err(errno.ENOTCONN)
@overload
def getsockopt(self, /, level: int, optname: int) -> int: ...
@overload
def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ...
def getsockopt(
self,
/,
level: int,
optname: int,
buflen: int | None = None,
) -> int | bytes:
self._check_closed()
raise OSError(f"FakeNet doesn't implement getsockopt({level}, {optname})")
@overload
def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ...
@overload
def setsockopt(
self,
/,
level: int,
optname: int,
value: None,
optlen: int,
) -> None: ...
def setsockopt(
self,
/,
level: int,
optname: int,
value: int | Buffer | None,
optlen: int | None = None,
) -> None:
self._check_closed()
if (level, optname) == (
trio.socket.IPPROTO_IPV6,
trio.socket.IPV6_V6ONLY,
) and not value:
raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True")
raise OSError(f"FakeNet doesn't implement setsockopt({level}, {optname}, ...)")
################################################################
# Various boilerplate and trivial stubs
################################################################
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: builtins.type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self.close()
async def send(self, data: Buffer, flags: int = 0) -> int:
return await self.sendto(data, flags, None)
# __ prefixed arguments because typeshed uses that and typechecker issues
@overload
async def sendto(
self,
__data: Buffer, # noqa: PYI063
__address: tuple[object, ...] | str | Buffer,
) -> int: ...
# __ prefixed arguments because typeshed uses that and typechecker issues
@overload
async def sendto(
self,
__data: Buffer, # noqa: PYI063
__flags: int,
__address: tuple[object, ...] | str | Buffer | None,
) -> int: ...
async def sendto( # type: ignore[explicit-any]
self,
*args: Any,
) -> int:
data: Buffer
flags: int
address: tuple[object, ...] | str | Buffer
if len(args) == 2:
data, address = args
flags = 0
elif len(args) == 3:
data, flags, address = args
else:
raise TypeError("wrong number of arguments")
return await self._sendmsg([data], [], flags, address)
async def recv(self, bufsize: int, flags: int = 0) -> bytes:
data, _address = await self.recvfrom(bufsize, flags)
return data
async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int:
got_bytes, _address = await self.recvfrom_into(buf, nbytes, flags)
return got_bytes
async def recvfrom(
self,
bufsize: int,
flags: int = 0,
) -> tuple[bytes, AddressFormat]:
data, _ancdata, _msg_flags, address = await self._recvmsg(bufsize, flags)
return data, address
async def recvfrom_into(
self,
buf: Buffer,
nbytes: int = 0,
flags: int = 0,
) -> tuple[int, AddressFormat]:
if nbytes != 0 and nbytes != memoryview(buf).nbytes:
raise NotImplementedError("partial recvfrom_into")
got_nbytes, _ancdata, _msg_flags, address = await self._recvmsg_into(
[buf],
0,
flags,
)
return got_nbytes, address
async def _recvmsg(
self,
bufsize: int,
ancbufsize: int = 0,
flags: int = 0,
) -> tuple[bytes, list[tuple[int, int, bytes]], int, AddressFormat]:
buf = bytearray(bufsize)
got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into(
[buf],
ancbufsize,
flags,
)
return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address)
if sys.platform != "win32" or (
not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
):
recvmsg = _recvmsg
def fileno(self) -> int:
raise NotImplementedError("can't get fileno() for FakeNet sockets")
def detach(self) -> int:
raise NotImplementedError("can't detach() a FakeNet socket")
def get_inheritable(self) -> bool:
return False
def set_inheritable(self, inheritable: bool) -> None:
if inheritable:
raise NotImplementedError("FakeNet can't make inheritable sockets")
if sys.platform == "win32" or (
not TYPE_CHECKING and hasattr(socket.socket, "share")
):
def share(self, process_id: int) -> bytes:
raise NotImplementedError("FakeNet can't share sockets")
@@ -0,0 +1,633 @@
from __future__ import annotations
import operator
from collections.abc import Awaitable, Callable
from typing import TypeAlias, TypeVar
from .. import _core, _util
from .._highlevel_generic import StapledStream
from ..abc import ReceiveStream, SendStream
AsyncHook: TypeAlias = Callable[[], Awaitable[object]]
# Would be nice to exclude awaitable here, but currently not possible.
SyncHook: TypeAlias = Callable[[], object]
SendStreamT = TypeVar("SendStreamT", bound=SendStream)
ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream)
################################################################
# In-memory streams - Unbounded buffer version
################################################################
class _UnboundedByteQueue:
def __init__(self) -> None:
self._data = bytearray()
self._closed = False
self._lot = _core.ParkingLot()
self._fetch_lock = _util.ConflictDetector(
"another task is already fetching data",
)
# This object treats "close" as being like closing the send side of a
# channel: so after close(), calling put() raises ClosedResourceError, and
# calling the get() variants drains the buffer and then returns an empty
# bytearray.
def close(self) -> None:
self._closed = True
self._lot.unpark_all()
def close_and_wipe(self) -> None:
self._data = bytearray()
self.close()
def put(self, data: bytes | bytearray | memoryview) -> None:
if self._closed:
raise _core.ClosedResourceError("virtual connection closed")
self._data += data
self._lot.unpark_all()
def _check_max_bytes(self, max_bytes: int | None) -> None:
if max_bytes is None:
return
max_bytes = operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
def _get_impl(self, max_bytes: int | None) -> bytearray:
assert self._closed or self._data
if max_bytes is None:
max_bytes = len(self._data)
if self._data:
chunk = self._data[:max_bytes]
del self._data[:max_bytes]
assert chunk
return chunk
else:
return bytearray()
def get_nowait(self, max_bytes: int | None = None) -> bytearray:
with self._fetch_lock:
self._check_max_bytes(max_bytes)
if not self._closed and not self._data:
raise _core.WouldBlock
return self._get_impl(max_bytes)
async def get(self, max_bytes: int | None = None) -> bytearray:
with self._fetch_lock:
self._check_max_bytes(max_bytes)
if not self._closed and not self._data:
await self._lot.park()
else:
await _core.checkpoint()
return self._get_impl(max_bytes)
@_util.final
class MemorySendStream(SendStream):
"""An in-memory :class:`~trio.abc.SendStream`.
Args:
send_all_hook: An async function, or None. Called from
:meth:`send_all`. Can do whatever you like.
wait_send_all_might_not_block_hook: An async function, or None. Called
from :meth:`wait_send_all_might_not_block`. Can do whatever you
like.
close_hook: A synchronous function, or None. Called from :meth:`close`
and :meth:`aclose`. Can do whatever you like.
.. attribute:: send_all_hook
wait_send_all_might_not_block_hook
close_hook
All of these hooks are also exposed as attributes on the object, and
you can change them at any time.
"""
def __init__(
self,
send_all_hook: AsyncHook | None = None,
wait_send_all_might_not_block_hook: AsyncHook | None = None,
close_hook: SyncHook | None = None,
) -> None:
self._conflict_detector = _util.ConflictDetector(
"another task is using this stream",
)
self._outgoing = _UnboundedByteQueue()
self.send_all_hook = send_all_hook
self.wait_send_all_might_not_block_hook = wait_send_all_might_not_block_hook
self.close_hook = close_hook
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
"""Places the given data into the object's internal buffer, and then
calls the :attr:`send_all_hook` (if any).
"""
# Execute two checkpoints so we have more of a chance to detect
# buggy user code that calls this twice at the same time.
with self._conflict_detector:
await _core.checkpoint()
await _core.checkpoint()
self._outgoing.put(data)
if self.send_all_hook is not None:
await self.send_all_hook()
async def wait_send_all_might_not_block(self) -> None:
"""Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and
then returns immediately.
"""
# Execute two checkpoints so that we have more of a chance to detect
# buggy user code that calls this twice at the same time.
with self._conflict_detector:
await _core.checkpoint()
await _core.checkpoint()
# check for being closed:
self._outgoing.put(b"")
if self.wait_send_all_might_not_block_hook is not None:
await self.wait_send_all_might_not_block_hook()
def close(self) -> None:
"""Marks this stream as closed, and then calls the :attr:`close_hook`
(if any).
"""
# XXX should this cancel any pending calls to the send_all_hook and
# wait_send_all_might_not_block_hook? Those are the only places where
# send_all and wait_send_all_might_not_block can be blocked.
#
# The way we set things up, send_all_hook is memory_stream_pump, and
# wait_send_all_might_not_block_hook is unset. memory_stream_pump is
# synchronous. So normally, send_all and wait_send_all_might_not_block
# cannot block at all.
self._outgoing.close()
if self.close_hook is not None:
self.close_hook()
async def aclose(self) -> None:
"""Same as :meth:`close`, but async."""
self.close()
await _core.checkpoint()
async def get_data(self, max_bytes: int | None = None) -> bytearray:
"""Retrieves data from the internal buffer, blocking if necessary.
Args:
max_bytes (int or None): The maximum amount of data to
retrieve. None (the default) means to retrieve all the data
that's present (but still blocks until at least one byte is
available).
Returns:
If this stream has been closed, an empty bytearray. Otherwise, the
requested data.
"""
return await self._outgoing.get(max_bytes)
def get_data_nowait(self, max_bytes: int | None = None) -> bytearray:
"""Retrieves data from the internal buffer, but doesn't block.
See :meth:`get_data` for details.
Raises:
trio.WouldBlock: if no data is available to retrieve.
"""
return self._outgoing.get_nowait(max_bytes)
@_util.final
class MemoryReceiveStream(ReceiveStream):
"""An in-memory :class:`~trio.abc.ReceiveStream`.
Args:
receive_some_hook: An async function, or None. Called from
:meth:`receive_some`. Can do whatever you like.
close_hook: A synchronous function, or None. Called from :meth:`close`
and :meth:`aclose`. Can do whatever you like.
.. attribute:: receive_some_hook
close_hook
Both hooks are also exposed as attributes on the object, and you can
change them at any time.
"""
def __init__(
self,
receive_some_hook: AsyncHook | None = None,
close_hook: SyncHook | None = None,
) -> None:
self._conflict_detector = _util.ConflictDetector(
"another task is using this stream",
)
self._incoming = _UnboundedByteQueue()
self._closed = False
self.receive_some_hook = receive_some_hook
self.close_hook = close_hook
async def receive_some(self, max_bytes: int | None = None) -> bytearray:
"""Calls the :attr:`receive_some_hook` (if any), and then retrieves
data from the internal buffer, blocking if necessary.
"""
# Execute two checkpoints so we have more of a chance to detect
# buggy user code that calls this twice at the same time.
with self._conflict_detector:
await _core.checkpoint()
await _core.checkpoint()
if self._closed:
raise _core.ClosedResourceError
if self.receive_some_hook is not None:
await self.receive_some_hook()
# self._incoming's closure state tracks whether we got an EOF.
# self._closed tracks whether we, ourselves, are closed.
# self.close() sends an EOF to wake us up and sets self._closed,
# so after we wake up we have to check self._closed again.
data = await self._incoming.get(max_bytes)
if self._closed:
raise _core.ClosedResourceError
return data
def close(self) -> None:
"""Discards any pending data from the internal buffer, and marks this
stream as closed.
"""
self._closed = True
self._incoming.close_and_wipe()
if self.close_hook is not None:
self.close_hook()
async def aclose(self) -> None:
"""Same as :meth:`close`, but async."""
self.close()
await _core.checkpoint()
def put_data(self, data: bytes | bytearray | memoryview) -> None:
"""Appends the given data to the internal buffer."""
self._incoming.put(data)
def put_eof(self) -> None:
"""Adds an end-of-file marker to the internal buffer."""
self._incoming.close()
# TODO: investigate why this is necessary for the docs
MemorySendStream.__module__ = MemorySendStream.__module__.replace(
"._memory_streams", ""
)
MemoryReceiveStream.__module__ = MemoryReceiveStream.__module__.replace(
"._memory_streams", ""
)
def memory_stream_pump(
memory_send_stream: MemorySendStream,
memory_receive_stream: MemoryReceiveStream,
*,
max_bytes: int | None = None,
) -> bool:
"""Take data out of the given :class:`MemorySendStream`'s internal buffer,
and put it into the given :class:`MemoryReceiveStream`'s internal buffer.
Args:
memory_send_stream (MemorySendStream): The stream to get data from.
memory_receive_stream (MemoryReceiveStream): The stream to put data into.
max_bytes (int or None): The maximum amount of data to transfer in this
call, or None to transfer all available data.
Returns:
True if it successfully transferred some data, or False if there was no
data to transfer.
This is used to implement :func:`memory_stream_one_way_pair` and
:func:`memory_stream_pair`; see the latter's docstring for an example
of how you might use it yourself.
"""
try:
data = memory_send_stream.get_data_nowait(max_bytes)
except _core.WouldBlock:
return False
try:
if not data:
memory_receive_stream.put_eof()
else:
memory_receive_stream.put_data(data)
except _core.ClosedResourceError:
raise _core.BrokenResourceError("MemoryReceiveStream was closed") from None
return True
def memory_stream_one_way_pair() -> tuple[MemorySendStream, MemoryReceiveStream]:
"""Create a connected, pure-Python, unidirectional stream with infinite
buffering and flexible configuration options.
You can think of this as being a no-operating-system-involved
Trio-streamsified version of :func:`os.pipe` (except that :func:`os.pipe`
returns the streams in the wrong order we follow the superior convention
that data flows from left to right).
Returns:
A tuple (:class:`MemorySendStream`, :class:`MemoryReceiveStream`), where
the :class:`MemorySendStream` has its hooks set up so that it calls
:func:`memory_stream_pump` from its
:attr:`~MemorySendStream.send_all_hook` and
:attr:`~MemorySendStream.close_hook`.
The end result is that data automatically flows from the
:class:`MemorySendStream` to the :class:`MemoryReceiveStream`. But you're
also free to rearrange things however you like. For example, you can
temporarily set the :attr:`~MemorySendStream.send_all_hook` to None if you
want to simulate a stall in data transmission. Or see
:func:`memory_stream_pair` for a more elaborate example.
"""
send_stream = MemorySendStream()
recv_stream = MemoryReceiveStream()
def pump_from_send_stream_to_recv_stream() -> None:
memory_stream_pump(send_stream, recv_stream)
# await not used
async def async_pump_from_send_stream_to_recv_stream() -> None: # noqa: RUF029
pump_from_send_stream_to_recv_stream()
send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream
send_stream.close_hook = pump_from_send_stream_to_recv_stream
return send_stream, recv_stream
def _make_stapled_pair(
one_way_pair: Callable[[], tuple[SendStreamT, ReceiveStreamT]],
) -> tuple[
StapledStream[SendStreamT, ReceiveStreamT],
StapledStream[SendStreamT, ReceiveStreamT],
]:
pipe1_send, pipe1_recv = one_way_pair()
pipe2_send, pipe2_recv = one_way_pair()
stream1 = StapledStream(pipe1_send, pipe2_recv)
stream2 = StapledStream(pipe2_send, pipe1_recv)
return stream1, stream2
def memory_stream_pair() -> tuple[
StapledStream[MemorySendStream, MemoryReceiveStream],
StapledStream[MemorySendStream, MemoryReceiveStream],
]:
"""Create a connected, pure-Python, bidirectional stream with infinite
buffering and flexible configuration options.
This is a convenience function that creates two one-way streams using
:func:`memory_stream_one_way_pair`, and then uses
:class:`~trio.StapledStream` to combine them into a single bidirectional
stream.
This is like a no-operating-system-involved, Trio-streamsified version of
:func:`socket.socketpair`.
Returns:
A pair of :class:`~trio.StapledStream` objects that are connected so
that data automatically flows from one to the other in both directions.
After creating a stream pair, you can send data back and forth, which is
enough for simple tests::
left, right = memory_stream_pair()
await left.send_all(b"123")
assert await right.receive_some() == b"123"
await right.send_all(b"456")
assert await left.receive_some() == b"456"
But if you read the docs for :class:`~trio.StapledStream` and
:func:`memory_stream_one_way_pair`, you'll see that all the pieces
involved in wiring this up are public APIs, so you can adjust to suit the
requirements of your tests. For example, here's how to tweak a stream so
that data flowing from left to right trickles in one byte at a time (but
data flowing from right to left proceeds at full speed)::
left, right = memory_stream_pair()
async def trickle():
# left is a StapledStream, and left.send_stream is a MemorySendStream
# right is a StapledStream, and right.recv_stream is a MemoryReceiveStream
while memory_stream_pump(left.send_stream, right.recv_stream, max_bytes=1):
# Pause between each byte
await trio.sleep(1)
# Normally this send_all_hook calls memory_stream_pump directly without
# passing in a max_bytes. We replace it with our custom version:
left.send_stream.send_all_hook = trickle
And here's a simple test using our modified stream objects::
async def sender():
await left.send_all(b"12345")
await left.send_eof()
async def receiver():
async for data in right:
print(data)
async with trio.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(receiver)
By default, this will print ``b"12345"`` and then immediately exit; with
our trickle stream it instead sleeps 1 second, then prints ``b"1"``, then
sleeps 1 second, then prints ``b"2"``, etc.
Pro-tip: you can insert sleep calls (like in our example above) to
manipulate the flow of data across tasks... and then use
:class:`MockClock` and its :attr:`~MockClock.autojump_threshold`
functionality to keep your test suite running quickly.
If you want to stress test a protocol implementation, one nice trick is to
use the :mod:`random` module (preferably with a fixed seed) to move random
numbers of bytes at a time, and insert random sleeps in between them. You
can also set up a custom :attr:`~MemoryReceiveStream.receive_some_hook` if
you want to manipulate things on the receiving side, and not just the
sending side.
"""
return _make_stapled_pair(memory_stream_one_way_pair)
################################################################
# In-memory streams - Lockstep version
################################################################
class _LockstepByteQueue:
def __init__(self) -> None:
self._data = bytearray()
self._sender_closed = False
self._receiver_closed = False
self._receiver_waiting = False
self._waiters = _core.ParkingLot()
self._send_conflict_detector = _util.ConflictDetector(
"another task is already sending",
)
self._receive_conflict_detector = _util.ConflictDetector(
"another task is already receiving",
)
def _something_happened(self) -> None:
self._waiters.unpark_all()
# Always wakes up when one side is closed, because everyone always reacts
# to that.
async def _wait_for(self, fn: Callable[[], bool]) -> None:
while True:
if fn():
break
if self._sender_closed or self._receiver_closed:
break
await self._waiters.park()
await _core.checkpoint()
def close_sender(self) -> None:
self._sender_closed = True
self._something_happened()
def close_receiver(self) -> None:
self._receiver_closed = True
self._something_happened()
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
with self._send_conflict_detector:
if self._sender_closed:
raise _core.ClosedResourceError
if self._receiver_closed:
raise _core.BrokenResourceError
assert not self._data
self._data += data
self._something_happened()
await self._wait_for(lambda: self._data == b"")
if self._sender_closed:
raise _core.ClosedResourceError
if self._data and self._receiver_closed:
raise _core.BrokenResourceError
async def wait_send_all_might_not_block(self) -> None:
with self._send_conflict_detector:
if self._sender_closed:
raise _core.ClosedResourceError
if self._receiver_closed:
await _core.checkpoint()
return
await self._wait_for(lambda: self._receiver_waiting)
if self._sender_closed:
raise _core.ClosedResourceError
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
with self._receive_conflict_detector:
# Argument validation
if max_bytes is not None:
max_bytes = operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
# State validation
if self._receiver_closed:
raise _core.ClosedResourceError
# Wake wait_send_all_might_not_block and wait for data
self._receiver_waiting = True
self._something_happened()
try:
await self._wait_for(lambda: self._data != b"")
finally:
self._receiver_waiting = False
if self._receiver_closed:
raise _core.ClosedResourceError
# Get data, possibly waking send_all
if self._data:
# Neat trick: if max_bytes is None, then obj[:max_bytes] is
# the same as obj[:].
got = self._data[:max_bytes]
del self._data[:max_bytes]
self._something_happened()
return got
else:
assert self._sender_closed
return b""
class _LockstepSendStream(SendStream):
def __init__(self, lbq: _LockstepByteQueue) -> None:
self._lbq = lbq
def close(self) -> None:
self._lbq.close_sender()
async def aclose(self) -> None:
self.close()
await _core.checkpoint()
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
await self._lbq.send_all(data)
async def wait_send_all_might_not_block(self) -> None:
await self._lbq.wait_send_all_might_not_block()
class _LockstepReceiveStream(ReceiveStream):
def __init__(self, lbq: _LockstepByteQueue) -> None:
self._lbq = lbq
def close(self) -> None:
self._lbq.close_receiver()
async def aclose(self) -> None:
self.close()
await _core.checkpoint()
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
return await self._lbq.receive_some(max_bytes)
def lockstep_stream_one_way_pair() -> tuple[SendStream, ReceiveStream]:
"""Create a connected, pure Python, unidirectional stream where data flows
in lockstep.
Returns:
A tuple
(:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`).
This stream has *absolutely no* buffering. Each call to
:meth:`~trio.abc.SendStream.send_all` will block until all the given data
has been returned by a call to
:meth:`~trio.abc.ReceiveStream.receive_some`.
This can be useful for testing flow control mechanisms in an extreme case,
or for setting up "clogged" streams to use with
:func:`check_one_way_stream` and friends.
In addition to fulfilling the :class:`~trio.abc.SendStream` and
:class:`~trio.abc.ReceiveStream` interfaces, the return objects
also have a synchronous ``close`` method.
"""
lbq = _LockstepByteQueue()
return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq)
def lockstep_stream_pair() -> tuple[
StapledStream[SendStream, ReceiveStream],
StapledStream[SendStream, ReceiveStream],
]:
"""Create a connected, pure-Python, bidirectional stream where data flows
in lockstep.
Returns:
A tuple (:class:`~trio.StapledStream`, :class:`~trio.StapledStream`).
This is a convenience function that creates two one-way streams using
:func:`lockstep_stream_one_way_pair`, and then uses
:class:`~trio.StapledStream` to combine them into a single bidirectional
stream.
"""
return _make_stapled_pair(lockstep_stream_one_way_pair)
@@ -0,0 +1,36 @@
from .. import socket as tsocket
from .._highlevel_socket import SocketListener, SocketStream
async def open_stream_to_socket_listener(
socket_listener: SocketListener,
) -> SocketStream:
"""Connect to the given :class:`~trio.SocketListener`.
This is particularly useful in tests when you want to let a server pick
its own port, and then connect to it::
listeners = await trio.open_tcp_listeners(0)
client = await trio.testing.open_stream_to_socket_listener(listeners[0])
Args:
socket_listener (~trio.SocketListener): The
:class:`~trio.SocketListener` to connect to.
Returns:
SocketStream: a stream connected to the given listener.
"""
family = socket_listener.socket.family
sockaddr = socket_listener.socket.getsockname()
if family in (tsocket.AF_INET, tsocket.AF_INET6):
sockaddr = list(sockaddr)
if sockaddr[0] == "0.0.0.0":
sockaddr[0] = "127.0.0.1"
if sockaddr[0] == "::":
sockaddr[0] = "::1"
sockaddr = tuple(sockaddr)
sock = tsocket.socket(family=family)
await sock.connect(sockaddr)
return SocketStream(sock)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,87 @@
from __future__ import annotations
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING
import attrs
from .. import Event, _core, _util
if TYPE_CHECKING:
from collections.abc import AsyncIterator
@_util.final
@attrs.define(eq=False, slots=False)
class Sequencer:
"""A convenience class for forcing code in different tasks to run in an
explicit linear order.
Instances of this class implement a ``__call__`` method which returns an
async context manager. The idea is that you pass a sequence number to
``__call__`` to say where this block of code should go in the linear
sequence. Block 0 starts immediately, and then block N doesn't start until
block N-1 has finished.
Example:
An extremely elaborate way to print the numbers 0-5, in order::
async def worker1(seq):
async with seq(0):
print(0)
async with seq(4):
print(4)
async def worker2(seq):
async with seq(2):
print(2)
async with seq(5):
print(5)
async def worker3(seq):
async with seq(1):
print(1)
async with seq(3):
print(3)
async def main():
seq = trio.testing.Sequencer()
async with trio.open_nursery() as nursery:
nursery.start_soon(worker1, seq)
nursery.start_soon(worker2, seq)
nursery.start_soon(worker3, seq)
"""
_sequence_points: defaultdict[int, Event] = attrs.field(
factory=lambda: defaultdict(Event),
init=False,
)
_claimed: set[int] = attrs.field(factory=set, init=False)
_broken: bool = attrs.field(default=False, init=False)
@asynccontextmanager
async def __call__(self, position: int) -> AsyncIterator[None]:
if position in self._claimed:
raise RuntimeError(f"Attempted to reuse sequence point {position}")
if self._broken:
raise RuntimeError("sequence broken!")
self._claimed.add(position)
if position != 0:
try:
await self._sequence_points[position].wait()
except _core.Cancelled:
self._broken = True
for event in self._sequence_points.values():
event.set()
raise RuntimeError(
"Sequencer wait cancelled -- sequence broken",
) from None
else:
if self._broken:
raise RuntimeError("sequence broken!")
try:
yield
finally:
self._sequence_points[position + 1].set()
@@ -0,0 +1,50 @@
from __future__ import annotations
from functools import partial, wraps
from typing import TYPE_CHECKING, TypeVar
from .. import _core
from ..abc import Clock, Instrument
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from typing_extensions import ParamSpec
ArgsT = ParamSpec("ArgsT")
RetT = TypeVar("RetT")
def trio_test(fn: Callable[ArgsT, Awaitable[RetT]]) -> Callable[ArgsT, RetT]:
"""Converts an async test function to be synchronous, running via Trio.
Usage::
@trio_test
async def test_whatever():
await ...
If a pytest fixture is passed in that subclasses the :class:`~trio.abc.Clock` or
:class:`~trio.abc.Instrument` ABCs, then those are passed to :meth:`trio.run()`.
"""
@wraps(fn)
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
__tracebackhide__ = True
clocks = [c for c in kwargs.values() if isinstance(c, Clock)]
if not clocks:
clock = None
elif len(clocks) == 1:
clock = clocks[0]
else:
raise ValueError("too many clocks spoil the broth!")
instruments = [i for i in kwargs.values() if isinstance(i, Instrument)]
return _core.run(
partial(fn, *args, **kwargs),
clock=clock,
instruments=instruments,
)
return wrapper