xref: /openbmc/qemu/rust/common/src/callbacks.rs (revision ccafa85a97e38698b798115bba6c18c849846e25)
1 // SPDX-License-Identifier: MIT
2 
3 //! Utility functions to deal with callbacks from C to Rust.
4 
5 use std::{mem, ptr::NonNull};
6 
/// Trait for functions (types implementing [`Fn`]) that can be used as
/// callbacks. These include both zero-capture closures and function items,
/// i.e. the types of named functions such as `str::to_owned`.  Function
/// pointers (values of type `fn(...) -> ...`) are *not* accepted: they are
/// not zero-sized, because the function they point to is only known at run
/// time.
///
/// In Rust, calling a function through the `Fn` trait normally requires a
/// `self` parameter, even though for zero-sized functions (zero-capture
/// closures and function items) the type itself contains all necessary
/// information to call the function. This trait provides a `call` function
/// that doesn't require `self`, allowing zero-sized functions to be called
/// using only their type.
///
/// This enables zero-sized functions to be passed entirely through generic
/// parameters and resolved at compile-time. A typical use is a function
/// receiving an unused parameter of generic type `F` and calling it via
/// `F::call` or passing it to another function via `func::<F>`.
///
/// QEMU uses this trick to create wrappers to C callbacks.  The wrappers
/// are needed to convert an opaque `*mut c_void` into a Rust reference,
/// but they only have a single opaque that they can use.  The `FnCall`
/// trait makes it possible to use that opaque for `self` or any other
/// reference:
///
/// ```ignore
/// // The compiler creates a new `rust_bh_cb` wrapper for each function
/// // passed to `qemu_bh_schedule_oneshot` below.
/// unsafe extern "C" fn rust_bh_cb<T, F: for<'a> FnCall<(&'a T,)>>(
///     opaque: *mut c_void,
/// ) {
///     // SAFETY: the opaque was passed as a reference to `T`.
///     F::call((unsafe { &*(opaque.cast::<T>()) }, ))
/// }
///
/// // The `_f` parameter is unused but it helps the compiler build the appropriate `F`.
/// // Using a reference allows usage in const context.
/// fn qemu_bh_schedule_oneshot<T, F: for<'a> FnCall<(&'a T,)>>(_f: &F, opaque: &T) {
///     let cb: unsafe extern "C" fn(*mut c_void) = rust_bh_cb::<T, F>;
///     unsafe {
///         bindings::qemu_bh_schedule_oneshot(cb, opaque as *const T as *const c_void as *mut c_void)
///     }
/// }
/// ```
///
/// Each wrapper is a separate instance of `rust_bh_cb` and is therefore
/// compiled to a separate function ("monomorphization").  If you wanted
/// to pass `self` as the opaque value, the generic parameters would be
/// `rust_bh_cb::<Self, F>`.
///
/// `Args` is a tuple type whose types are the arguments of the function,
/// while `R` is the returned type.
///
/// # Examples
///
/// ```
/// # use common::callbacks::FnCall;
/// fn call_it<F: for<'a> FnCall<(&'a str,), String>>(_f: &F, s: &str) -> String {
///     F::call((s,))
/// }
///
/// let s: String = call_it(&str::to_owned, "hello world");
/// assert_eq!(s, "hello world");
/// ```
///
/// Note that the compiler will produce a different version of `call_it` for
/// each function that is passed to it.  Therefore the argument is not really
/// used, except to decide what is `F` and what `F::call` does.
///
/// Attempting to pass a non-zero-sized closure causes a compile-time failure:
///
/// ```compile_fail
/// # use common::callbacks::FnCall;
/// # fn call_it<'a, F: FnCall<(&'a str,), String>>(_f: &F, s: &'a str) -> String {
/// #     F::call((s,))
/// # }
/// let x: &'static str = "goodbye world";
/// call_it(&move |_| String::from(x), "hello world");
/// ```
///
/// The same applies to function pointers, even when they point to a named
/// function:
///
/// ```compile_fail
/// # use common::callbacks::FnCall;
/// # fn call_it<F: for<'a> FnCall<(&'a str,), String>>(_f: &F, s: &str) -> String {
/// #     F::call((s,))
/// # }
/// let f: fn(&str) -> String = str::to_owned;
/// call_it(&f, "hello world");
/// ```
///
/// `()` can be used to indicate "no function":
///
/// ```
/// # use common::callbacks::FnCall;
/// fn optional<F: for<'a> FnCall<(&'a str,), String>>(_f: &F, s: &str) -> Option<String> {
///     if F::IS_SOME {
///         Some(F::call((s,)))
///     } else {
///         None
///     }
/// }
///
/// assert!(optional(&(), "hello world").is_none());
/// ```
///
/// Invoking `F::call` will then be a run-time error.
///
/// ```should_panic
/// # use common::callbacks::FnCall;
/// # fn call_it<F: for<'a> FnCall<(&'a str,), String>>(_f: &F, s: &str) -> String {
/// #     F::call((s,))
/// # }
/// let s: String = call_it(&(), "hello world"); // panics
/// ```
///
/// # Safety
///
/// Because `Self` is a zero-sized type, all instances of the type are
/// equivalent. However, in addition to this, `Self` must have no invariants
/// that could be violated by creating a reference to it.
///
/// This is always true for zero-capture closures and function items, as long
/// as the code is able to name the function in the first place.
pub unsafe trait FnCall<Args, R = ()>: 'static + Sync + Sized {
    /// `true` if `Self` is an actual function type and not `()`.
    ///
    /// # Examples
    ///
    /// You can use `IS_SOME` to reject `()` at compile time, instead of
    /// panicking when `F::call` is invoked:
    ///
    /// ```compile_fail
    /// # use common::callbacks::FnCall;
    /// fn call_it<F: for<'a> FnCall<(&'a str,), String>>(_f: &F, s: &str) -> String {
    ///     const { assert!(F::IS_SOME) }
    ///     F::call((s,))
    /// }
    ///
    /// let s: String = call_it(&(), "hello world"); // does not compile
    /// ```
    const IS_SOME: bool;

    /// `false` if `Self` is an actual function type, `true` if it is `()`.
    fn is_none() -> bool {
        !Self::IS_SOME
    }

    /// `true` if `Self` is an actual function type, `false` if it is `()`.
    fn is_some() -> bool {
        Self::IS_SOME
    }

    /// Call the function, passing each element of the tuple `a` as one
    /// argument.  The `()` implementation panics instead.
    fn call(a: Args) -> R;
}
146 
147 /// `()` acts as a "null" callback.  Using `()` and `function` is nicer
148 /// than `None` and `Some(function)`, because the compiler is unable to
149 /// infer the type of just `None`.  Therefore, the trait itself acts as the
150 /// option type, with functions [`FnCall::is_some`] and [`FnCall::is_none`].
151 unsafe impl<Args, R> FnCall<Args, R> for () {
152     const IS_SOME: bool = false;
153 
154     /// Call the function with the arguments in args.
155     fn call(_a: Args) -> R {
156         panic!("callback not specified")
157     }
158 }
159 
// Implements `FnCall` for every `Fn` type taking the listed argument types.
// Each invocation covers one arity; the identifiers serve both as the
// generic type parameters and as the names of the destructured arguments.
macro_rules! impl_call {
    ($($args:ident,)* ) => (
        // SAFETY: because each function is treated as a separate type,
        // accessing `FnCall` is only possible in code that would be
        // allowed to call the function.
        unsafe impl<F, $($args,)* R> FnCall<($($args,)*), R> for F
        where
            F: 'static + Sync + Sized + Fn($($args, )*) -> R,
        {
            const IS_SOME: bool = true;

            #[inline(always)]
            fn call(a: ($($args,)*)) -> R {
                // Closures with captures and `fn` pointers carry run-time
                // state that cannot be rebuilt from the type alone; reject
                // them when `call` is instantiated, with a readable message.
                const {
                    assert!(
                        mem::size_of::<Self>() == 0,
                        "FnCall needs a zero-sized function: a function item \
                         or a closure without captures"
                    )
                };

                // SAFETY: the safety of this method is the condition for implementing
                // `FnCall`.  As to the `NonNull` idiom to create a zero-sized type,
                // see https://github.com/rust-lang/libs-team/issues/292.
                let f: &'static F = unsafe { &*NonNull::<Self>::dangling().as_ptr() };
                let ($($args,)*) = a;
                f($($args,)*)
            }
        }
    )
}
185 
186 impl_call!(_1, _2, _3, _4, _5,);
187 impl_call!(_1, _2, _3, _4,);
188 impl_call!(_1, _2, _3,);
189 impl_call!(_1, _2,);
190 impl_call!(_1,);
191 impl_call!();
192 
#[cfg(test)]
mod tests {
    use super::*;

