Skip to main content

auraxis/realtime/
client.rs

1use super::Message as CensusMessage;
2use crate::AuraxisError;
3use crate::realtime::{Action, Event, REALTIME_URL, SubscriptionSettings};
4use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, RwLock};
7use std::task::{Context, Poll};
8
9use std::time::Duration;
10
11use futures_util::stream::{SplitSink, SplitStream};
12use futures_util::{Future, Sink, SinkExt, Stream, StreamExt};
13use metrics::{counter, describe_counter};
14use stream_reconnect::{ReconnectStream, UnderlyingStream};
15use tokio::net::TcpStream;
16use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
17use tokio::sync::watch;
18use tokio_tungstenite::tungstenite::Message;
19use tokio_tungstenite::tungstenite::error::Error as WsError;
20use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
21use tracing::{debug, error, info, warn};
22
/// Connection settings used by [`RealtimeClient`] to build the Census
/// realtime websocket URL.
#[derive(Debug, Clone)]
pub struct RealtimeClientConfig {
    /// Value sent as the `environment` query parameter (defaults to `ps2`).
    pub environment: String,
    /// Census service ID, sent as the `service-id=s:<id>` query parameter.
    pub service_id: String,
    /// Endpoint override; when `None` the crate's default `REALTIME_URL` is used.
    pub realtime_url: Option<String>,
}

