xref: /openbmc/linux/rust/macros/paste.rs (revision 3ddc8b84)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 use proc_macro::{Delimiter, Group, Ident, Spacing, Span, TokenTree};
4 
/// Concatenates the contents of a `[< ... >]` paste group into a single identifier.
///
/// `tokens` are the tokens between the `<` and `>` delimiters. Each identifier or literal
/// becomes one segment of the resulting name:
///
/// - a leading `r#` on a raw identifier is dropped (`r#foo` pastes as `foo`);
/// - the surrounding quotes of a plain string literal are dropped (`"foo"` pastes as `foo`);
///   other literal forms are pasted verbatim.
///
/// A segment may be followed by `:modifier` suffixes, which apply to the segment right
/// before them:
///
/// - `:lower` / `:upper` convert the segment to lower / upper case;
/// - `:span` makes that segment's span the span of the resulting identifier (at most once).
///
/// Without a `:span` modifier the resulting identifier carries `group_span`.
///
/// # Panics
///
/// Panics if a `:` is not followed by an identifier, if a modifier has no preceding segment,
/// if a modifier is unknown, if `:span` appears more than once, if any other kind of token
/// appears, or if the pasted text is not accepted by [`Ident::new`].
fn concat(tokens: &[TokenTree], group_span: Span) -> TokenTree {
    let mut tokens = tokens.iter();
    let mut segments: Vec<(String, Span)> = Vec::new();
    let mut span = None;
    // `while let` rather than `for`: the modifier arm pulls the next token itself.
    while let Some(token) = tokens.next() {
        match token {
            TokenTree::Literal(lit) => {
                let mut value = lit.to_string();
                // A plain string literal stringifies with its quotes, which can never be part
                // of an identifier; drop them so the literal's contents are pasted instead.
                if value.len() >= 2 && value.starts_with('"') && value.ends_with('"') {
                    value.pop();
                    value.remove(0);
                }
                segments.push((value, lit.span()));
            }
            TokenTree::Ident(ident) => {
                let name = ident.to_string();
                // The `r#` prefix of a raw identifier is not part of the pasted name.
                let value = name.strip_prefix("r#").unwrap_or(name.as_str()).to_owned();
                segments.push((value, ident.span()));
            }
            TokenTree::Punct(p) if p.as_char() == ':' => {
                let Some(TokenTree::Ident(ident)) = tokens.next() else {
                    panic!("expected identifier as modifier");
                };

                // Modifiers rewrite the segment immediately before them.
                let (mut value, sp) = segments.pop().expect("expected identifier before modifier");
                match ident.to_string().as_str() {
                    // Set the overall span of concatenated token as current span
                    "span" => {
                        assert!(
                            span.is_none(),
                            "span modifier should only appear at most once"
                        );
                        span = Some(sp);
                    }
                    "lower" => value = value.to_lowercase(),
                    "upper" => value = value.to_uppercase(),
                    v => panic!("unknown modifier `{v}`"),
                }
                segments.push((value, sp));
            }
            _ => panic!("unexpected token in paste segments"),
        }
    }

    let pasted: String = segments.into_iter().map(|(value, _)| value).collect();
    TokenTree::Ident(Ident::new(&pasted, span.unwrap_or(group_span)))
}
48 
49 pub(crate) fn expand(tokens: &mut Vec<TokenTree>) {
50     for token in tokens.iter_mut() {
51         if let TokenTree::Group(group) = token {
52             let delimiter = group.delimiter();
53             let span = group.span();
54             let mut stream: Vec<_> = group.stream().into_iter().collect();
55             // Find groups that looks like `[< A B C D >]`
56             if delimiter == Delimiter::Bracket
57                 && stream.len() >= 3
58                 && matches!(&stream[0], TokenTree::Punct(p) if p.as_char() == '<')
59                 && matches!(&stream[stream.len() - 1], TokenTree::Punct(p) if p.as_char() == '>')
60             {
61                 // Replace the group with concatenated token
62                 *token = concat(&stream[1..stream.len() - 1], span);
63             } else {
64                 // Recursively expand tokens inside the group
65                 expand(&mut stream);
66                 let mut group = Group::new(delimiter, stream.into_iter().collect());
67                 group.set_span(span);
68                 *token = TokenTree::Group(group);
69             }
70         }
71     }
72 
73     // Path segments cannot contain invisible delimiter group, so remove them if any.
74     for i in (0..tokens.len().saturating_sub(3)).rev() {
75         // Looking for a double colon
76         if matches!(
77             (&tokens[i + 1], &tokens[i + 2]),
78             (TokenTree::Punct(a), TokenTree::Punct(b))
79                 if a.as_char() == ':' && a.spacing() == Spacing::Joint && b.as_char() == ':'
80         ) {
81             match &tokens[i + 3] {
82                 TokenTree::Group(group) if group.delimiter() == Delimiter::None => {
83                     tokens.splice(i + 3..i + 4, group.stream());
84                 }
85                 _ => (),
86             }
87 
88             match &tokens[i] {
89                 TokenTree::Group(group) if group.delimiter() == Delimiter::None => {
90                     tokens.splice(i..i + 1, group.stream());
91                 }
92                 _ => (),
93             }
94         }
95     }
96 }
97