rust: macros: replace Self with the concrete type in #[pin_data]

When using `#[pin_data]` on a struct that uses `Self` in its field
types, a type error is emitted when trying to use `pin_init!`, because
`Self` then refers to an internal type generated by the macro instead
of the defined struct.

This patch fixes the issue by replacing all occurrences of `Self` in
the `#[pin_data]` macro input with the concrete type. Since Rust allows
type definitions inside of blocks, which are expressions, the macro
also checks for these and emits a compile error when it finds `trait`,
`enum`, `union`, `struct` or `impl`. These keywords create new `Self`
contexts, which conflicts with the current implementation of replacing
every `Self` ident: if they were allowed, some `Self` idents would be
replaced incorrectly.

Signed-off-by: Benno Lossin <benno.lossin@proton.me>
Reported-by: Alice Ryhl <aliceryhl@google.com>
Reviewed-by: Alice Ryhl <aliceryhl@google.com>
Reviewed-by: Martin Rodriguez Reboredo <yakoyoku@gmail.com>
Reviewed-by: Gary Guo <gary@garyguo.net>
Link: https://lore.kernel.org/r/20230424081112.99890-3-benno.lossin@proton.me
[ Added newline in commit message ]
Signed-off-by: Miguel Ojeda <ojeda@kernel.org>
diff --git a/rust/macros/pin_data.rs b/rust/macros/pin_data.rs
index c593b05..6d58cfd 100644
--- a/rust/macros/pin_data.rs
+++ b/rust/macros/pin_data.rs
@@ -1,7 +1,7 @@
 // SPDX-License-Identifier: Apache-2.0 OR MIT
 
 use crate::helpers::{parse_generics, Generics};
-use proc_macro::TokenStream;
+use proc_macro::{Group, Punct, Spacing, TokenStream, TokenTree};
 
 pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream {
     // This proc-macro only does some pre-parsing and then delegates the actual parsing to
@@ -12,16 +12,122 @@ pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream {
             impl_generics,
             ty_generics,
         },
-        mut rest,
+        rest,
     ) = parse_generics(input);
+    // The struct definition might contain the `Self` type. Since `__pin_data!` will define a new
+    // type with the same generics and bounds, this poses a problem, since `Self` will refer to the
+    // new type as opposed to this struct definition. Therefore we have to replace `Self` with the
+    // concrete name.
+
+    // Errors that occur when replacing `Self` with `struct_name`.
+    let mut errs = TokenStream::new();
+    // The name of the struct with ty_generics.
+    let struct_name = rest
+        .iter()
+        .skip_while(|tt| !matches!(tt, TokenTree::Ident(i) if i.to_string() == "struct"))
+        .nth(1)
+        .and_then(|tt| match tt {
+            TokenTree::Ident(_) => {
+                let tt = tt.clone();
+                let mut res = vec![tt];
+                if !ty_generics.is_empty() {
+                    // We add this, so it is maximally compatible with e.g. `Self::CONST` which
+                    // will be replaced by `StructName::<$generics>::CONST`.
+                    res.push(TokenTree::Punct(Punct::new(':', Spacing::Joint)));
+                    res.push(TokenTree::Punct(Punct::new(':', Spacing::Alone)));
+                    res.push(TokenTree::Punct(Punct::new('<', Spacing::Alone)));
+                    res.extend(ty_generics.iter().cloned());
+                    res.push(TokenTree::Punct(Punct::new('>', Spacing::Alone)));
+                }
+                Some(res)
+            }
+            _ => None,
+        })
+        .unwrap_or_else(|| {
+            // If we did not find the name of the struct then we will use `Self` as the replacement
+            // and add a compile error to ensure it does not compile.
+            errs.extend(
+                "::core::compile_error!(\"Could not locate type name.\");"
+                    .parse::<TokenStream>()
+                    .unwrap(),
+            );
+            "Self".parse::<TokenStream>().unwrap().into_iter().collect()
+        });
+    let impl_generics = impl_generics
+        .into_iter()
+        .flat_map(|tt| replace_self_and_deny_type_defs(&struct_name, tt, &mut errs))
+        .collect::<Vec<_>>();
+    let mut rest = rest
+        .into_iter()
+        .flat_map(|tt| {
+            // We ignore top level `struct` tokens, since they would emit a compile error.
+            if matches!(&tt, TokenTree::Ident(i) if i.to_string() == "struct") {
+                vec![tt]
+            } else {
+                replace_self_and_deny_type_defs(&struct_name, tt, &mut errs)
+            }
+        })
+        .collect::<Vec<_>>();
     // This should be the body of the struct `{...}`.
     let last = rest.pop();
-    quote!(::kernel::__pin_data! {
+    let mut quoted = quote!(::kernel::__pin_data! {
         parse_input:
         @args(#args),
         @sig(#(#rest)*),
         @impl_generics(#(#impl_generics)*),
         @ty_generics(#(#ty_generics)*),
         @body(#last),
-    })
+    });
+    quoted.extend(errs);
+    quoted
+}
+
+/// Replaces `Self` with `struct_name` and errors on `enum`, `trait`, `struct`, `union` and `impl`
+/// keywords.
+///
+/// Groups are processed recursively, so tokens nested inside delimiters are handled as well.
+///
+/// The error is appended to `errs` to allow normal parsing to continue.
+fn replace_self_and_deny_type_defs(
+    struct_name: &[TokenTree],
+    tt: TokenTree,
+    errs: &mut TokenStream,
+) -> Vec<TokenTree> {
+    match tt {
+        TokenTree::Ident(ref i)
+            if matches!(
+                i.to_string().as_str(),
+                "enum" | "trait" | "struct" | "union" | "impl"
+            ) =>
+        {
+            errs.extend(
+                format!(
+                    "::core::compile_error!(\"Cannot use `{i}` inside of struct definition with \
+                        `#[pin_data]`.\");"
+                )
+                .parse::<TokenStream>()
+                .unwrap()
+                .into_iter()
+                .map(|mut tok| {
+                    tok.set_span(tt.span());
+                    tok
+                }),
+            );
+            vec![tt]
+        }
+        TokenTree::Ident(i) if i.to_string() == "Self" => struct_name.to_vec(),
+        TokenTree::Literal(_) | TokenTree::Punct(_) | TokenTree::Ident(_) => vec![tt],
+        TokenTree::Group(g) => {
+            let mut group = Group::new(
+                g.delimiter(),
+                g.stream()
+                    .into_iter()
+                    .flat_map(|tt| replace_self_and_deny_type_defs(struct_name, tt, errs))
+                    .collect(),
+            );
+            // `Group::new` uses the call-site span; restore the original span for diagnostics.
+            group.set_span(g.span());
+            vec![TokenTree::Group(group)]
+        }
+    }
 }