import asyncio
import logging
from asyncio.streams import StreamReader, StreamWriter
from contextlib import suppress
from typing import AsyncGenerator, Dict, Iterable, Optional
from .commands import Command, T
from .connections import ControlConnection, DataConnection
from .responses import Data
# Define the public API of this module
__all__ = ["AsyncioClient"]
class _StreamHelper:
_reader: Optional[StreamReader] = None
_writer: Optional[StreamWriter] = None
@property
def reader(self) -> StreamReader:
assert self._reader, "connect() not called yet"
return self._reader
@property
def writer(self) -> StreamWriter:
assert self._writer, "connect() not called yet"
return self._writer
async def write_and_drain(self, data: bytes, timeout: Optional[float] = None):
writer = self.writer
writer.write(data)
# Cannot simply await the drain, as if the remote end has disconnected
# then the drain will never complete as the OS cannot clear its send buffer.
write_task = asyncio.create_task(writer.drain())
_, pending = await asyncio.wait([write_task], timeout=timeout)
if len(pending):
for task in pending:
task.cancel()
raise asyncio.TimeoutError("Timeout writing data")
async def connect(self, host: str, port: int):
self._reader, self._writer = await asyncio.open_connection(host, port)
async def close(self):
writer = self.writer
self._reader = None
self._writer = None
writer.close()
await writer.wait_closed()
# (stray "[docs]" documentation-link marker removed here; it was not Python code)
class AsyncioClient:
    """Asyncio implementation of a PandABlocks client.

    For example::

        async with AsyncioClient("hostname-or-ip") as client:
            # Control port is now connected
            resp1, resp2 = await asyncio.gather(client.send(cmd1), client.send(cmd2))
            resp3 = await client.send(cmd3)
            async for data in client.data():
                handle(data)
        # Control and data ports are now disconnected
    """

    def __init__(self, host: str):
        self._host = host
        self._ctrl_connection = ControlConnection()
        self._ctrl_task: Optional[asyncio.Task] = None
        # id(command) -> queue on which send() waits for that command's response
        self._ctrl_queues: Dict[int, asyncio.Queue] = {}
        self._ctrl_stream = _StreamHelper()

    async def connect(self):
        """Connect to the control port, and be ready to handle commands"""
        await self._ctrl_stream.connect(self._host, 8888)
        self._ctrl_task = asyncio.create_task(
            self._ctrl_read_forever(self._ctrl_stream.reader)
        )

    def is_connected(self):
        """True while the control port read task is running.

        NOTE: This is not a reliable indication that the remote end is still
        connected; the read task only stops once end-of-stream has actually
        been read, or when it fails or is cancelled."""
        return self._ctrl_task is not None and not self._ctrl_task.done()

    async def close(self):
        """Close the control connection, and wait for completion"""
        assert self._ctrl_task, "connect() not called yet"
        self._ctrl_task.cancel()
        await self._ctrl_stream.close()

    async def __aenter__(self) -> "AsyncioClient":
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def _ctrl_read_forever(self, reader: asyncio.StreamReader):
        """Read from the control port and route each response to its waiter.

        Runs until cancelled, until ``reader`` fails, or until end-of-stream
        is read. On end-of-stream every `send` still waiting for a response is
        failed with `ConnectionError`.
        """
        while True:
            received = await reader.read(4096)
            if not received:
                # read() only returns b"" at end-of-stream, and would keep
                # returning it immediately, so looping again would spin.
                # Wake the callers that are still waiting instead.
                while self._ctrl_queues:
                    _, waiting = self._ctrl_queues.popitem()
                    waiting.put_nowait(
                        ConnectionError("Control connection closed by remote end")
                    )
                return
            try:
                to_send = self._ctrl_connection.receive_bytes(received)
                await self._ctrl_stream.write_and_drain(to_send)
                for command, response in self._ctrl_connection.responses():
                    # The waiter may already have given up (timeout or
                    # cancellation) and removed its queue; drop the response.
                    queue = self._ctrl_queues.pop(id(command), None)
                    if queue is not None:
                        queue.put_nowait(response)
            except Exception:
                # Decode leniently: a decode error raised inside this handler
                # would escape the loop and stop the read task.
                logging.exception(
                    "Error handling '%s'", received.decode(errors="replace")
                )

    async def send(self, command: Command[T], timeout: Optional[float] = None) -> T:
        """Send a command to control port of the PandA, returning its response.

        Args:
            command: The `Command` to send
            timeout: Seconds to wait for the write to drain, and then again
                for the response to arrive. None waits indefinitely

        Raises:
            asyncio.TimeoutError: If the write or the response timed out
            Exception: If the response delivered for the command is itself an
                exception instance, it is raised
        """
        queue: asyncio.Queue[T] = asyncio.Queue()
        # Need to use the id as non-frozen dataclasses don't hash
        key = id(command)
        self._ctrl_queues[key] = queue
        try:
            to_send = self._ctrl_connection.send(command)
            await self._ctrl_stream.write_and_drain(to_send, timeout)
            response = await asyncio.wait_for(queue.get(), timeout)
        finally:
            # On success the read task has already removed the entry. On
            # failure or timeout remove it here so it cannot accumulate, but
            # only if it is still the queue registered by this call.
            if self._ctrl_queues.get(key) is queue:
                del self._ctrl_queues[key]
        if isinstance(response, Exception):
            raise response
        return response

    async def data(
        self,
        scaled: bool = True,
        flush_period: Optional[float] = None,
        frame_timeout: Optional[float] = None,
    ) -> AsyncGenerator[Data, None]:
        """Connect to data port and yield data frames

        The generator finishes when the data port reports end-of-stream, and
        the data connection is closed whenever the generator exits.

        Args:
            scaled: Whether to scale and average data frames, reduces throughput
            flush_period: How often to flush partial data frames, None is on every
                chunk of data from the server
            frame_timeout: If no data is received for this amount of time, raise
                `asyncio.TimeoutError`
        """
        stream = _StreamHelper()
        connection = DataConnection()
        # Each item is an iterable of Data to yield; None marks end-of-stream
        queue: asyncio.Queue[Optional[Iterable[Data]]] = asyncio.Queue()

        def raise_in_consumer(error: Exception):
            # The unreachable yield makes this a generator, so the error is
            # raised when the consumer loop below iterates the queued item
            # rather than inside the background task that queued it.
            raise error
            yield

        async def periodic_flush():
            if flush_period is not None:
                while True:
                    # Every flush_period seconds flush and queue data
                    await asyncio.sleep(flush_period)
                    queue.put_nowait(connection.flush())

        async def read_from_stream():
            reader = stream.reader
            # Should we flush every FrameData?
            flush_every_frame = flush_period is None
            while True:
                try:
                    recv = await asyncio.wait_for(reader.read(4096), frame_timeout)
                    if recv:
                        chunk = connection.receive_bytes(recv, flush_every_frame)
                except asyncio.TimeoutError:
                    queue.put_nowait(
                        raise_in_consumer(
                            asyncio.TimeoutError(
                                f"No data received for {frame_timeout}s"
                            )
                        )
                    )
                    break
                except asyncio.CancelledError:
                    raise
                except Exception as error:
                    # Hand the failure to the consumer, which would otherwise
                    # wait on the queue forever.
                    queue.put_nowait(raise_in_consumer(error))
                    break
                if not recv:
                    # End-of-stream: read() would now return b"" immediately
                    # on every call. When flushing is periodic, queue one last
                    # flush so that partial data is not lost, then stop.
                    if not flush_every_frame:
                        queue.put_nowait(connection.flush())
                    queue.put_nowait(None)
                    break
                queue.put_nowait(chunk)

        await stream.connect(self._host, 8889)
        fut = None
        try:
            # Inside the try so the socket is closed if this first write fails
            await stream.write_and_drain(connection.connect(scaled))
            fut = asyncio.gather(periodic_flush(), read_from_stream())
            while True:
                chunk = await queue.get()
                if chunk is None:
                    break
                for data in chunk:
                    yield data
        finally:
            if fut is not None:
                fut.cancel()
            await stream.close()
            if fut is not None:
                with suppress(asyncio.CancelledError):
                    await fut