# Source code for appmesh.transport_mixin

# transport_mixin.py
"""Shared transport logic for TCP and WSS clients."""

# Standard library imports
import json
import logging
import threading
import uuid
from http import HTTPStatus
from typing import Optional

# Third-party imports
import requests
from requests.structures import CaseInsensitiveDict

# Local imports
from .app import App
from .app_run import OutputHandler
from .client_http import AppMeshClient
from .exceptions import AppMeshConnectionError
from .subscribe import (
    EVENT_TYPE_DISCONNECTED,
    AppEvent,
    EventCallback,
    MessageDemuxer,
    SubscriptionResult,
)
from .tcp_messages import RequestMessage, ResponseMessage

logger = logging.getLogger(__name__)

# Endpoints whose successful response carries a fresh access_token in its JSON body.
# Login / auth / TOTP-validate: the token is applied only when the request
# itself carried the X-Set-Cookie header.
_AUTH_SET_COOKIE_PATHS = frozenset(
    (
        "/appmesh/login",
        "/appmesh/auth",
        "/appmesh/totp/validate",
    )
)
# Renew / TOTP-setup: the token is always applied (a session already exists).
_AUTH_RENEW_PATHS = frozenset(
    (
        "/appmesh/token/renew",
        "/appmesh/totp/setup",
    )
)
# Logging off clears the locally stored token.
_LOGOFF_PATH = "/appmesh/self/logoff"


