1016 lines
35 KiB
Python
1016 lines
35 KiB
Python
import errno
|
|
|
|
import pytest
|
|
import attr
|
|
|
|
import os
|
|
import socket as stdlib_socket
|
|
import inspect
|
|
import tempfile
|
|
import sys as _sys
|
|
from .._core.tests.tutil import creates_ipv6, binds_ipv6
|
|
from .. import _core
|
|
from .. import _socket as _tsocket
|
|
from .. import socket as tsocket
|
|
from .._socket import _NUMERIC_ONLY, _try_sync
|
|
from ..testing import assert_checkpoints, wait_all_tasks_blocked
|
|
|
|
################################################################
|
|
# utils
|
|
################################################################
|
|
|
|
|
|
class MonkeypatchedGAI:
    """Scriptable replacement for ``getaddrinfo`` that records every call.

    Canned responses are keyed by the normalized positional-argument tuple of
    the wrapped function. A call with no canned response is forwarded to the
    wrapped function only if the last normalized argument has the
    ``AI_NUMERICHOST`` bit set; any other call raises ``RuntimeError``.
    """

    def __init__(self, orig_getaddrinfo):
        self._orig_getaddrinfo = orig_getaddrinfo
        self._responses = {}
        self.record = []

    def _frozenbind(self, *args, **kwargs):
        """Return the call's arguments as a normalized, hashable tuple."""
        signature = inspect.signature(self._orig_getaddrinfo)
        call = signature.bind(*args, **kwargs)
        call.apply_defaults()
        normalized = call.args
        # Every argument must be expressible positionally, otherwise the
        # tuple would not capture the whole call.
        assert not call.kwargs
        return normalized

    def set(self, response, *args, **kwargs):
        """Register ``response`` as the canned result for the given call."""
        key = self._frozenbind(*args, **kwargs)
        self._responses[key] = response

    def getaddrinfo(self, *args, **kwargs):
        """Record the call, then answer it from the canned responses."""
        key = self._frozenbind(*args, **kwargs)
        self.record.append(key)
        if key in self._responses:
            return self._responses[key]
        if not (key[-1] & stdlib_socket.AI_NUMERICHOST):
            raise RuntimeError("gai called with unexpected arguments {}".format(key))
        return self._orig_getaddrinfo(*args, **kwargs)
|
|
|
|
|
|
@pytest.fixture
def monkeygai(monkeypatch):
    """Patch ``socket.getaddrinfo`` with a ``MonkeypatchedGAI`` and return it."""
    fake_gai = MonkeypatchedGAI(stdlib_socket.getaddrinfo)
    monkeypatch.setattr(stdlib_socket, "getaddrinfo", fake_gai.getaddrinfo)
    return fake_gai
|
|
|
|
|
|
async def test__try_sync():
    """Exercise ``_try_sync`` with and without a custom exception predicate.

    The cases below expect that ``BlockingIOError`` (with no predicate) or an
    exception accepted by the supplied predicate does not propagate out of
    the ``async with`` block, while any other exception does.
    """
    # No exception at all: the block must still execute a checkpoint.
    with assert_checkpoints():
        async with _try_sync():
            pass

    # An unrelated exception propagates, and a checkpoint still happens.
    with assert_checkpoints():
        with pytest.raises(KeyError):
            async with _try_sync():
                raise KeyError

    # With no predicate, BlockingIOError is expected to be absorbed.
    async with _try_sync():
        raise BlockingIOError

    # Custom predicate: only ValueError counts as the "blocking" exception.
    def _is_ValueError(exc):
        return isinstance(exc, ValueError)

    async with _try_sync(_is_ValueError):
        raise ValueError

    # With the custom predicate, BlockingIOError is an ordinary error again.
    with assert_checkpoints():
        with pytest.raises(BlockingIOError):
            async with _try_sync(_is_ValueError):
                raise BlockingIOError
|
|
|
|
|
|
################################################################
|
|
# basic re-exports
|
|
################################################################
|
|
|
|
|
|
def test_socket_has_some_reexports():
    """Spot-check that a few stdlib socket names are re-exported unchanged."""
    for name in ("SOL_SOCKET", "TCP_NODELAY", "gaierror", "ntohs"):
        assert getattr(tsocket, name) == getattr(stdlib_socket, name)
