xref: /openbmc/linux/rust/macros/zeroable.rs (revision c900529f3d9161bfde5cca0754f83b4d3c3e0220)
// SPDX-License-Identifier: GPL-2.0
2*071cedc8SBenno Lossin 
3*071cedc8SBenno Lossin use crate::helpers::{parse_generics, Generics};
4*071cedc8SBenno Lossin use proc_macro::{TokenStream, TokenTree};
5*071cedc8SBenno Lossin 
derive(input: TokenStream) -> TokenStream6*071cedc8SBenno Lossin pub(crate) fn derive(input: TokenStream) -> TokenStream {
7*071cedc8SBenno Lossin     let (
8*071cedc8SBenno Lossin         Generics {
9*071cedc8SBenno Lossin             impl_generics,
10*071cedc8SBenno Lossin             ty_generics,
11*071cedc8SBenno Lossin         },
12*071cedc8SBenno Lossin         mut rest,
13*071cedc8SBenno Lossin     ) = parse_generics(input);
14*071cedc8SBenno Lossin     // This should be the body of the struct `{...}`.
15*071cedc8SBenno Lossin     let last = rest.pop();
16*071cedc8SBenno Lossin     // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
17*071cedc8SBenno Lossin     let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
18*071cedc8SBenno Lossin     // Are we inside of a generic where we want to add `Zeroable`?
19*071cedc8SBenno Lossin     let mut in_generic = !impl_generics.is_empty();
20*071cedc8SBenno Lossin     // Have we already inserted `Zeroable`?
21*071cedc8SBenno Lossin     let mut inserted = false;
22*071cedc8SBenno Lossin     // Level of `<>` nestings.
23*071cedc8SBenno Lossin     let mut nested = 0;
24*071cedc8SBenno Lossin     for tt in impl_generics {
25*071cedc8SBenno Lossin         match &tt {
26*071cedc8SBenno Lossin             // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
27*071cedc8SBenno Lossin             TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
28*071cedc8SBenno Lossin                 if in_generic && !inserted {
29*071cedc8SBenno Lossin                     new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
30*071cedc8SBenno Lossin                 }
31*071cedc8SBenno Lossin                 in_generic = true;
32*071cedc8SBenno Lossin                 inserted = false;
33*071cedc8SBenno Lossin                 new_impl_generics.push(tt);
34*071cedc8SBenno Lossin             }
35*071cedc8SBenno Lossin             // If we find `'`, then we are entering a lifetime.
36*071cedc8SBenno Lossin             TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
37*071cedc8SBenno Lossin                 in_generic = false;
38*071cedc8SBenno Lossin                 new_impl_generics.push(tt);
39*071cedc8SBenno Lossin             }
40*071cedc8SBenno Lossin             TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
41*071cedc8SBenno Lossin                 new_impl_generics.push(tt);
42*071cedc8SBenno Lossin                 if in_generic {
43*071cedc8SBenno Lossin                     new_impl_generics.extend(quote! { ::kernel::init::Zeroable + });
44*071cedc8SBenno Lossin                     inserted = true;
45*071cedc8SBenno Lossin                 }
46*071cedc8SBenno Lossin             }
47*071cedc8SBenno Lossin             TokenTree::Punct(p) if p.as_char() == '<' => {
48*071cedc8SBenno Lossin                 nested += 1;
49*071cedc8SBenno Lossin                 new_impl_generics.push(tt);
50*071cedc8SBenno Lossin             }
51*071cedc8SBenno Lossin             TokenTree::Punct(p) if p.as_char() == '>' => {
52*071cedc8SBenno Lossin                 assert!(nested > 0);
53*071cedc8SBenno Lossin                 nested -= 1;
54*071cedc8SBenno Lossin                 new_impl_generics.push(tt);
55*071cedc8SBenno Lossin             }
56*071cedc8SBenno Lossin             _ => new_impl_generics.push(tt),
57*071cedc8SBenno Lossin         }
58*071cedc8SBenno Lossin     }
59*071cedc8SBenno Lossin     assert_eq!(nested, 0);
60*071cedc8SBenno Lossin     if in_generic && !inserted {
61*071cedc8SBenno Lossin         new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
62*071cedc8SBenno Lossin     }
63*071cedc8SBenno Lossin     quote! {
64*071cedc8SBenno Lossin         ::kernel::__derive_zeroable!(
65*071cedc8SBenno Lossin             parse_input:
66*071cedc8SBenno Lossin                 @sig(#(#rest)*),
67*071cedc8SBenno Lossin                 @impl_generics(#(#new_impl_generics)*),
68*071cedc8SBenno Lossin                 @ty_generics(#(#ty_generics)*),
69*071cedc8SBenno Lossin                 @body(#last),
70*071cedc8SBenno Lossin         );
71*071cedc8SBenno Lossin     }
72*071cedc8SBenno Lossin }
73