class TransportClientMixin:
    """Mixin providing shared request/response logic for TCP and WSS transport clients.

    Subclasses must define:
        - _transport: the transport object (TCPTransport or WSSTransport)
        - _token: the current access token string
        - _HTTP_USER_AGENT_TRANSPORT: user agent string for this transport

    The methods below additionally read the following from the host class:
    ``_transport_client_addr``, ``_transport_name``, ``forward_to``,
    ``_HTTP_HEADER_KEY_USER_AGENT``, ``_HTTP_HEADER_KEY_AUTH``,
    ``_HTTP_HEADER_KEY_X_TARGET_HOST``, ``get_app_output`` and ``delete_app``,
    and call ``_on_token_changed`` / ``add_app`` on the next class in the MRO.
    """

    _ENCODING_UTF8 = "utf-8"
    # Seconds to wait for a response that is routed through the demuxer.
    _DEMUXER_RESPONSE_TIMEOUT = 60.0

    def _convert_bytes(self, body) -> bytes:
        """Prepare request body for transmission.

        Args:
            body: ``None``, a bytes-like object, ``str``, ``dict`` or ``list``.

        Returns:
            ``b""`` for ``None``; a ``bytes`` copy for bytes-like input;
            UTF-8 text for ``str``; UTF-8 encoded JSON for ``dict``/``list``.

        Raises:
            TypeError: If ``body`` is of any other type.
        """
        if body is None:
            return b""
        if isinstance(body, (bytes, bytearray, memoryview)):
            return bytes(body)
        if isinstance(body, str):
            return body.encode(self._ENCODING_UTF8)
        if isinstance(body, (dict, list)):
            return json.dumps(body).encode(self._ENCODING_UTF8)
        raise TypeError(f"Unsupported body type: {type(body)}")

    def _on_token_changed(self, token: Optional[str]) -> None:
        """Store token locally and delegate to base class."""
        self._token = token
        super()._on_token_changed(token)

    def _get_access_token(self) -> Optional[str]:
        """Get the current access token."""
        return self._token

    def _sync_transport_token(self, response, path: str, request_headers: Optional[dict]) -> None:
        """Extract and apply token from auth endpoint responses (TCP/WSS only).

        HTTP transport relies on Set-Cookie for automatic cookie jar updates;
        TCP/WSS must extract the token from the JSON response body.

        Args:
            response: Response object exposing ``status_code`` and ``json()``.
            path: Request URI path that produced ``response``.
            request_headers: Caller-supplied request headers, if any.
        """
        if response.status_code != HTTPStatus.OK:
            return

        if path == _LOGOFF_PATH:
            self._on_token_changed(None)
            return

        # Login/auth/totp_validate: apply only when client requested cookie mode
        if path in _AUTH_SET_COOKIE_PATHS:
            if not request_headers or request_headers.get("X-Set-Cookie") != "true":
                return
        elif path not in _AUTH_RENEW_PATHS:
            return

        # Extract access_token from JSON body. Best-effort: a malformed body must
        # not turn an otherwise successful request into a failure.
        try:
            token = response.json().get("access_token")
            if token:
                self._on_token_changed(token)
        except Exception:  # pylint: disable=broad-exception-caught
            # Log the path only; never the body, which may contain credentials.
            logger.debug("token sync skipped for %s", path, exc_info=True)

    def _request_http(
        self,
        method: AppMeshClient._Method,
        path: str,
        query: Optional[dict] = None,
        header: Optional[dict] = None,
        body=None,
        raise_on_fail: bool = True,
    ) -> requests.Response:
        """Send HTTP request over transport.

        Args:
            method: HTTP method.
            path: URI path.
            query: Query parameters.
            header: HTTP headers.
            body: Request body.
            raise_on_fail: Raise exception on HTTP error.

        Returns:
            Simulated HTTP response.

        Raises:
            AppMeshConnectionError: If no response is received (demuxer timeout
                or broken connection); the transport is closed first.
            requests.HTTPError: Via ``raise_for_status()`` when ``raise_on_fail``
                is true and the status is not 428 (Precondition Required).
            TypeError: If ``body`` has an unsupported type.
        """
        transport = self._transport
        if not transport.connected():
            transport.connect()

        # Prepare request message (ensure no fields are assigned None!)
        appmesh_request = RequestMessage()
        appmesh_request.uuid = str(uuid.uuid4())
        appmesh_request.http_method = method.value
        appmesh_request.request_uri = path
        appmesh_request.client_addr = self._transport_client_addr
        appmesh_request.headers[self._HTTP_HEADER_KEY_USER_AGENT] = self._HTTP_USER_AGENT_TRANSPORT

        # Add authentication token
        token = self._get_access_token()
        if token:
            appmesh_request.headers[self._HTTP_HEADER_KEY_AUTH] = token

        # Add forwarding host
        target_host = self.forward_to
        if target_host:
            appmesh_request.headers[self._HTTP_HEADER_KEY_X_TARGET_HOST] = target_host

        # Add custom headers (applied last, so they override the defaults above)
        if header:
            appmesh_request.headers.update(header)

        # Add query parameters
        if query:
            appmesh_request.query.update(query)

        # Prepare body
        body_bytes = self._convert_bytes(body)
        if body_bytes:
            appmesh_request.body = body_bytes

        # Send request and receive response
        data = appmesh_request.serialize()
        demuxer = getattr(self, "_demuxer", None)
        if demuxer and demuxer._running:
            # Demuxer is active — route through it to avoid concurrent socket reads
            appmesh_resp = demuxer.send_and_receive(
                appmesh_request.uuid, data, timeout=self._DEMUXER_RESPONSE_TIMEOUT
            )
            if not appmesh_resp:
                transport.close()
                raise AppMeshConnectionError(f"{self._transport_name} demuxer response timeout")
        else:
            transport.send_message(data)
            resp_data = transport.receive_message()
            if not resp_data:  # Covers None and empty bytes
                transport.close()
                raise AppMeshConnectionError(f"{self._transport_name} connection broken")
            appmesh_resp = ResponseMessage.from_bytes(resp_data)

        response = requests.Response()
        response.status_code = appmesh_resp.http_status
        response.headers = CaseInsensitiveDict(appmesh_resp.headers)

        # Set response content (always bytes after this point)
        if isinstance(appmesh_resp.body, bytes):
            response._content = appmesh_resp.body
        else:
            response._content = str(appmesh_resp.body).encode(self._ENCODING_UTF8)

        # Set content type
        if appmesh_resp.body_msg_type:
            response.headers["Content-Type"] = appmesh_resp.body_msg_type

        # NOTE(review): 428 is exempt from raising — presumably so callers can
        # handle a second-factor challenge themselves; confirm against callers.
        if raise_on_fail and response.status_code != HTTPStatus.PRECONDITION_REQUIRED:
            # Decode rather than str(): str(bytes) yields the "b'...'" repr, which
            # would otherwise leak into every HTTPError message.
            response.reason = response._content.decode(self._ENCODING_UTF8, errors="replace")
            response.url = f"{str(transport)}/{path.lstrip('/')}"
            response.raise_for_status()

        # Auto-sync token from auth endpoint responses
        self._sync_transport_token(response, path, header)

        return AppMeshClient._EncodingResponse(response)

    def add_app(self, app: App, subscribe_events: Optional[list] = None, callback: Optional[EventCallback] = None) -> App:
        """Register an app, optionally subscribing atomically and wiring a local callback.

        Reuses the base ``add_app`` for the HTTP round-trip + ``subscription_id``
        parsing, then registers ``callback`` against the local demuxer keyed by
        the new subscription.

        Args:
            app: Application definition to register.
            subscribe_events: Event types forwarded to the base ``add_app``.
            callback: Function called with AppEvent for each received event.

        Returns:
            The app object returned by the base ``add_app``.
        """
        result_app = super().add_app(app, subscribe_events=subscribe_events)
        if callback and result_app.subscription_id:
            self._ensure_demuxer()
            self._demuxer.register_event_callback(result_app.subscription_id, callback)
        return result_app

    def subscribe(self, app_name: str, events: Optional[list] = None, callback: Optional[EventCallback] = None) -> SubscriptionResult:
        """Subscribe to app events over the transport connection.

        Args:
            app_name: Application name, or "*" for all apps.
            events: List of event types (e.g. ["START", "EXIT", "STDOUT"]).
            callback: Function called with AppEvent for each received event.

        Returns:
            SubscriptionResult with subscription_id, app_name, and events.
        """
        path = "/appmesh/subscribe"
        if app_name and app_name != "*":
            path = f"/appmesh/app/{app_name}/subscribe"

        query = {}
        if events:
            query["events"] = ",".join(events)

        resp = self._request_http(AppMeshClient._Method.POST, path=path, query=query)
        result_data = resp.json()
        result = SubscriptionResult(
            subscription_id=result_data.get("subscription_id", ""),
            app_name=result_data.get("app_name", ""),
            events=result_data.get("events", []),
        )

        # The callback is wired only after the server has confirmed the subscription.
        if callback and result.subscription_id:
            self._ensure_demuxer()
            self._demuxer.register_event_callback(result.subscription_id, callback)
        return result

    def unsubscribe(self, subscription_id: str) -> None:
        """Remove an event subscription.

        The local callback is dropped even when the server request fails, so a
        failed or already-expired subscription cannot leave a stale callback
        registered on the demuxer.

        Args:
            subscription_id: The subscription ID returned by subscribe().
        """
        query = {"subscription_id": subscription_id}
        try:
            self._request_http(AppMeshClient._Method.DELETE, path="/appmesh/subscribe", query=query)
        finally:
            demuxer = getattr(self, "_demuxer", None)
            if demuxer:
                demuxer.unregister_event_callback(subscription_id)

    def _ensure_demuxer(self) -> None:
        """Start the message demuxer if not already running."""
        # NOTE(review): an existing demuxer is reused even if it has stopped
        # (``_running`` false); confirm whether it should be recreated then.
        if getattr(self, "_demuxer", None):
            return
        self._demuxer = MessageDemuxer(self._transport)
        self._demuxer.start()

    def wait_for_async_run(self, run, stdout_handler: Optional[OutputHandler] = None, timeout: int = 0) -> Optional[int]:
        """Override: use subscribe-based streaming on TCP/WSS instead of polling.

        Subscribes to ``STDOUT`` + ``EXIT`` + ``REMOVED``, then does a one-shot
        ``get_app_output`` to backfill bytes emitted before the subscribe took
        effect. Stdout events whose ``position`` is already covered by an earlier
        delivery are deduped (partial overlap → prefix trimmed).

        Live ``STDOUT`` events that arrive while the backfill request is still
        in flight are buffered and flushed right after the backfill, so the
        early bytes are never discarded by the dedupe and output is delivered
        in position order.

        Args:
            run: Async run handle; ``app_name`` and ``proc_uid`` are read from it.
            stdout_handler: Optional callable invoked as ``handler(text, position)``
                for each new chunk; exceptions it raises are logged and ignored.
            timeout: Seconds to wait; ``0`` or ``None`` waits indefinitely.

        Returns:
            The process exit code, or a sentinel: ``None`` on caller-side timeout
            (or when ``run`` has no app name), ``-1`` if the app was REMOVED
            before EXIT was observed, ``-2`` if the demuxer disconnected.
        """
        if not run or not run.app_name:
            return None

        wait_timeout: Optional[float] = None if timeout in (0, None) else float(timeout)

        # Sentinel exit codes distinguishable from real ones (>=0):
        #   None → caller-side timeout (done.wait returned without done.set)
        #   -1   → REMOVED before EXIT observed
        #   -2   → demuxer disconnected (transport failure)
        exit_code: Optional[int] = None
        delivered_until = 0  # next-byte offset already passed to stdout_handler
        backfill_done = False  # live STDOUT events are buffered until this flips
        pending: list = []  # (chunk, position) pairs held back during the backfill
        done = threading.Event()
        lock = threading.Lock()

        def emit_locked(chunk, pos: int) -> None:
            """Dedupe one chunk against ``delivered_until`` and hand it to the handler.

            The caller must hold ``lock``; this keeps handler calls serialized.
            """
            nonlocal delivered_until
            if not chunk:
                return
            chunk_bytes = chunk.encode("utf-8") if isinstance(chunk, str) else bytes(chunk)
            end = pos + len(chunk_bytes)
            if end <= delivered_until:
                return  # fully covered by an earlier delivery
            start_pos = pos
            if pos < delivered_until:
                # Partial overlap: drop the prefix that was already delivered.
                chunk_bytes = chunk_bytes[delivered_until - pos:]
                start_pos = delivered_until
            delivered_until = end
            if stdout_handler is not None:
                try:
                    stdout_handler(chunk_bytes.decode("utf-8", errors="replace"), start_pos)
                except Exception:  # pylint: disable=broad-exception-caught
                    # A faulty handler must not break the event stream.
                    logger.debug("stdout_handler raised for %s", run.app_name, exc_info=True)

        def on_event(event: AppEvent) -> None:
            nonlocal exit_code
            if event.event_type == "STDOUT":
                try:
                    pos = int(event.data.get("position", 0))
                except (TypeError, ValueError):
                    pos = 0
                chunk = event.data.get("output", "")
                with lock:
                    if backfill_done:
                        emit_locked(chunk, pos)
                    else:
                        pending.append((chunk, pos))
            elif event.event_type == "EXIT":
                try:
                    exit_code = int(event.data.get("exit_code", -1))
                except (TypeError, ValueError):
                    exit_code = -1
                done.set()
            elif event.event_type == "REMOVED":
                if exit_code is None:
                    exit_code = -1
                done.set()
            elif event.event_type == EVENT_TYPE_DISCONNECTED:
                if exit_code is None:
                    exit_code = -2
                done.set()

        sub = self.subscribe(run.app_name, ["STDOUT", "EXIT", "REMOVED"], callback=on_event)
        try:
            # Backfill bytes emitted before subscribe took effect; also catches
            # the case where the process already exited.
            try:
                backfill = self.get_app_output(
                    app_name=run.app_name,
                    stdout_position=0,
                    stdout_index=0,
                    process_uuid=run.proc_uid,
                    timeout=0,
                )
                if backfill.output:
                    with lock:
                        emit_locked(backfill.output, 0)
                if backfill.exit_code is not None and exit_code is None:
                    exit_code = backfill.exit_code
                    done.set()
            except Exception as exc:  # pylint: disable=broad-exception-caught
                logger.warning("backfill failed for %s: %s", run.app_name, exc)
            finally:
                # Release buffered live events whether or not the backfill worked.
                # The flag is flipped first so events can never be stranded.
                with lock:
                    backfill_done = True
                    for queued_chunk, queued_pos in pending:
                        emit_locked(queued_chunk, queued_pos)
                    pending.clear()

            done.wait(timeout=wait_timeout)
        finally:
            try:
                if sub.subscription_id:
                    self.unsubscribe(sub.subscription_id)
            except Exception:  # pylint: disable=broad-exception-caught
                logger.debug("unsubscribe failed for %s", run.app_name, exc_info=True)

        # Best-effort delete on a real exit. Sentinels (-1 REMOVED, -2 disconnected)
        # mean the daemon already lost track or the app is gone — don't try to delete.
        if exit_code is not None and exit_code >= 0:
            try:
                self.delete_app(run.app_name)
            except Exception:  # pylint: disable=broad-exception-caught
                logger.debug("delete_app failed for %s", run.app_name, exc_info=True)
        return exit_code