|
|
|
|
|
|
################################################################
|
|
# name resolution
|
|
################################################################
|
|
|
|
|
|
async def test_getaddrinfo(monkeygai):
    """Check ``tsocket.getaddrinfo`` on numeric, canned, and failing lookups."""

    def check(got, expected):
        # win32 returns 0 for the proto field
        # musl and glibc have inconsistent handling of the canonical name
        # field (https://github.com/python-trio/trio/issues/1499)
        # Neither field gets used much and there isn't much opportunity for us
        # to mess them up, so we don't bother checking them here
        def comparable(entries):
            # Each entry is (family, type, proto, canonname, sockaddr); keep
            # only the fields that are compared.
            return [
                (family, socktype, sockaddr)
                for family, socktype, _proto, _canonname, sockaddr in entries
            ]

        assert comparable(got) == comparable(expected)

    # Simple non-blocking non-error cases, ipv4 and ipv6:
    with assert_checkpoints():
        result = await tsocket.getaddrinfo(
            "127.0.0.1", "12345", type=tsocket.SOCK_STREAM
        )
    expected_v4 = (
        tsocket.AF_INET,  # 127.0.0.1 is ipv4
        tsocket.SOCK_STREAM,
        tsocket.IPPROTO_TCP,
        "",
        ("127.0.0.1", 12345),
    )
    check(result, [expected_v4])

    with assert_checkpoints():
        result = await tsocket.getaddrinfo("::1", "12345", type=tsocket.SOCK_DGRAM)
    expected_v6 = (
        tsocket.AF_INET6,
        tsocket.SOCK_DGRAM,
        tsocket.IPPROTO_UDP,
        "",
        ("::1", 12345, 0, 0),
    )
    check(result, [expected_v6])

    # A non-numeric lookup is answered by the canned monkeygai response.
    monkeygai.set("x", b"host", "port", family=0, type=0, proto=0, flags=0)
    with assert_checkpoints():
        result = await tsocket.getaddrinfo("host", "port")
    assert result == "x"
    assert monkeygai.record[-1] == (b"host", "port", 0, 0, 0, 0)

    # check raising an error from a non-blocking getaddrinfo
    with assert_checkpoints():
        with pytest.raises(tsocket.gaierror) as excinfo:
            await tsocket.getaddrinfo("::1", "12345", type=-1)
    acceptable_errnos = {
        tsocket.EAI_SOCKTYPE,  # Linux + glibc, Windows
        tsocket.EAI_SERVICE,  # Linux + musl
    }
    # macOS
    if hasattr(tsocket, "EAI_BADHINTS"):
        acceptable_errnos.add(tsocket.EAI_BADHINTS)
    assert excinfo.value.errno in acceptable_errnos

    # check raising an error from a blocking getaddrinfo (exploits the fact
    # that monkeygai raises if it gets a non-numeric request it hasn't been
    # given an answer for)
    with assert_checkpoints():
        with pytest.raises(RuntimeError):
            await tsocket.getaddrinfo("asdf", "12345")
|
|
|
|
|
|
async def test_getnameinfo():
    """Compare ``tsocket.getnameinfo`` against the blocking stdlib call."""
    # Trivial test:
    numeric_flags = stdlib_socket.NI_NUMERICHOST | stdlib_socket.NI_NUMERICSERV
    with assert_checkpoints():
        result = await tsocket.getnameinfo(("127.0.0.1", 1234), numeric_flags)
    assert result == ("127.0.0.1", "1234")

    # getnameinfo requires a numeric address as input:
    for hostname in ("google.com", "localhost"):
        with assert_checkpoints():
            with pytest.raises(tsocket.gaierror):
                await tsocket.getnameinfo((hostname, 80), 0)

    # Blocking call to get expected values:
    expected_host, expected_service = stdlib_socket.getnameinfo(("127.0.0.1", 80), 0)

    # Some working calls:
    result = await tsocket.getnameinfo(("127.0.0.1", 80), 0)
    assert result == (expected_host, expected_service)

    result = await tsocket.getnameinfo(("127.0.0.1", 80), tsocket.NI_NUMERICHOST)
    assert result == ("127.0.0.1", expected_service)

    result = await tsocket.getnameinfo(("127.0.0.1", 80), tsocket.NI_NUMERICSERV)
    assert result == (expected_host, "80")
|
|
|
|
|
|
################################################################
|
|
# constructors
|
|
################################################################
|
|
|
|
|
|
async def test_from_stdlib_socket():
    """``from_stdlib_socket`` wraps plain stdlib sockets and nothing else."""
    plain, peer = stdlib_socket.socketpair()
    assert not isinstance(plain, tsocket.SocketType)
    with plain, peer:
        wrapped = tsocket.from_stdlib_socket(plain)
        assert isinstance(wrapped, tsocket.SocketType)
        # The wrapper shares the file descriptor instead of duplicating it.
        assert plain.fileno() == wrapped.fileno()
        await wrapped.send(b"x")
        assert peer.recv(1) == b"x"

    # rejects other types
    with pytest.raises(TypeError):
        tsocket.from_stdlib_socket(1)

    class MySocket(stdlib_socket.socket):
        pass

    # Subclasses of the stdlib socket class are rejected as well.
    with MySocket() as subclass_sock:
        with pytest.raises(TypeError):
            tsocket.from_stdlib_socket(subclass_sock)
|
|
|
|
|
|
async def test_from_fd():
    """``fromfd`` yields a working socket on a different file descriptor."""
    original, peer = stdlib_socket.socketpair()
    duplicate = tsocket.fromfd(
        original.fileno(), original.family, original.type, original.proto
    )
    with original, peer, duplicate:
        assert duplicate.fileno() != original.fileno()
        await duplicate.send(b"x")
        assert peer.recv(3) == b"x"
|
|
|
|
|
|
async def test_socketpair_simple():
    """Both ends of a ``socketpair`` can send to and receive from each other."""

    async def child(sock):
        print("sending hello")
        await sock.send(b"h")
        assert await sock.recv(1) == b"h"

    left, right = tsocket.socketpair()
    with left, right:
        async with _core.open_nursery() as nursery:
            for endpoint in (left, right):
                nursery.start_soon(child, endpoint)
|
|
|
|
|
|
@pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only")
async def test_fromshare():
    """A socket shared with our own process works on a separate descriptor."""
    sender, receiver = tsocket.socketpair()
    with sender, receiver:
        # share with ourselves
        share_data = sender.share(os.getpid())
        reopened = tsocket.fromshare(share_data)
        with reopened:
            assert sender.fileno() != reopened.fileno()
            await reopened.send(b"x")
            assert await receiver.recv(1) == b"x"
|
|
|
|
|
|
async def test_socket():
    """The default constructor returns an ``AF_INET`` ``SocketType``."""
    with tsocket.socket() as default_sock:
        assert isinstance(default_sock, tsocket.SocketType)
        assert default_sock.family == tsocket.AF_INET
