Source code for appmesh.tcp_transport

# tcp_transport.py
"""TCP Transport layer handling socket connections."""
__all__ = ["TCPTransport"]

import logging
import socket
import ssl
import struct
from pathlib import Path
from typing import Optional, Union, Tuple

from .exceptions import AppMeshConnectionError

logger = logging.getLogger(__name__)


class TCPTransport:
    """TCP transport layer with TLS support.

    Every message is framed with an 8-byte header made of two big-endian
    unsigned 32-bit integers: a magic number followed by the payload length.
    A frame whose length field is 0 carries no payload and is surfaced by
    :meth:`receive_message` as ``None`` (EOF signal).
    """

    # Size in bytes of the frame header: a 4-byte magic number followed by a
    # 4-byte payload length, both big-endian unsigned 32-bit integers ("!II").
    # Must match the C++ service implementation, which uses uint32_t fields.
    TCP_MESSAGE_HEADER_LENGTH = 8
    TCP_MESSAGE_MAGIC = 0x07C707F8  # Magic number
    TCP_MAX_BLOCK_SIZE = 100 * 1024 * 1024  # 100 MB

    def __init__(
        self,
        address: Tuple[str, int],
        ssl_verify: Union[bool, str],
        ssl_client_cert: Optional[Union[str, Tuple[str, str]]] = None,
    ):
        """
        Initialize TCP transport with TLS configuration.

        No connection is opened here; call :meth:`connect` or use the
        instance as a context manager.

        Args:
            address: Server address as (host, port) tuple.
            ssl_verify: SSL server verification mode:
                - True: Use system CA certificates
                - False: Disable verification (insecure)
                - str: Path to custom CA bundle or directory
            ssl_client_cert: SSL client certificate:
                - str: Path to PEM file with cert and key
                - tuple: (cert_path, key_path)

        Note:
            TCP connections require an explicit full-chain CA specification for
            certificate validation, unlike HTTP, which can retrieve intermediate
            certificates automatically.
        """
        self.address = address
        self.ssl_verify = ssl_verify
        self.ssl_client_cert = ssl_client_cert
        self._socket: Optional[ssl.SSLSocket] = None

    def __enter__(self) -> "TCPTransport":
        """Context manager entry: connect if not already connected."""
        if not self.connected():
            self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit: close the connection."""
        self.close()

    def __del__(self):
        """Best-effort cleanup when the object is garbage collected."""
        try:
            self.close()
        except Exception:  # pylint: disable=broad-exception-caught
            pass  # suppress all exceptions

    def __str__(self) -> str:
        """Return TCP URI representation.

        The scheme is ``tcps`` when ``ssl_verify`` is truthy and ``tcp``
        otherwise.
        """
        scheme = "tcps" if self.ssl_verify else "tcp"
        host, port = self.address
        return f"{scheme}://{host}:{port}"

    def connect(self) -> None:
        """Establish TLS connection to server.

        Raises:
            AppMeshConnectionError: If the TCP connection or the TLS
                handshake fails.
            ValueError: If ``ssl_verify`` is a path that is neither a file
                nor a directory (raised while building the SSL context).
        """
        context = self._create_ssl_context()

        sock: Optional[socket.socket] = None
        try:
            # create_connection resolves the host and tries each returned
            # address, so both IPv4 and IPv6 servers are reachable.
            # timeout=None keeps the socket in blocking mode.
            sock = socket.create_connection(self.address, timeout=None)
            # Disable Nagle's algorithm
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            # Wrapping an already-connected socket performs the TLS handshake.
            self._socket = context.wrap_socket(sock, server_hostname=self.address[0])
        except (socket.error, ssl.SSLError) as e:
            # Release the descriptor if the plain socket still owns it.
            if sock is not None:
                sock.close()
            raise AppMeshConnectionError(f"Failed to connect to {self.address}: {e}") from e

    def _create_ssl_context(self) -> ssl.SSLContext:
        """Create and configure SSL context.

        Returns:
            A client-side context requiring TLS 1.2 or newer, configured
            from ``ssl_verify`` and ``ssl_client_cert``.

        Raises:
            ValueError: If ``ssl_verify`` is a string that is neither an
                existing file nor an existing directory.
        """
        context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)

        # Set minimum TLS version
        if hasattr(context, "minimum_version"):
            context.minimum_version = ssl.TLSVersion.TLSv1_2
        else:
            context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1

        # Configure SSL verification
        if not self.ssl_verify:
            # check_hostname must be switched off before CERT_NONE is allowed
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
        else:
            context.verify_mode = ssl.CERT_REQUIRED  # Require certificate verification
            context.load_default_certs()  # Load system's default CA certificates

            if isinstance(self.ssl_verify, str):
                path = Path(self.ssl_verify)
                if path.is_file():
                    # Load custom CA certificate file
                    context.load_verify_locations(cafile=str(path))
                elif path.is_dir():
                    # Load CA certificates from directory
                    context.load_verify_locations(capath=str(path))
                else:
                    raise ValueError(f"ssl_verify path '{self.ssl_verify}' is invalid")

        # Load client certificate if provided
        if self.ssl_client_cert:
            if isinstance(self.ssl_client_cert, tuple):
                # Separate cert and key files
                context.load_cert_chain(certfile=self.ssl_client_cert[0], keyfile=self.ssl_client_cert[1])
            else:
                # Cert and key in the same PEM file
                context.load_cert_chain(certfile=self.ssl_client_cert)

        return context

    def close(self) -> None:
        """Close the socket, if any, and mark the transport as disconnected."""
        if self._socket:
            try:
                self._socket.close()
            except Exception as e:
                logger.warning("Error closing socket: %s", e)
            finally:
                self._socket = None

    def connected(self) -> bool:
        """Return True while a socket object is held by this transport."""
        return self._socket is not None

    def send_message(self, data: Union[bytes, bytearray, list]) -> None:
        """
        Send a message with prefixed header.

        Args:
            data: Message data to send, or empty list for EOF signal.

        Raises:
            AppMeshConnectionError: If not connected, or if sending fails
                (the socket is closed in that case).
            TypeError: If ``data`` is non-empty and not a bytes-like object.
                Nothing is written to the socket in that case.
        """
        if not self._socket:
            raise AppMeshConnectionError("Cannot send message: not connected")

        # Validate the payload BEFORE writing the header: a header that is
        # not followed by its payload would desynchronise the stream.
        payload = memoryview(data) if data else None
        length = payload.nbytes if payload is not None else 0

        try:
            # Pack the header into 8 bytes using big-endian format
            header = struct.pack("!II", self.TCP_MESSAGE_MAGIC, length)
            self._socket.sendall(header)
            if length > 0:
                self._socket.sendall(payload)
        except (socket.error, ssl.SSLError) as e:
            self.close()
            raise AppMeshConnectionError(f"Error sending message: {e}") from e

    def receive_message(self) -> Optional[bytes]:
        """
        Receive a message with prefixed header.

        Returns:
            Message data, or None for EOF signal.

        Raises:
            AppMeshConnectionError: If not connected, if the header is
                invalid (bad magic number or oversized length), or if the
                socket fails or closes early. The socket is closed on error.
        """
        if not self._socket:
            raise AppMeshConnectionError("Cannot receive message: not connected")

        try:
            # Unpack the data (big-endian format)
            magic, length = struct.unpack("!II", self._recvall(self.TCP_MESSAGE_HEADER_LENGTH))

            if magic != self.TCP_MESSAGE_MAGIC:
                raise ValueError(f"Invalid magic number: 0x{magic:X}")

            if length > self.TCP_MAX_BLOCK_SIZE:
                raise ValueError(f"Message size {length} exceeds maximum {self.TCP_MAX_BLOCK_SIZE}")

            return self._recvall(length) if length > 0 else None
        except (socket.error, ssl.SSLError, EOFError, ValueError) as e:
            self.close()
            raise AppMeshConnectionError(f"Error receiving message: {e}") from e

    def _recvall(self, length: int) -> bytes:
        """
        Receive exactly `length` bytes from socket.
        https://stackoverflow.com/questions/64466530/using-a-custom-socket-recvall-function-works-only-if-thread-is-put-to-sleep

        Args:
            length: Number of bytes to receive.

        Returns:
            Received data.

        Raises:
            EOFError: If connection closes before receiving all data.
            ValueError: If length is not positive.
            socket.timeout: If the socket times out before all data arrived.
        """
        if length <= 0:
            raise ValueError(f"Invalid length: {length}")

        # Pre-allocate buffer
        buffer = bytearray(length)
        view = memoryview(buffer)
        bytes_received = 0

        while bytes_received < length:
            try:
                # Receive directly into buffer
                remaining = length - bytes_received
                chunk_size = self._socket.recv_into(view, remaining)

                if chunk_size == 0:
                    raise EOFError("Connection closed by peer")

                view = view[chunk_size:]  # advance memoryview
                bytes_received += chunk_size

            except InterruptedError:
                continue
            except socket.timeout as e:
                raise socket.timeout(f"Socket timed out after receiving {bytes_received}/{length} bytes") from e

        return bytes(buffer)