impl Default for RealtimeClientConfig {
    /// The `ps2` environment, an empty service ID and the default endpoint.
    fn default() -> Self {
        let environment = "ps2".to_owned();
        Self {
            environment,
            service_id: String::default(),
            realtime_url: Option::default(),
        }
    }
}
39
40#[derive(Debug, Clone)]
41pub struct RealtimeClient {
42    config: Arc<RealtimeClientConfig>,
43    state: Arc<RwLock<RealtimeClientState>>,
44}
45
46#[derive(Debug, Clone)]
47struct RealtimeClientState {
48    subscription_config: SubscriptionSettings,
49    ws_send: Option<UnboundedSender<Message>>,
50}
51
52struct WebSocket(WebSocketStream<MaybeTlsStream<TcpStream>>);
53
54type ReconnectWs = ReconnectStream<WebSocket, String, Result<Message, WsError>, WsError>;
55
56impl RealtimeClient {
57    #[must_use]
58    pub fn new(config: RealtimeClientConfig) -> Self {
59        describe_counter!(
60            "realtime_messages_total_sent",
61            "Total number of messages sent to Census stream"
62        );
63        describe_counter!(
64            "realtime_messages_received_total",
65            "Total number of messages received from Census stream"
66        );
67        describe_counter!(
68            "realtime_messages_received_total_errored",
69            "Total number of messages received from Census stream that errored"
70        );
71        describe_counter!(
72            "realtime_total_closed_connections",
73            "Total number of closed connections to Census stream"
74        );
75        describe_counter!(
76            "realtime_total_connections",
77            "Total number of connections to Census stream"
78        );
79        describe_counter!(
80            "realtime_messages_received_heartbeat",
81            "Total number of heartbeat messages received from Census stream"
82        );
83        describe_counter!(
84            "realtime_total_pings",
85            "Total number of ping messages sent to Census stream, may include errors"
86        );
87        describe_counter!(
88            "realtime_total_ping_errors",
89            "Total number of ping messages that failed to receive a response from Census stream"
90        );
91        describe_counter!(
92            "realtime_total_resubscriptions",
93            "Total number of resubscriptions to Census stream"
94        );
95
96        Self {
97            config: Arc::new(config),
98            state: Arc::new(RwLock::new(RealtimeClientState {
99                subscription_config: SubscriptionSettings::empty(),
100                ws_send: None,
101            })),
102        }
103    }
104
105    /// Send a message to the websocket connection.
106    ///
107    /// This function will be spawned as a task and will run concurrently to the
108    /// rest of the application. It will continually check for messages on the
109    /// receiver end of the channel. When a message is received, it will be sent to
110    /// the websocket connection. If sending the message fails, the error is logged
111    /// and the connection is closed.
112    ///
113    /// # Arguments
114    ///
115    /// * `websocket` - The websocket connection to send messages to.
116    /// * `receiver` - The channel receiving messages to send.
117    ///
118    /// # Errors
119    ///
120    /// This function will return an error if the websocket connection cannot be created.
121    pub async fn connect(&mut self) -> Result<Receiver<Event>, AuraxisError> {
122        if self.current_ws_sender().is_some() {
123            return Err(anyhow::anyhow!("RealtimeClient is already connected").into());
124        }
125
126        let census_addr = format!(
127            "{}?environment={}&service-id=s:{}",
128            self.config.realtime_url.as_deref().unwrap_or(REALTIME_URL),
129            self.config.environment,
130            self.config.service_id
131        );
132
133        let websocket = ReconnectWs::connect(census_addr).await?;
134
135        let (ws_send, ws_recv) = websocket.split();
136        let (ws_send_tx, ws_send_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
137        let (event_stream_tx, event_stream_rx) = tokio::sync::mpsc::channel::<Event>(1000);
138        let (shutdown_tx, shutdown_rx) = watch::channel(false);
139
140        self.set_ws_sender(Some(ws_send_tx.clone()));
141
142        tokio::spawn(Self::send_ws(ws_send, ws_send_rx, shutdown_rx.clone()));
143        tokio::spawn(Self::ping_ws(ws_send_tx.clone(), shutdown_rx.clone()));
144        tokio::spawn(Self::resubscribe(
145            self.clone(),
146            ws_send_tx.clone(),
147            shutdown_rx.clone(),
148        ));
149        tokio::spawn(Self::read_ws(
150            self.clone(),
151            ws_send_tx,
152            ws_recv,
153            event_stream_tx,
154            shutdown_tx,
155            shutdown_rx,
156        ));
157
158        Ok(event_stream_rx)
159    }
160
161    pub fn subscribe(&mut self, subscription: SubscriptionSettings) {
162        let ws_send = {
163            let mut state = self.state.write().expect("realtime client state poisoned");
164            state.subscription_config.merge(subscription);
165            state.ws_send.clone()
166        };
167
168        let subscribe_message = match self.subscribe_message() {
169            Ok(Some(message)) => message,
170            Ok(None) => return,
171            Err(err) => {
172                error!("Failed to serialize subscription update: {err}");
173                return;
174            }
175        };
176
177        if let Some(ws_send) = ws_send
178            && let Err(err) = ws_send.send(subscribe_message)
179        {
180            warn!("Failed to enqueue live subscription update: {err}");
181            self.set_ws_sender(None);
182        }
183    }
184
185    pub fn clear_subscribe(&mut self, subscription: SubscriptionSettings) {
186        let (ws_send, current_subscription) = {
187            let mut state = self.state.write().expect("realtime client state poisoned");
188            state.subscription_config.clear(&subscription);
189            (state.ws_send.clone(), state.subscription_config.clone())
190        };
191
192        let clear_message = match Self::clear_subscribe_message(&subscription) {
193            Ok(message) => message,
194            Err(err) => {
195                error!("Failed to serialize clear subscription update: {err}");
196                return;
197            }
198        };
199
200        if let Some(ws_send) = ws_send {
201            if let Err(err) = ws_send.send(clear_message) {
202                warn!("Failed to enqueue clear subscription update: {err}");
203                self.set_ws_sender(None);
204                return;
205            }
206
207            if subscription.logical_and_characters_with_worlds.is_some()
208                && !current_subscription.is_empty()
209            {
210                match serde_json::to_string(&Action::Subscribe(current_subscription))
211                    .map(|message| Message::Text(message.into()))
212                {
213                    Ok(message) => {
214                        if let Err(err) = ws_send.send(message) {
215                            warn!("Failed to enqueue resubscribe after logical-and update: {err}");
216                            self.set_ws_sender(None);
217                        }
218                    }
219                    Err(err) => {
220                        error!("Failed to serialize resubscribe after logical-and update: {err}");
221                    }
222                }
223            }
224        }
225    }
226
227    pub fn clear_all_subscriptions(&mut self) {
228        let ws_send = {
229            let mut state = self.state.write().expect("realtime client state poisoned");
230            state.subscription_config = SubscriptionSettings::empty();
231            state.ws_send.clone()
232        };
233
234        if let Some(ws_send) = ws_send {
235            match Self::clear_all_subscribe_message() {
236                Ok(message) => {
237                    if let Err(err) = ws_send.send(message) {
238                        warn!("Failed to enqueue clear-all subscription update: {err}");
239                        self.set_ws_sender(None);
240                    }
241                }
242                Err(err) => {
243                    error!("Failed to serialize clear-all subscription update: {err}");
244                }
245            }
246        }
247    }
248
249    async fn resubscribe(
250        self,
251        ws_send: UnboundedSender<Message>,
252        mut shutdown: watch::Receiver<bool>,
253    ) -> Result<(), AuraxisError> {
254        loop {
255            if *shutdown.borrow() {
256                return Ok(());
257            }
258
259            let Some(message) = self.subscribe_message()? else {
260                tokio::select! {
261                    _ = shutdown.changed() => return Ok(()),
262                    _ = tokio::time::sleep(Duration::from_secs(60 * 30)) => {}
263                }
264                continue;
265            };
266
267            let res = ws_send.send(message);
268
269            match res {
270                Ok(_) => {
271                    counter!("realtime_total_resubscriptions").increment(1);
272                    tokio::select! {
273                        _ = shutdown.changed() => return Ok(()),
274                        _ = tokio::time::sleep(Duration::from_secs(60 * 30)) => {}
275                    }
276                }
277                Err(err) => {
278                    warn!("Subscription loop shutting down: {}", err);
279                    return Ok(());
280                }
281            }
282        }
283    }
284
285    async fn ping_ws(
286        ping_send: UnboundedSender<Message>,
287        mut shutdown: watch::Receiver<bool>,
288    ) -> Result<(), AuraxisError> {
289        loop {
290            match ping_send.send(Message::Ping(b"Hello".to_vec().into())) {
291                Ok(_) => {
292                    counter!("realtime_total_pings").increment(1);
293                }
294                Err(err) => {
295                    warn!("Ping loop shutting down: {}", err);
296                    counter!("realtime_total_ping_errors").increment(1);
297                    return Ok(());
298                }
299            }
300
301            tokio::select! {
302                _ = shutdown.changed() => return Ok(()),
303                _ = tokio::time::sleep(Duration::from_secs(1)) => {}
304            }
305        }
306    }
307
308    async fn send_ws(
309        mut ws_send: SplitSink<ReconnectWs, Message>,
310        mut ws_send_rx: UnboundedReceiver<Message>,
311        mut shutdown: watch::Receiver<bool>,
312    ) -> Result<(), AuraxisError> {
313        loop {
314            let message = tokio::select! {
315                _ = shutdown.changed() => break,
316                message = ws_send_rx.recv() => message,
317            };
318
319            let Some(msg) = message else {
320                break;
321            };
322
323            // debug!("Sent: {:?}", msg.to_string());
324            if let Err(err) = ws_send.send(msg).await {
325                warn!("Send loop shutting down: {err}");
326                return Err(err.into());
327            }
328            counter!("realtime_messages_total_sent").increment(1);
329        }
330
331        Ok(())
332    }
333
334    async fn read_ws(
335        self,
336        ws_send: UnboundedSender<Message>,
337        mut ws_recv: SplitStream<ReconnectWs>,
338        event_stream_tx: Sender<Event>,
339        shutdown_tx: watch::Sender<bool>,
340        mut shutdown: watch::Receiver<bool>,
341    ) -> Result<(), AuraxisError> {
342        loop {
343            let message = tokio::select! {
344                _ = shutdown.changed() => break,
345                message = ws_recv.next() => message,
346            };
347
348            let Some(message) = message else {
349                break;
350            };
351
352            counter!("realtime_messages_received_total").increment(1);
353            match message {
354                Ok(msg) => {
355                    // debug!("Received: {:?}", msg.to_string());
356                    if let Err(err) = Self::handle_ws_msg(
357                        self.clone(),
358                        ws_send.clone(),
359                        event_stream_tx.clone(),
360                        shutdown_tx.clone(),
361                        msg,
362                    )
363                    .await
364                    {
365                        counter!("realtime_messages_received_total_errored").increment(1);
366                        error!("{:?}", err);
367                    }
368                }
369                Err(err) => {
370                    //println!("{:?}", &err);
371                    counter!("realtime_messages_received_total_errored").increment(1);
372
373                    match err {
374                        WsError::ConnectionClosed => {
375                            error!("Connection closed");
376                            counter!("realtime_total_closed_connections").increment(1);
377                            break;
378                        }
379                        WsError::AlreadyClosed
380                        | WsError::Io(_)
381                        | WsError::Tls(_)
382                        | WsError::Capacity(_)
383                        | WsError::Protocol(_)
384                        | WsError::WriteBufferFull(_)
385                        | WsError::Utf8(_)
386                        | WsError::Url(_)
387                        | WsError::Http(_)
388                        | WsError::HttpFormat(_)
389                        | WsError::AttackAttempt => {}
390                    }
391                }
392            }
393        }
394
395        self.set_ws_sender(None);
396        signal_shutdown(&shutdown_tx);
397
398        Ok(())
399    }
400
401    async fn handle_ws_msg(
402        self,
403        ws_send: UnboundedSender<Message>,
404        events: Sender<Event>,
405        shutdown: watch::Sender<bool>,
406        msg: Message,
407    ) -> Result<(), AuraxisError> {
408        match msg {
409            Message::Text(text) => {
410                let message: CensusMessage = serde_json::from_str(&text)?;
411
412                match message {
413                    CensusMessage::ConnectionStateChanged { connected } => {
414                        if connected {
415                            info!("Connected to Census!");
416
417                            counter!("realtime_total_connections").increment(1);
418
419                            let Some(subscription_message) = self.subscribe_message()? else {
420                                return Ok(());
421                            };
422                            debug!("Subscribing with {:?}", subscription_message);
423
424                            if let Err(err) = ws_send.send(subscription_message) {
425                                signal_shutdown(&shutdown);
426                                debug!(
427                                    "Subscription send aborted because ws channel closed: {err}"
428                                );
429                            }
430                        }
431                    }
432                    CensusMessage::Heartbeat { .. } => {
433                        counter!("realtime_messages_received_heartbeat").increment(1);
434                    }
435                    CensusMessage::ServiceStateChanged { .. } => {}
436                    CensusMessage::ServiceMessage { payload } => {
437                        if events.send(payload).await.is_err() {
438                            debug!("Dropping realtime event because consumer channel is closed");
439                            signal_shutdown(&shutdown);
440                            return Ok(());
441                        }
442                    }
443                    CensusMessage::Subscription { subscription } => {
444                        debug!("Subscribed: {:?}", subscription);
445                    }
446                }
447            }
448            Message::Binary(_) | Message::Pong(_) | Message::Frame(_) => {}
449            Message::Ping(ping) => {
450                if let Err(err) = ws_send.send(Message::Pong(ping)) {
451                    signal_shutdown(&shutdown);
452                    debug!("Pong send aborted because ws channel closed: {err}");
453                }
454            }
455            Message::Close(close) => {
456                counter!("realtime_total_closed_connections").increment(1);
457                if let Some(close_frame) = close {
458                    error!(
459                        "Websocket closed. Code: {}, Reason: {}",
460                        close_frame.code, close_frame.reason
461                    );
462                }
463                warn!("Websocket close frame received; waiting for reconnect");
464            }
465        }
466
467        Ok(())
468    }
469
470    fn subscribe_message(&self) -> Result<Option<Message>, AuraxisError> {
471        let subscription = self.current_subscription();
472        if subscription.is_empty() {
473            return Ok(None);
474        }
475
476        Ok(Some(Message::Text(
477            serde_json::to_string(&Action::Subscribe(subscription))?.into(),
478        )))
479    }
480
481    fn clear_subscribe_message(
482        subscription: &SubscriptionSettings,
483    ) -> Result<Message, AuraxisError> {
484        Ok(Message::Text(
485            serde_json::to_string(&Action::ClearSubscribe {
486                all: None,
487                event_names: subscription.event_names.clone(),
488                characters: subscription.characters.clone(),
489                worlds: subscription.worlds.clone(),
490                service: subscription.service.clone(),
491            })?
492            .into(),
493        ))
494    }
495
496    fn clear_all_subscribe_message() -> Result<Message, AuraxisError> {
497        Ok(Message::Text(
498            serde_json::to_string(&Action::ClearSubscribe {
499                all: Some(true),
500                event_names: None,
501                characters: None,
502                worlds: None,
503                service: crate::realtime::Service::Event,
504            })?
505            .into(),
506        ))
507    }
508
509    fn current_subscription(&self) -> SubscriptionSettings {
510        self.state
511            .read()
512            .expect("realtime client state poisoned")
513            .subscription_config
514            .clone()
515    }
516
517    fn current_ws_sender(&self) -> Option<UnboundedSender<Message>> {
518        self.state
519            .read()
520            .expect("realtime client state poisoned")
521            .ws_send
522            .clone()
523    }
524
525    fn set_ws_sender(&self, ws_send: Option<UnboundedSender<Message>>) {
526        self.state
527            .write()
528            .expect("realtime client state poisoned")
529            .ws_send = ws_send;
530    }
531}
532
533fn signal_shutdown(shutdown: &watch::Sender<bool>) {
534    let _ = shutdown.send(true);
535}
536
537impl Stream for WebSocket {
538    type Item = Result<Message, WsError>;
539
540    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
541        Pin::new(&mut self.0).poll_next(cx)
542    }
543}
544
545impl Sink<Message> for WebSocket {
546    type Error = WsError;
547
548    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
549        Pin::new(&mut self.0).poll_ready(cx)
550    }
551
552    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
553        Pin::new(&mut self.0).start_send(item)
554    }
555
556    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
557        Pin::new(&mut self.0).poll_flush(cx)
558    }
559
560    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
561        Pin::new(&mut self.0).poll_close(cx)
562    }
563}
564
565impl UnderlyingStream<String, Result<Message, WsError>, WsError> for WebSocket {
566    // Establishes connection.
567    // Additionally, this will be used when reconnect tries are attempted.
568    fn establish(addr: String) -> Pin<Box<dyn Future<Output = Result<Self, WsError>> + Send>> {
569        Box::pin(async move {
570            // In this case, we are trying to connect to the WebSocket endpoint
571            let (websocket, _) = connect_async(addr).await?;
572            Ok(WebSocket(websocket))
573        })
574    }
575
576    // The following errors are considered disconnect errors.
577    fn is_write_disconnect_error(&self, err: &WsError) -> bool {
578        matches!(
579            err,
580            WsError::ConnectionClosed
581                | WsError::AlreadyClosed
582                | WsError::Io(_)
583                | WsError::Tls(_)
584                | WsError::Protocol(_)
585        )
586    }
587
588    // If an `Err` is read, then there might be an disconnection.
589    fn is_read_disconnect_error(&self, item: &Result<Message, WsError>) -> bool {
590        if let Err(e) = item {
591            self.is_write_disconnect_error(e)
592        } else {
593            false
594        }
595    }
596
597    // Return "Exhausted" if all retry attempts are failed.
598    fn exhaust_err() -> WsError {
599        WsError::Io(io::Error::other("Exhausted"))
600    }
601}