// SPDX-License-Identifier: GPL-2.0
2*071cedc8SBenno Lossin
3*071cedc8SBenno Lossin use crate::helpers::{parse_generics, Generics};
4*071cedc8SBenno Lossin use proc_macro::{TokenStream, TokenTree};
5*071cedc8SBenno Lossin
derive(input: TokenStream) -> TokenStream6*071cedc8SBenno Lossin pub(crate) fn derive(input: TokenStream) -> TokenStream {
7*071cedc8SBenno Lossin let (
8*071cedc8SBenno Lossin Generics {
9*071cedc8SBenno Lossin impl_generics,
10*071cedc8SBenno Lossin ty_generics,
11*071cedc8SBenno Lossin },
12*071cedc8SBenno Lossin mut rest,
13*071cedc8SBenno Lossin ) = parse_generics(input);
14*071cedc8SBenno Lossin // This should be the body of the struct `{...}`.
15*071cedc8SBenno Lossin let last = rest.pop();
16*071cedc8SBenno Lossin // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
17*071cedc8SBenno Lossin let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
18*071cedc8SBenno Lossin // Are we inside of a generic where we want to add `Zeroable`?
19*071cedc8SBenno Lossin let mut in_generic = !impl_generics.is_empty();
20*071cedc8SBenno Lossin // Have we already inserted `Zeroable`?
21*071cedc8SBenno Lossin let mut inserted = false;
22*071cedc8SBenno Lossin // Level of `<>` nestings.
23*071cedc8SBenno Lossin let mut nested = 0;
24*071cedc8SBenno Lossin for tt in impl_generics {
25*071cedc8SBenno Lossin match &tt {
26*071cedc8SBenno Lossin // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
27*071cedc8SBenno Lossin TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
28*071cedc8SBenno Lossin if in_generic && !inserted {
29*071cedc8SBenno Lossin new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
30*071cedc8SBenno Lossin }
31*071cedc8SBenno Lossin in_generic = true;
32*071cedc8SBenno Lossin inserted = false;
33*071cedc8SBenno Lossin new_impl_generics.push(tt);
34*071cedc8SBenno Lossin }
35*071cedc8SBenno Lossin // If we find `'`, then we are entering a lifetime.
36*071cedc8SBenno Lossin TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
37*071cedc8SBenno Lossin in_generic = false;
38*071cedc8SBenno Lossin new_impl_generics.push(tt);
39*071cedc8SBenno Lossin }
40*071cedc8SBenno Lossin TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
41*071cedc8SBenno Lossin new_impl_generics.push(tt);
42*071cedc8SBenno Lossin if in_generic {
43*071cedc8SBenno Lossin new_impl_generics.extend(quote! { ::kernel::init::Zeroable + });
44*071cedc8SBenno Lossin inserted = true;
45*071cedc8SBenno Lossin }
46*071cedc8SBenno Lossin }
47*071cedc8SBenno Lossin TokenTree::Punct(p) if p.as_char() == '<' => {
48*071cedc8SBenno Lossin nested += 1;
49*071cedc8SBenno Lossin new_impl_generics.push(tt);
50*071cedc8SBenno Lossin }
51*071cedc8SBenno Lossin TokenTree::Punct(p) if p.as_char() == '>' => {
52*071cedc8SBenno Lossin assert!(nested > 0);
53*071cedc8SBenno Lossin nested -= 1;
54*071cedc8SBenno Lossin new_impl_generics.push(tt);
55*071cedc8SBenno Lossin }
56*071cedc8SBenno Lossin _ => new_impl_generics.push(tt),
57*071cedc8SBenno Lossin }
58*071cedc8SBenno Lossin }
59*071cedc8SBenno Lossin assert_eq!(nested, 0);
60*071cedc8SBenno Lossin if in_generic && !inserted {
61*071cedc8SBenno Lossin new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
62*071cedc8SBenno Lossin }
63*071cedc8SBenno Lossin quote! {
64*071cedc8SBenno Lossin ::kernel::__derive_zeroable!(
65*071cedc8SBenno Lossin parse_input:
66*071cedc8SBenno Lossin @sig(#(#rest)*),
67*071cedc8SBenno Lossin @impl_generics(#(#new_impl_generics)*),
68*071cedc8SBenno Lossin @ty_generics(#(#ty_generics)*),
69*071cedc8SBenno Lossin @body(#last),
70*071cedc8SBenno Lossin );
71*071cedc8SBenno Lossin }
72*071cedc8SBenno Lossin }
73