From 9badd58557531f63ed727991d59a35d717bd5e9a Mon Sep 17 00:00:00 2001 From: Florian Klein Date: Sat, 27 Jun 2026 15:30:14 +0200 Subject: [PATCH] Allow visits from both visitor and VisitorMut at the same time --- derive/README.md | 6 +++--- derive/src/lib.rs | 2 ++ derive/src/visit.rs | 15 ++++++++------ src/ast/mod.rs | 2 +- src/ast/visitor.rs | 50 ++++++++++++++++++++++++++++++++------------- 5 files changed, 51 insertions(+), 24 deletions(-) diff --git a/derive/README.md b/derive/README.md index b5ccc69e0f..addabbb17e 100644 --- a/derive/README.md +++ b/derive/README.md @@ -188,11 +188,11 @@ impl sqlparser::ast::VisitMut for ShowStatementIn { &mut self, visitor: &mut V, ) -> ::std::ops::ControlFlow { - sqlparser::ast::VisitMut::visit(&mut self.clause, visitor)?; - sqlparser::ast::VisitMut::visit(&mut self.parent_type, visitor)?; + sqlparser::ast::VisitMut::visit_mut(&mut self.clause, visitor)?; + sqlparser::ast::VisitMut::visit_mut(&mut self.parent_type, visitor)?; if let Some(value) = &mut self.parent_name { visitor.pre_visit_relation(value)?; - sqlparser::ast::VisitMut::visit(value, visitor)?; + sqlparser::ast::VisitMut::visit_mut(value, visitor)?; visitor.post_visit_relation(value)?; } ::std::ops::ControlFlow::Continue(()) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index e3eaeea6d5..bb32140cbd 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -35,6 +35,7 @@ pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStre input, &visit::VisitType { visit_trait: quote!(VisitMut), + visit_method: quote!(visit_mut), visitor_trait: quote!(VisitorMut), modifier: Some(quote!(mut)), }, @@ -49,6 +50,7 @@ pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::Tok input, &visit::VisitType { visit_trait: quote!(Visit), + visit_method: quote!(visit), visitor_trait: quote!(Visitor), modifier: None, }, diff --git a/derive/src/visit.rs b/derive/src/visit.rs index cb02733b77..5e59fec6d1 100644 --- a/derive/src/visit.rs +++ b/derive/src/visit.rs @@ -29,6 +29,7 @@ use syn::{Path, PathArguments}; pub(crate) struct VisitType { pub visit_trait: TokenStream, + pub visit_method: TokenStream, pub visitor_trait: TokenStream, pub modifier: Option, } @@ -41,6 +42,7 @@ pub(crate) fn derive_visit( let VisitType { visit_trait, + visit_method, visitor_trait, modifier, } = visit_type; @@ -59,7 +61,7 @@ pub(crate) fn derive_visit( // See tests in https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info. impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause { #[cfg_attr(feature = "recursive-protection", recursive::recursive)] - fn visit( + fn #visit_method( &#modifier self, visitor: &mut V ) -> ::core::ops::ControlFlow { @@ -154,6 +156,7 @@ fn visit_children( data: &Data, VisitType { visit_trait, + visit_method, modifier, .. }: &VisitType, @@ -169,13 +172,13 @@ fn visit_children( let (pre_visit, post_visit) = attributes.visit(quote!(value)); quote_spanned!(f.span() => if let Some(value) = &#modifier self.#name { - #pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit + #pre_visit sqlparser::ast::#visit_trait::#visit_method(value, visitor)?; #post_visit } ) } else { let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name)); quote_spanned!(f.span() => - #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit + #pre_visit sqlparser::ast::#visit_trait::#visit_method(&#modifier self.#name, visitor)?; #post_visit ) } }); @@ -188,7 +191,7 @@ fn visit_children( let index = Index::from(i); let attributes = Attributes::parse(&f.attrs); let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index)); - quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit) + quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::#visit_method(&#modifier self.#index, visitor)?; #post_visit) }); quote! { #(#recurse)* @@ -208,7 +211,7 @@ fn visit_children( let name = &f.ident; let attributes = Attributes::parse(&f.attrs); let (pre_visit, post_visit) = attributes.visit(name.to_token_stream()); - quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit) + quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::#visit_method(#name, visitor)?; #post_visit) }); quote!( @@ -223,7 +226,7 @@ fn visit_children( let name = format_ident!("_{}", i); let attributes = Attributes::parse(&f.attrs); let (pre_visit, post_visit) = attributes.visit(name.to_token_stream()); - quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit) + quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::#visit_method(#name, visitor)?; #post_visit) }); quote! { diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 88883cfbb8..73d0941276 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -402,7 +402,7 @@ impl Visit for Ident { #[cfg(feature = "visitor")] impl VisitMut for Ident { - fn visit(&mut self, visitor: &mut V) -> ControlFlow { + fn visit_mut(&mut self, visitor: &mut V) -> ControlFlow { visitor.pre_visit_ident(self)?; visitor.post_visit_ident(self) } diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index c70e83ec62..8300029eaf 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -59,7 +59,7 @@ pub trait VisitMut { /// Implementations should call the appropriate mutable visitor hooks to /// traverse and allow in-place mutation of child nodes. Returning a /// `ControlFlow` value permits early termination of the traversal. - fn visit(&mut self, visitor: &mut V) -> ControlFlow; + fn visit_mut(&mut self, visitor: &mut V) -> ControlFlow; } impl Visit for Option { @@ -87,26 +87,26 @@ impl Visit for Box { } impl VisitMut for Option { - fn visit(&mut self, visitor: &mut V) -> ControlFlow { + fn visit_mut(&mut self, visitor: &mut V) -> ControlFlow { if let Some(s) = self { - s.visit(visitor)?; + s.visit_mut(visitor)?; } ControlFlow::Continue(()) } } impl VisitMut for Vec { - fn visit(&mut self, visitor: &mut V) -> ControlFlow { + fn visit_mut(&mut self, visitor: &mut V) -> ControlFlow { for v in self { - v.visit(visitor)?; + v.visit_mut(visitor)?; } ControlFlow::Continue(()) } } impl VisitMut for Box { - fn visit(&mut self, visitor: &mut V) -> ControlFlow { - T::visit(self, visitor) + fn visit_mut(&mut self, visitor: &mut V) -> ControlFlow { + T::visit_mut(self, visitor) } } @@ -118,7 +118,7 @@ macro_rules! visit_noop { } })+ $(impl VisitMut for $t { - fn visit(&mut self, _visitor: &mut V) -> ControlFlow { + fn visit_mut(&mut self, _visitor: &mut V) -> ControlFlow { ControlFlow::Continue(()) } })+ @@ -320,7 +320,7 @@ pub trait Visitor { /// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap(); /// /// // Drive the visitor through the AST -/// statements.visit(&mut Replacer); +/// statements.visit_mut(&mut Replacer); /// /// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)"); /// ``` @@ -503,7 +503,7 @@ where F: FnMut(&mut ObjectName) -> ControlFlow, { let mut visitor = RelationVisitor(f); - v.visit(&mut visitor)?; + v.visit_mut(&mut visitor)?; ControlFlow::Continue(()) } @@ -633,7 +633,7 @@ where V: VisitMut, F: FnMut(&mut Expr) -> ControlFlow, { - v.visit(&mut ExprVisitor(f))?; + v.visit_mut(&mut ExprVisitor(f))?; ControlFlow::Continue(()) } @@ -720,7 +720,7 @@ where V: VisitMut, F: FnMut(&mut Statement) -> ControlFlow, { - v.visit(&mut StatementVisitor(f))?; + v.visit_mut(&mut StatementVisitor(f))?; ControlFlow::Continue(()) } @@ -1059,7 +1059,7 @@ mod tests { #[cfg(test)] mod visit_mut_tests { - use crate::ast::{Ident, Statement, Value, ValueWithSpan, VisitMut, VisitorMut}; + use crate::ast::{Ident, Statement, Value, ValueWithSpan, Visit, VisitMut, Visitor, VisitorMut}; use crate::dialect::GenericDialect; use crate::parser::Parser; use crate::tokenizer::Tokenizer; @@ -1092,7 +1092,7 @@ mod visit_mut_tests { .parse_statement() .unwrap(); - let flow = s.visit(visitor); + let flow = s.visit_mut(visitor); assert_eq!(flow, ControlFlow::Continue(())); s } @@ -1139,4 +1139,26 @@ mod visit_mut_tests { let mutated = do_visit_mut("SELECT a, b FROM t", &mut visitor); assert_eq!(mutated.to_string(), "SELECT A, B FROM T"); } + + struct DummyVisitor; + impl Visitor for DummyVisitor { + type Break = (); + } + + struct DummyVisitorMut; + impl VisitorMut for DummyVisitorMut { + type Break = (); + } + + #[test] + fn test_both_visit_and_visit_mut() { + let mut visitor = DummyVisitor; + let mut visitor_mut = DummyVisitorMut; + let mut statements = Parser::parse_sql(&GenericDialect {}, "SELECT 1").unwrap(); + + let _ = statements.visit(&mut visitor); + let _ = statements.visit_mut(&mut visitor_mut); + + assert_eq!(statements[0].to_string(), "SELECT 1"); + } }