| // SPDX-License-Identifier: GPL-2.0 |
| |
| use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; |
| use std::collections::HashSet; |
| use std::fmt::Write; |
| |
/// Implements the `#[vtable]` attribute for trait definitions and trait implementations.
///
/// For a trait, this prepends to the trait body:
/// - a `const USE_VTABLE_ATTR: ();` marker without a default value, so every implementor has to
///   define it, and
/// - a `const HAS_<METHOD>: bool = false;` default for every method declared in the trait.
///
/// For an `impl` block, this prepends `const USE_VTABLE_ATTR: () = ();` and a
/// `const HAS_<METHOD>: bool = true;` for every method the block defines.
///
/// In both cases a `HAS_<METHOD>` constant already declared in the body is left alone, which lets
/// users override the generated value.
///
/// # Panics
///
/// Panics if the item contains neither a `trait` nor an `impl` keyword, or if its last token tree
/// is not a brace-delimited body.
pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream {
    let mut tokens: Vec<_> = ts.into_iter().collect();

    // Scan for the `trait` or `impl` keyword; the first one found decides the expansion.
    let is_trait = tokens
        .iter()
        .find_map(|token| match token {
            TokenTree::Ident(ident) => match ident.to_string().as_str() {
                "trait" => Some(true),
                "impl" => Some(false),
                _ => None,
            },
            _ => None,
        })
        .expect("#[vtable] attribute should only be applied to trait or impl block");

    // Retrieve the main body. The main body should be the last token tree.
    let body = match tokens.pop() {
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
        _ => panic!("cannot locate main body of trait or impl block"),
    };

    let (functions, mut consts) = scan_items(body.stream());

    let mut const_items = if is_trait {
        "
            /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable)
            /// attribute when implementing this trait.
            const USE_VTABLE_ATTR: ();
        "
        .to_owned()
    } else {
        "const USE_VTABLE_ATTR: () = ();".to_owned()
    };

    for f in functions {
        let const_name = has_const_name(&f);
        // Skip if it's declared already -- this allows user override. Because every generated
        // name is recorded below, a method declared more than once (e.g. under mutually
        // exclusive `#[cfg]`s) also yields only a single constant, in both traits and impls.
        if consts.contains(&const_name) {
            continue;
        }
        if is_trait {
            // We don't know on the implementation-site whether a method is required or provided
            // so we have to generate a const for all methods.
            write!(
                const_items,
                "/// Indicates if the `{f}` method is overridden by the implementor.
                const {const_name}: bool = false;",
            )
            .unwrap();
        } else {
            write!(const_items, "const {const_name}: bool = true;").unwrap();
        }
        consts.insert(const_name);
    }

    // Generated items go first, followed by the user's original items.
    let mut new_body: TokenStream = const_items
        .parse()
        .expect("generated #[vtable] items should always lex");
    new_body.extend(body.stream());

    // Keep the original brace span so diagnostics still point at the user's code.
    let mut new_group = Group::new(Delimiter::Brace, new_body);
    new_group.set_span(body.span());
    tokens.push(TokenTree::Group(new_group));
    tokens.into_iter().collect()
}

/// Collects the names of the `fn` and `const` items declared at the top level of `body`.
///
/// Returns the method names in declaration order and the set of constant names. Only top-level
/// tokens are inspected: nested groups (method bodies, argument lists, ...) arrive as single
/// [`TokenTree::Group`]s and are never descended into.
fn scan_items(body: TokenStream) -> (Vec<String>, HashSet<String>) {
    let mut body_it = body.into_iter();
    let mut functions = Vec::new();
    let mut consts = HashSet::new();
    while let Some(token) = body_it.next() {
        match token {
            TokenTree::Ident(ident) if ident.to_string() == "fn" => {
                // A non-identifier after `fn` means we've possibly encountered a fn pointer
                // type instead of a method declaration; the token is consumed either way.
                if let Some(TokenTree::Ident(name)) = body_it.next() {
                    functions.push(name.to_string());
                }
            }
            TokenTree::Ident(ident) if ident.to_string() == "const" => {
                // A non-identifier after `const` means we've possibly encountered an inline
                // const block instead of a constant declaration.
                if let Some(TokenTree::Ident(name)) = body_it.next() {
                    consts.insert(name.to_string());
                }
            }
            _ => (),
        }
    }
    (functions, consts)
}

/// Returns the name of the `HAS_*` constant generated for the method `fn_name`,
/// e.g. `read_iter` -> `HAS_READ_ITER`.
///
/// A raw identifier such as `r#type` is stripped of its `r#` prefix first (giving `HAS_TYPE`),
/// because `#` cannot appear inside an identifier.
fn has_const_name(fn_name: &str) -> String {
    let bare = fn_name.strip_prefix("r#").unwrap_or(fn_name);
    format!("HAS_{}", bare.to_uppercase())
}