xref: /openbmc/linux/rust/macros/zeroable.rs (revision 6fa24b41)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 use crate::helpers::{parse_generics, Generics};
4 use proc_macro::{TokenStream, TokenTree};
5 
6 pub(crate) fn derive(input: TokenStream) -> TokenStream {
7     let (
8         Generics {
9             impl_generics,
10             ty_generics,
11         },
12         mut rest,
13     ) = parse_generics(input);
14     // This should be the body of the struct `{...}`.
15     let last = rest.pop();
16     // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
17     let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
18     // Are we inside of a generic where we want to add `Zeroable`?
19     let mut in_generic = !impl_generics.is_empty();
20     // Have we already inserted `Zeroable`?
21     let mut inserted = false;
22     // Level of `<>` nestings.
23     let mut nested = 0;
24     for tt in impl_generics {
25         match &tt {
26             // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
27             TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
28                 if in_generic && !inserted {
29                     new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
30                 }
31                 in_generic = true;
32                 inserted = false;
33                 new_impl_generics.push(tt);
34             }
35             // If we find `'`, then we are entering a lifetime.
36             TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
37                 in_generic = false;
38                 new_impl_generics.push(tt);
39             }
40             TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
41                 new_impl_generics.push(tt);
42                 if in_generic {
43                     new_impl_generics.extend(quote! { ::kernel::init::Zeroable + });
44                     inserted = true;
45                 }
46             }
47             TokenTree::Punct(p) if p.as_char() == '<' => {
48                 nested += 1;
49                 new_impl_generics.push(tt);
50             }
51             TokenTree::Punct(p) if p.as_char() == '>' => {
52                 assert!(nested > 0);
53                 nested -= 1;
54                 new_impl_generics.push(tt);
55             }
56             _ => new_impl_generics.push(tt),
57         }
58     }
59     assert_eq!(nested, 0);
60     if in_generic && !inserted {
61         new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
62     }
63     quote! {
64         ::kernel::__derive_zeroable!(
65             parse_input:
66                 @sig(#(#rest)*),
67                 @impl_generics(#(#new_impl_generics)*),
68                 @ty_generics(#(#ty_generics)*),
69                 @body(#last),
70         );
71     }
72 }
73