1use super::Message as CensusMessage;
2use crate::AuraxisError;
3use crate::realtime::{Action, Event, REALTIME_URL, SubscriptionSettings};
4use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, RwLock};
7use std::task::{Context, Poll};
8
9use std::time::Duration;
10
11use futures_util::stream::{SplitSink, SplitStream};
12use futures_util::{Future, Sink, SinkExt, Stream, StreamExt};
13use metrics::{counter, describe_counter};
14use stream_reconnect::{ReconnectStream, UnderlyingStream};
15use tokio::net::TcpStream;
16use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
17use tokio::sync::watch;
18use tokio_tungstenite::tungstenite::Message;
19use tokio_tungstenite::tungstenite::error::Error as WsError;
20use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
21use tracing::{debug, error, info, warn};
22
/// Connection settings used by [`RealtimeClient`] to build the realtime endpoint URL.
#[derive(Debug, Clone)]
pub struct RealtimeClientConfig {
    /// Value of the `environment` query parameter (defaults to `"ps2"`).
    pub environment: String,
    /// Census service ID, sent as `service-id=s:<service_id>`.
    pub service_id: String,
    /// Endpoint override; `None` falls back to the crate's `REALTIME_URL`.
    pub realtime_url: Option<String>,
}

impl Default for RealtimeClientConfig {
    /// The `ps2` environment with an empty service ID and the default endpoint.
    fn default() -> Self {
        let environment = "ps2".to_owned();
        let service_id = String::default();
        Self {
            environment,
            service_id,
            realtime_url: Option::None,
        }
    }
}
39
/// Client for the Census realtime event stream.
///
/// Cloning is cheap: clones share the same configuration and the same
/// subscription/connection state through `Arc`s, so a change made through one
/// clone is visible to all of them.
#[derive(Debug, Clone)]
pub struct RealtimeClient {
    // Connection settings; never mutated after construction.
    config: Arc<RealtimeClientConfig>,
    // Subscription settings and the outgoing-message queue, shared between
    // clones and the background tasks spawned by `connect`.
    state: Arc<RwLock<RealtimeClientState>>,
}
45
/// Mutable state kept behind the `RealtimeClient` lock.
#[derive(Debug, Clone)]
struct RealtimeClientState {
    // Accumulated subscription; re-sent to Census whenever a connection
    // reports itself as connected, and periodically while connected.
    subscription_config: SubscriptionSettings,
    // Queue feeding the websocket send loop. `None` before `connect`, and
    // cleared again when the connection is torn down or a send finds the
    // queue closed.
    ws_send: Option<UnboundedSender<Message>>,
}
51
/// Local newtype over the tungstenite stream so this crate can implement
/// `stream_reconnect::UnderlyingStream` (plus `Stream`/`Sink`) for it.
struct WebSocket(WebSocketStream<MaybeTlsStream<TcpStream>>);