    // The `_f` parameter is unused but it helps the compiler infer `F`.
    fn do_test_call<'a, F: FnCall<(&'a str,), String>>(_f: &F) -> String {
        F::call(("hello world",))
    }

    #[test]
    fn test_call() {
        assert_eq!(do_test_call(&str::to_owned), "hello world")
    }

    // `()` is the "null" callback: calling it must panic.
    #[test]
    #[should_panic(expected = "callback not specified")]
    fn test_call_none() {
        do_test_call(&());
    }

    // The `_f` parameter is unused but it helps the compiler infer `F`.
    fn do_test_is_some<'a, F: FnCall<(&'a str,), String>>(_f: &F) {
        assert!(F::is_some());
    }

    #[test]
    fn test_is_some() {
        do_test_is_some(&str::to_owned);
    }

    // Returns `(is_some, is_none)` for the function type `F`.
    fn do_test_option<'a, F: FnCall<(&'a str,), String>>(_f: &F) -> (bool, bool) {
        (F::is_some(), F::is_none())
    }

    #[test]
    fn test_is_none() {
        assert_eq!(do_test_option(&()), (false, true));
        assert_eq!(do_test_option(&str::to_owned), (true, false));
    }

    fn add(a: u32, b: u32) -> u32 {
        a + b
    }

    // Exercise arities other than one, with both function items and
    // zero-capture closures.
    fn do_test_call2<F: FnCall<(u32, u32), u32>>(_f: &F) -> u32 {
        F::call((2, 3))
    }

    fn do_test_call0<F: FnCall<(), u32>>(_f: &F) -> u32 {
        F::call(())
    }

    #[test]
    fn test_call_arity() {
        assert_eq!(do_test_call2(&add), 5);
        assert_eq!(do_test_call2(&|a: u32, b: u32| a * b), 6);
        assert_eq!(do_test_call0(&|| 7u32), 7);
    }
}
217