|
|
|
|
|
|
@creates_ipv6
async def test_socket_v6():
    """An explicitly requested ``AF_INET6`` socket reports that family."""
    with tsocket.socket(tsocket.AF_INET6, tsocket.SOCK_DGRAM) as v6_sock:
        assert isinstance(v6_sock, tsocket.SocketType)
        assert v6_sock.family == tsocket.AF_INET6
|
|
|
|
|
|
@pytest.mark.skipif(not _sys.platform == "linux", reason="linux only")
async def test_sniff_sockopts():
    """Sockets built from an existing fd report that fd's real family and type.

    Every family/type combination is checked through both the regular
    constructor (``fileno=``) and ``fromfd``. ``fromfd`` is deliberately
    given ``AF_INET``/``SOCK_STREAM`` regardless of the real socket, so the
    assertions only pass if the values are taken from the descriptor itself.
    """
    from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM

    # Iterate over the combinations of families/types we're testing. Each
    # socket is owned by a ``with`` block so a failing assertion cannot leak
    # descriptors into later tests.
    for family in (AF_INET, AF_INET6):
        for sock_type in (SOCK_DGRAM, SOCK_STREAM):
            with stdlib_socket.socket(family, sock_type) as raw_sock:
                # regular Trio socket constructor
                wrapped = tsocket.socket(fileno=raw_sock.fileno())
                try:
                    # check family / type for correctness:
                    assert wrapped.family == raw_sock.family
                    assert wrapped.type == raw_sock.type
                finally:
                    # ``wrapped`` shares raw_sock's descriptor; detach rather
                    # than close so raw_sock stays usable and is closed once.
                    wrapped.detach()

                # fromfd constructor
                with tsocket.fromfd(
                    raw_sock.fileno(), AF_INET, SOCK_STREAM
                ) as duplicate:
                    # check family / type for correctness:
                    assert duplicate.family == raw_sock.family
                    assert duplicate.type == raw_sock.type
|
|
|
|
|
|
################################################################
|
|
# _SocketType
|
|
################################################################
|
|
|
|
|
|
async def test_SocketType_basics():
    """Smoke-test the simple synchronous parts of the socket wrapper.

    Covers context-manager use, ``fileno``, inheritability, socket options,
    ``repr``, ``detach``, ``close``, ``dir``/attribute lookup, and the
    ``type``/``family``/``proto`` attributes.
    """
    sock = tsocket.socket()
    with sock as cm_enter_value:
        # Entering the context manager hands back the socket itself.
        assert cm_enter_value is sock
        assert isinstance(sock.fileno(), int)
        assert not sock.get_inheritable()
        sock.set_inheritable(True)
        assert sock.get_inheritable()

        # Option values round-trip through setsockopt/getsockopt.
        sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
        assert not sock.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
        sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, True)
        assert sock.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
    # closed sockets have fileno() == -1
    assert sock.fileno() == -1

    # smoke test
    repr(sock)

    # detach
    with tsocket.socket() as sock:
        fd = sock.fileno()
        assert sock.detach() == fd
        assert sock.fileno() == -1

    # close
    sock = tsocket.socket()
    assert sock.fileno() >= 0
    sock.close()
    assert sock.fileno() == -1

    # share was tested above together with fromshare

    # check __dir__
    assert "family" in dir(sock)
    assert "recv" in dir(sock)
    assert "setsockopt" in dir(sock)

    # our __getattr__ handles unknown names
    with pytest.raises(AttributeError):
        sock.asdf

    # type family proto
    stdlib_sock = stdlib_socket.socket()
    sock = tsocket.from_stdlib_socket(stdlib_sock)
    assert sock.type == _tsocket.real_socket_type(stdlib_sock.type)
    assert sock.family == stdlib_sock.family
    assert sock.proto == stdlib_sock.proto
    sock.close()
|
|
|
|
|
|
async def test_SocketType_dup():
    """A ``dup()``-ed socket keeps working after the original is closed."""
    original, peer = tsocket.socketpair()
    with original, peer:
        clone = original.dup()
        with clone:
            assert isinstance(clone, tsocket.SocketType)
            assert clone.fileno() != original.fileno()
            original.close()
            await clone.send(b"x")
            assert await peer.recv(1) == b"x"
|
|
|
|
|
|
async def test_SocketType_shutdown():
    """``shutdown`` works and ``did_shutdown_SHUT_WR`` tracks write shutdown."""
    near, far = tsocket.socketpair()
    with near, far:
        await near.send(b"x")
        assert await far.recv(1) == b"x"
        assert not near.did_shutdown_SHUT_WR
        assert not far.did_shutdown_SHUT_WR
        near.shutdown(tsocket.SHUT_WR)
        # Only the side that shut down its write half is flagged.
        assert near.did_shutdown_SHUT_WR
        assert not far.did_shutdown_SHUT_WR
        # The peer sees EOF, but the reverse direction still carries data.
        assert await far.recv(1) == b""
        await far.send(b"y")
        assert await near.recv(1) == b"y"

    # SHUT_RD must leave the flag unset; SHUT_RDWR must set it.
    near, far = tsocket.socketpair()
    with near, far:
        assert not near.did_shutdown_SHUT_WR
        near.shutdown(tsocket.SHUT_RD)
        assert not near.did_shutdown_SHUT_WR

    near, far = tsocket.socketpair()
    with near, far:
        assert not near.did_shutdown_SHUT_WR
        near.shutdown(tsocket.SHUT_RDWR)
        assert near.did_shutdown_SHUT_WR