/// Self-reconnecting websocket: built from the endpoint URL (`String`), yields
/// `Result<Message, WsError>` items, and reconnects according to the
/// `UnderlyingStream` impl for `WebSocket`.
type ReconnectWs = ReconnectStream<WebSocket, String, Result<Message, WsError>, WsError>;
55
56impl RealtimeClient {
57 #[must_use]
58 pub fn new(config: RealtimeClientConfig) -> Self {
59 describe_counter!(
60 "realtime_messages_total_sent",
61 "Total number of messages sent to Census stream"
62 );
63 describe_counter!(
64 "realtime_messages_received_total",
65 "Total number of messages received from Census stream"
66 );
67 describe_counter!(
68 "realtime_messages_received_total_errored",
69 "Total number of messages received from Census stream that errored"
70 );
71 describe_counter!(
72 "realtime_total_closed_connections",
73 "Total number of closed connections to Census stream"
74 );
75 describe_counter!(
76 "realtime_total_connections",
77 "Total number of connections to Census stream"
78 );
79 describe_counter!(
80 "realtime_messages_received_heartbeat",
81 "Total number of heartbeat messages received from Census stream"
82 );
83 describe_counter!(
84 "realtime_total_pings",
85 "Total number of ping messages sent to Census stream, may include errors"
86 );
87 describe_counter!(
88 "realtime_total_ping_errors",
89 "Total number of ping messages that failed to receive a response from Census stream"
90 );
91 describe_counter!(
92 "realtime_total_resubscriptions",
93 "Total number of resubscriptions to Census stream"
94 );
95
96 Self {
97 config: Arc::new(config),
98 state: Arc::new(RwLock::new(RealtimeClientState {
99 subscription_config: SubscriptionSettings::empty(),
100 ws_send: None,
101 })),
102 }
103 }
104
105 pub async fn connect(&mut self) -> Result<Receiver<Event>, AuraxisError> {
122 if self.current_ws_sender().is_some() {
123 return Err(anyhow::anyhow!("RealtimeClient is already connected").into());
124 }
125
126 let census_addr = format!(
127 "{}?environment={}&service-id=s:{}",
128 self.config.realtime_url.as_deref().unwrap_or(REALTIME_URL),
129 self.config.environment,
130 self.config.service_id
131 );
132
133 let websocket = ReconnectWs::connect(census_addr).await?;
134
135 let (ws_send, ws_recv) = websocket.split();
136 let (ws_send_tx, ws_send_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
137 let (event_stream_tx, event_stream_rx) = tokio::sync::mpsc::channel::<Event>(1000);
138 let (shutdown_tx, shutdown_rx) = watch::channel(false);
139
140 self.set_ws_sender(Some(ws_send_tx.clone()));
141
142 tokio::spawn(Self::send_ws(ws_send, ws_send_rx, shutdown_rx.clone()));
143 tokio::spawn(Self::ping_ws(ws_send_tx.clone(), shutdown_rx.clone()));
144 tokio::spawn(Self::resubscribe(
145 self.clone(),
146 ws_send_tx.clone(),
147 shutdown_rx.clone(),
148 ));
149 tokio::spawn(Self::read_ws(
150 self.clone(),
151 ws_send_tx,
152 ws_recv,
153 event_stream_tx,
154 shutdown_tx,
155 shutdown_rx,
156 ));
157
158 Ok(event_stream_rx)
159 }
160
161 pub fn subscribe(&mut self, subscription: SubscriptionSettings) {
162 let ws_send = {
163 let mut state = self.state.write().expect("realtime client state poisoned");
164 state.subscription_config.merge(subscription);
165 state.ws_send.clone()
166 };
167
168 let subscribe_message = match self.subscribe_message() {
169 Ok(Some(message)) => message,
170 Ok(None) => return,
171 Err(err) => {
172 error!("Failed to serialize subscription update: {err}");
173 return;
174 }
175 };
176
177 if let Some(ws_send) = ws_send
178 && let Err(err) = ws_send.send(subscribe_message)
179 {
180 warn!("Failed to enqueue live subscription update: {err}");
181 self.set_ws_sender(None);
182 }
183 }
184
185 pub fn clear_subscribe(&mut self, subscription: SubscriptionSettings) {
186 let (ws_send, current_subscription) = {
187 let mut state = self.state.write().expect("realtime client state poisoned");
188 state.subscription_config.clear(&subscription);
189 (state.ws_send.clone(), state.subscription_config.clone())
190 };
191
192 let clear_message = match Self::clear_subscribe_message(&subscription) {
193 Ok(message) => message,
194 Err(err) => {
195 error!("Failed to serialize clear subscription update: {err}");
196 return;
197 }
198 };
199
200 if let Some(ws_send) = ws_send {
201 if let Err(err) = ws_send.send(clear_message) {
202 warn!("Failed to enqueue clear subscription update: {err}");
203 self.set_ws_sender(None);
204 return;
205 }
206
207 if subscription.logical_and_characters_with_worlds.is_some()
208 && !current_subscription.is_empty()
209 {
210 match serde_json::to_string(&Action::Subscribe(current_subscription))
211 .map(|message| Message::Text(message.into()))
212 {
213 Ok(message) => {
214 if let Err(err) = ws_send.send(message) {
215 warn!("Failed to enqueue resubscribe after logical-and update: {err}");
216 self.set_ws_sender(None);
217 }
218 }
219 Err(err) => {
220 error!("Failed to serialize resubscribe after logical-and update: {err}");
221 }
222 }
223 }
224 }
225 }
226
227 pub fn clear_all_subscriptions(&mut self) {
228 let ws_send = {
229 let mut state = self.state.write().expect("realtime client state poisoned");
230 state.subscription_config = SubscriptionSettings::empty();
231 state.ws_send.clone()
232 };
233
234 if let Some(ws_send) = ws_send {
235 match Self::clear_all_subscribe_message() {
236 Ok(message) => {
237 if let Err(err) = ws_send.send(message) {
238 warn!("Failed to enqueue clear-all subscription update: {err}");
239 self.set_ws_sender(None);
240 }
241 }
242 Err(err) => {
243 error!("Failed to serialize clear-all subscription update: {err}");
244 }
245 }
246 }
247 }
248
249 async fn resubscribe(
250 self,
251 ws_send: UnboundedSender<Message>,
252 mut shutdown: watch::Receiver<bool>,
253 ) -> Result<(), AuraxisError> {
254 loop {
255 if *shutdown.borrow() {
256 return Ok(());
257 }
258
259 let Some(message) = self.subscribe_message()? else {
260 tokio::select! {
261 _ = shutdown.changed() => return Ok(()),
262 _ = tokio::time::sleep(Duration::from_secs(60 * 30)) => {}
263 }
264 continue;
265 };
266
267 let res = ws_send.send(message);
268
269 match res {
270 Ok(_) => {
271 counter!("realtime_total_resubscriptions").increment(1);
272 tokio::select! {
273 _ = shutdown.changed() => return Ok(()),
274 _ = tokio::time::sleep(Duration::from_secs(60 * 30)) => {}
275 }
276 }
277 Err(err) => {
278 warn!("Subscription loop shutting down: {}", err);
279 return Ok(());
280 }
281 }
282 }
283 }
284
285 async fn ping_ws(
286 ping_send: UnboundedSender<Message>,
287 mut shutdown: watch::Receiver<bool>,
288 ) -> Result<(), AuraxisError> {
289 loop {
290 match ping_send.send(Message::Ping(b"Hello".to_vec().into())) {
291 Ok(_) => {
292 counter!("realtime_total_pings").increment(1);
293 }
294 Err(err) => {
295 warn!("Ping loop shutting down: {}", err);
296 counter!("realtime_total_ping_errors").increment(1);
297 return Ok(());
298 }
299 }
300
301 tokio::select! {
302 _ = shutdown.changed() => return Ok(()),
303 _ = tokio::time::sleep(Duration::from_secs(1)) => {}
304 }
305 }
306 }
307
308 async fn send_ws(
309 mut ws_send: SplitSink<ReconnectWs, Message>,
310 mut ws_send_rx: UnboundedReceiver<Message>,
311 mut shutdown: watch::Receiver<bool>,
312 ) -> Result<(), AuraxisError> {
313 loop {
314 let message = tokio::select! {
315 _ = shutdown.changed() => break,
316 message = ws_send_rx.recv() => message,
317 };
318
319 let Some(msg) = message else {
320 break;
321 };
322
323 if let Err(err) = ws_send.send(msg).await {
325 warn!("Send loop shutting down: {err}");
326 return Err(err.into());
327 }
328 counter!("realtime_messages_total_sent").increment(1);
329 }
330
331 Ok(())
332 }
333
334 async fn read_ws(
335 self,
336 ws_send: UnboundedSender<Message>,
337 mut ws_recv: SplitStream<ReconnectWs>,
338 event_stream_tx: Sender<Event>,
339 shutdown_tx: watch::Sender<bool>,
340 mut shutdown: watch::Receiver<bool>,
341 ) -> Result<(), AuraxisError> {
342 loop {
343 let message = tokio::select! {
344 _ = shutdown.changed() => break,
345 message = ws_recv.next() => message,
346 };
347
348 let Some(message) = message else {
349 break;
350 };
351
352 counter!("realtime_messages_received_total").increment(1);
353 match message {
354 Ok(msg) => {
355 if let Err(err) = Self::handle_ws_msg(
357 self.clone(),
358 ws_send.clone(),
359 event_stream_tx.clone(),
360 shutdown_tx.clone(),
361 msg,
362 )
363 .await
364 {
365 counter!("realtime_messages_received_total_errored").increment(1);
366 error!("{:?}", err);
367 }
368 }
369 Err(err) => {
370 counter!("realtime_messages_received_total_errored").increment(1);
372
373 match err {
374 WsError::ConnectionClosed => {
375 error!("Connection closed");
376 counter!("realtime_total_closed_connections").increment(1);
377 break;
378 }
379 WsError::AlreadyClosed
380 | WsError::Io(_)
381 | WsError::Tls(_)
382 | WsError::Capacity(_)
383 | WsError::Protocol(_)
384 | WsError::WriteBufferFull(_)
385 | WsError::Utf8(_)
386 | WsError::Url(_)
387 | WsError::Http(_)
388 | WsError::HttpFormat(_)
389 | WsError::AttackAttempt => {}
390 }
391 }
392 }
393 }
394
395 self.set_ws_sender(None);
396 signal_shutdown(&shutdown_tx);
397
398 Ok(())
399 }
400
401 async fn handle_ws_msg(
402 self,
403 ws_send: UnboundedSender<Message>,
404 events: Sender<Event>,
405 shutdown: watch::Sender<bool>,
406 msg: Message,
407 ) -> Result<(), AuraxisError> {
408 match msg {
409 Message::Text(text) => {
410 let message: CensusMessage = serde_json::from_str(&text)?;
411
412 match message {
413 CensusMessage::ConnectionStateChanged { connected } => {
414 if connected {
415 info!("Connected to Census!");
416
417 counter!("realtime_total_connections").increment(1);
418
419 let Some(subscription_message) = self.subscribe_message()? else {
420 return Ok(());
421 };
422 debug!("Subscribing with {:?}", subscription_message);
423
424 if let Err(err) = ws_send.send(subscription_message) {
425 signal_shutdown(&shutdown);
426 debug!(
427 "Subscription send aborted because ws channel closed: {err}"
428 );
429 }
430 }
431 }
432 CensusMessage::Heartbeat { .. } => {
433 counter!("realtime_messages_received_heartbeat").increment(1);
434 }
435 CensusMessage::ServiceStateChanged { .. } => {}
436 CensusMessage::ServiceMessage { payload } => {
437 if events.send(payload).await.is_err() {
438 debug!("Dropping realtime event because consumer channel is closed");
439 signal_shutdown(&shutdown);
440 return Ok(());
441 }
442 }
443 CensusMessage::Subscription { subscription } => {
444 debug!("Subscribed: {:?}", subscription);
445 }
446 }
447 }
448 Message::Binary(_) | Message::Pong(_) | Message::Frame(_) => {}
449 Message::Ping(ping) => {
450 if let Err(err) = ws_send.send(Message::Pong(ping)) {
451 signal_shutdown(&shutdown);
452 debug!("Pong send aborted because ws channel closed: {err}");
453 }
454 }
455 Message::Close(close) => {
456 counter!("realtime_total_closed_connections").increment(1);
457 if let Some(close_frame) = close {
458 error!(
459 "Websocket closed. Code: {}, Reason: {}",
460 close_frame.code, close_frame.reason
461 );
462 }
463 warn!("Websocket close frame received; waiting for reconnect");
464 }
465 }
466
467 Ok(())
468 }
469
470 fn subscribe_message(&self) -> Result<Option<Message>, AuraxisError> {
471 let subscription = self.current_subscription();
472 if subscription.is_empty() {
473 return Ok(None);
474 }
475
476 Ok(Some(Message::Text(
477 serde_json::to_string(&Action::Subscribe(subscription))?.into(),
478 )))
479 }
480
481 fn clear_subscribe_message(
482 subscription: &SubscriptionSettings,
483 ) -> Result<Message, AuraxisError> {
484 Ok(Message::Text(
485 serde_json::to_string(&Action::ClearSubscribe {
486 all: None,
487 event_names: subscription.event_names.clone(),
488 characters: subscription.characters.clone(),
489 worlds: subscription.worlds.clone(),
490 service: subscription.service.clone(),
491 })?
492 .into(),
493 ))
494 }
495
496 fn clear_all_subscribe_message() -> Result<Message, AuraxisError> {
497 Ok(Message::Text(
498 serde_json::to_string(&Action::ClearSubscribe {
499 all: Some(true),
500 event_names: None,
501 characters: None,
502 worlds: None,
503 service: crate::realtime::Service::Event,
504 })?
505 .into(),
506 ))
507 }
508
509 fn current_subscription(&self) -> SubscriptionSettings {
510 self.state
511 .read()
512 .expect("realtime client state poisoned")
513 .subscription_config
514 .clone()
515 }
516
517 fn current_ws_sender(&self) -> Option<UnboundedSender<Message>> {
518 self.state
519 .read()
520 .expect("realtime client state poisoned")
521 .ws_send
522 .clone()
523 }
524
525 fn set_ws_sender(&self, ws_send: Option<UnboundedSender<Message>>) {
526 self.state
527 .write()
528 .expect("realtime client state poisoned")
529 .ws_send = ws_send;
530 }
531}
532
533fn signal_shutdown(shutdown: &watch::Sender<bool>) {
534 let _ = shutdown.send(true);
535}
536
537impl Stream for WebSocket {
538 type Item = Result<Message, WsError>;
539
540 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
541 Pin::new(&mut self.0).poll_next(cx)
542 }
543}
544
545impl Sink<Message> for WebSocket {
546 type Error = WsError;
547
548 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
549 Pin::new(&mut self.0).poll_ready(cx)
550 }
551
552 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
553 Pin::new(&mut self.0).start_send(item)
554 }
555
556 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
557 Pin::new(&mut self.0).poll_flush(cx)
558 }
559
560 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
561 Pin::new(&mut self.0).poll_close(cx)
562 }
563}
564
565impl UnderlyingStream<String, Result<Message, WsError>, WsError> for WebSocket {
566 fn establish(addr: String) -> Pin<Box<dyn Future<Output = Result<Self, WsError>> + Send>> {
569 Box::pin(async move {
570 let (websocket, _) = connect_async(addr).await?;
572 Ok(WebSocket(websocket))
573 })
574 }
575
576 fn is_write_disconnect_error(&self, err: &WsError) -> bool {
578 matches!(
579 err,
580 WsError::ConnectionClosed
581 | WsError::AlreadyClosed
582 | WsError::Io(_)
583 | WsError::Tls(_)
584 | WsError::Protocol(_)
585 )
586 }
587
588 fn is_read_disconnect_error(&self, item: &Result<Message, WsError>) -> bool {
590 if let Err(e) = item {
591 self.is_write_disconnect_error(e)
592 } else {
593 false
594 }
595 }
596
597 fn exhaust_err() -> WsError {
599 WsError::Io(io::Error::other("Exhausted"))
600 }
601}