Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions derive/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,11 @@ impl sqlparser::ast::VisitMut for ShowStatementIn {
&mut self,
visitor: &mut V,
) -> ::std::ops::ControlFlow<V::Break> {
sqlparser::ast::VisitMut::visit(&mut self.clause, visitor)?;
sqlparser::ast::VisitMut::visit(&mut self.parent_type, visitor)?;
sqlparser::ast::VisitMut::visit_mut(&mut self.clause, visitor)?;
sqlparser::ast::VisitMut::visit_mut(&mut self.parent_type, visitor)?;
if let Some(value) = &mut self.parent_name {
visitor.pre_visit_relation(value)?;
sqlparser::ast::VisitMut::visit(value, visitor)?;
sqlparser::ast::VisitMut::visit_mut(value, visitor)?;
visitor.post_visit_relation(value)?;
}
::std::ops::ControlFlow::Continue(())
Expand Down
2 changes: 2 additions & 0 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStre
input,
&visit::VisitType {
visit_trait: quote!(VisitMut),
visit_method: quote!(visit_mut),
visitor_trait: quote!(VisitorMut),
modifier: Some(quote!(mut)),
},
Expand All @@ -49,6 +50,7 @@ pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::Tok
input,
&visit::VisitType {
visit_trait: quote!(Visit),
visit_method: quote!(visit),
visitor_trait: quote!(Visitor),
modifier: None,
},
Expand Down
15 changes: 9 additions & 6 deletions derive/src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use syn::{Path, PathArguments};

pub(crate) struct VisitType {
pub visit_trait: TokenStream,
pub visit_method: TokenStream,
pub visitor_trait: TokenStream,
pub modifier: Option<TokenStream>,
}
Expand All @@ -41,6 +42,7 @@ pub(crate) fn derive_visit(

let VisitType {
visit_trait,
visit_method,
visitor_trait,
modifier,
} = visit_type;
Expand All @@ -59,7 +61,7 @@ pub(crate) fn derive_visit(
// See tests in https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info.
impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
fn visit<V: sqlparser::ast::#visitor_trait>(
fn #visit_method<V: sqlparser::ast::#visitor_trait>(
&#modifier self,
visitor: &mut V
) -> ::core::ops::ControlFlow<V::Break> {
Expand Down Expand Up @@ -154,6 +156,7 @@ fn visit_children(
data: &Data,
VisitType {
visit_trait,
visit_method,
modifier,
..
}: &VisitType,
Expand All @@ -169,13 +172,13 @@ fn visit_children(
let (pre_visit, post_visit) = attributes.visit(quote!(value));
quote_spanned!(f.span() =>
if let Some(value) = &#modifier self.#name {
#pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit
#pre_visit sqlparser::ast::#visit_trait::#visit_method(value, visitor)?; #post_visit
}
)
} else {
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
quote_spanned!(f.span() =>
#pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit
#pre_visit sqlparser::ast::#visit_trait::#visit_method(&#modifier self.#name, visitor)?; #post_visit
)
}
});
Expand All @@ -188,7 +191,7 @@ fn visit_children(
let index = Index::from(i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit)
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::#visit_method(&#modifier self.#index, visitor)?; #post_visit)
});
quote! {
#(#recurse)*
Expand All @@ -208,7 +211,7 @@ fn visit_children(
let name = &f.ident;
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::#visit_method(#name, visitor)?; #post_visit)
});