|
|
|
|
|
|
@pytest.mark.parametrize(
    "address, socket_type",
    [
        ("127.0.0.1", tsocket.AF_INET),
        pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6),
    ],
)
async def test_SocketType_simple_server(address, socket_type):
    """Run a minimal accept/connect round trip over the loopback address."""
    # listen, bind, accept, connect, getpeername, getsockname
    # NOTE: despite its name, ``socket_type`` carries an address family.
    server_listener = tsocket.socket(socket_type)
    client = tsocket.socket(socket_type)
    with server_listener, client:
        await server_listener.bind((address, 0))
        server_listener.listen(20)
        listen_addr = server_listener.getsockname()[:2]
        async with _core.open_nursery() as nursery:
            # connect runs in the background while we accept here
            nursery.start_soon(client.connect, listen_addr)
            server, client_addr = await server_listener.accept()
        with server:
            assert client_addr == server.getpeername() == client.getsockname()
            await server.send(b"x")
            assert await client.recv(1) == b"x"
|
|
|
|
|
|
async def test_SocketType_is_readable():
    """``is_readable`` is true exactly while unread data is pending."""
    reader, writer = tsocket.socketpair()
    with reader, writer:
        assert not reader.is_readable()
        await writer.send(b"x")
        # Wait until the byte has actually arrived before checking.
        await _core.wait_readable(reader)
        assert reader.is_readable()
        assert await reader.recv(1) == b"x"
        assert not reader.is_readable()
|
|
|
|
|
|
# On some macOS systems, getaddrinfo likes to return V4-mapped addresses even
# when we *don't* pass AI_V4MAPPED.
# https://github.com/python-trio/trio/issues/580
def gai_without_v4mapped_is_buggy():  # pragma: no cover
    """Return True if an IPv4 literal resolves under ``AF_INET6``.

    The lookup is made without ``AI_V4MAPPED``; a ``gaierror`` is the
    non-buggy outcome and yields False.
    """
    try:
        stdlib_socket.getaddrinfo("1.2.3.4", 0, family=stdlib_socket.AF_INET6)
    except stdlib_socket.gaierror:
        return False
    return True
|
|
|
|
|
|
@attr.s
class Addresses:
    """Bundle of literal host strings for one address family.

    Used as the ``addrs`` parameter of ``test_SocketType_resolve``.
    """

    # Expected result of resolving a null host as a local (bind) address.
    bind_all = attr.ib()
    # Expected result of resolving a null host as a remote address.
    localhost = attr.ib()
    # Some ordinary numeric address; resolved together with a service name.
    arbitrary = attr.ib()
    # Expected result of resolving the special "<broadcast>" host.
    broadcast = attr.ib()
|
|
|
|
|
|
# Direct thorough tests of the implicit resolver helpers
@pytest.mark.parametrize(
    "socket_type, addrs",
    [
        (
            tsocket.AF_INET,
            Addresses(
                bind_all="0.0.0.0",
                localhost="127.0.0.1",
                arbitrary="1.2.3.4",
                broadcast="255.255.255.255",
            ),
        ),
        pytest.param(
            tsocket.AF_INET6,
            Addresses(
                bind_all="::",
                localhost="::1",
                arbitrary="1::2",
                broadcast="::ffff:255.255.255.255",
            ),
            marks=creates_ipv6,
        ),
    ],
)
async def test_SocketType_resolve(socket_type, addrs):
    """Exercise ``_resolve_local_address_nocp`` / ``_resolve_remote_address_nocp``.

    ``socket_type`` is the address family under test (``AF_INET`` or
    ``AF_INET6``) and ``addrs`` supplies the matching literal addresses.
    """
    v6 = socket_type == tsocket.AF_INET6

    def pad(addr):
        # For IPv6, extend short address tuples with zeros up to four items
        # so that tuples of different lengths compare equal.
        if v6:
            while len(addr) < 4:
                addr += (0,)
        return addr

    def assert_eq(actual, expected):
        assert pad(expected) == pad(actual)

    with tsocket.socket(family=socket_type) as sock:
        # For some reason the stdlib special-cases "" to pass NULL to
        # getaddrinfo. They also error out on None, but whatever, None is much
        # more consistent, so we accept it too.
        for null in [None, ""]:
            got = await sock._resolve_local_address_nocp((null, 80))
            assert_eq(got, (addrs.bind_all, 80))
            got = await sock._resolve_remote_address_nocp((null, 80))
            assert_eq(got, (addrs.localhost, 80))

        # AI_PASSIVE only affects the wildcard address, so for everything else
        # _resolve_local_address_nocp and _resolve_remote_address_nocp should
        # work the same:
        for resolver in ["_resolve_local_address_nocp", "_resolve_remote_address_nocp"]:

            # Shorthand that calls whichever resolver this iteration covers.
            async def res(*args):
                return await getattr(sock, resolver)(*args)

            assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80))
            if v6:
                # Check handling of different length ipv6 address tuples
                assert_eq(await res(("1::2", 80)), ("1::2", 80, 0, 0))
                assert_eq(await res(("1::2", 80, 0)), ("1::2", 80, 0, 0))
                assert_eq(await res(("1::2", 80, 0, 0)), ("1::2", 80, 0, 0))
                # Non-zero flowinfo/scopeid get passed through
                assert_eq(await res(("1::2", 80, 1)), ("1::2", 80, 1, 0))
                assert_eq(await res(("1::2", 80, 1, 2)), ("1::2", 80, 1, 2))

                # And again with a string port, as a trick to avoid the
                # already-resolved address fastpath and make sure we call
                # getaddrinfo
                assert_eq(await res(("1::2", "80")), ("1::2", 80, 0, 0))
                assert_eq(await res(("1::2", "80", 0)), ("1::2", 80, 0, 0))
                assert_eq(await res(("1::2", "80", 0, 0)), ("1::2", 80, 0, 0))
                assert_eq(await res(("1::2", "80", 1)), ("1::2", 80, 1, 0))
                assert_eq(await res(("1::2", "80", 1, 2)), ("1::2", 80, 1, 2))

                # V4 mapped addresses resolved if V6ONLY is False
                sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, False)
                assert_eq(await res(("1.2.3.4", "http")), ("::ffff:1.2.3.4", 80))

            # Check the <broadcast> special case, because why not
            assert_eq(await res(("<broadcast>", 123)), (addrs.broadcast, 123))

            # But not if it's true (at least on systems where getaddrinfo works
            # correctly)
            if v6 and not gai_without_v4mapped_is_buggy():
                sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, True)
                with pytest.raises(tsocket.gaierror) as excinfo:
                    await res(("1.2.3.4", 80))
                # Windows, macOS
                expected_errnos = {tsocket.EAI_NONAME}
                # Linux
                if hasattr(tsocket, "EAI_ADDRFAMILY"):
                    expected_errnos.add(tsocket.EAI_ADDRFAMILY)
                assert excinfo.value.errno in expected_errnos

            # A family where we know nothing about the addresses, so should just
            # pass them through. This should work on Linux, which is enough to
            # smoke test the basic functionality...
            try:
                netlink_sock = tsocket.socket(
                    family=tsocket.AF_NETLINK, type=tsocket.SOCK_DGRAM
                )
            except (AttributeError, OSError):
                # AF_NETLINK is missing or unusable here; skip this check.
                pass
            else:
                assert await getattr(netlink_sock, resolver)("asdf") == "asdf"
                netlink_sock.close()

            # Malformed addresses: not a tuple, too short, and too long for
            # the family under test.
            with pytest.raises(ValueError):
                await res("1.2.3.4")
            with pytest.raises(ValueError):
                await res(("1.2.3.4",))
            with pytest.raises(ValueError):
                if v6:
                    await res(("1.2.3.4", 80, 0, 0, 0))
                else:
                    await res(("1.2.3.4", 80, 0, 0))
