initial commit
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625)
|
||||
|
||||
from .. import _deprecate as _deprecate
|
||||
from .._core import (
|
||||
MockClock as MockClock,
|
||||
wait_all_tasks_blocked as wait_all_tasks_blocked,
|
||||
)
|
||||
from .._threads import (
|
||||
active_thread_count as active_thread_count,
|
||||
wait_all_threads_completed as wait_all_threads_completed,
|
||||
)
|
||||
from .._util import fixup_module_metadata
|
||||
from ._check_streams import (
|
||||
check_half_closeable_stream as check_half_closeable_stream,
|
||||
check_one_way_stream as check_one_way_stream,
|
||||
check_two_way_stream as check_two_way_stream,
|
||||
)
|
||||
from ._checkpoints import (
|
||||
assert_checkpoints as assert_checkpoints,
|
||||
assert_no_checkpoints as assert_no_checkpoints,
|
||||
)
|
||||
from ._memory_streams import (
|
||||
MemoryReceiveStream as MemoryReceiveStream,
|
||||
MemorySendStream as MemorySendStream,
|
||||
lockstep_stream_one_way_pair as lockstep_stream_one_way_pair,
|
||||
lockstep_stream_pair as lockstep_stream_pair,
|
||||
memory_stream_one_way_pair as memory_stream_one_way_pair,
|
||||
memory_stream_pair as memory_stream_pair,
|
||||
memory_stream_pump as memory_stream_pump,
|
||||
)
|
||||
from ._network import open_stream_to_socket_listener as open_stream_to_socket_listener
|
||||
from ._raises_group import Matcher as _Matcher, RaisesGroup as _RaisesGroup
|
||||
from ._sequencer import Sequencer as Sequencer
|
||||
from ._trio_test import trio_test as trio_test
|
||||
|
||||
################################################################
|
||||
|
||||
|
||||
_deprecate.deprecate_attributes(
|
||||
__name__,
|
||||
{
|
||||
"RaisesGroup": _deprecate.DeprecatedAttribute(
|
||||
_RaisesGroup,
|
||||
version="0.33.0",
|
||||
issue=3326,
|
||||
instead="See https://docs.pytest.org/en/stable/reference/reference.html#pytest.RaisesGroup",
|
||||
),
|
||||
"Matcher": _deprecate.DeprecatedAttribute(
|
||||
_Matcher,
|
||||
version="0.33.0",
|
||||
issue=3326,
|
||||
instead="See https://docs.pytest.org/en/stable/reference/reference.html#pytest.RaisesExc",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
fixup_module_metadata(__name__, globals())
|
||||
del fixup_module_metadata
|
||||
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,570 @@
|
||||
# Generic stream tests
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, Generator
|
||||
from contextlib import contextmanager, suppress
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Generic,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from .. import CancelScope, _core
|
||||
from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream
|
||||
from .._highlevel_generic import aclose_forcefully
|
||||
from ._checkpoints import assert_checkpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
ArgsT = ParamSpec("ArgsT")
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
Res1 = TypeVar("Res1", bound=AsyncResource)
|
||||
Res2 = TypeVar("Res2", bound=AsyncResource)
|
||||
StreamMaker: TypeAlias = Callable[[], Awaitable[tuple[Res1, Res2]]]
|
||||
|
||||
|
||||
class _ForceCloseBoth(Generic[Res1, Res2]):
|
||||
def __init__(self, both: tuple[Res1, Res2]) -> None:
|
||||
self._first, self._second = both
|
||||
|
||||
async def __aenter__(self) -> tuple[Res1, Res2]:
|
||||
return self._first, self._second
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
try:
|
||||
await aclose_forcefully(self._first)
|
||||
finally:
|
||||
await aclose_forcefully(self._second)
|
||||
|
||||
|
||||
# This is used in this file instead of pytest.raises in order to avoid a dependency
|
||||
# on pytest, as the check_* functions are publicly exported.
|
||||
@contextmanager
|
||||
def _assert_raises(
|
||||
expected_exc: type[BaseException],
|
||||
wrapped: bool = False,
|
||||
) -> Generator[None, None, None]:
|
||||
__tracebackhide__ = True
|
||||
try:
|
||||
yield
|
||||
except BaseExceptionGroup as exc:
|
||||
assert wrapped, "caught exceptiongroup, but expected an unwrapped exception"
|
||||
# assert in except block ignored below
|
||||
assert len(exc.exceptions) == 1 # noqa: PT017
|
||||
assert isinstance(exc.exceptions[0], expected_exc) # noqa: PT017
|
||||
except expected_exc:
|
||||
assert not wrapped, "caught exception, but expected an exceptiongroup"
|
||||
else:
|
||||
raise AssertionError(f"expected exception: {expected_exc}")
|
||||
|
||||
|
||||
async def check_one_way_stream(
|
||||
stream_maker: StreamMaker[SendStream, ReceiveStream],
|
||||
clogged_stream_maker: StreamMaker[SendStream, ReceiveStream] | None,
|
||||
) -> None:
|
||||
"""Perform a number of generic tests on a custom one-way stream
|
||||
implementation.
|
||||
|
||||
Args:
|
||||
stream_maker: An async (!) function which returns a connected
|
||||
(:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`)
|
||||
pair.
|
||||
clogged_stream_maker: Either None, or an async function similar to
|
||||
stream_maker, but with the extra property that the returned stream
|
||||
is in a state where ``send_all`` and
|
||||
``wait_send_all_might_not_block`` will block until ``receive_some``
|
||||
has been called. This allows for more thorough testing of some edge
|
||||
cases, especially around ``wait_send_all_might_not_block``.
|
||||
|
||||
Raises:
|
||||
AssertionError: if a test fails.
|
||||
|
||||
"""
|
||||
async with _ForceCloseBoth(await stream_maker()) as (s, r):
|
||||
assert isinstance(s, SendStream)
|
||||
assert isinstance(r, ReceiveStream)
|
||||
|
||||
async def do_send_all(data: bytes | bytearray | memoryview) -> None:
|
||||
with assert_checkpoints(): # We're testing that it doesn't return anything.
|
||||
assert await s.send_all(data) is None # type: ignore[func-returns-value]
|
||||
|
||||
async def do_receive_some(max_bytes: int | None = None) -> bytes | bytearray:
|
||||
with assert_checkpoints():
|
||||
return await r.receive_some(max_bytes)
|
||||
|
||||
async def checked_receive_1(expected: bytes) -> None:
|
||||
assert await do_receive_some(1) == expected
|
||||
|
||||
async def do_aclose(resource: AsyncResource) -> None:
|
||||
with assert_checkpoints():
|
||||
await resource.aclose()
|
||||
|
||||
# Simple sending/receiving
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send_all, b"x")
|
||||
nursery.start_soon(checked_receive_1, b"x")
|
||||
|
||||
async def send_empty_then_y() -> None:
|
||||
# Streams should tolerate sending b"" without giving it any
|
||||
# special meaning.
|
||||
await do_send_all(b"")
|
||||
await do_send_all(b"y")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(send_empty_then_y)
|
||||
nursery.start_soon(checked_receive_1, b"y")
|
||||
|
||||
# ---- Checking various argument types ----
|
||||
|
||||
# send_all accepts bytearray and memoryview
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send_all, bytearray(b"1"))
|
||||
nursery.start_soon(checked_receive_1, b"1")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send_all, memoryview(b"2"))
|
||||
nursery.start_soon(checked_receive_1, b"2")
|
||||
|
||||
# max_bytes must be a positive integer
|
||||
with _assert_raises(ValueError):
|
||||
await r.receive_some(-1)
|
||||
with _assert_raises(ValueError):
|
||||
await r.receive_some(0)
|
||||
with _assert_raises(TypeError):
|
||||
await r.receive_some(1.5) # type: ignore[arg-type]
|
||||
# it can also be missing or None
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send_all, b"x")
|
||||
assert await do_receive_some() == b"x"
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send_all, b"x")
|
||||
assert await do_receive_some(None) == b"x"
|
||||
|
||||
with _assert_raises(_core.BusyResourceError, wrapped=True):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_receive_some, 1)
|
||||
nursery.start_soon(do_receive_some, 1)
|
||||
|
||||
# Method always has to exist, and an empty stream with a blocked
|
||||
# receive_some should *always* allow send_all. (Technically it's legal
|
||||
# for send_all to wait until receive_some is called to run, though; a
|
||||
# stream doesn't *have* to have any internal buffering. That's why we
|
||||
# start a concurrent receive_some call, then cancel it.)
|
||||
async def simple_check_wait_send_all_might_not_block(
|
||||
scope: CancelScope,
|
||||
) -> None:
|
||||
with assert_checkpoints():
|
||||
await s.wait_send_all_might_not_block()
|
||||
scope.cancel()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
simple_check_wait_send_all_might_not_block,
|
||||
nursery.cancel_scope,
|
||||
)
|
||||
nursery.start_soon(do_receive_some, 1)
|
||||
|
||||
# closing the r side leads to BrokenResourceError on the s side
|
||||
# (eventually)
|
||||
async def expect_broken_stream_on_send() -> None:
|
||||
with _assert_raises(_core.BrokenResourceError):
|
||||
while True:
|
||||
await do_send_all(b"x" * 100)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_broken_stream_on_send)
|
||||
nursery.start_soon(do_aclose, r)
|
||||
|
||||
# once detected, the stream stays broken
|
||||
with _assert_raises(_core.BrokenResourceError):
|
||||
await do_send_all(b"x" * 100)
|
||||
|
||||
# r closed -> ClosedResourceError on the receive side
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
await do_receive_some(4096)
|
||||
|
||||
# we can close the same stream repeatedly, it's fine
|
||||
await do_aclose(r)
|
||||
await do_aclose(r)
|
||||
|
||||
# closing the sender side
|
||||
await do_aclose(s)
|
||||
|
||||
# now trying to send raises ClosedResourceError
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
await do_send_all(b"x" * 100)
|
||||
|
||||
# even if it's an empty send
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
await do_send_all(b"")
|
||||
|
||||
# ditto for wait_send_all_might_not_block
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
with assert_checkpoints():
|
||||
await s.wait_send_all_might_not_block()
|
||||
|
||||
# and again, repeated closing is fine
|
||||
await do_aclose(s)
|
||||
await do_aclose(s)
|
||||
|
||||
async with _ForceCloseBoth(await stream_maker()) as (s, r):
|
||||
# if send-then-graceful-close, receiver gets data then b""
|
||||
async def send_then_close() -> None:
|
||||
await do_send_all(b"y")
|
||||
await do_aclose(s)
|
||||
|
||||
async def receive_send_then_close() -> None:
|
||||
# We want to make sure that if the sender closes the stream before
|
||||
# we read anything, then we still get all the data. But some
|
||||
# streams might block on the do_send_all call. So we let the
|
||||
# sender get as far as it can, then we receive.
|
||||
await _core.wait_all_tasks_blocked()
|
||||
await checked_receive_1(b"y")
|
||||
await checked_receive_1(b"")
|
||||
await do_aclose(r)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(send_then_close)
|
||||
nursery.start_soon(receive_send_then_close)
|
||||
|
||||
async with _ForceCloseBoth(await stream_maker()) as (s, r):
|
||||
await aclose_forcefully(r)
|
||||
|
||||
with _assert_raises(_core.BrokenResourceError):
|
||||
while True:
|
||||
await do_send_all(b"x" * 100)
|
||||
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
await do_receive_some(4096)
|
||||
|
||||
async with _ForceCloseBoth(await stream_maker()) as (s, r):
|
||||
await aclose_forcefully(s)
|
||||
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
await do_send_all(b"123")
|
||||
|
||||
# after the sender does a forceful close, the receiver might either
|
||||
# get BrokenResourceError or a clean b""; either is OK. Not OK would be
|
||||
# if it freezes, or returns data.
|
||||
with suppress(_core.BrokenResourceError):
|
||||
await checked_receive_1(b"")
|
||||
|
||||
# cancelled aclose still closes
|
||||
async with _ForceCloseBoth(await stream_maker()) as (s, r):
|
||||
with _core.CancelScope() as scope:
|
||||
scope.cancel()
|
||||
await r.aclose()
|
||||
|
||||
with _core.CancelScope() as scope:
|
||||
scope.cancel()
|
||||
await s.aclose()
|
||||
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
await do_send_all(b"123")
|
||||
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
await do_receive_some(4096)
|
||||
|
||||
# Check that we can still gracefully close a stream after an operation has
|
||||
# been cancelled. This can be challenging if cancellation can leave the
|
||||
# stream internals in an inconsistent state, e.g. for
|
||||
# SSLStream. Unfortunately this test isn't very thorough; the really
|
||||
# challenging case for something like SSLStream is it gets cancelled
|
||||
# *while* it's sending data on the underlying, not before. But testing
|
||||
# that requires some special-case handling of the particular stream setup;
|
||||
# we can't do it here. Maybe we could do a bit better with
|
||||
# https://github.com/python-trio/trio/issues/77
|
||||
async with _ForceCloseBoth(await stream_maker()) as (s, r):
|
||||
|
||||
async def expect_cancelled(
|
||||
afn: Callable[ArgsT, Awaitable[object]],
|
||||
*args: ArgsT.args,
|
||||
**kwargs: ArgsT.kwargs,
|
||||
) -> None:
|
||||
with _assert_raises(_core.Cancelled):
|
||||
await afn(*args, **kwargs)
|
||||
|
||||
with _core.CancelScope() as scope:
|
||||
scope.cancel()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_cancelled, do_send_all, b"x")
|
||||
nursery.start_soon(expect_cancelled, do_receive_some, 1)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_aclose, s)
|
||||
nursery.start_soon(do_aclose, r)
|
||||
|
||||
# Check that if a task is blocked in receive_some, then closing the
|
||||
# receive stream causes it to wake up.
|
||||
async with _ForceCloseBoth(await stream_maker()) as (s, r):
|
||||
|
||||
async def receive_expecting_closed() -> None:
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
await r.receive_some(10)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receive_expecting_closed)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
await aclose_forcefully(r)
|
||||
|
||||
# check wait_send_all_might_not_block, if we can
|
||||
if clogged_stream_maker is not None:
|
||||
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
|
||||
record: list[str] = []
|
||||
|
||||
async def waiter(cancel_scope: CancelScope) -> None:
|
||||
record.append("waiter sleeping")
|
||||
with assert_checkpoints():
|
||||
await s.wait_send_all_might_not_block()
|
||||
record.append("waiter wokeup")
|
||||
cancel_scope.cancel()
|
||||
|
||||
async def receiver() -> None:
|
||||
# give wait_send_all_might_not_block a chance to block
|
||||
await _core.wait_all_tasks_blocked()
|
||||
record.append("receiver starting")
|
||||
while True:
|
||||
await r.receive_some(16834)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(waiter, nursery.cancel_scope)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
nursery.start_soon(receiver)
|
||||
|
||||
assert record == [
|
||||
"waiter sleeping",
|
||||
"receiver starting",
|
||||
"waiter wokeup",
|
||||
]
|
||||
|
||||
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
|
||||
# simultaneous wait_send_all_might_not_block fails
|
||||
with _assert_raises(_core.BusyResourceError, wrapped=True):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(s.wait_send_all_might_not_block)
|
||||
nursery.start_soon(s.wait_send_all_might_not_block)
|
||||
|
||||
# and simultaneous send_all and wait_send_all_might_not_block (NB
|
||||
# this test might destroy the stream b/c we end up cancelling
|
||||
# send_all and e.g. SSLStream can't handle that, so we have to
|
||||
# recreate afterwards)
|
||||
with _assert_raises(_core.BusyResourceError, wrapped=True):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(s.wait_send_all_might_not_block)
|
||||
nursery.start_soon(s.send_all, b"123")
|
||||
|
||||
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
|
||||
# send_all and send_all blocked simultaneously should also raise
|
||||
# (but again this might destroy the stream)
|
||||
with _assert_raises(_core.BusyResourceError, wrapped=True):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(s.send_all, b"123")
|
||||
nursery.start_soon(s.send_all, b"123")
|
||||
|
||||
# closing the receiver causes wait_send_all_might_not_block to return,
|
||||
# with or without an exception
|
||||
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
|
||||
|
||||
async def sender() -> None:
|
||||
try:
|
||||
with assert_checkpoints():
|
||||
await s.wait_send_all_might_not_block()
|
||||
except _core.BrokenResourceError: # pragma: no cover
|
||||
pass
|
||||
|
||||
async def receiver() -> None:
|
||||
await _core.wait_all_tasks_blocked()
|
||||
await aclose_forcefully(r)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sender)
|
||||
nursery.start_soon(receiver)
|
||||
|
||||
# and again with the call starting after the close
|
||||
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
|
||||
await aclose_forcefully(r)
|
||||
try:
|
||||
with assert_checkpoints():
|
||||
await s.wait_send_all_might_not_block()
|
||||
except _core.BrokenResourceError: # pragma: no cover
|
||||
pass
|
||||
|
||||
# Check that if a task is blocked in a send-side method, then closing
|
||||
# the send stream causes it to wake up.
|
||||
async def close_soon(s: SendStream) -> None:
|
||||
await _core.wait_all_tasks_blocked()
|
||||
await aclose_forcefully(s)
|
||||
|
||||
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(close_soon, s)
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
await s.send_all(b"xyzzy")
|
||||
|
||||
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(close_soon, s)
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
await s.wait_send_all_might_not_block()
|
||||
|
||||
|
||||
async def check_two_way_stream(
|
||||
stream_maker: StreamMaker[Stream, Stream],
|
||||
clogged_stream_maker: StreamMaker[Stream, Stream] | None,
|
||||
) -> None:
|
||||
"""Perform a number of generic tests on a custom two-way stream
|
||||
implementation.
|
||||
|
||||
This is similar to :func:`check_one_way_stream`, except that the maker
|
||||
functions are expected to return objects implementing the
|
||||
:class:`~trio.abc.Stream` interface.
|
||||
|
||||
This function tests a *superset* of what :func:`check_one_way_stream`
|
||||
checks – if you call this, then you don't need to also call
|
||||
:func:`check_one_way_stream`.
|
||||
|
||||
"""
|
||||
await check_one_way_stream(stream_maker, clogged_stream_maker)
|
||||
|
||||
async def flipped_stream_maker() -> tuple[Stream, Stream]:
|
||||
return (await stream_maker())[::-1]
|
||||
|
||||
flipped_clogged_stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]] | None
|
||||
|
||||
if clogged_stream_maker is not None:
|
||||
|
||||
async def flipped_clogged_stream_maker() -> tuple[Stream, Stream]:
|
||||
return (await clogged_stream_maker())[::-1]
|
||||
|
||||
else:
|
||||
flipped_clogged_stream_maker = None
|
||||
await check_one_way_stream(flipped_stream_maker, flipped_clogged_stream_maker)
|
||||
|
||||
async with _ForceCloseBoth(await stream_maker()) as (s1, s2):
|
||||
assert isinstance(s1, Stream)
|
||||
assert isinstance(s2, Stream)
|
||||
|
||||
# Duplex can be a bit tricky, might as well check it as well
|
||||
DUPLEX_TEST_SIZE = 2**20
|
||||
CHUNK_SIZE_MAX = 2**14
|
||||
|
||||
r = random.Random(0)
|
||||
i = r.getrandbits(8 * DUPLEX_TEST_SIZE)
|
||||
test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little")
|
||||
|
||||
async def sender(
|
||||
s: Stream,
|
||||
data: bytes | bytearray | memoryview,
|
||||
seed: int,
|
||||
) -> None:
|
||||
r = random.Random(seed)
|
||||
m = memoryview(data)
|
||||
while m:
|
||||
chunk_size = r.randint(1, CHUNK_SIZE_MAX)
|
||||
await s.send_all(m[:chunk_size])
|
||||
m = m[chunk_size:]
|
||||
|
||||
async def receiver(s: Stream, data: bytes | bytearray, seed: int) -> None:
|
||||
r = random.Random(seed)
|
||||
got = bytearray()
|
||||
while len(got) < len(data):
|
||||
chunk = await s.receive_some(r.randint(1, CHUNK_SIZE_MAX))
|
||||
assert chunk
|
||||
got += chunk
|
||||
assert got == data
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sender, s1, test_data, 0)
|
||||
nursery.start_soon(sender, s2, test_data[::-1], 1)
|
||||
nursery.start_soon(receiver, s1, test_data[::-1], 2)
|
||||
nursery.start_soon(receiver, s2, test_data, 3)
|
||||
|
||||
async def expect_receive_some_empty() -> None:
|
||||
assert await s2.receive_some(10) == b""
|
||||
await s2.aclose()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_receive_some_empty)
|
||||
nursery.start_soon(s1.aclose)
|
||||
|
||||
|
||||
async def check_half_closeable_stream(
|
||||
stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream],
|
||||
clogged_stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream] | None,
|
||||
) -> None:
|
||||
"""Perform a number of generic tests on a custom half-closeable stream
|
||||
implementation.
|
||||
|
||||
This is similar to :func:`check_two_way_stream`, except that the maker
|
||||
functions are expected to return objects that implement the
|
||||
:class:`~trio.abc.HalfCloseableStream` interface.
|
||||
|
||||
This function tests a *superset* of what :func:`check_two_way_stream`
|
||||
checks – if you call this, then you don't need to also call
|
||||
:func:`check_two_way_stream`.
|
||||
|
||||
"""
|
||||
await check_two_way_stream(stream_maker, clogged_stream_maker)
|
||||
|
||||
async with _ForceCloseBoth(await stream_maker()) as (s1, s2):
|
||||
assert isinstance(s1, HalfCloseableStream)
|
||||
assert isinstance(s2, HalfCloseableStream)
|
||||
|
||||
async def send_x_then_eof(s: HalfCloseableStream) -> None:
|
||||
await s.send_all(b"x")
|
||||
with assert_checkpoints():
|
||||
await s.send_eof()
|
||||
|
||||
async def expect_x_then_eof(r: HalfCloseableStream) -> None:
|
||||
await _core.wait_all_tasks_blocked()
|
||||
assert await r.receive_some(10) == b"x"
|
||||
assert await r.receive_some(10) == b""
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(send_x_then_eof, s1)
|
||||
nursery.start_soon(expect_x_then_eof, s2)
|
||||
|
||||
# now sending is disallowed
|
||||
with _assert_raises(_core.ClosedResourceError):
|
||||
await s1.send_all(b"y")
|
||||
|
||||
# but we can do send_eof again
|
||||
with assert_checkpoints():
|
||||
await s1.send_eof()
|
||||
|
||||
# and we can still send stuff back the other way
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(send_x_then_eof, s2)
|
||||
nursery.start_soon(expect_x_then_eof, s1)
|
||||
|
||||
if clogged_stream_maker is not None:
|
||||
async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2):
|
||||
# send_all and send_eof simultaneously is not ok
|
||||
with _assert_raises(_core.BusyResourceError, wrapped=True):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(s1.send_all, b"x")
|
||||
await _core.wait_all_tasks_blocked()
|
||||
nursery.start_soon(s1.send_eof)
|
||||
|
||||
async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2):
|
||||
# wait_send_all_might_not_block and send_eof simultaneously is not
|
||||
# ok either
|
||||
with _assert_raises(_core.BusyResourceError, wrapped=True):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(s1.wait_send_all_might_not_block)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
nursery.start_soon(s1.send_eof)
|
||||
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .. import _core
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _assert_yields_or_not(expected: bool) -> Generator[None, None, None]:
|
||||
"""Check if checkpoints are executed in a block of code."""
|
||||
__tracebackhide__ = True
|
||||
task = _core.current_task()
|
||||
orig_cancel = task._cancel_points
|
||||
orig_schedule = task._schedule_points
|
||||
try:
|
||||
yield
|
||||
if expected and (
|
||||
task._cancel_points == orig_cancel or task._schedule_points == orig_schedule
|
||||
):
|
||||
raise AssertionError("assert_checkpoints block did not yield!")
|
||||
finally:
|
||||
if not expected and (
|
||||
task._cancel_points != orig_cancel or task._schedule_points != orig_schedule
|
||||
):
|
||||
raise AssertionError("assert_no_checkpoints block yielded!")
|
||||
|
||||
|
||||
def assert_checkpoints() -> AbstractContextManager[None]:
|
||||
"""Use as a context manager to check that the code inside the ``with``
|
||||
block either exits with an exception or executes at least one
|
||||
:ref:`checkpoint <checkpoints>`.
|
||||
|
||||
Raises:
|
||||
AssertionError: if no checkpoint was executed.
|
||||
|
||||
Example:
|
||||
Check that :func:`trio.sleep` is a checkpoint, even if it doesn't
|
||||
block::
|
||||
|
||||
with trio.testing.assert_checkpoints():
|
||||
await trio.sleep(0)
|
||||
|
||||
"""
|
||||
__tracebackhide__ = True
|
||||
return _assert_yields_or_not(True)
|
||||
|
||||
|
||||
def assert_no_checkpoints() -> AbstractContextManager[None]:
|
||||
"""Use as a context manager to check that the code inside the ``with``
|
||||
block does not execute any :ref:`checkpoints <checkpoints>`.
|
||||
|
||||
Raises:
|
||||
AssertionError: if a checkpoint was executed.
|
||||
|
||||
Example:
|
||||
Synchronous code never contains any checkpoints, but we can double-check
|
||||
that::
|
||||
|
||||
send_channel, receive_channel = trio.open_memory_channel(10)
|
||||
with trio.testing.assert_no_checkpoints():
|
||||
send_channel.send_nowait(None)
|
||||
|
||||
"""
|
||||
__tracebackhide__ = True
|
||||
return _assert_yields_or_not(False)
|
||||
@@ -0,0 +1,584 @@
|
||||
# This should eventually be cleaned up and become public, but for right now I'm just
|
||||
# implementing enough to test DTLS.
|
||||
|
||||
# TODO:
|
||||
# - user-defined routers
|
||||
# - TCP
|
||||
# - UDP broadcast
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import errno
|
||||
import ipaddress
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, NoReturn, TypeAlias, overload
|
||||
|
||||
import attrs
|
||||
|
||||
import trio
|
||||
from trio._util import NoPublicConstructor, final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import builtins
|
||||
from collections.abc import Iterable
|
||||
from socket import AddressFamily, SocketKind
|
||||
from types import TracebackType
|
||||
|
||||
from typing_extensions import Buffer, Self
|
||||
|
||||
from trio._socket import AddressFormat
|
||||
|
||||
IPAddress: TypeAlias = ipaddress.IPv4Address | ipaddress.IPv6Address
|
||||
|
||||
|
||||
def _family_for(ip: IPAddress) -> int:
|
||||
if isinstance(ip, ipaddress.IPv4Address):
|
||||
return trio.socket.AF_INET
|
||||
elif isinstance(ip, ipaddress.IPv6Address):
|
||||
return trio.socket.AF_INET6
|
||||
raise NotImplementedError("Unhandled IPAddress instance type") # pragma: no cover
|
||||
|
||||
|
||||
def _wildcard_ip_for(family: int) -> IPAddress:
|
||||
if family == trio.socket.AF_INET:
|
||||
return ipaddress.ip_address("0.0.0.0")
|
||||
elif family == trio.socket.AF_INET6:
|
||||
return ipaddress.ip_address("::")
|
||||
raise NotImplementedError("Unhandled ip address family") # pragma: no cover
|
||||
|
||||
|
||||
# not used anywhere
|
||||
def _localhost_ip_for(family: int) -> IPAddress: # pragma: no cover
|
||||
if family == trio.socket.AF_INET:
|
||||
return ipaddress.ip_address("127.0.0.1")
|
||||
elif family == trio.socket.AF_INET6:
|
||||
return ipaddress.ip_address("::1")
|
||||
raise NotImplementedError("Unhandled ip address family")
|
||||
|
||||
|
||||
def _fake_err(code: int) -> NoReturn:
|
||||
raise OSError(code, os.strerror(code))
|
||||
|
||||
|
||||
def _scatter(data: bytes, buffers: Iterable[Buffer]) -> int:
|
||||
written = 0
|
||||
for buf in buffers: # pragma: no branch
|
||||
next_piece = data[written : written + memoryview(buf).nbytes]
|
||||
with memoryview(buf) as mbuf:
|
||||
mbuf[: len(next_piece)] = next_piece
|
||||
written += len(next_piece)
|
||||
if written == len(data): # pragma: no branch
|
||||
break
|
||||
return written
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class UDPEndpoint:
|
||||
ip: IPAddress
|
||||
port: int
|
||||
|
||||
def as_python_sockaddr(self) -> tuple[str, int] | tuple[str, int, int, int]:
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int] = (
|
||||
self.ip.compressed,
|
||||
self.port,
|
||||
)
|
||||
if isinstance(self.ip, ipaddress.IPv6Address):
|
||||
sockaddr += (0, 0) # type: ignore[assignment]
|
||||
return sockaddr
|
||||
|
||||
@classmethod
|
||||
def from_python_sockaddr(
|
||||
cls,
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int],
|
||||
) -> UDPEndpoint:
|
||||
ip, port = sockaddr[:2]
|
||||
return cls(ip=ipaddress.ip_address(ip), port=port)
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class UDPBinding:
|
||||
local: UDPEndpoint
|
||||
# remote: UDPEndpoint # ??
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class UDPPacket:
|
||||
source: UDPEndpoint
|
||||
destination: UDPEndpoint
|
||||
payload: bytes = attrs.field(repr=lambda p: p.hex())
|
||||
|
||||
# not used/tested anywhere
|
||||
def reply(self, payload: bytes) -> UDPPacket: # pragma: no cover
|
||||
return UDPPacket(
|
||||
source=self.destination,
|
||||
destination=self.source,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class FakeSocketFactory(trio.abc.SocketFactory):
|
||||
fake_net: FakeNet
|
||||
|
||||
def socket(self, family: int, type_: int, proto: int) -> FakeSocket: # type: ignore[override]
|
||||
return FakeSocket._create(self.fake_net, family, type_, proto)
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class FakeHostnameResolver(trio.abc.HostnameResolver):
|
||||
fake_net: FakeNet
|
||||
|
||||
async def getaddrinfo(
|
||||
self,
|
||||
host: bytes | None,
|
||||
port: bytes | str | int | None,
|
||||
family: int = 0,
|
||||
type: int = 0,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
) -> list[
|
||||
tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
|
||||
]
|
||||
]:
|
||||
raise NotImplementedError("FakeNet doesn't do fake DNS yet")
|
||||
|
||||
async def getnameinfo(
|
||||
self,
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int],
|
||||
flags: int,
|
||||
) -> tuple[str, str]:
|
||||
raise NotImplementedError("FakeNet doesn't do fake DNS yet")
|
||||
|
||||
|
||||
@final
|
||||
class FakeNet:
|
||||
def __init__(self) -> None:
|
||||
# When we need to pick an arbitrary unique ip address/port, use these:
|
||||
self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts() # untested
|
||||
self._auto_ipv6_iter = ipaddress.IPv6Network("1::/16").hosts() # untested
|
||||
self._auto_port_iter = iter(range(50000, 65535))
|
||||
|
||||
self._bound: dict[UDPBinding, FakeSocket] = {}
|
||||
|
||||
self.route_packet = None
|
||||
|
||||
def _bind(self, binding: UDPBinding, socket: FakeSocket) -> None:
|
||||
if binding in self._bound:
|
||||
_fake_err(errno.EADDRINUSE)
|
||||
self._bound[binding] = socket
|
||||
|
||||
def enable(self) -> None:
|
||||
trio.socket.set_custom_socket_factory(FakeSocketFactory(self))
|
||||
trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self))
|
||||
|
||||
def send_packet(self, packet: UDPPacket) -> None:
|
||||
if self.route_packet is None:
|
||||
self.deliver_packet(packet)
|
||||
else:
|
||||
self.route_packet(packet)
|
||||
|
||||
def deliver_packet(self, packet: UDPPacket) -> None:
|
||||
binding = UDPBinding(local=packet.destination)
|
||||
if binding in self._bound:
|
||||
self._bound[binding]._deliver_packet(packet)
|
||||
else:
|
||||
# No valid destination, so drop it
|
||||
pass
|
||||
|
||||
|
||||
@final
|
||||
class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor):
|
||||
def __init__(
|
||||
self,
|
||||
fake_net: FakeNet,
|
||||
family: AddressFamily,
|
||||
type: SocketKind,
|
||||
proto: int,
|
||||
) -> None:
|
||||
self._fake_net = fake_net
|
||||
|
||||
if not family: # pragma: no cover
|
||||
family = trio.socket.AF_INET
|
||||
if not type: # pragma: no cover
|
||||
type = trio.socket.SOCK_STREAM # noqa: A001 # name shadowing builtin
|
||||
|
||||
if family not in (trio.socket.AF_INET, trio.socket.AF_INET6):
|
||||
raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}")
|
||||
if type != trio.socket.SOCK_DGRAM:
|
||||
raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}")
|
||||
|
||||
self._family = family
|
||||
self._type = type
|
||||
self._proto = proto
|
||||
|
||||
self._closed = False
|
||||
|
||||
self._packet_sender, self._packet_receiver = trio.open_memory_channel[
|
||||
UDPPacket
|
||||
](float("inf"))
|
||||
|
||||
# This is the source-of-truth for what port etc. this socket is bound to
|
||||
self._binding: UDPBinding | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> SocketKind:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def family(self) -> AddressFamily:
|
||||
return self._family
|
||||
|
||||
@property
|
||||
def proto(self) -> int:
|
||||
return self._proto
|
||||
|
||||
def _check_closed(self) -> None:
|
||||
if self._closed:
|
||||
_fake_err(errno.EBADF)
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if self._binding is not None:
|
||||
del self._fake_net._bound[self._binding]
|
||||
self._packet_receiver.close()
|
||||
|
||||
async def _resolve_address_nocp(
|
||||
self,
|
||||
address: object,
|
||||
*,
|
||||
local: bool,
|
||||
) -> tuple[str, int]:
|
||||
return await trio._socket._resolve_address_nocp( # type: ignore[no-any-return]
|
||||
self.type,
|
||||
self.family,
|
||||
self.proto,
|
||||
address=address,
|
||||
ipv6_v6only=False,
|
||||
local=local,
|
||||
)
|
||||
|
||||
def _deliver_packet(self, packet: UDPPacket) -> None:
|
||||
# sending to a closed socket -- UDP packets get dropped
|
||||
with contextlib.suppress(trio.BrokenResourceError):
|
||||
self._packet_sender.send_nowait(packet)
|
||||
|
||||
################################################################
|
||||
# Actual IO operation implementations
|
||||
################################################################
|
||||
|
||||
async def bind(self, addr: object) -> None:
|
||||
self._check_closed()
|
||||
if self._binding is not None:
|
||||
_fake_err(errno.EINVAL)
|
||||
await trio.lowlevel.checkpoint()
|
||||
ip_str, port, *_ = await self._resolve_address_nocp(addr, local=True)
|
||||
assert _ == [], "TODO: handle other values?"
|
||||
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
assert _family_for(ip) == self.family
|
||||
# We convert binds to INET_ANY into binds to localhost
|
||||
if ip == ipaddress.ip_address("0.0.0.0"):
|
||||
ip = ipaddress.ip_address("127.0.0.1")
|
||||
elif ip == ipaddress.ip_address("::"):
|
||||
ip = ipaddress.ip_address("::1")
|
||||
if port == 0:
|
||||
port = next(self._fake_net._auto_port_iter)
|
||||
binding = UDPBinding(local=UDPEndpoint(ip, port))
|
||||
self._fake_net._bind(binding, self)
|
||||
self._binding = binding
|
||||
|
||||
async def connect(self, peer: object) -> NoReturn:
|
||||
raise NotImplementedError("FakeNet does not (yet) support connected sockets")
|
||||
|
||||
async def _sendmsg(
|
||||
self,
|
||||
buffers: Iterable[Buffer],
|
||||
ancdata: Iterable[tuple[int, int, Buffer]] = (),
|
||||
flags: int = 0,
|
||||
address: AddressFormat | None = None,
|
||||
) -> int:
|
||||
self._check_closed()
|
||||
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
if address is not None:
|
||||
address = await self._resolve_address_nocp(address, local=False)
|
||||
if ancdata:
|
||||
raise NotImplementedError("FakeNet doesn't support ancillary data")
|
||||
if flags:
|
||||
raise NotImplementedError(f"FakeNet send flags must be 0, not {flags}")
|
||||
|
||||
if address is None:
|
||||
_fake_err(errno.ENOTCONN)
|
||||
|
||||
destination = UDPEndpoint.from_python_sockaddr(address)
|
||||
|
||||
if self._binding is None:
|
||||
await self.bind((_wildcard_ip_for(self.family).compressed, 0))
|
||||
|
||||
payload = b"".join(buffers)
|
||||
|
||||
assert self._binding is not None
|
||||
packet = UDPPacket(
|
||||
source=self._binding.local,
|
||||
destination=destination,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
self._fake_net.send_packet(packet)
|
||||
|
||||
return len(payload)
|
||||
|
||||
if sys.platform != "win32" or (
|
||||
not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
|
||||
):
|
||||
sendmsg = _sendmsg
|
||||
|
||||
async def _recvmsg_into(
|
||||
self,
|
||||
buffers: Iterable[Buffer],
|
||||
ancbufsize: int = 0,
|
||||
flags: int = 0,
|
||||
) -> tuple[
|
||||
int,
|
||||
list[tuple[int, int, bytes]],
|
||||
int,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]:
|
||||
if ancbufsize != 0:
|
||||
raise NotImplementedError("FakeNet doesn't support ancillary data")
|
||||
if flags != 0:
|
||||
raise NotImplementedError("FakeNet doesn't support any recv flags")
|
||||
if self._binding is None:
|
||||
# I messed this up a few times when writing tests ... but it also never happens
|
||||
# in any of the existing tests, so maybe it could be intentional...
|
||||
raise NotImplementedError(
|
||||
"The code will most likely hang if you try to receive on a fakesocket "
|
||||
"without a binding. If that is not the case, or you explicitly want to "
|
||||
"test that, remove this warning.",
|
||||
)
|
||||
|
||||
self._check_closed()
|
||||
|
||||
ancdata: list[tuple[int, int, bytes]] = []
|
||||
msg_flags = 0
|
||||
|
||||
packet = await self._packet_receiver.receive()
|
||||
address = packet.source.as_python_sockaddr()
|
||||
written = _scatter(packet.payload, buffers)
|
||||
if written < len(packet.payload):
|
||||
msg_flags |= trio.socket.MSG_TRUNC
|
||||
return written, ancdata, msg_flags, address
|
||||
|
||||
if sys.platform != "win32" or (
|
||||
not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
|
||||
):
|
||||
recvmsg_into = _recvmsg_into
|
||||
|
||||
################################################################
|
||||
# Simple state query stuff
|
||||
################################################################
|
||||
|
||||
def getsockname(self) -> tuple[str, int] | tuple[str, int, int, int]:
|
||||
self._check_closed()
|
||||
if self._binding is not None:
|
||||
return self._binding.local.as_python_sockaddr()
|
||||
elif self.family == trio.socket.AF_INET:
|
||||
return ("0.0.0.0", 0)
|
||||
else:
|
||||
assert self.family == trio.socket.AF_INET6
|
||||
return ("::", 0)
|
||||
|
||||
# TODO: This method is not tested, and seems to make incorrect assumptions. It should maybe raise NotImplementedError.
|
||||
def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]:
|
||||
self._check_closed()
|
||||
if self._binding is not None:
|
||||
assert hasattr(
|
||||
self._binding,
|
||||
"remote",
|
||||
), "This method seems to assume that self._binding has a remote UDPEndpoint"
|
||||
if self._binding.remote is not None: # pragma: no cover
|
||||
assert isinstance(
|
||||
self._binding.remote,
|
||||
UDPEndpoint,
|
||||
), "Self._binding.remote should be a UDPEndpoint"
|
||||
return self._binding.remote.as_python_sockaddr()
|
||||
_fake_err(errno.ENOTCONN)
|
||||
|
||||
@overload
|
||||
def getsockopt(self, /, level: int, optname: int) -> int: ...
|
||||
|
||||
@overload
|
||||
def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ...
|
||||
|
||||
def getsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
buflen: int | None = None,
|
||||
) -> int | bytes:
|
||||
self._check_closed()
|
||||
raise OSError(f"FakeNet doesn't implement getsockopt({level}, {optname})")
|
||||
|
||||
@overload
|
||||
def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ...
|
||||
|
||||
@overload
|
||||
def setsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: None,
|
||||
optlen: int,
|
||||
) -> None: ...
|
||||
|
||||
def setsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: int | Buffer | None,
|
||||
optlen: int | None = None,
|
||||
) -> None:
|
||||
self._check_closed()
|
||||
|
||||
if (level, optname) == (
|
||||
trio.socket.IPPROTO_IPV6,
|
||||
trio.socket.IPV6_V6ONLY,
|
||||
) and not value:
|
||||
raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True")
|
||||
|
||||
raise OSError(f"FakeNet doesn't implement setsockopt({level}, {optname}, ...)")
|
||||
|
||||
################################################################
|
||||
# Various boilerplate and trivial stubs
|
||||
################################################################
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: builtins.type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
async def send(self, data: Buffer, flags: int = 0) -> int:
|
||||
return await self.sendto(data, flags, None)
|
||||
|
||||
# __ prefixed arguments because typeshed uses that and typechecker issues
|
||||
@overload
|
||||
async def sendto(
|
||||
self,
|
||||
__data: Buffer, # noqa: PYI063
|
||||
__address: tuple[object, ...] | str | Buffer,
|
||||
) -> int: ...
|
||||
|
||||
# __ prefixed arguments because typeshed uses that and typechecker issues
|
||||
@overload
|
||||
async def sendto(
|
||||
self,
|
||||
__data: Buffer, # noqa: PYI063
|
||||
__flags: int,
|
||||
__address: tuple[object, ...] | str | Buffer | None,
|
||||
) -> int: ...
|
||||
|
||||
async def sendto( # type: ignore[explicit-any]
|
||||
self,
|
||||
*args: Any,
|
||||
) -> int:
|
||||
data: Buffer
|
||||
flags: int
|
||||
address: tuple[object, ...] | str | Buffer
|
||||
if len(args) == 2:
|
||||
data, address = args
|
||||
flags = 0
|
||||
elif len(args) == 3:
|
||||
data, flags, address = args
|
||||
else:
|
||||
raise TypeError("wrong number of arguments")
|
||||
return await self._sendmsg([data], [], flags, address)
|
||||
|
||||
async def recv(self, bufsize: int, flags: int = 0) -> bytes:
|
||||
data, _address = await self.recvfrom(bufsize, flags)
|
||||
return data
|
||||
|
||||
async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int:
|
||||
got_bytes, _address = await self.recvfrom_into(buf, nbytes, flags)
|
||||
return got_bytes
|
||||
|
||||
async def recvfrom(
|
||||
self,
|
||||
bufsize: int,
|
||||
flags: int = 0,
|
||||
) -> tuple[bytes, AddressFormat]:
|
||||
data, _ancdata, _msg_flags, address = await self._recvmsg(bufsize, flags)
|
||||
return data, address
|
||||
|
||||
async def recvfrom_into(
|
||||
self,
|
||||
buf: Buffer,
|
||||
nbytes: int = 0,
|
||||
flags: int = 0,
|
||||
) -> tuple[int, AddressFormat]:
|
||||
if nbytes != 0 and nbytes != memoryview(buf).nbytes:
|
||||
raise NotImplementedError("partial recvfrom_into")
|
||||
got_nbytes, _ancdata, _msg_flags, address = await self._recvmsg_into(
|
||||
[buf],
|
||||
0,
|
||||
flags,
|
||||
)
|
||||
return got_nbytes, address
|
||||
|
||||
async def _recvmsg(
|
||||
self,
|
||||
bufsize: int,
|
||||
ancbufsize: int = 0,
|
||||
flags: int = 0,
|
||||
) -> tuple[bytes, list[tuple[int, int, bytes]], int, AddressFormat]:
|
||||
buf = bytearray(bufsize)
|
||||
got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into(
|
||||
[buf],
|
||||
ancbufsize,
|
||||
flags,
|
||||
)
|
||||
return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address)
|
||||
|
||||
if sys.platform != "win32" or (
|
||||
not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
|
||||
):
|
||||
recvmsg = _recvmsg
|
||||
|
||||
def fileno(self) -> int:
|
||||
raise NotImplementedError("can't get fileno() for FakeNet sockets")
|
||||
|
||||
def detach(self) -> int:
|
||||
raise NotImplementedError("can't detach() a FakeNet socket")
|
||||
|
||||
def get_inheritable(self) -> bool:
|
||||
return False
|
||||
|
||||
def set_inheritable(self, inheritable: bool) -> None:
|
||||
if inheritable:
|
||||
raise NotImplementedError("FakeNet can't make inheritable sockets")
|
||||
|
||||
if sys.platform == "win32" or (
|
||||
not TYPE_CHECKING and hasattr(socket.socket, "share")
|
||||
):
|
||||
|
||||
def share(self, process_id: int) -> bytes:
|
||||
raise NotImplementedError("FakeNet can't share sockets")
|
||||
@@ -0,0 +1,633 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TypeAlias, TypeVar
|
||||
|
||||
from .. import _core, _util
|
||||
from .._highlevel_generic import StapledStream
|
||||
from ..abc import ReceiveStream, SendStream
|
||||
|
||||
AsyncHook: TypeAlias = Callable[[], Awaitable[object]]
|
||||
# Would be nice to exclude awaitable here, but currently not possible.
|
||||
SyncHook: TypeAlias = Callable[[], object]
|
||||
SendStreamT = TypeVar("SendStreamT", bound=SendStream)
|
||||
ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream)
|
||||
|
||||
|
||||
################################################################
|
||||
# In-memory streams - Unbounded buffer version
|
||||
################################################################
|
||||
|
||||
|
||||
class _UnboundedByteQueue:
|
||||
def __init__(self) -> None:
|
||||
self._data = bytearray()
|
||||
self._closed = False
|
||||
self._lot = _core.ParkingLot()
|
||||
self._fetch_lock = _util.ConflictDetector(
|
||||
"another task is already fetching data",
|
||||
)
|
||||
|
||||
# This object treats "close" as being like closing the send side of a
|
||||
# channel: so after close(), calling put() raises ClosedResourceError, and
|
||||
# calling the get() variants drains the buffer and then returns an empty
|
||||
# bytearray.
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
self._lot.unpark_all()
|
||||
|
||||
def close_and_wipe(self) -> None:
|
||||
self._data = bytearray()
|
||||
self.close()
|
||||
|
||||
def put(self, data: bytes | bytearray | memoryview) -> None:
|
||||
if self._closed:
|
||||
raise _core.ClosedResourceError("virtual connection closed")
|
||||
self._data += data
|
||||
self._lot.unpark_all()
|
||||
|
||||
def _check_max_bytes(self, max_bytes: int | None) -> None:
|
||||
if max_bytes is None:
|
||||
return
|
||||
max_bytes = operator.index(max_bytes)
|
||||
if max_bytes < 1:
|
||||
raise ValueError("max_bytes must be >= 1")
|
||||
|
||||
def _get_impl(self, max_bytes: int | None) -> bytearray:
|
||||
assert self._closed or self._data
|
||||
if max_bytes is None:
|
||||
max_bytes = len(self._data)
|
||||
if self._data:
|
||||
chunk = self._data[:max_bytes]
|
||||
del self._data[:max_bytes]
|
||||
assert chunk
|
||||
return chunk
|
||||
else:
|
||||
return bytearray()
|
||||
|
||||
def get_nowait(self, max_bytes: int | None = None) -> bytearray:
|
||||
with self._fetch_lock:
|
||||
self._check_max_bytes(max_bytes)
|
||||
if not self._closed and not self._data:
|
||||
raise _core.WouldBlock
|
||||
return self._get_impl(max_bytes)
|
||||
|
||||
async def get(self, max_bytes: int | None = None) -> bytearray:
|
||||
with self._fetch_lock:
|
||||
self._check_max_bytes(max_bytes)
|
||||
if not self._closed and not self._data:
|
||||
await self._lot.park()
|
||||
else:
|
||||
await _core.checkpoint()
|
||||
return self._get_impl(max_bytes)
|
||||
|
||||
|
||||
@_util.final
|
||||
class MemorySendStream(SendStream):
|
||||
"""An in-memory :class:`~trio.abc.SendStream`.
|
||||
|
||||
Args:
|
||||
send_all_hook: An async function, or None. Called from
|
||||
:meth:`send_all`. Can do whatever you like.
|
||||
wait_send_all_might_not_block_hook: An async function, or None. Called
|
||||
from :meth:`wait_send_all_might_not_block`. Can do whatever you
|
||||
like.
|
||||
close_hook: A synchronous function, or None. Called from :meth:`close`
|
||||
and :meth:`aclose`. Can do whatever you like.
|
||||
|
||||
.. attribute:: send_all_hook
|
||||
wait_send_all_might_not_block_hook
|
||||
close_hook
|
||||
|
||||
All of these hooks are also exposed as attributes on the object, and
|
||||
you can change them at any time.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
send_all_hook: AsyncHook | None = None,
|
||||
wait_send_all_might_not_block_hook: AsyncHook | None = None,
|
||||
close_hook: SyncHook | None = None,
|
||||
) -> None:
|
||||
self._conflict_detector = _util.ConflictDetector(
|
||||
"another task is using this stream",
|
||||
)
|
||||
self._outgoing = _UnboundedByteQueue()
|
||||
self.send_all_hook = send_all_hook
|
||||
self.wait_send_all_might_not_block_hook = wait_send_all_might_not_block_hook
|
||||
self.close_hook = close_hook
|
||||
|
||||
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
|
||||
"""Places the given data into the object's internal buffer, and then
|
||||
calls the :attr:`send_all_hook` (if any).
|
||||
|
||||
"""
|
||||
# Execute two checkpoints so we have more of a chance to detect
|
||||
# buggy user code that calls this twice at the same time.
|
||||
with self._conflict_detector:
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
self._outgoing.put(data)
|
||||
if self.send_all_hook is not None:
|
||||
await self.send_all_hook()
|
||||
|
||||
async def wait_send_all_might_not_block(self) -> None:
|
||||
"""Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and
|
||||
then returns immediately.
|
||||
|
||||
"""
|
||||
# Execute two checkpoints so that we have more of a chance to detect
|
||||
# buggy user code that calls this twice at the same time.
|
||||
with self._conflict_detector:
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
# check for being closed:
|
||||
self._outgoing.put(b"")
|
||||
if self.wait_send_all_might_not_block_hook is not None:
|
||||
await self.wait_send_all_might_not_block_hook()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Marks this stream as closed, and then calls the :attr:`close_hook`
|
||||
(if any).
|
||||
|
||||
"""
|
||||
# XXX should this cancel any pending calls to the send_all_hook and
|
||||
# wait_send_all_might_not_block_hook? Those are the only places where
|
||||
# send_all and wait_send_all_might_not_block can be blocked.
|
||||
#
|
||||
# The way we set things up, send_all_hook is memory_stream_pump, and
|
||||
# wait_send_all_might_not_block_hook is unset. memory_stream_pump is
|
||||
# synchronous. So normally, send_all and wait_send_all_might_not_block
|
||||
# cannot block at all.
|
||||
self._outgoing.close()
|
||||
if self.close_hook is not None:
|
||||
self.close_hook()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Same as :meth:`close`, but async."""
|
||||
self.close()
|
||||
await _core.checkpoint()
|
||||
|
||||
async def get_data(self, max_bytes: int | None = None) -> bytearray:
|
||||
"""Retrieves data from the internal buffer, blocking if necessary.
|
||||
|
||||
Args:
|
||||
max_bytes (int or None): The maximum amount of data to
|
||||
retrieve. None (the default) means to retrieve all the data
|
||||
that's present (but still blocks until at least one byte is
|
||||
available).
|
||||
|
||||
Returns:
|
||||
If this stream has been closed, an empty bytearray. Otherwise, the
|
||||
requested data.
|
||||
|
||||
"""
|
||||
return await self._outgoing.get(max_bytes)
|
||||
|
||||
def get_data_nowait(self, max_bytes: int | None = None) -> bytearray:
|
||||
"""Retrieves data from the internal buffer, but doesn't block.
|
||||
|
||||
See :meth:`get_data` for details.
|
||||
|
||||
Raises:
|
||||
trio.WouldBlock: if no data is available to retrieve.
|
||||
|
||||
"""
|
||||
return self._outgoing.get_nowait(max_bytes)
|
||||
|
||||
|
||||
@_util.final
|
||||
class MemoryReceiveStream(ReceiveStream):
|
||||
"""An in-memory :class:`~trio.abc.ReceiveStream`.
|
||||
|
||||
Args:
|
||||
receive_some_hook: An async function, or None. Called from
|
||||
:meth:`receive_some`. Can do whatever you like.
|
||||
close_hook: A synchronous function, or None. Called from :meth:`close`
|
||||
and :meth:`aclose`. Can do whatever you like.
|
||||
|
||||
.. attribute:: receive_some_hook
|
||||
close_hook
|
||||
|
||||
Both hooks are also exposed as attributes on the object, and you can
|
||||
change them at any time.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
receive_some_hook: AsyncHook | None = None,
|
||||
close_hook: SyncHook | None = None,
|
||||
) -> None:
|
||||
self._conflict_detector = _util.ConflictDetector(
|
||||
"another task is using this stream",
|
||||
)
|
||||
self._incoming = _UnboundedByteQueue()
|
||||
self._closed = False
|
||||
self.receive_some_hook = receive_some_hook
|
||||
self.close_hook = close_hook
|
||||
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytearray:
|
||||
"""Calls the :attr:`receive_some_hook` (if any), and then retrieves
|
||||
data from the internal buffer, blocking if necessary.
|
||||
|
||||
"""
|
||||
# Execute two checkpoints so we have more of a chance to detect
|
||||
# buggy user code that calls this twice at the same time.
|
||||
with self._conflict_detector:
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
if self._closed:
|
||||
raise _core.ClosedResourceError
|
||||
if self.receive_some_hook is not None:
|
||||
await self.receive_some_hook()
|
||||
# self._incoming's closure state tracks whether we got an EOF.
|
||||
# self._closed tracks whether we, ourselves, are closed.
|
||||
# self.close() sends an EOF to wake us up and sets self._closed,
|
||||
# so after we wake up we have to check self._closed again.
|
||||
data = await self._incoming.get(max_bytes)
|
||||
if self._closed:
|
||||
raise _core.ClosedResourceError
|
||||
return data
|
||||
|
||||
def close(self) -> None:
|
||||
"""Discards any pending data from the internal buffer, and marks this
|
||||
stream as closed.
|
||||
|
||||
"""
|
||||
self._closed = True
|
||||
self._incoming.close_and_wipe()
|
||||
if self.close_hook is not None:
|
||||
self.close_hook()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Same as :meth:`close`, but async."""
|
||||
self.close()
|
||||
await _core.checkpoint()
|
||||
|
||||
def put_data(self, data: bytes | bytearray | memoryview) -> None:
|
||||
"""Appends the given data to the internal buffer."""
|
||||
self._incoming.put(data)
|
||||
|
||||
def put_eof(self) -> None:
|
||||
"""Adds an end-of-file marker to the internal buffer."""
|
||||
self._incoming.close()
|
||||
|
||||
|
||||
# TODO: investigate why this is necessary for the docs
|
||||
MemorySendStream.__module__ = MemorySendStream.__module__.replace(
|
||||
"._memory_streams", ""
|
||||
)
|
||||
MemoryReceiveStream.__module__ = MemoryReceiveStream.__module__.replace(
|
||||
"._memory_streams", ""
|
||||
)
|
||||
|
||||
|
||||
def memory_stream_pump(
|
||||
memory_send_stream: MemorySendStream,
|
||||
memory_receive_stream: MemoryReceiveStream,
|
||||
*,
|
||||
max_bytes: int | None = None,
|
||||
) -> bool:
|
||||
"""Take data out of the given :class:`MemorySendStream`'s internal buffer,
|
||||
and put it into the given :class:`MemoryReceiveStream`'s internal buffer.
|
||||
|
||||
Args:
|
||||
memory_send_stream (MemorySendStream): The stream to get data from.
|
||||
memory_receive_stream (MemoryReceiveStream): The stream to put data into.
|
||||
max_bytes (int or None): The maximum amount of data to transfer in this
|
||||
call, or None to transfer all available data.
|
||||
|
||||
Returns:
|
||||
True if it successfully transferred some data, or False if there was no
|
||||
data to transfer.
|
||||
|
||||
This is used to implement :func:`memory_stream_one_way_pair` and
|
||||
:func:`memory_stream_pair`; see the latter's docstring for an example
|
||||
of how you might use it yourself.
|
||||
|
||||
"""
|
||||
try:
|
||||
data = memory_send_stream.get_data_nowait(max_bytes)
|
||||
except _core.WouldBlock:
|
||||
return False
|
||||
try:
|
||||
if not data:
|
||||
memory_receive_stream.put_eof()
|
||||
else:
|
||||
memory_receive_stream.put_data(data)
|
||||
except _core.ClosedResourceError:
|
||||
raise _core.BrokenResourceError("MemoryReceiveStream was closed") from None
|
||||
return True
|
||||
|
||||
|
||||
def memory_stream_one_way_pair() -> tuple[MemorySendStream, MemoryReceiveStream]:
|
||||
"""Create a connected, pure-Python, unidirectional stream with infinite
|
||||
buffering and flexible configuration options.
|
||||
|
||||
You can think of this as being a no-operating-system-involved
|
||||
Trio-streamsified version of :func:`os.pipe` (except that :func:`os.pipe`
|
||||
returns the streams in the wrong order – we follow the superior convention
|
||||
that data flows from left to right).
|
||||
|
||||
Returns:
|
||||
A tuple (:class:`MemorySendStream`, :class:`MemoryReceiveStream`), where
|
||||
the :class:`MemorySendStream` has its hooks set up so that it calls
|
||||
:func:`memory_stream_pump` from its
|
||||
:attr:`~MemorySendStream.send_all_hook` and
|
||||
:attr:`~MemorySendStream.close_hook`.
|
||||
|
||||
The end result is that data automatically flows from the
|
||||
:class:`MemorySendStream` to the :class:`MemoryReceiveStream`. But you're
|
||||
also free to rearrange things however you like. For example, you can
|
||||
temporarily set the :attr:`~MemorySendStream.send_all_hook` to None if you
|
||||
want to simulate a stall in data transmission. Or see
|
||||
:func:`memory_stream_pair` for a more elaborate example.
|
||||
|
||||
"""
|
||||
send_stream = MemorySendStream()
|
||||
recv_stream = MemoryReceiveStream()
|
||||
|
||||
def pump_from_send_stream_to_recv_stream() -> None:
|
||||
memory_stream_pump(send_stream, recv_stream)
|
||||
|
||||
# await not used
|
||||
async def async_pump_from_send_stream_to_recv_stream() -> None: # noqa: RUF029
|
||||
pump_from_send_stream_to_recv_stream()
|
||||
|
||||
send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream
|
||||
send_stream.close_hook = pump_from_send_stream_to_recv_stream
|
||||
return send_stream, recv_stream
|
||||
|
||||
|
||||
def _make_stapled_pair(
|
||||
one_way_pair: Callable[[], tuple[SendStreamT, ReceiveStreamT]],
|
||||
) -> tuple[
|
||||
StapledStream[SendStreamT, ReceiveStreamT],
|
||||
StapledStream[SendStreamT, ReceiveStreamT],
|
||||
]:
|
||||
pipe1_send, pipe1_recv = one_way_pair()
|
||||
pipe2_send, pipe2_recv = one_way_pair()
|
||||
stream1 = StapledStream(pipe1_send, pipe2_recv)
|
||||
stream2 = StapledStream(pipe2_send, pipe1_recv)
|
||||
return stream1, stream2
|
||||
|
||||
|
||||
def memory_stream_pair() -> tuple[
|
||||
StapledStream[MemorySendStream, MemoryReceiveStream],
|
||||
StapledStream[MemorySendStream, MemoryReceiveStream],
|
||||
]:
|
||||
"""Create a connected, pure-Python, bidirectional stream with infinite
|
||||
buffering and flexible configuration options.
|
||||
|
||||
This is a convenience function that creates two one-way streams using
|
||||
:func:`memory_stream_one_way_pair`, and then uses
|
||||
:class:`~trio.StapledStream` to combine them into a single bidirectional
|
||||
stream.
|
||||
|
||||
This is like a no-operating-system-involved, Trio-streamsified version of
|
||||
:func:`socket.socketpair`.
|
||||
|
||||
Returns:
|
||||
A pair of :class:`~trio.StapledStream` objects that are connected so
|
||||
that data automatically flows from one to the other in both directions.
|
||||
|
||||
After creating a stream pair, you can send data back and forth, which is
|
||||
enough for simple tests::
|
||||
|
||||
left, right = memory_stream_pair()
|
||||
await left.send_all(b"123")
|
||||
assert await right.receive_some() == b"123"
|
||||
await right.send_all(b"456")
|
||||
assert await left.receive_some() == b"456"
|
||||
|
||||
But if you read the docs for :class:`~trio.StapledStream` and
|
||||
:func:`memory_stream_one_way_pair`, you'll see that all the pieces
|
||||
involved in wiring this up are public APIs, so you can adjust to suit the
|
||||
requirements of your tests. For example, here's how to tweak a stream so
|
||||
that data flowing from left to right trickles in one byte at a time (but
|
||||
data flowing from right to left proceeds at full speed)::
|
||||
|
||||
left, right = memory_stream_pair()
|
||||
async def trickle():
|
||||
# left is a StapledStream, and left.send_stream is a MemorySendStream
|
||||
# right is a StapledStream, and right.recv_stream is a MemoryReceiveStream
|
||||
while memory_stream_pump(left.send_stream, right.recv_stream, max_bytes=1):
|
||||
# Pause between each byte
|
||||
await trio.sleep(1)
|
||||
# Normally this send_all_hook calls memory_stream_pump directly without
|
||||
# passing in a max_bytes. We replace it with our custom version:
|
||||
left.send_stream.send_all_hook = trickle
|
||||
|
||||
And here's a simple test using our modified stream objects::
|
||||
|
||||
async def sender():
|
||||
await left.send_all(b"12345")
|
||||
await left.send_eof()
|
||||
|
||||
async def receiver():
|
||||
async for data in right:
|
||||
print(data)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(sender)
|
||||
nursery.start_soon(receiver)
|
||||
|
||||
By default, this will print ``b"12345"`` and then immediately exit; with
|
||||
our trickle stream it instead sleeps 1 second, then prints ``b"1"``, then
|
||||
sleeps 1 second, then prints ``b"2"``, etc.
|
||||
|
||||
Pro-tip: you can insert sleep calls (like in our example above) to
|
||||
manipulate the flow of data across tasks... and then use
|
||||
:class:`MockClock` and its :attr:`~MockClock.autojump_threshold`
|
||||
functionality to keep your test suite running quickly.
|
||||
|
||||
If you want to stress test a protocol implementation, one nice trick is to
|
||||
use the :mod:`random` module (preferably with a fixed seed) to move random
|
||||
numbers of bytes at a time, and insert random sleeps in between them. You
|
||||
can also set up a custom :attr:`~MemoryReceiveStream.receive_some_hook` if
|
||||
you want to manipulate things on the receiving side, and not just the
|
||||
sending side.
|
||||
|
||||
"""
|
||||
return _make_stapled_pair(memory_stream_one_way_pair)
|
||||
|
||||
|
||||
################################################################
|
||||
# In-memory streams - Lockstep version
|
||||
################################################################
|
||||
|
||||
|
||||
class _LockstepByteQueue:
|
||||
def __init__(self) -> None:
|
||||
self._data = bytearray()
|
||||
self._sender_closed = False
|
||||
self._receiver_closed = False
|
||||
self._receiver_waiting = False
|
||||
self._waiters = _core.ParkingLot()
|
||||
self._send_conflict_detector = _util.ConflictDetector(
|
||||
"another task is already sending",
|
||||
)
|
||||
self._receive_conflict_detector = _util.ConflictDetector(
|
||||
"another task is already receiving",
|
||||
)
|
||||
|
||||
def _something_happened(self) -> None:
|
||||
self._waiters.unpark_all()
|
||||
|
||||
# Always wakes up when one side is closed, because everyone always reacts
|
||||
# to that.
|
||||
async def _wait_for(self, fn: Callable[[], bool]) -> None:
|
||||
while True:
|
||||
if fn():
|
||||
break
|
||||
if self._sender_closed or self._receiver_closed:
|
||||
break
|
||||
await self._waiters.park()
|
||||
await _core.checkpoint()
|
||||
|
||||
def close_sender(self) -> None:
|
||||
self._sender_closed = True
|
||||
self._something_happened()
|
||||
|
||||
def close_receiver(self) -> None:
|
||||
self._receiver_closed = True
|
||||
self._something_happened()
|
||||
|
||||
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
|
||||
with self._send_conflict_detector:
|
||||
if self._sender_closed:
|
||||
raise _core.ClosedResourceError
|
||||
if self._receiver_closed:
|
||||
raise _core.BrokenResourceError
|
||||
assert not self._data
|
||||
self._data += data
|
||||
self._something_happened()
|
||||
await self._wait_for(lambda: self._data == b"")
|
||||
if self._sender_closed:
|
||||
raise _core.ClosedResourceError
|
||||
if self._data and self._receiver_closed:
|
||||
raise _core.BrokenResourceError
|
||||
|
||||
async def wait_send_all_might_not_block(self) -> None:
|
||||
with self._send_conflict_detector:
|
||||
if self._sender_closed:
|
||||
raise _core.ClosedResourceError
|
||||
if self._receiver_closed:
|
||||
await _core.checkpoint()
|
||||
return
|
||||
await self._wait_for(lambda: self._receiver_waiting)
|
||||
if self._sender_closed:
|
||||
raise _core.ClosedResourceError
|
||||
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
|
||||
with self._receive_conflict_detector:
|
||||
# Argument validation
|
||||
if max_bytes is not None:
|
||||
max_bytes = operator.index(max_bytes)
|
||||
if max_bytes < 1:
|
||||
raise ValueError("max_bytes must be >= 1")
|
||||
# State validation
|
||||
if self._receiver_closed:
|
||||
raise _core.ClosedResourceError
|
||||
# Wake wait_send_all_might_not_block and wait for data
|
||||
self._receiver_waiting = True
|
||||
self._something_happened()
|
||||
try:
|
||||
await self._wait_for(lambda: self._data != b"")
|
||||
finally:
|
||||
self._receiver_waiting = False
|
||||
if self._receiver_closed:
|
||||
raise _core.ClosedResourceError
|
||||
# Get data, possibly waking send_all
|
||||
if self._data:
|
||||
# Neat trick: if max_bytes is None, then obj[:max_bytes] is
|
||||
# the same as obj[:].
|
||||
got = self._data[:max_bytes]
|
||||
del self._data[:max_bytes]
|
||||
self._something_happened()
|
||||
return got
|
||||
else:
|
||||
assert self._sender_closed
|
||||
return b""
|
||||
|
||||
|
||||
class _LockstepSendStream(SendStream):
|
||||
def __init__(self, lbq: _LockstepByteQueue) -> None:
|
||||
self._lbq = lbq
|
||||
|
||||
def close(self) -> None:
|
||||
self._lbq.close_sender()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.close()
|
||||
await _core.checkpoint()
|
||||
|
||||
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
|
||||
await self._lbq.send_all(data)
|
||||
|
||||
async def wait_send_all_might_not_block(self) -> None:
|
||||
await self._lbq.wait_send_all_might_not_block()
|
||||
|
||||
|
||||
class _LockstepReceiveStream(ReceiveStream):
|
||||
def __init__(self, lbq: _LockstepByteQueue) -> None:
|
||||
self._lbq = lbq
|
||||
|
||||
def close(self) -> None:
|
||||
self._lbq.close_receiver()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.close()
|
||||
await _core.checkpoint()
|
||||
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
|
||||
return await self._lbq.receive_some(max_bytes)
|
||||
|
||||
|
||||
def lockstep_stream_one_way_pair() -> tuple[SendStream, ReceiveStream]:
|
||||
"""Create a connected, pure Python, unidirectional stream where data flows
|
||||
in lockstep.
|
||||
|
||||
Returns:
|
||||
A tuple
|
||||
(:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`).
|
||||
|
||||
This stream has *absolutely no* buffering. Each call to
|
||||
:meth:`~trio.abc.SendStream.send_all` will block until all the given data
|
||||
has been returned by a call to
|
||||
:meth:`~trio.abc.ReceiveStream.receive_some`.
|
||||
|
||||
This can be useful for testing flow control mechanisms in an extreme case,
|
||||
or for setting up "clogged" streams to use with
|
||||
:func:`check_one_way_stream` and friends.
|
||||
|
||||
In addition to fulfilling the :class:`~trio.abc.SendStream` and
|
||||
:class:`~trio.abc.ReceiveStream` interfaces, the return objects
|
||||
also have a synchronous ``close`` method.
|
||||
|
||||
"""
|
||||
|
||||
lbq = _LockstepByteQueue()
|
||||
return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq)
|
||||
|
||||
|
||||
def lockstep_stream_pair() -> tuple[
|
||||
StapledStream[SendStream, ReceiveStream],
|
||||
StapledStream[SendStream, ReceiveStream],
|
||||
]:
|
||||
"""Create a connected, pure-Python, bidirectional stream where data flows
|
||||
in lockstep.
|
||||
|
||||
Returns:
|
||||
A tuple (:class:`~trio.StapledStream`, :class:`~trio.StapledStream`).
|
||||
|
||||
This is a convenience function that creates two one-way streams using
|
||||
:func:`lockstep_stream_one_way_pair`, and then uses
|
||||
:class:`~trio.StapledStream` to combine them into a single bidirectional
|
||||
stream.
|
||||
|
||||
"""
|
||||
return _make_stapled_pair(lockstep_stream_one_way_pair)
|
||||
@@ -0,0 +1,36 @@
|
||||
from .. import socket as tsocket
|
||||
from .._highlevel_socket import SocketListener, SocketStream
|
||||
|
||||
|
||||
async def open_stream_to_socket_listener(
|
||||
socket_listener: SocketListener,
|
||||
) -> SocketStream:
|
||||
"""Connect to the given :class:`~trio.SocketListener`.
|
||||
|
||||
This is particularly useful in tests when you want to let a server pick
|
||||
its own port, and then connect to it::
|
||||
|
||||
listeners = await trio.open_tcp_listeners(0)
|
||||
client = await trio.testing.open_stream_to_socket_listener(listeners[0])
|
||||
|
||||
Args:
|
||||
socket_listener (~trio.SocketListener): The
|
||||
:class:`~trio.SocketListener` to connect to.
|
||||
|
||||
Returns:
|
||||
SocketStream: a stream connected to the given listener.
|
||||
|
||||
"""
|
||||
family = socket_listener.socket.family
|
||||
sockaddr = socket_listener.socket.getsockname()
|
||||
if family in (tsocket.AF_INET, tsocket.AF_INET6):
|
||||
sockaddr = list(sockaddr)
|
||||
if sockaddr[0] == "0.0.0.0":
|
||||
sockaddr[0] = "127.0.0.1"
|
||||
if sockaddr[0] == "::":
|
||||
sockaddr[0] = "::1"
|
||||
sockaddr = tuple(sockaddr)
|
||||
|
||||
sock = tsocket.socket(family=family)
|
||||
await sock.connect(sockaddr)
|
||||
return SocketStream(sock)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import attrs
|
||||
|
||||
from .. import Event, _core, _util
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
|
||||
@_util.final
|
||||
@attrs.define(eq=False, slots=False)
|
||||
class Sequencer:
|
||||
"""A convenience class for forcing code in different tasks to run in an
|
||||
explicit linear order.
|
||||
|
||||
Instances of this class implement a ``__call__`` method which returns an
|
||||
async context manager. The idea is that you pass a sequence number to
|
||||
``__call__`` to say where this block of code should go in the linear
|
||||
sequence. Block 0 starts immediately, and then block N doesn't start until
|
||||
block N-1 has finished.
|
||||
|
||||
Example:
|
||||
An extremely elaborate way to print the numbers 0-5, in order::
|
||||
|
||||
async def worker1(seq):
|
||||
async with seq(0):
|
||||
print(0)
|
||||
async with seq(4):
|
||||
print(4)
|
||||
|
||||
async def worker2(seq):
|
||||
async with seq(2):
|
||||
print(2)
|
||||
async with seq(5):
|
||||
print(5)
|
||||
|
||||
async def worker3(seq):
|
||||
async with seq(1):
|
||||
print(1)
|
||||
async with seq(3):
|
||||
print(3)
|
||||
|
||||
async def main():
|
||||
seq = trio.testing.Sequencer()
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(worker1, seq)
|
||||
nursery.start_soon(worker2, seq)
|
||||
nursery.start_soon(worker3, seq)
|
||||
|
||||
"""
|
||||
|
||||
_sequence_points: defaultdict[int, Event] = attrs.field(
|
||||
factory=lambda: defaultdict(Event),
|
||||
init=False,
|
||||
)
|
||||
_claimed: set[int] = attrs.field(factory=set, init=False)
|
||||
_broken: bool = attrs.field(default=False, init=False)
|
||||
|
||||
@asynccontextmanager
|
||||
async def __call__(self, position: int) -> AsyncIterator[None]:
|
||||
if position in self._claimed:
|
||||
raise RuntimeError(f"Attempted to reuse sequence point {position}")
|
||||
if self._broken:
|
||||
raise RuntimeError("sequence broken!")
|
||||
self._claimed.add(position)
|
||||
if position != 0:
|
||||
try:
|
||||
await self._sequence_points[position].wait()
|
||||
except _core.Cancelled:
|
||||
self._broken = True
|
||||
for event in self._sequence_points.values():
|
||||
event.set()
|
||||
raise RuntimeError(
|
||||
"Sequencer wait cancelled -- sequence broken",
|
||||
) from None
|
||||
else:
|
||||
if self._broken:
|
||||
raise RuntimeError("sequence broken!")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._sequence_points[position + 1].set()
|
||||
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial, wraps
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from .. import _core
|
||||
from ..abc import Clock, Instrument
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
ArgsT = ParamSpec("ArgsT")
|
||||
|
||||
|
||||
RetT = TypeVar("RetT")
|
||||
|
||||
|
||||
def trio_test(fn: Callable[ArgsT, Awaitable[RetT]]) -> Callable[ArgsT, RetT]:
|
||||
"""Converts an async test function to be synchronous, running via Trio.
|
||||
|
||||
Usage::
|
||||
|
||||
@trio_test
|
||||
async def test_whatever():
|
||||
await ...
|
||||
|
||||
If a pytest fixture is passed in that subclasses the :class:`~trio.abc.Clock` or
|
||||
:class:`~trio.abc.Instrument` ABCs, then those are passed to :meth:`trio.run()`.
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
|
||||
__tracebackhide__ = True
|
||||
clocks = [c for c in kwargs.values() if isinstance(c, Clock)]
|
||||
if not clocks:
|
||||
clock = None
|
||||
elif len(clocks) == 1:
|
||||
clock = clocks[0]
|
||||
else:
|
||||
raise ValueError("too many clocks spoil the broth!")
|
||||
instruments = [i for i in kwargs.values() if isinstance(i, Instrument)]
|
||||
return _core.run(
|
||||
partial(fn, *args, **kwargs),
|
||||
clock=clock,
|
||||
instruments=instruments,
|
||||
)
|
||||
|
||||
return wrapper
|
||||
Reference in New Issue
Block a user