
Support WebSocket APIs

Use the Omni Gateway Policy Development Kit (PDK) WebSocket library to implement custom policies that support WebSocket APIs.

To view an example policy project that uses the PDK WebSocket library, see WebSocket Policy Example.

Enable WebSocket Support

To enable WebSocket functionality, add the experimental_websocket feature to the PDK dependency in your Cargo.toml file:

[dependencies]
pdk = { version = "1.9.0", features = ["experimental_websocket"] }

Configure WebSocket Handlers

To process WebSocket traffic, register on_upgrade_upstream and on_upgrade_downstream handlers on the FilterBuilder:

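use pdk::hl::*;
use pdk::websockets::{Decoder, Encoder, Frame, FrameType};

pub async fn configure(launcher: Launcher) -> Result<()> {
    let handler = FilterBuilder::new()
        // YourPolicyState stands in for your policy's own state type
        .on_create(|| YourPolicyState {})
        // Lets the HTTP request continue unchanged
        .on_request(|_req: RequestState| async move { Flow::Continue(()) })
        // Handles client-to-server frames after the connection upgrade
        .on_upgrade_upstream(handle_upstream)
        // Handles server-to-client frames after the connection upgrade
        .on_upgrade_downstream(handle_downstream)
        .build();

    launcher.launch(handler).await?;
    Ok(())
}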

Supported Frame Types

The PDK supports these WebSocket frame types:

  • FrameType::Text: Text data frames

  • FrameType::Binary: Binary data frames

  • FrameType::Close: Connection close frames

  • FrameType::Ping: Ping frames

  • FrameType::Pong: Pong frames

To preserve WebSocket protocol integrity, pass control frames (Close, Ping, and Pong) through without modification.
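The frame types above correspond to opcodes defined in RFC 6455, section 5.2. The following sketch uses a hypothetical Opcode enum (not the PDK's FrameType) to show how control frames are told apart from data frames, and a hypothetical rewrite_text function that modifies only Text payloads. Leaving control frames alone matters because a Pong must echo the payload of the Ping it answers, and a Close payload, when present, starts with a 2-byte status code, so rewriting either can break the connection.

```rust
/// Hypothetical opcode type; the values are the RFC 6455 opcodes behind the
/// frame types listed above. This is not the PDK's FrameType.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Opcode {
    Continuation, // 0x0
    Text,         // 0x1
    Binary,       // 0x2
    Close,        // 0x8
    Ping,         // 0x9
    Pong,         // 0xA
}

impl Opcode {
    /// Maps a raw 4-bit opcode to a variant; reserved opcodes yield None.
    pub fn from_u8(value: u8) -> Option<Opcode> {
        match value {
            0x0 => Some(Opcode::Continuation),
            0x1 => Some(Opcode::Text),
            0x2 => Some(Opcode::Binary),
            0x8 => Some(Opcode::Close),
            0x9 => Some(Opcode::Ping),
            0xA => Some(Opcode::Pong),
            _ => None,
        }
    }

    /// Control frames use opcodes 0x8 to 0xF; the defined ones are Close, Ping, Pong.
    pub fn is_control(self) -> bool {
        matches!(self, Opcode::Close | Opcode::Ping | Opcode::Pong)
    }
}

/// Prefixes Text payloads and returns every other payload unchanged,
/// so Close, Ping, and Pong frames pass through untouched.
pub fn rewrite_text(opcode: Opcode, payload: Vec<u8>, prefix: &str) -> Vec<u8> {
    if opcode == Opcode::Text {
        let mut out = prefix.as_bytes().to_vec();
        out.extend_from_slice(&payload);
        out
    } else {
        payload
    }
}
```

For example, rewrite_text(Opcode::Ping, payload, "Modified: ") returns the Ping payload as-is, so the peer's Pong still matches.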

Maintain Frame Boundaries

WebSocket frames can arrive split across multiple network chunks, so a single chunk might contain part of a frame, one frame, or several frames. To make sure your policy processes only complete frames regardless of how they arrive from the network, structure each handler to:

  1. Accumulate bytes with state.accumulate().await when no complete frame is available yet.

  2. Keep the trailing bytes of any incomplete frame (the remainder) and prepend them to the next chunk.

  3. Parse and modify frames only when enough data is available to decode them.
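The following standalone sketch illustrates the remainder pattern from the steps above. It assumes nothing about the PDK's Decoder: the hypothetical parse function decodes RFC 6455 frames from a byte buffer and returns every complete frame plus the leftover bytes of a trailing partial frame, which is the role remainder_bytes plays in the handlers. It skips validation that a production decoder would perform, such as checking reserved bits and control-frame size limits.

```rust
/// A decoded frame: FIN flag, 4-bit opcode, and unmasked payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RawFrame {
    pub fin: bool,
    pub opcode: u8,
    pub payload: Vec<u8>,
}

/// Decodes every complete frame at the start of `bytes` and returns them
/// together with the leftover bytes of a trailing incomplete frame.
pub fn parse(bytes: &[u8]) -> (Vec<RawFrame>, Vec<u8>) {
    let mut frames = Vec::new();
    let mut pos = 0;
    // Each successful parse consumes at least the 2-byte header, so this terminates.
    while let Some((frame, used)) = parse_one(&bytes[pos..]) {
        frames.push(frame);
        pos += used;
    }
    (frames, bytes[pos..].to_vec())
}

/// Decodes one frame from the start of `buf`, returning the frame and the
/// number of bytes it occupied, or None if `buf` holds only part of a frame.
fn parse_one(buf: &[u8]) -> Option<(RawFrame, usize)> {
    let header = buf.get(0..2)?;
    let fin = header[0] & 0x80 != 0;
    let opcode = header[0] & 0x0F;
    let masked = header[1] & 0x80 != 0;
    // 7-bit length, or 126/127 followed by a 16-bit/64-bit big-endian length.
    let (len, mut pos) = match header[1] & 0x7F {
        126 => {
            let b = buf.get(2..4)?;
            (u16::from_be_bytes([b[0], b[1]]) as usize, 4)
        }
        127 => {
            let mut b = [0u8; 8];
            b.copy_from_slice(buf.get(2..10)?);
            (u64::from_be_bytes(b) as usize, 10)
        }
        n => (n as usize, 2),
    };
    // Client-to-server frames carry a 4-byte masking key before the payload.
    let key = if masked {
        let k = buf.get(pos..pos + 4)?;
        pos += 4;
        Some([k[0], k[1], k[2], k[3]])
    } else {
        None
    };
    let end = pos.checked_add(len)?;
    let data = buf.get(pos..end)?;
    let payload = match key {
        Some(k) => data.iter().enumerate().map(|(i, b)| b ^ k[i % 4]).collect(),
        None => data.to_vec(),
    };
    Some((RawFrame { fin, opcode, payload }, end))
}
```

If the first chunk holds only [0x81, 0x02, b'h'], parse returns no frames and keeps all three bytes as the remainder; once the next chunk supplies b'i', prepending the remainder yields the complete Text frame "hi".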

Process Client-to-Server Frames

Process client-to-server frames in the upstream handler, which receives an UpstreamState. The handler must manage frame boundaries by accumulating partial frames:

async fn handle_upstream(mut state: UpstreamState) -> Result<(), BoxError> {
    let mut remainder_bytes = Vec::new();

    loop {
        let current_bytes = state.bytes();

        let mut bytes = remainder_bytes.clone();
        bytes.extend_from_slice(&current_bytes);

        let (frames, remainder) = Decoder::parse(bytes);

        // Wait for complete frames before processing
        if frames.is_empty() {
            state = state.accumulate().await;
        } else {
            remainder_bytes = remainder;

            // Inspect and modify frames
            let modified: Vec<Frame> = frames
                .into_iter()
                .map(|frame| match frame.frame_type() {
                    FrameType::Text => {
                        let text = String::from_utf8_lossy(frame.data());
                        Frame::text(format!("Modified: {}", text), frame.fin())
                    }
                    _ => frame, // Binary and control frames pass through unchanged
                })
                .collect();

            // Re-encode the frames for the client-to-server direction
            let encoded = Encoder::default().encode_client(modified);
            state.set_body(&encoded);
            state = state.next().await;
        }
    }
}
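RFC 6455 requires a client to mask every frame it sends to a server, and forbids a server from masking frames it sends to a client. This is presumably why the upstream handler calls encode_client while the downstream handler calls encode_server, although the PDK's internal encoding isn't described here. The sketch below uses a hypothetical encode_frame function to show the wire format both directions share: pass Some(key) for client-to-server frames and None for server-to-client frames.

```rust
/// Encodes one frame per RFC 6455 section 5.2. `mask` is Some(key) for
/// client-to-server frames and None for server-to-client frames.
pub fn encode_frame(fin: bool, opcode: u8, payload: &[u8], mask: Option<[u8; 4]>) -> Vec<u8> {
    let mut out = Vec::with_capacity(payload.len() + 14);
    // First byte: FIN bit plus the 4-bit opcode (reserved bits left at zero).
    let fin_bit: u8 = if fin { 0x80 } else { 0x00 };
    out.push(fin_bit | (opcode & 0x0F));
    // Second byte: MASK bit plus a 7-bit length, extended to 16 or 64 bits if needed.
    let mask_bit: u8 = if mask.is_some() { 0x80 } else { 0x00 };
    let len = payload.len();
    if len < 126 {
        out.push(mask_bit | len as u8);
    } else if len <= u16::MAX as usize {
        out.push(mask_bit | 126);
        out.extend_from_slice(&(len as u16).to_be_bytes());
    } else {
        out.push(mask_bit | 127);
        out.extend_from_slice(&(len as u64).to_be_bytes());
    }
    match mask {
        Some(key) => {
            // The key precedes the payload; each byte is XORed with key[i % 4].
            out.extend_from_slice(&key);
            out.extend(payload.iter().enumerate().map(|(i, b)| b ^ key[i % 4]));
        }
        None => out.extend_from_slice(payload),
    }
    out
}
```

Because masking is a plain XOR, applying the same key again restores the original payload. A real client must choose an unpredictable masking key for every frame; a fixed key is only suitable for illustration.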

Process Server-to-Client Frames

Process server-to-client frames in the downstream handler, which receives a DownstreamState. As in the upstream handler, accumulate partial frames before decoding:

async fn handle_downstream(mut state: DownstreamState) -> Result<(), BoxError> {
    let mut remainder_bytes = Vec::new();

    loop {
        let current_bytes = state.bytes();

        let mut bytes = remainder_bytes.clone();
        bytes.extend_from_slice(&current_bytes);

        let (frames, remainder) = Decoder::parse(bytes);

        // Wait for complete frames before processing
        if frames.is_empty() {
            state = state.accumulate().await;
        } else {
            remainder_bytes = remainder;

            // Inspect and modify frames
            let modified: Vec<Frame> = frames
                .into_iter()
                .map(|frame| match frame.frame_type() {
                    FrameType::Text => {
                        let text = String::from_utf8_lossy(frame.data());
                        Frame::text(format!("Echo: {}", text), frame.fin())
                    }
                    _ => frame, // Binary and control frames pass through unchanged
                })
                .collect();

            // Re-encode the frames for the server-to-client direction
            let encoded = Encoder::default().encode_server(modified);
            state.set_body(&encoded);
            state = state.next().await;
        }
    }
}
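Both handlers rewrite one frame at a time and preserve its fin() flag. Under RFC 6455, section 5.4, a message can be fragmented: only the first frame carries the Text or Binary type, the remaining fragments use the continuation opcode (0x0), and the last fragment has FIN set. Control frames can appear between fragments. Decoding each fragment separately with String::from_utf8_lossy, as the examples above do, can replace a multi-byte character split across two fragments with U+FFFD. The frame type list above doesn't include a continuation type, so how the PDK surfaces continuation frames isn't covered here. The hypothetical MessageAssembler below, which works on raw opcodes, sketches how a policy could reassemble complete messages before inspecting them.

```rust
/// Hypothetical helper that rebuilds complete messages from fragments.
/// Opcodes: 0x0 continuation, 0x1 text, 0x2 binary; 0x8 and above are control.
#[derive(Default)]
pub struct MessageAssembler {
    opcode: Option<u8>, // opcode of the message in progress, if any
    buffer: Vec<u8>,    // payload collected so far
}

impl MessageAssembler {
    /// Feeds one frame; returns Ok(Some((opcode, payload))) once a message
    /// (or a control frame) is complete, and Ok(None) while fragments are pending.
    pub fn push(
        &mut self,
        fin: bool,
        opcode: u8,
        payload: &[u8],
    ) -> Result<Option<(u8, Vec<u8>)>, &'static str> {
        if opcode >= 0x8 {
            // Control frames are never fragmented and may arrive mid-message.
            return Ok(Some((opcode, payload.to_vec())));
        }
        match (opcode, self.opcode) {
            (0x0, None) => return Err("continuation frame without a started message"),
            (0x0, Some(_)) => {}
            (_, Some(_)) => return Err("new data frame before the previous message finished"),
            (op, None) => self.opcode = Some(op),
        }
        self.buffer.extend_from_slice(payload);
        if fin {
            let op = self.opcode.take().expect("opcode was set above");
            Ok(Some((op, std::mem::take(&mut self.buffer))))
        } else {
            Ok(None)
        }
    }
}
```

For example, feeding a Text frame "Hel" with FIN clear, then a Ping, then a continuation frame "lo" with FIN set yields the Ping immediately and the complete Text message "Hello" on the third call, so UTF-8 decoding can run on the whole message.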