|
|
|
|
|
|
async def test_SocketType_unresolved_names():
    """``bind`` and ``connect`` accept host names and resolve them implicitly."""
    with tsocket.socket() as server:
        await server.bind(("localhost", 0))
        assert server.getsockname()[0] == "127.0.0.1"
        server.listen(10)

        with tsocket.socket() as client:
            await client.connect(("localhost", server.getsockname()[1]))
            assert client.getpeername() == server.getsockname()

    # check gaierror propagates out
    with tsocket.socket() as unbound:
        with pytest.raises(tsocket.gaierror):
            # definitely not a valid request
            await unbound.bind(("1.2:3", -1))
|
|
|
|
|
|
# This tests all the complicated paths through _nonblocking_helper, using recv
# as a stand-in for all the methods that use _nonblocking_helper.
async def test_SocketType_non_blocking_paths():
    """Drive ``recv`` through every success, failure, and cancellation path.

    Covered in order: cancelled before the call, immediate success,
    immediate failure, block-then-succeed, block-then-cancelled, and
    wake-up-then-block-again.
    """
    a, b = stdlib_socket.socketpair()
    with a, b:
        ta = tsocket.from_stdlib_socket(a)
        b.setblocking(False)

        # cancel before even calling
        b.send(b"1")
        with _core.CancelScope() as cscope:
            cscope.cancel()
            with assert_checkpoints():
                with pytest.raises(_core.Cancelled):
                    await ta.recv(10)
        # immediate success (also checks that the previous attempt didn't
        # actually read anything)
        with assert_checkpoints():
            # Without the ``assert`` the comparison result was thrown away,
            # so this case verified nothing.
            assert await ta.recv(10) == b"1"
        # immediate failure
        with assert_checkpoints():
            with pytest.raises(TypeError):
                await ta.recv("haha")
        # block then succeed

        async def do_successful_blocking_recv():
            with assert_checkpoints():
                assert await ta.recv(10) == b"2"

        async with _core.open_nursery() as nursery:
            nursery.start_soon(do_successful_blocking_recv)
            # Make sure the task is parked in recv before data arrives.
            await wait_all_tasks_blocked()
            b.send(b"2")
        # block then cancelled

        async def do_cancelled_blocking_recv():
            with assert_checkpoints():
                with pytest.raises(_core.Cancelled):
                    await ta.recv(10)

        async with _core.open_nursery() as nursery:
            nursery.start_soon(do_cancelled_blocking_recv)
            await wait_all_tasks_blocked()
            nursery.cancel_scope.cancel()
        # Okay, here's the trickiest one: we want to exercise the path where
        # the task is signaled to wake, goes to recv, but then the recv fails,
        # so it has to go back to sleep and try again. Strategy: have two
        # tasks waiting on two sockets (to work around the rule against having
        # two tasks waiting on the same socket), wake them both up at the same
        # time, and whichever one runs first "steals" the data from the
        # other:
        tb = tsocket.from_stdlib_socket(b)

        async def t1():
            with assert_checkpoints():
                assert await ta.recv(1) == b"a"
            with assert_checkpoints():
                assert await tb.recv(1) == b"b"

        async def t2():
            with assert_checkpoints():
                assert await tb.recv(1) == b"b"
            with assert_checkpoints():
                assert await ta.recv(1) == b"a"

        async with _core.open_nursery() as nursery:
            nursery.start_soon(t1)
            nursery.start_soon(t2)
            await wait_all_tasks_blocked()
            # First round: one byte per direction, waking both tasks at once.
            a.send(b"b")
            b.send(b"a")
            await wait_all_tasks_blocked()
            # Second round: feed whichever recv calls are still pending.
            a.send(b"b")
            b.send(b"a")