quote!(
Expand All @@ -223,7 +226,7 @@ fn visit_children(
let name = format_ident!("_{}", i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::#visit_method(#name, visitor)?; #post_visit)
});

quote! {
Expand Down
2 changes: 1 addition & 1 deletion src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ impl Visit for Ident {

#[cfg(feature = "visitor")]
impl VisitMut for Ident {
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
fn visit_mut<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
visitor.pre_visit_ident(self)?;
visitor.post_visit_ident(self)
}
Expand Down
50 changes: 36 additions & 14 deletions src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub trait VisitMut {
/// Implementations should call the appropriate mutable visitor hooks to
/// traverse and allow in-place mutation of child nodes. Returning a
/// `ControlFlow` value permits early termination of the traversal.
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
fn visit_mut<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
}

impl<T: Visit> Visit for Option<T> {
Expand Down Expand Up @@ -87,26 +87,26 @@ impl<T: Visit> Visit for Box<T> {
}

impl<T: VisitMut> VisitMut for Option<T> {
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
fn visit_mut<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
if let Some(s) = self {
s.visit(visitor)?;
s.visit_mut(visitor)?;
}
ControlFlow::Continue(())
}
}

impl<T: VisitMut> VisitMut for Vec<T> {
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
fn visit_mut<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
for v in self {
v.visit(visitor)?;
v.visit_mut(visitor)?;
}
ControlFlow::Continue(())
}
}

impl<T: VisitMut> VisitMut for Box<T> {
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
T::visit(self, visitor)
fn visit_mut<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
T::visit_mut(self, visitor)
}
}

Expand All @@ -118,7 +118,7 @@ macro_rules! visit_noop {
}
})+
$(impl VisitMut for $t {
fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
fn visit_mut<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
ControlFlow::Continue(())
}
})+
Expand Down Expand Up @@ -320,7 +320,7 @@ pub trait Visitor {
/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
///
/// // Drive the visitor through the AST
/// statements.visit(&mut Replacer);
/// statements.visit_mut(&mut Replacer);
///
/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
/// ```
Expand Down Expand Up @@ -503,7 +503,7 @@ where
F: FnMut(&mut ObjectName) -> ControlFlow<E>,
{
let mut visitor = RelationVisitor(f);
v.visit(&mut visitor)?;
v.visit_mut(&mut visitor)?;
ControlFlow::Continue(())
}

Expand Down Expand Up @@ -633,7 +633,7 @@ where
V: VisitMut,
F: FnMut(&mut Expr) -> ControlFlow<E>,
{
v.visit(&mut ExprVisitor(f))?;
v.visit_mut(&mut ExprVisitor(f))?;
ControlFlow::Continue(())
}

Expand Down Expand Up @@ -720,7 +720,7 @@ where
V: VisitMut,
F: FnMut(&mut Statement) -> ControlFlow<E>,
{
v.visit(&mut StatementVisitor(f))?;
v.visit_mut(&mut StatementVisitor(f))?;
ControlFlow::Continue(())
}

Expand Down Expand Up @@ -1059,7 +1059,7 @@ mod tests {

#[cfg(test)]
mod visit_mut_tests {
use crate::ast::{Ident, Statement, Value, ValueWithSpan, VisitMut, VisitorMut};
use crate::ast::{Ident, Statement, Value, ValueWithSpan, Visit, VisitMut, Visitor, VisitorMut};
use crate::dialect::GenericDialect;
use crate::parser::Parser;
use crate::tokenizer::Tokenizer;
Expand Down Expand Up @@ -1092,7 +1092,7 @@ mod visit_mut_tests {
.parse_statement()
.unwrap();

let flow = s.visit(visitor);
let flow = s.visit_mut(visitor);
assert_eq!(flow, ControlFlow::Continue(()));
s
}
Expand Down Expand Up @@ -1139,4 +1139,26 @@ mod visit_mut_tests {
let mutated = do_visit_mut("SELECT a, b FROM t", &mut visitor);
assert_eq!(mutated.to_string(), "SELECT A, B FROM T");
}

struct DummyVisitor;
impl Visitor for DummyVisitor {
type Break = ();
}

struct DummyVisitorMut;
impl VisitorMut for DummyVisitorMut {
type Break = ();
}

#[test]
fn test_both_visit_and_visit_mut() {
let mut visitor = DummyVisitor;
let mut visitor_mut = DummyVisitorMut;
let mut statements = Parser::parse_sql(&GenericDialect {}, "SELECT 1").unwrap();

let _ = statements.visit(&mut visitor);
let _ = statements.visit_mut(&mut visitor_mut);

assert_eq!(statements[0].to_string(), "SELECT 1");
}
}