xref: /openbmc/linux/rust/macros/pin_data.rs (revision 8fcd94ad)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro::{Punct, Spacing, TokenStream, TokenTree};
4 
5 pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream {
6     // This proc-macro only does some pre-parsing and then delegates the actual parsing to
7     // `kernel::__pin_data!`.
8     //
9     // In here we only collect the generics, since parsing them in declarative macros is very
10     // elaborate. We also do not need to analyse their structure, we only need to collect them.
11 
12     // `impl_generics`, the declared generics with their bounds.
13     let mut impl_generics = vec![];
14     // Only the names of the generics, without any bounds.
15     let mut ty_generics = vec![];
16     // Tokens not related to the generics e.g. the `impl` token.
17     let mut rest = vec![];
18     // The current level of `<`.
19     let mut nesting = 0;
20     let mut toks = input.into_iter();
21     // If we are at the beginning of a generic parameter.
22     let mut at_start = true;
23     for tt in &mut toks {
24         match tt.clone() {
25             TokenTree::Punct(p) if p.as_char() == '<' => {
26                 if nesting >= 1 {
27                     impl_generics.push(tt);
28                 }
29                 nesting += 1;
30             }
31             TokenTree::Punct(p) if p.as_char() == '>' => {
32                 if nesting == 0 {
33                     break;
34                 } else {
35                     nesting -= 1;
36                     if nesting >= 1 {
37                         impl_generics.push(tt);
38                     }
39                     if nesting == 0 {
40                         break;
41                     }
42                 }
43             }
44             tt => {
45                 if nesting == 1 {
46                     match &tt {
47                         TokenTree::Ident(i) if i.to_string() == "const" => {}
48                         TokenTree::Ident(_) if at_start => {
49                             ty_generics.push(tt.clone());
50                             ty_generics.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
51                             at_start = false;
52                         }
53                         TokenTree::Punct(p) if p.as_char() == ',' => at_start = true,
54                         TokenTree::Punct(p) if p.as_char() == '\'' && at_start => {
55                             ty_generics.push(tt.clone());
56                         }
57                         _ => {}
58                     }
59                 }
60                 if nesting >= 1 {
61                     impl_generics.push(tt);
62                 } else if nesting == 0 {
63                     rest.push(tt);
64                 }
65             }
66         }
67     }
68     rest.extend(toks);
69     // This should be the body of the struct `{...}`.
70     let last = rest.pop();
71     quote!(::kernel::__pin_data! {
72         parse_input:
73         @args(#args),
74         @sig(#(#rest)*),
75         @impl_generics(#(#impl_generics)*),
76         @ty_generics(#(#ty_generics)*),
77         @body(#last),
78     })
79 }
80