|
|
|
|
|
|
# This tests the complicated paths through connect
async def test_SocketType_connect_paths():
    """Cover the argument-error, cancellation, and failure paths of ``connect``."""
    with tsocket.socket() as sock:
        with pytest.raises(ValueError):
            # Should be a tuple
            await sock.connect("localhost")

    # cancelled before we start
    with tsocket.socket() as sock:
        with _core.CancelScope() as cancel_scope:
            cancel_scope.cancel()
            with pytest.raises(_core.Cancelled):
                await sock.connect(("127.0.0.1", 80))

    # Cancelled in between the connect() call and the connect completing
    with _core.CancelScope() as cancel_scope:
        with tsocket.socket() as sock, tsocket.socket() as listener:
            await listener.bind(("127.0.0.1", 0))
            listener.listen()

            # Swap in our weird subclass under the trio.socket._SocketType's
            # nose -- and then swap it back out again before we hit
            # wait_socket_writable, which insists on a real socket.
            class CancelSocket(stdlib_socket.socket):
                def connect(self, *args, **kwargs):
                    # Request cancellation in the middle of the connect call.
                    cancel_scope.cancel()
                    # Put a plain stdlib socket back in place, taking over
                    # this object's file descriptor.
                    sock._sock = stdlib_socket.fromfd(
                        self.detach(), self.family, self.type
                    )
                    sock._sock.connect(*args, **kwargs)
                    # If connect *doesn't* raise, then pretend it did
                    raise BlockingIOError  # pragma: no cover

            # Release the original underlying socket before replacing it.
            sock._sock.close()
            sock._sock = CancelSocket()

            with assert_checkpoints():
                with pytest.raises(_core.Cancelled):
                    await sock.connect(listener.getsockname())
            # The cancelled connect is expected to leave the socket closed.
            assert sock.fileno() == -1

    # Failed connect (hopefully after raising BlockingIOError)
    with tsocket.socket() as sock:
        with pytest.raises(OSError):
            # TCP port 2 is not assigned. Pretty sure nothing will be
            # listening there. (We used to bind a port and then *not* call
            # listen() to ensure nothing was listening there, but it turns
            # out on macOS if you do this it takes 30 seconds for the
            # connect to fail. Really. Also if you use a non-routable
            # address. This way fails instantly though. As long as nothing
            # is listening on port 2.)
            await sock.connect(("127.0.0.1", 2))
|
|
|
|
|
|
async def test_resolve_remote_address_exception_closes_socket():
    """A cancellation raised during address resolution closes the socket."""
    # Here we are testing issue 247, any cancellation will leave the socket closed
    with _core.CancelScope() as cancel_scope:
        with tsocket.socket() as sock:

            # Stand-in resolver that cancels the enclosing scope and then
            # hits a checkpoint.
            async def _resolve_remote_address_nocp(self, *args, **kwargs):
                cancel_scope.cancel()
                await _core.checkpoint()

            # Assigned on the instance, so it is called as a plain function:
            # the first connect argument lands in ``self``.
            sock._resolve_remote_address_nocp = _resolve_remote_address_nocp
            with assert_checkpoints():
                with pytest.raises(_core.Cancelled):
                    await sock.connect("")
            assert sock.fileno() == -1
