use atomic_waker::AtomicWaker; use once_cell::unsync::Lazy; use std::future; use std::marker::PhantomData; use std::ops::Deref; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::{self, Receiver, RecvError, SendError, Sender, TryRecvError}; use std::sync::{Arc, Condvar, Mutex, RwLock}; use std::task::Poll; use wasm_bindgen::prelude::wasm_bindgen; use wasm_bindgen::{JsCast, JsValue}; // Unsafe wrapper type that allows us to use `T` when it's not `Send` from other threads. // `value` **must** only be accessed on the main thread. pub struct MainThreadSafe { // We wrap this in an `Arc` to allow it to be safely cloned without accessing the value. // The `RwLock` lets us safely drop in any thread. // The `Option` lets us safely drop `T` only in the main thread, while letting other threads drop `None`. value: Arc>>, handler: fn(&RwLock>, E), sender: AsyncSender, // Prevent's `Send` or `Sync` to be automatically implemented. local: PhantomData<*const ()>, } impl MainThreadSafe { thread_local! { static MAIN_THREAD: Lazy = Lazy::new(|| { #[wasm_bindgen] extern "C" { #[derive(Clone)] type Global; #[wasm_bindgen(method, getter, js_name = Window)] fn window(this: &Global) -> JsValue; } let global: Global = js_sys::global().unchecked_into(); !global.window().is_undefined() }); } #[track_caller] fn new(value: T, handler: fn(&RwLock>, E)) -> Option { Self::MAIN_THREAD.with(|safe| { if !*safe.deref() { panic!("only callable from inside the `Window`") } }); let value = Arc::new(RwLock::new(Some(value))); let (sender, receiver) = channel::(); wasm_bindgen_futures::spawn_local({ let value = Arc::clone(&value); async move { while let Ok(event) = receiver.next().await { handler(&value, event) } // An error was returned because the channel was closed, which // happens when all senders are dropped. value.write().unwrap().take().unwrap(); } }); Some(Self { value, handler, sender, local: PhantomData, }) } pub fn send(&self, event: E) { Self::MAIN_THREAD.with(|is_main_thread| { if *is_main_thread.deref() { (self.handler)(&self.value, event) } else { self.sender.send(event).unwrap() } }) } fn is_main_thread(&self) -> bool { Self::MAIN_THREAD.with(|is_main_thread| *is_main_thread.deref()) } pub fn with(&self, f: impl FnOnce(&T) -> R) -> Option { Self::MAIN_THREAD.with(|is_main_thread| { if *is_main_thread.deref() { Some(f(self.value.read().unwrap().as_ref().unwrap())) } else { None } }) } fn with_mut(&self, f: impl FnOnce(&mut T) -> R) -> Option { Self::MAIN_THREAD.with(|is_main_thread| { if *is_main_thread.deref() { Some(f(self.value.write().unwrap().as_mut().unwrap())) } else { None } }) } } impl Clone for MainThreadSafe { fn clone(&self) -> Self { Self { value: self.value.clone(), handler: self.handler, sender: self.sender.clone(), local: PhantomData, } } } unsafe impl Send for MainThreadSafe {} unsafe impl Sync for MainThreadSafe {} fn channel() -> (AsyncSender, AsyncReceiver) { let (sender, receiver) = mpsc::channel(); let sender = Arc::new(Mutex::new(sender)); let waker = Arc::new(AtomicWaker::new()); let closed = Arc::new(AtomicBool::new(false)); let sender = AsyncSender { sender, closed: closed.clone(), waker: Arc::clone(&waker), }; let receiver = AsyncReceiver { receiver, closed, waker, }; (sender, receiver) } struct AsyncSender { // We need to wrap it into a `Mutex` to make it `Sync`. So the sender can't // be accessed on the main thread, as it could block. Additionally we need // to wrap it in an `Arc` to make it clonable on the main thread without // having to block. sender: Arc>>, closed: Arc, waker: Arc, } impl AsyncSender { pub fn send(&self, event: T) -> Result<(), SendError> { self.sender.lock().unwrap().send(event)?; self.waker.wake(); Ok(()) } } impl Clone for AsyncSender { fn clone(&self) -> Self { Self { sender: self.sender.clone(), waker: self.waker.clone(), closed: self.closed.clone(), } } } impl Drop for AsyncSender { fn drop(&mut self) { // If it's the last + the one held by the receiver make sure to wake it // up and tell it that all receiver have dropped. if Arc::strong_count(&self.closed) == 2 { self.closed.store(true, Ordering::Relaxed); self.waker.wake() } } } struct AsyncReceiver { receiver: Receiver, closed: Arc, waker: Arc, } impl AsyncReceiver { pub async fn next(&self) -> Result { future::poll_fn(|cx| match self.receiver.try_recv() { Ok(event) => Poll::Ready(Ok(event)), Err(TryRecvError::Empty) => { if self.closed.load(Ordering::Relaxed) { return Poll::Ready(Err(RecvError)); } self.waker.register(cx.waker()); match self.receiver.try_recv() { Ok(event) => Poll::Ready(Ok(event)), Err(TryRecvError::Empty) => { if self.closed.load(Ordering::Relaxed) { Poll::Ready(Err(RecvError)) } else { Poll::Pending } } Err(TryRecvError::Disconnected) => Poll::Ready(Err(RecvError)), } } Err(TryRecvError::Disconnected) => Poll::Ready(Err(RecvError)), }) .await } } pub struct Dispatcher(MainThreadSafe>); pub enum Closure { Ref(Box), RefMut(Box), } impl Dispatcher { #[track_caller] pub fn new(value: T) -> Option { MainThreadSafe::new(value, |value, closure| match closure { Closure::Ref(f) => f(value.read().unwrap().as_ref().unwrap()), Closure::RefMut(f) => f(value.write().unwrap().as_mut().unwrap()), }) .map(Self) } pub fn dispatch(&self, f: impl 'static + FnOnce(&T) + Send) { if self.is_main_thread() { self.0.with(f).unwrap() } else { self.0.send(Closure::Ref(Box::new(f))) } } pub fn dispatch_mut(&self, f: impl 'static + FnOnce(&mut T) + Send) { if self.is_main_thread() { self.0.with_mut(f).unwrap() } else { self.0.send(Closure::RefMut(Box::new(f))) } } pub fn queue(&self, f: impl 'static + FnOnce(&T) -> R + Send) -> R { if self.is_main_thread() { self.0.with(f).unwrap() } else { let pair = Arc::new((Mutex::new(None), Condvar::new())); let closure = Closure::Ref(Box::new({ let pair = pair.clone(); move |value| { *pair.0.lock().unwrap() = Some(f(value)); pair.1.notify_one(); } })); self.0.send(closure); let mut started = pair.0.lock().unwrap(); while started.is_none() { started = pair.1.wait(started).unwrap(); } started.take().unwrap() } } } impl Deref for Dispatcher { type Target = MainThreadSafe>; fn deref(&self) -> &Self::Target { &self.0 } } type ChannelValue = MainThreadSafe; pub struct Channel(ChannelValue); impl Channel { pub fn new(value: T, handler: fn(&T, E)) -> Option { MainThreadSafe::new((value, handler), |runner, event| { let lock = runner.read().unwrap(); let (value, handler) = lock.as_ref().unwrap(); handler(value, event); }) .map(Self) } } impl Clone for Channel { fn clone(&self) -> Self { Self(self.0.clone()) } } impl Deref for Channel { type Target = ChannelValue; fn deref(&self) -> &Self::Target { &self.0 } }