1 // Copyright 2024, Linaro Limited
2 // Author(s): Manos Pitsidianakis <manos.pitsidianakis@linaro.org>
3 // SPDX-License-Identifier: GPL-2.0-or-later
4
5 use proc_macro::TokenStream;
6 use quote::quote;
7 use syn::{
8 parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, Data,
9 DeriveInput, Field, Fields, FieldsUnnamed, Ident, Meta, Path, Token, Type, Variant, Visibility,
10 };
11
12 mod utils;
13 use utils::MacroError;
14
get_fields<'a>( input: &'a DeriveInput, msg: &str, ) -> Result<&'a Punctuated<Field, Comma>, MacroError>15 fn get_fields<'a>(
16 input: &'a DeriveInput,
17 msg: &str,
18 ) -> Result<&'a Punctuated<Field, Comma>, MacroError> {
19 if let Data::Struct(s) = &input.data {
20 if let Fields::Named(fs) = &s.fields {
21 Ok(&fs.named)
22 } else {
23 Err(MacroError::Message(
24 format!("Named fields required for {}", msg),
25 input.ident.span(),
26 ))
27 }
28 } else {
29 Err(MacroError::Message(
30 format!("Struct required for {}", msg),
31 input.ident.span(),
32 ))
33 }
34 }
35
get_unnamed_field<'a>(input: &'a DeriveInput, msg: &str) -> Result<&'a Field, MacroError>36 fn get_unnamed_field<'a>(input: &'a DeriveInput, msg: &str) -> Result<&'a Field, MacroError> {
37 if let Data::Struct(s) = &input.data {
38 let unnamed = match &s.fields {
39 Fields::Unnamed(FieldsUnnamed {
40 unnamed: ref fields,
41 ..
42 }) => fields,
43 _ => {
44 return Err(MacroError::Message(
45 format!("Tuple struct required for {}", msg),
46 s.fields.span(),
47 ))
48 }
49 };
50 if unnamed.len() != 1 {
51 return Err(MacroError::Message(
52 format!("A single field is required for {}", msg),
53 s.fields.span(),
54 ));
55 }
56 Ok(&unnamed[0])
57 } else {
58 Err(MacroError::Message(
59 format!("Struct required for {}", msg),
60 input.ident.span(),
61 ))
62 }
63 }
64
is_c_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError>65 fn is_c_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError> {
66 let expected = parse_quote! { #[repr(C)] };
67
68 if input.attrs.iter().any(|attr| attr == &expected) {
69 Ok(())
70 } else {
71 Err(MacroError::Message(
72 format!("#[repr(C)] required for {}", msg),
73 input.ident.span(),
74 ))
75 }
76 }
77
is_transparent_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError>78 fn is_transparent_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError> {
79 let expected = parse_quote! { #[repr(transparent)] };
80
81 if input.attrs.iter().any(|attr| attr == &expected) {
82 Ok(())
83 } else {
84 Err(MacroError::Message(
85 format!("#[repr(transparent)] required for {}", msg),
86 input.ident.span(),
87 ))
88 }
89 }
90
derive_object_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError>91 fn derive_object_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> {
92 is_c_repr(&input, "#[derive(Object)]")?;
93
94 let name = &input.ident;
95 let parent = &get_fields(&input, "#[derive(Object)]")?[0].ident;
96
97 Ok(quote! {
98 ::qemu_api::assert_field_type!(#name, #parent,
99 ::qemu_api::qom::ParentField<<#name as ::qemu_api::qom::ObjectImpl>::ParentType>);
100
101 ::qemu_api::module_init! {
102 MODULE_INIT_QOM => unsafe {
103 ::qemu_api::bindings::type_register_static(&<#name as ::qemu_api::qom::ObjectImpl>::TYPE_INFO);
104 }
105 }
106 })
107 }
108
109 #[proc_macro_derive(Object)]
derive_object(input: TokenStream) -> TokenStream110 pub fn derive_object(input: TokenStream) -> TokenStream {
111 let input = parse_macro_input!(input as DeriveInput);
112 let expanded = derive_object_or_error(input).unwrap_or_else(Into::into);
113
114 TokenStream::from(expanded)
115 }
116
derive_opaque_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError>117 fn derive_opaque_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> {
118 is_transparent_repr(&input, "#[derive(Wrapper)]")?;
119
120 let name = &input.ident;
121 let field = &get_unnamed_field(&input, "#[derive(Wrapper)]")?;
122 let typ = &field.ty;
123
124 // TODO: how to add "::qemu_api"? For now, this is only used in the
125 // qemu_api crate so it's not a problem.
126 Ok(quote! {
127 unsafe impl crate::cell::Wrapper for #name {
128 type Wrapped = <#typ as crate::cell::Wrapper>::Wrapped;
129 }
130 impl #name {
131 pub unsafe fn from_raw<'a>(ptr: *mut <Self as crate::cell::Wrapper>::Wrapped) -> &'a Self {
132 let ptr = ::std::ptr::NonNull::new(ptr).unwrap().cast::<Self>();
133 unsafe { ptr.as_ref() }
134 }
135
136 pub const fn as_mut_ptr(&self) -> *mut <Self as crate::cell::Wrapper>::Wrapped {
137 self.0.as_mut_ptr()
138 }
139
140 pub const fn as_ptr(&self) -> *const <Self as crate::cell::Wrapper>::Wrapped {
141 self.0.as_ptr()
142 }
143
144 pub const fn as_void_ptr(&self) -> *mut ::core::ffi::c_void {
145 self.0.as_void_ptr()
146 }
147
148 pub const fn raw_get(slot: *mut Self) -> *mut <Self as crate::cell::Wrapper>::Wrapped {
149 slot.cast()
150 }
151 }
152 })
153 }
154
155 #[proc_macro_derive(Wrapper)]
derive_opaque(input: TokenStream) -> TokenStream156 pub fn derive_opaque(input: TokenStream) -> TokenStream {
157 let input = parse_macro_input!(input as DeriveInput);
158 let expanded = derive_opaque_or_error(input).unwrap_or_else(Into::into);
159
160 TokenStream::from(expanded)
161 }
162
163 #[rustfmt::skip::macros(quote)]
derive_offsets_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError>164 fn derive_offsets_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> {
165 is_c_repr(&input, "#[derive(offsets)]")?;
166
167 let name = &input.ident;
168 let fields = get_fields(&input, "#[derive(offsets)]")?;
169 let field_names: Vec<&Ident> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
170 let field_types: Vec<&Type> = fields.iter().map(|f| &f.ty).collect();
171 let field_vis: Vec<&Visibility> = fields.iter().map(|f| &f.vis).collect();
172
173 Ok(quote! {
174 ::qemu_api::with_offsets! {
175 struct #name {
176 #(#field_vis #field_names: #field_types,)*
177 }
178 }
179 })
180 }
181
182 #[proc_macro_derive(offsets)]
derive_offsets(input: TokenStream) -> TokenStream183 pub fn derive_offsets(input: TokenStream) -> TokenStream {
184 let input = parse_macro_input!(input as DeriveInput);
185 let expanded = derive_offsets_or_error(input).unwrap_or_else(Into::into);
186
187 TokenStream::from(expanded)
188 }
189
190 #[allow(non_snake_case)]
get_repr_uN(input: &DeriveInput, msg: &str) -> Result<Path, MacroError>191 fn get_repr_uN(input: &DeriveInput, msg: &str) -> Result<Path, MacroError> {
192 let repr = input.attrs.iter().find(|attr| attr.path().is_ident("repr"));
193 if let Some(repr) = repr {
194 let nested = repr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
195 for meta in nested {
196 match meta {
197 Meta::Path(path) if path.is_ident("u8") => return Ok(path),
198 Meta::Path(path) if path.is_ident("u16") => return Ok(path),
199 Meta::Path(path) if path.is_ident("u32") => return Ok(path),
200 Meta::Path(path) if path.is_ident("u64") => return Ok(path),
201 _ => {}
202 }
203 }
204 }
205
206 Err(MacroError::Message(
207 format!("#[repr(u8/u16/u32/u64) required for {}", msg),
208 input.ident.span(),
209 ))
210 }
211
get_variants(input: &DeriveInput) -> Result<&Punctuated<Variant, Comma>, MacroError>212 fn get_variants(input: &DeriveInput) -> Result<&Punctuated<Variant, Comma>, MacroError> {
213 if let Data::Enum(e) = &input.data {
214 if let Some(v) = e.variants.iter().find(|v| v.fields != Fields::Unit) {
215 return Err(MacroError::Message(
216 "Cannot derive TryInto for enum with non-unit variants.".to_string(),
217 v.fields.span(),
218 ));
219 }
220 Ok(&e.variants)
221 } else {
222 Err(MacroError::Message(
223 "Cannot derive TryInto for union or struct.".to_string(),
224 input.ident.span(),
225 ))
226 }
227 }
228
229 #[rustfmt::skip::macros(quote)]
derive_tryinto_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError>230 fn derive_tryinto_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> {
231 let repr = get_repr_uN(&input, "#[derive(TryInto)]")?;
232
233 let name = &input.ident;
234 let variants = get_variants(&input)?;
235 let discriminants: Vec<&Ident> = variants.iter().map(|f| &f.ident).collect();
236
237 Ok(quote! {
238 impl core::convert::TryFrom<#repr> for #name {
239 type Error = #repr;
240
241 fn try_from(value: #repr) -> Result<Self, Self::Error> {
242 #(const #discriminants: #repr = #name::#discriminants as #repr;)*;
243 match value {
244 #(#discriminants => Ok(Self::#discriminants),)*
245 _ => Err(value),
246 }
247 }
248 }
249 })
250 }
251
252 #[proc_macro_derive(TryInto)]
derive_tryinto(input: TokenStream) -> TokenStream253 pub fn derive_tryinto(input: TokenStream) -> TokenStream {
254 let input = parse_macro_input!(input as DeriveInput);
255 let expanded = derive_tryinto_or_error(input).unwrap_or_else(Into::into);
256
257 TokenStream::from(expanded)
258 }
259