|
|
|
|
|
|
async def test_send_recv_variants():
    """Exercise each send/recv flavour on stream and datagram sockets.

    Methods that are probed with ``hasattr`` (``sendmsg``, ``recvmsg``,
    ``recvmsg_into``) and the ``MSG_MORE`` flag are only tested where
    available.
    """
    a, b = tsocket.socketpair()
    with a, b:
        # recv, including with flags
        assert await a.send(b"x") == 1
        # MSG_PEEK leaves the byte queued, so the next recv sees it again.
        assert await b.recv(10, tsocket.MSG_PEEK) == b"x"
        assert await b.recv(10) == b"x"

        # recv_into
        await a.send(b"x")
        buf = bytearray(10)
        await b.recv_into(buf)
        assert buf == b"x" + b"\x00" * 9

        if hasattr(a, "sendmsg"):
            assert await a.sendmsg([b"xxx"], []) == 3
            assert await b.recv(10) == b"xxx"

    # Unconnected datagram sockets, each bound to a loopback port.
    a = tsocket.socket(type=tsocket.SOCK_DGRAM)
    b = tsocket.socket(type=tsocket.SOCK_DGRAM)
    with a, b:
        await a.bind(("127.0.0.1", 0))
        await b.bind(("127.0.0.1", 0))

        # The same destination given as a numeric address and as a host name.
        targets = [b.getsockname(), ("localhost", b.getsockname()[1])]

        # recvfrom + sendto, with and without names
        for target in targets:
            assert await a.sendto(b"xxx", target) == 3
            (data, addr) = await b.recvfrom(10)
            assert data == b"xxx"
            assert addr == a.getsockname()

        # sendto + flags
        #
        # I can't find any flags that send() accepts... on Linux at least
        # passing MSG_MORE to send_some on a connected UDP socket seems to
        # just be ignored.
        #
        # But there's no MSG_MORE on Windows or macOS. I guess send_some flags
        # are really not very useful, but at least this tests them a bit.
        if hasattr(tsocket, "MSG_MORE"):
            await a.sendto(b"xxx", tsocket.MSG_MORE, b.getsockname())
            await a.sendto(b"yyy", tsocket.MSG_MORE, b.getsockname())
            await a.sendto(b"zzz", b.getsockname())
            (data, addr) = await b.recvfrom(10)
            # The three sends are expected to arrive as a single datagram.
            assert data == b"xxxyyyzzz"
            assert addr == a.getsockname()

        # recvfrom_into
        assert await a.sendto(b"xxx", b.getsockname()) == 3
        buf = bytearray(10)
        (nbytes, addr) = await b.recvfrom_into(buf)
        assert nbytes == 3
        assert buf == b"xxx" + b"\x00" * 7
        assert addr == a.getsockname()

        if hasattr(b, "recvmsg"):
            assert await a.sendto(b"xxx", b.getsockname()) == 3
            (data, ancdata, msg_flags, addr) = await b.recvmsg(10)
            assert data == b"xxx"
            assert ancdata == []
            assert msg_flags == 0
            assert addr == a.getsockname()

        if hasattr(b, "recvmsg_into"):
            assert await a.sendto(b"xyzw", b.getsockname()) == 4
            buf1 = bytearray(2)
            buf2 = bytearray(3)
            ret = await b.recvmsg_into([buf1, buf2])
            (nbytes, ancdata, msg_flags, addr) = ret
            assert nbytes == 4
            # The four bytes are scattered across the two buffers in order.
            assert buf1 == b"xy"
            assert buf2 == b"zw" + b"\x00"
            assert ancdata == []
            assert msg_flags == 0
            assert addr == a.getsockname()

        if hasattr(a, "sendmsg"):
            for target in targets:
                assert await a.sendmsg([b"x", b"yz"], [], 0, target) == 3
                assert await b.recvfrom(10) == (b"xyz", a.getsockname())

    # Connected datagram socket.
    a = tsocket.socket(type=tsocket.SOCK_DGRAM)
    b = tsocket.socket(type=tsocket.SOCK_DGRAM)
    with a, b:
        await b.bind(("127.0.0.1", 0))
        await a.connect(b.getsockname())
        # send on a connected udp socket; each call creates a separate
        # datagram
        await a.send(b"xxx")
        await a.send(b"yyy")
        assert await b.recv(10) == b"xxx"
        assert await b.recv(10) == b"yyy"
|
|
|
|
|
|
async def test_idna(monkeygai):
    """Host names reach the underlying ``getaddrinfo`` IDNA-encoded as bytes."""
    # This is the encoding for "faß.de", which uses one of the characters that
    # IDNA 2003 handles incorrectly:
    monkeygai.set("ok faß.de", b"xn--fa-hia.de", 80)
    # Numeric lookups are accepted with the host as either str or bytes.
    for numeric_host in ("::1", b"::1"):
        monkeygai.set("ok ::1", numeric_host, 80, flags=_NUMERIC_ONLY)
    # Some things that should not reach the underlying socket.getaddrinfo:
    monkeygai.set("bad", "fass.de", 80)
    # We always call socket.getaddrinfo with bytes objects:
    monkeygai.set("bad", "xn--fa-hia.de", 80)

    for query in ("::1", b"::1"):
        assert "ok ::1" == await tsocket.getaddrinfo(query, 80)
    for query in ("faß.de", "xn--fa-hia.de", b"xn--fa-hia.de"):
        assert "ok faß.de" == await tsocket.getaddrinfo(query, 80)
|
|
|
|
|
|
async def test_getprotobyname():
    """``getprotobyname`` maps protocol names to their protocol numbers."""
    # These are the constants used in IP header fields, so the numeric values
    # had *better* be stable across systems...
    for proto_name, proto_number in (("udp", 17), ("tcp", 6)):
        assert await tsocket.getprotobyname(proto_name) == proto_number
|
|
|
|
|
|
async def test_custom_hostname_resolver(monkeygai):
    """Installing a custom resolver reroutes ``getaddrinfo``/``getnameinfo``.

    Also checks that ``set_custom_hostname_resolver`` returns the previously
    installed resolver and that passing ``None`` restores the default path.
    """

    # Resolver that just echoes its arguments, so the test can see exactly
    # what was passed through.
    class CustomResolver:
        async def getaddrinfo(self, host, port, family, type, proto, flags):
            return ("custom_gai", host, port, family, type, proto, flags)

        async def getnameinfo(self, sockaddr, flags):
            return ("custom_gni", sockaddr, flags)

    cr = CustomResolver()

    # Nothing was installed before, so the previous resolver is None.
    assert tsocket.set_custom_hostname_resolver(cr) is None

    # Check that the arguments are all getting passed through.
    # We have to use valid calls to avoid making the underlying system
    # getaddrinfo cranky when it's used for NUMERIC checks.
    for vals in [
        (tsocket.AF_INET, 0, 0, 0),
        (0, tsocket.SOCK_STREAM, 0, 0),
        (0, 0, tsocket.IPPROTO_TCP, 0),
        (0, 0, 0, tsocket.AI_CANONNAME),
    ]:
        assert await tsocket.getaddrinfo("localhost", "foo", *vals) == (
            "custom_gai",
            b"localhost",
            "foo",
            *vals,
        )

    # IDNA encoding is handled before calling the special object
    got = await tsocket.getaddrinfo("föö", "foo")
    expected = ("custom_gai", b"xn--f-1gaa", "foo", 0, 0, 0, 0)
    assert got == expected

    assert await tsocket.getnameinfo("a", 0) == ("custom_gni", "a", 0)

    # We can set it back to None
    assert tsocket.set_custom_hostname_resolver(None) is cr

    # And now Trio switches back to calling socket.getaddrinfo (specifically
    # our monkeypatched version of socket.getaddrinfo)
    monkeygai.set("x", b"host", "port", family=0, type=0, proto=0, flags=0)
    assert await tsocket.getaddrinfo("host", "port") == "x"
