diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs index 4c46a51bef58..3bbf9a0ad3a2 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs @@ -6,7 +6,7 @@ FxHashSet, RootDatabase, defs::Definition, helpers::mod_path_to_ast, - imports::insert_use::{ImportScope, InsertUseConfig, insert_use}, + imports::insert_use::{ImportScope, InsertUseConfig, insert_use_with_editor}, path_transform::PathTransform, search::FileReference, }; @@ -16,12 +16,14 @@ SyntaxKind::*, SyntaxNode, T, ast::{ - self, AstNode, HasAttrs, HasGenericParams, HasName, HasVisibility, edit::AstNodeEdit, make, + self, AstNode, HasAttrs, HasGenericParams, HasName, HasVisibility, edit::AstNodeEdit, + syntax_factory::SyntaxFactory, }, - match_ast, ted, + match_ast, + syntax_editor::{Position, SyntaxEditor}, }; -use crate::{AssistContext, AssistId, Assists, assist_context::SourceChangeBuilder}; +use crate::{AssistContext, AssistId, Assists}; // Assist: extract_struct_from_enum_variant // @@ -58,6 +60,8 @@ pub(crate) fn extract_struct_from_enum_variant( "Extract struct from enum variant", target, |builder| { + let make = SyntaxFactory::with_mappings(); + let mut editor = builder.make_editor(variant.syntax()); let edition = enum_hir.krate(ctx.db()).edition(ctx.db()); let variant_hir_name = variant_hir.name(ctx.db()); let enum_module_def = ModuleDef::from(enum_hir); @@ -73,40 +77,56 @@ pub(crate) fn extract_struct_from_enum_variant( def_file_references = Some(references); continue; } - builder.edit_file(file_id.file_id(ctx.db())); let processed = process_references( ctx, - builder, &mut visited_modules_set, &enum_module_def, &variant_hir_name, references, ); + if processed.is_empty() { + continue; + } + let mut file_editor = builder.make_editor(processed[0].0.syntax()); processed.into_iter().for_each(|(path, node, import)| { - apply_references(ctx.config.insert_use, path, node, import, edition) + apply_references( + ctx.config.insert_use, + path, + node, + import, + edition, + &mut file_editor, + &make, + ) }); + file_editor.add_mappings(make.take()); + builder.add_file_edits(file_id.file_id(ctx.db()), file_editor); } - builder.edit_file(ctx.vfs_file_id()); - let variant = builder.make_mut(variant.clone()); if let Some(references) = def_file_references { let processed = process_references( ctx, - builder, &mut visited_modules_set, &enum_module_def, &variant_hir_name, references, ); processed.into_iter().for_each(|(path, node, import)| { - apply_references(ctx.config.insert_use, path, node, import, edition) + apply_references( + ctx.config.insert_use, + path, + node, + import, + edition, + &mut editor, + &make, + ) }); } - let generic_params = enum_ast - .generic_param_list() - .and_then(|known_generics| extract_generic_params(&known_generics, &field_list)); - let generics = generic_params.as_ref().map(|generics| generics.clone_for_update()); + let generic_params = enum_ast.generic_param_list().and_then(|known_generics| { + extract_generic_params(&make, &known_generics, &field_list) + }); // resolve GenericArg in field_list to actual type let field_list = if let Some((target_scope, source_scope)) = @@ -124,25 +144,45 @@ pub(crate) fn extract_struct_from_enum_variant( } } } else { - field_list.clone_for_update() + field_list.clone() }; - let def = - create_struct_def(variant_name.clone(), &variant, &field_list, generics, &enum_ast); + let (comments_for_struct, comments_to_delete) = + collect_variant_comments(&make, variant.syntax()); + for element in &comments_to_delete { + editor.delete(element.clone()); + } + + let def = create_struct_def( + &make, + variant_name.clone(), + &field_list, + generic_params.clone(), + &enum_ast, + ); let enum_ast = variant.parent_enum(); let indent = enum_ast.indent_level(); let def = def.indent(indent); - ted::insert_all( - ted::Position::before(enum_ast.syntax()), - vec![ - def.syntax().clone().into(), - make::tokens::whitespace(&format!("\n\n{indent}")).into(), - ], + let mut insert_items: Vec = Vec::new(); + for attr in enum_ast.attrs() { + insert_items.push(attr.syntax().clone().into()); + insert_items.push(make.whitespace("\n").into()); + } + insert_items.extend(comments_for_struct); + insert_items.push(def.syntax().clone().into()); + insert_items.push(make.whitespace(&format!("\n\n{indent}")).into()); + editor.insert_all_with_whitespace( + Position::before(enum_ast.syntax()), + insert_items, + &make, ); - update_variant(&variant, generic_params.map(|g| g.clone_for_update())); + update_variant(&make, &mut editor, &variant, generic_params); + + editor.add_mappings(make.finish_with_mappings()); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } @@ -184,6 +224,7 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &En } fn extract_generic_params( + make: &SyntaxFactory, known_generics: &ast::GenericParamList, field_list: &Either, ) -> Option { @@ -201,7 +242,7 @@ fn extract_generic_params( }; let generics = generics.into_iter().filter_map(|(param, tag)| tag.then_some(param)); - tagged_one.then(|| make::generic_param_list(generics)) + tagged_one.then(|| make.generic_param_list(generics)) } fn tag_generics_in_variant(ty: &ast::Type, generics: &mut [(ast::GenericParam, bool)]) -> bool { @@ -250,82 +291,74 @@ fn tag_generics_in_variant(ty: &ast::Type, generics: &mut [(ast::GenericParam, b } fn create_struct_def( + make: &SyntaxFactory, name: ast::Name, - variant: &ast::Variant, field_list: &Either, generics: Option, enum_: &ast::Enum, ) -> ast::Struct { let enum_vis = enum_.visibility(); - let insert_vis = |node: &'_ SyntaxNode, vis: &'_ SyntaxNode| { - let vis = vis.clone_for_update(); - ted::insert(ted::Position::before(node), vis); - }; - // for fields without any existing visibility, use visibility of enum let field_list: ast::FieldList = match field_list { Either::Left(field_list) => { if let Some(vis) = &enum_vis { - field_list - .fields() - .filter(|field| field.visibility().is_none()) - .filter_map(|field| field.name()) - .for_each(|it| insert_vis(it.syntax(), vis.syntax())); + let new_fields = field_list.fields().map(|field| { + if field.visibility().is_none() + && let Some(name) = field.name() + && let Some(ty) = field.ty() + { + make.record_field(Some(vis.clone()), name, ty) + } else { + field + } + }); + make.record_field_list(new_fields).into() + } else { + field_list.clone().into() } - - field_list.clone().into() } Either::Right(field_list) => { if let Some(vis) = &enum_vis { - field_list - .fields() - .filter(|field| field.visibility().is_none()) - .filter_map(|field| field.ty()) - .for_each(|it| insert_vis(it.syntax(), vis.syntax())); + let new_fields = field_list.fields().map(|field| { + if field.visibility().is_none() + && let Some(ty) = field.ty() + { + make.tuple_field(Some(vis.clone()), ty) + } else { + field + } + }); + make.tuple_field_list(new_fields).into() + } else { + field_list.clone().into() } - - field_list.clone().into() } }; - let strukt = make::struct_(enum_vis, name, generics, field_list).clone_for_update(); - - // take comments from variant - ted::insert_all( - ted::Position::first_child_of(strukt.syntax()), - take_all_comments(variant.syntax()), - ); - - // copy attributes from enum - ted::insert_all( - ted::Position::first_child_of(strukt.syntax()), - enum_ - .attrs() - .flat_map(|it| { - vec![it.syntax().clone_for_update().into(), make::tokens::single_newline().into()] - }) - .collect(), - ); - - strukt + make.struct_(enum_vis, name, generics, field_list) } -fn update_variant(variant: &ast::Variant, generics: Option) -> Option<()> { +fn update_variant( + make: &SyntaxFactory, + editor: &mut SyntaxEditor, + variant: &ast::Variant, + generics: Option, +) -> Option<()> { let name = variant.name()?; let generic_args = generics .filter(|generics| generics.generic_params().count() > 0) .map(|generics| generics.to_generic_args()); // FIXME: replace with a `ast::make` constructor let ty = match generic_args { - Some(generic_args) => make::ty(&format!("{name}{generic_args}")), - None => make::ty(&name.text()), + Some(generic_args) => make.ty(&format!("{name}{generic_args}")), + None => make.ty(&name.text()), }; // change from a record to a tuple field list - let tuple_field = make::tuple_field(None, ty); - let field_list = make::tuple_field_list(iter::once(tuple_field)).clone_for_update(); - ted::replace(variant.field_list()?.syntax(), field_list.syntax()); + let tuple_field = make.tuple_field(None, ty); + let field_list = make.tuple_field_list(iter::once(tuple_field)); + editor.replace(variant.field_list()?.syntax(), field_list.syntax()); // remove any ws after the name if let Some(ws) = name @@ -333,35 +366,39 @@ fn update_variant(variant: &ast::Variant, generics: Option Vec { - let mut remove_next_ws = false; - node.children_with_tokens() - .filter_map(move |child| match child.kind() { +fn collect_variant_comments( + make: &SyntaxFactory, + node: &SyntaxNode, +) -> (Vec, Vec) { + let mut to_insert: Vec = Vec::new(); + let mut to_delete: Vec = Vec::new(); + let mut after_comment = false; + + for child in node.children_with_tokens() { + match child.kind() { COMMENT => { - remove_next_ws = true; - child.detach(); - Some(child) + after_comment = true; + to_insert.push(child.clone()); + to_delete.push(child); } - WHITESPACE if remove_next_ws => { - remove_next_ws = false; - child.detach(); - Some(make::tokens::single_newline().into()) + WHITESPACE if after_comment => { + after_comment = false; + to_insert.push(make.whitespace("\n").into()); + to_delete.push(child); } _ => { - remove_next_ws = false; - None + after_comment = false; } - }) - .collect() + } + } + + (to_insert, to_delete) } fn apply_references( @@ -370,20 +407,27 @@ fn apply_references( node: SyntaxNode, import: Option<(ImportScope, hir::ModPath)>, edition: Edition, + editor: &mut SyntaxEditor, + make: &SyntaxFactory, ) { if let Some((scope, path)) = import { - insert_use(&scope, mod_path_to_ast(&path, edition), &insert_use_cfg); + insert_use_with_editor( + &scope, + mod_path_to_ast(&path, edition), + &insert_use_cfg, + editor, + make, + ); } // deep clone to prevent cycle - let path = make::path_from_segments(iter::once(segment.clone_subtree()), false); - ted::insert_raw(ted::Position::before(segment.syntax()), path.clone_for_update().syntax()); - ted::insert_raw(ted::Position::before(segment.syntax()), make::token(T!['('])); - ted::insert_raw(ted::Position::after(&node), make::token(T![')'])); + let path = make.path_from_segments(iter::once(segment.clone()), false); + editor.insert(Position::before(segment.syntax()), make.token(T!['('])); + editor.insert(Position::before(segment.syntax()), path.syntax()); + editor.insert(Position::after(&node), make.token(T![')'])); } fn process_references( ctx: &AssistContext<'_>, - builder: &mut SourceChangeBuilder, visited_modules: &mut FxHashSet, enum_module_def: &ModuleDef, variant_hir_name: &Name, @@ -394,8 +438,6 @@ fn process_references( refs.into_iter() .flat_map(|reference| { let (segment, scope_node, module) = reference_to_node(&ctx.sema, reference)?; - let segment = builder.make_mut(segment); - let scope_node = builder.make_syntax_mut(scope_node); if !visited_modules.contains(&module) { let cfg = ctx.config.find_path_config(ctx.sema.is_nightly(module.krate(ctx.sema.db))); diff --git a/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs b/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs index 3a109a48e489..41ce1e59603d 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs @@ -709,7 +709,11 @@ fn insert_use_with_editor_( Some(b) => { cov_mark::hit!(insert_empty_module); syntax_editor.insert(Position::after(&b), syntax_factory.whitespace("\n")); - syntax_editor.insert(Position::after(&b), use_item.syntax()); + syntax_editor.insert_with_whitespace( + Position::after(&b), + use_item.syntax(), + syntax_factory, + ); } None => { cov_mark::hit!(insert_empty_file); diff --git a/src/tools/rust-analyzer/crates/syntax/src/syntax_editor.rs b/src/tools/rust-analyzer/crates/syntax/src/syntax_editor.rs index dbb9f15e173e..8e4dc75d2219 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/syntax_editor.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/syntax_editor.rs @@ -14,7 +14,10 @@ use rowan::TextRange; use rustc_hash::FxHashMap; -use crate::{AstNode, SyntaxElement, SyntaxNode, SyntaxToken}; +use crate::{ + AstNode, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, T, + ast::{self, edit::IndentLevel, syntax_factory::SyntaxFactory}, +}; mod edit_algo; mod edits; @@ -101,6 +104,34 @@ pub fn insert_all(&mut self, position: Position, elements: Vec) { self.changes.push(Change::InsertAll(position, elements)) } + pub fn insert_with_whitespace( + &mut self, + position: Position, + element: impl Element, + factory: &SyntaxFactory, + ) { + self.insert_all_with_whitespace(position, vec![element.syntax_element()], factory) + } + + pub fn insert_all_with_whitespace( + &mut self, + position: Position, + mut elements: Vec, + factory: &SyntaxFactory, + ) { + if let Some(first) = elements.first() + && let Some(ws) = ws_before(&position, first, factory) + { + elements.insert(0, ws.into()); + } + if let Some(last) = elements.last() + && let Some(ws) = ws_after(&position, last, factory) + { + elements.push(ws.into()); + } + self.insert_all(position, elements) + } + pub fn delete(&mut self, element: impl Element) { let element = element.syntax_element(); debug_assert!(is_ancestor_or_self_of_element(&element, &self.root)); @@ -412,6 +443,86 @@ fn syntax_element(self) -> SyntaxElement { } } +fn ws_before( + position: &Position, + new: &SyntaxElement, + factory: &SyntaxFactory, +) -> Option { + let prev = match &position.repr { + PositionRepr::FirstChild(_) => return None, + PositionRepr::After(it) => it, + }; + + if prev.kind() == T!['{'] + && new.kind() == SyntaxKind::USE + && let Some(item_list) = prev.parent().and_then(ast::ItemList::cast) + { + let mut indent = IndentLevel::from_element(&item_list.syntax().clone().into()); + indent.0 += 1; + return Some(factory.whitespace(&format!("\n{indent}"))); + } + + if prev.kind() == T!['{'] + && ast::Stmt::can_cast(new.kind()) + && let Some(stmt_list) = prev.parent().and_then(ast::StmtList::cast) + { + let mut indent = IndentLevel::from_element(&stmt_list.syntax().clone().into()); + indent.0 += 1; + return Some(factory.whitespace(&format!("\n{indent}"))); + } + + ws_between(prev, new, factory) +} + +fn ws_after( + position: &Position, + new: &SyntaxElement, + factory: &SyntaxFactory, +) -> Option { + let next = match &position.repr { + PositionRepr::FirstChild(parent) => parent.first_child_or_token()?, + PositionRepr::After(sibling) => sibling.next_sibling_or_token()?, + }; + ws_between(new, &next, factory) +} + +fn ws_between( + left: &SyntaxElement, + right: &SyntaxElement, + factory: &SyntaxFactory, +) -> Option { + if left.kind() == SyntaxKind::WHITESPACE || right.kind() == SyntaxKind::WHITESPACE { + return None; + } + if right.kind() == T![;] || right.kind() == T![,] { + return None; + } + if left.kind() == T![<] || right.kind() == T![>] { + return None; + } + if left.kind() == T![&] && right.kind() == SyntaxKind::LIFETIME { + return None; + } + if right.kind() == SyntaxKind::GENERIC_ARG_LIST { + return None; + } + if right.kind() == SyntaxKind::USE { + let mut indent = IndentLevel::from_element(left); + if left.kind() == SyntaxKind::USE { + indent.0 = IndentLevel::from_element(right).0.max(indent.0); + } + return Some(factory.whitespace(&format!("\n{indent}"))); + } + if left.kind() == SyntaxKind::ATTR { + let mut indent = IndentLevel::from_element(right); + if right.kind() == SyntaxKind::ATTR { + indent.0 = IndentLevel::from_element(left).0.max(indent.0); + } + return Some(factory.whitespace(&format!("\n{indent}"))); + } + Some(factory.whitespace(" ")) +} + fn is_ancestor_or_self(node: &SyntaxNode, ancestor: &SyntaxNode) -> bool { node == ancestor || node.ancestors().any(|it| &it == ancestor) }