95 lines
2.6 KiB
Python
95 lines
2.6 KiB
Python
import pytest
|
|
|
|
import attr
|
|
|
|
from ..abc import SendStream, ReceiveStream
|
|
from .._highlevel_generic import StapledStream
|
|
|
|
|
|
@attr.s
class RecordSendStream(SendStream):
    """``SendStream`` test double that only logs the calls made on it.

    No data is transmitted anywhere: every method appends one entry to
    ``record`` so that a test can assert on the exact sequence of calls.
    """

    # Chronological call log; ``factory=list`` builds a new list per instance.
    record = attr.ib(factory=list)

    async def send_all(self, data):
        """Append a ``("send_all", data)`` entry to ``record``."""
        self.record.append(("send_all", data))

    async def wait_send_all_might_not_block(self):
        """Append a ``"wait_send_all_might_not_block"`` entry to ``record``."""
        self.record.append("wait_send_all_might_not_block")

    async def aclose(self):
        """Append an ``"aclose"`` entry to ``record``."""
        self.record.append("aclose")
|
|
|
|
|
|
@attr.s
class RecordReceiveStream(ReceiveStream):
    """``ReceiveStream`` test double that only logs the calls made on it.

    No data is ever produced: every method appends one entry to ``record``
    so that a test can assert on the exact sequence of calls.
    """

    # Chronological call log; ``factory=list`` builds a new list per instance.
    record = attr.ib(factory=list)

    async def receive_some(self, max_bytes=None):
        """Append a ``("receive_some", max_bytes)`` entry; returns ``None``."""
        self.record.append(("receive_some", max_bytes))

    async def aclose(self):
        """Append an ``"aclose"`` entry to ``record``."""
        self.record.append("aclose")
|
|
|
|
|
|
async def test_StapledStream():
    """Check that ``StapledStream`` forwards each call to the right half.

    Both halves are recording doubles, so every assertion compares the
    recorded call log against the calls expected to have been forwarded.
    """
    send_stream = RecordSendStream()
    receive_stream = RecordReceiveStream()
    stapled = StapledStream(send_stream, receive_stream)

    # The wrapped streams are exposed unchanged as attributes.
    assert stapled.send_stream is send_stream
    assert stapled.receive_stream is receive_stream

    # Send-side calls are expected to reach the send half, in order.
    await stapled.send_all(b"foo")
    await stapled.wait_send_all_might_not_block()
    assert send_stream.record == [
        ("send_all", b"foo"),
        "wait_send_all_might_not_block",
    ]
    send_stream.record.clear()

    # Before a send_eof is attached to the send half (below), send_eof() is
    # expected to show up as an aclose() on it.
    await stapled.send_eof()
    assert send_stream.record == ["aclose"]
    send_stream.record.clear()

    async def fake_send_eof():
        send_stream.record.append("send_eof")

    # Once the send half does have a send_eof, that is expected to be called
    # instead of aclose().
    send_stream.send_eof = fake_send_eof
    await stapled.send_eof()
    assert send_stream.record == ["send_eof"]

    # Nothing so far should have touched the receive half.
    send_stream.record.clear()
    assert receive_stream.record == []

    # Receive-side calls are expected to reach only the receive half.
    await stapled.receive_some(1234)
    assert receive_stream.record == [("receive_some", 1234)]
    assert send_stream.record == []
    receive_stream.record.clear()

    # Closing the stapled stream is expected to close both halves.
    await stapled.aclose()
    assert receive_stream.record == ["aclose"]
    assert send_stream.record == ["aclose"]
|
|
|
|
|
|
async def test_StapledStream_with_erroring_close():
    """Check that both halves get closed even when each ``aclose`` raises.

    Make sure that if one of the aclose methods errors out, then the other
    one still gets called.
    """

    class BrokenSendStream(RecordSendStream):
        async def aclose(self):
            # Record the call first so the test can see it happened, then fail.
            await super().aclose()
            raise ValueError

    class BrokenReceiveStream(RecordReceiveStream):
        async def aclose(self):
            # Record the call first so the test can see it happened, then fail.
            await super().aclose()
            raise ValueError

    stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())

    with pytest.raises(ValueError) as excinfo:
        await stapled.aclose()
    # The propagated error is expected to carry the other ValueError as its
    # implicit context, i.e. neither failure was silently dropped.
    assert isinstance(excinfo.value.__context__, ValueError)

    # Both halves recorded their aclose() despite the errors.
    assert stapled.send_stream.record == ["aclose"]
    assert stapled.receive_stream.record == ["aclose"]
|