|
|
|
|
|
|
async def test_custom_socket_factory():
    """A custom factory handles ``socket()`` but not ``fileno=`` or socketpair."""

    class CustomSocketFactory:
        def socket(self, family, type, proto):
            return ("hi", family, type, proto)

    factory = CustomSocketFactory()

    # Nothing was installed before, so the previous factory is None.
    previous = tsocket.set_custom_socket_factory(factory)
    assert previous is None

    assert tsocket.socket() == ("hi", tsocket.AF_INET, tsocket.SOCK_STREAM, 0)
    assert tsocket.socket(1, 2, 3) == ("hi", 1, 2, 3)

    # socket with fileno= doesn't call our custom method
    raw_fd = stdlib_socket.socket().detach()
    from_fileno = tsocket.socket(fileno=raw_fd)
    assert hasattr(from_fileno, "bind")
    from_fileno.close()

    # Likewise for socketpair
    left, right = tsocket.socketpair()
    with left, right:
        assert hasattr(left, "bind")
        assert hasattr(right, "bind")

    # Uninstalling hands back the factory we installed.
    assert tsocket.set_custom_socket_factory(None) is factory
|
|
|
|
|
|
async def test_SocketType_is_abstract():
    """``SocketType`` cannot be instantiated directly."""
    socket_type = tsocket.SocketType
    with pytest.raises(TypeError):
        socket_type()
|
|
|
|
|
|
@pytest.mark.skipif(not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets")
async def test_unix_domain_socket():
    """Round-trip one byte over ``AF_UNIX`` via a filesystem and an abstract path."""
    # Bind has a special branch to use a thread, since it has to do filesystem
    # traversal. Maybe connect should too? Not sure.

    async def check_AF_UNIX(path):
        with tsocket.socket(family=tsocket.AF_UNIX) as listener:
            await listener.bind(path)
            listener.listen(10)
            with tsocket.socket(family=tsocket.AF_UNIX) as client:
                await client.connect(path)
                accepted, _peer = await listener.accept()
                with accepted:
                    await client.send(b"x")
                    assert await accepted.recv(1) == b"x"

    # Can't use tmpdir fixture, because we can exceed the maximum AF_UNIX path
    # length on macOS.
    with tempfile.TemporaryDirectory() as tmpdir:
        await check_AF_UNIX(f"{tmpdir}/sock")

    try:
        # Random suffix so the abstract name is unlikely to collide.
        random_suffix = os.urandom(20).hex().encode("ascii")
        await check_AF_UNIX(b"\x00trio-test-" + random_suffix)
    except FileNotFoundError:
        # macOS doesn't support abstract filenames with the leading NUL byte
        pass
|
|
|
|
|
|
async def test_interrupted_by_close():
    """Closing a socket makes tasks blocked in send/recv raise ClosedResourceError."""
    a_stdlib, b_stdlib = stdlib_socket.socketpair()
    with a_stdlib, b_stdlib:
        a_stdlib.setblocking(False)

        data = b"x" * 99999

        # Keep sending until the non-blocking send refuses more data, so the
        # later ``a.send`` has to wait.
        try:
            while True:
                a_stdlib.send(data)
        except BlockingIOError:
            pass

        a = tsocket.from_stdlib_socket(a_stdlib)

        async def sender():
            with pytest.raises(_core.ClosedResourceError):
                await a.send(data)

        async def receiver():
            with pytest.raises(_core.ClosedResourceError):
                await a.recv(1)

        async with _core.open_nursery() as nursery:
            nursery.start_soon(sender)
            nursery.start_soon(receiver)
            # Close only once both tasks are blocked in their socket calls.
            await wait_all_tasks_blocked()
            a.close()
|
|
|
|
|
|
async def test_many_sockets():
    """Wait on thousands of sockets at once, then cancel all the waits.

    If the process runs out of file descriptors while creating the sockets,
    the test proceeds with however many it managed to open and prints a
    diagnostic afterwards.
    """
    total = 5000  # Must be more than MAX_AFD_GROUP_SIZE
    sockets = []
    try:
        for _ in range(total // 2):
            try:
                a, b = stdlib_socket.socketpair()
            except OSError as e:  # pragma: no cover
                assert e.errno in (errno.EMFILE, errno.ENFILE)
                break
            sockets += [a, b]
        async with _core.open_nursery() as nursery:
            for s in sockets:
                nursery.start_soon(_core.wait_readable, s)
            await _core.wait_all_tasks_blocked()
            nursery.cancel_scope.cancel()
    finally:
        # Always release the descriptors, even if something above failed.
        for sock in sockets:
            sock.close()
    # Report the number of sockets actually opened. The old message printed
    # (x-1)*2, which under-counted, and the old check missed a failure on
    # the very last iteration.
    if len(sockets) < total:  # pragma: no cover
        print(f"Unable to open more than {len(sockets)} sockets.")
|