Merge pull request #21996 from Shourya742/2026-04-08-migrate-extract-struct-from-enum-variant

Migrate extract struct from enum variant to new SyntaxEditor and Port whitespace heuristics to SyntaxEditor
This commit is contained in:
Chayim Refael Friedman
2026-04-09 12:22:45 +00:00
committed by GitHub
3 changed files with 258 additions and 101 deletions
@@ -6,7 +6,7 @@
FxHashSet, RootDatabase,
defs::Definition,
helpers::mod_path_to_ast,
imports::insert_use::{ImportScope, InsertUseConfig, insert_use},
imports::insert_use::{ImportScope, InsertUseConfig, insert_use_with_editor},
path_transform::PathTransform,
search::FileReference,
};
@@ -16,12 +16,14 @@
SyntaxKind::*,
SyntaxNode, T,
ast::{
self, AstNode, HasAttrs, HasGenericParams, HasName, HasVisibility, edit::AstNodeEdit, make,
self, AstNode, HasAttrs, HasGenericParams, HasName, HasVisibility, edit::AstNodeEdit,
syntax_factory::SyntaxFactory,
},
match_ast, ted,
match_ast,
syntax_editor::{Position, SyntaxEditor},
};
use crate::{AssistContext, AssistId, Assists, assist_context::SourceChangeBuilder};
use crate::{AssistContext, AssistId, Assists};
// Assist: extract_struct_from_enum_variant
//
@@ -58,6 +60,8 @@ pub(crate) fn extract_struct_from_enum_variant(
"Extract struct from enum variant",
target,
|builder| {
let make = SyntaxFactory::with_mappings();
let mut editor = builder.make_editor(variant.syntax());
let edition = enum_hir.krate(ctx.db()).edition(ctx.db());
let variant_hir_name = variant_hir.name(ctx.db());
let enum_module_def = ModuleDef::from(enum_hir);
@@ -73,40 +77,56 @@ pub(crate) fn extract_struct_from_enum_variant(
def_file_references = Some(references);
continue;
}
builder.edit_file(file_id.file_id(ctx.db()));
let processed = process_references(
ctx,
builder,
&mut visited_modules_set,
&enum_module_def,
&variant_hir_name,
references,
);
if processed.is_empty() {
continue;
}
let mut file_editor = builder.make_editor(processed[0].0.syntax());
processed.into_iter().for_each(|(path, node, import)| {
apply_references(ctx.config.insert_use, path, node, import, edition)
apply_references(
ctx.config.insert_use,
path,
node,
import,
edition,
&mut file_editor,
&make,
)
});
file_editor.add_mappings(make.take());
builder.add_file_edits(file_id.file_id(ctx.db()), file_editor);
}
builder.edit_file(ctx.vfs_file_id());
let variant = builder.make_mut(variant.clone());
if let Some(references) = def_file_references {
let processed = process_references(
ctx,
builder,
&mut visited_modules_set,
&enum_module_def,
&variant_hir_name,
references,
);
processed.into_iter().for_each(|(path, node, import)| {
apply_references(ctx.config.insert_use, path, node, import, edition)
apply_references(
ctx.config.insert_use,
path,
node,
import,
edition,
&mut editor,
&make,
)
});
}
let generic_params = enum_ast
.generic_param_list()
.and_then(|known_generics| extract_generic_params(&known_generics, &field_list));
let generics = generic_params.as_ref().map(|generics| generics.clone_for_update());
let generic_params = enum_ast.generic_param_list().and_then(|known_generics| {
extract_generic_params(&make, &known_generics, &field_list)
});
// resolve GenericArg in field_list to actual type
let field_list = if let Some((target_scope, source_scope)) =
@@ -124,25 +144,45 @@ pub(crate) fn extract_struct_from_enum_variant(
}
}
} else {
field_list.clone_for_update()
field_list.clone()
};
let def =
create_struct_def(variant_name.clone(), &variant, &field_list, generics, &enum_ast);
let (comments_for_struct, comments_to_delete) =
collect_variant_comments(&make, variant.syntax());
for element in &comments_to_delete {
editor.delete(element.clone());
}
let def = create_struct_def(
&make,
variant_name.clone(),
&field_list,
generic_params.clone(),
&enum_ast,
);
let enum_ast = variant.parent_enum();
let indent = enum_ast.indent_level();
let def = def.indent(indent);
ted::insert_all(
ted::Position::before(enum_ast.syntax()),
vec![
def.syntax().clone().into(),
make::tokens::whitespace(&format!("\n\n{indent}")).into(),
],
let mut insert_items: Vec<SyntaxElement> = Vec::new();
for attr in enum_ast.attrs() {
insert_items.push(attr.syntax().clone().into());
insert_items.push(make.whitespace("\n").into());
}
insert_items.extend(comments_for_struct);
insert_items.push(def.syntax().clone().into());
insert_items.push(make.whitespace(&format!("\n\n{indent}")).into());
editor.insert_all_with_whitespace(
Position::before(enum_ast.syntax()),
insert_items,
&make,
);
update_variant(&variant, generic_params.map(|g| g.clone_for_update()));
update_variant(&make, &mut editor, &variant, generic_params);
editor.add_mappings(make.finish_with_mappings());
builder.add_file_edits(ctx.vfs_file_id(), editor);
},
)
}
@@ -184,6 +224,7 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &En
}
fn extract_generic_params(
make: &SyntaxFactory,
known_generics: &ast::GenericParamList,
field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
) -> Option<ast::GenericParamList> {
@@ -201,7 +242,7 @@ fn extract_generic_params(
};
let generics = generics.into_iter().filter_map(|(param, tag)| tag.then_some(param));
tagged_one.then(|| make::generic_param_list(generics))
tagged_one.then(|| make.generic_param_list(generics))
}
fn tag_generics_in_variant(ty: &ast::Type, generics: &mut [(ast::GenericParam, bool)]) -> bool {
@@ -250,82 +291,74 @@ fn tag_generics_in_variant(ty: &ast::Type, generics: &mut [(ast::GenericParam, b
}
fn create_struct_def(
make: &SyntaxFactory,
name: ast::Name,
variant: &ast::Variant,
field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
generics: Option<ast::GenericParamList>,
enum_: &ast::Enum,
) -> ast::Struct {
let enum_vis = enum_.visibility();
let insert_vis = |node: &'_ SyntaxNode, vis: &'_ SyntaxNode| {
let vis = vis.clone_for_update();
ted::insert(ted::Position::before(node), vis);
};
// for fields without any existing visibility, use visibility of enum
let field_list: ast::FieldList = match field_list {
Either::Left(field_list) => {
if let Some(vis) = &enum_vis {
field_list
.fields()
.filter(|field| field.visibility().is_none())
.filter_map(|field| field.name())
.for_each(|it| insert_vis(it.syntax(), vis.syntax()));
let new_fields = field_list.fields().map(|field| {
if field.visibility().is_none()
&& let Some(name) = field.name()
&& let Some(ty) = field.ty()
{
make.record_field(Some(vis.clone()), name, ty)
} else {
field
}
});
make.record_field_list(new_fields).into()
} else {
field_list.clone().into()
}
field_list.clone().into()
}
Either::Right(field_list) => {
if let Some(vis) = &enum_vis {
field_list
.fields()
.filter(|field| field.visibility().is_none())
.filter_map(|field| field.ty())
.for_each(|it| insert_vis(it.syntax(), vis.syntax()));
let new_fields = field_list.fields().map(|field| {
if field.visibility().is_none()
&& let Some(ty) = field.ty()
{
make.tuple_field(Some(vis.clone()), ty)
} else {
field
}
});
make.tuple_field_list(new_fields).into()
} else {
field_list.clone().into()
}
field_list.clone().into()
}
};
let strukt = make::struct_(enum_vis, name, generics, field_list).clone_for_update();
// take comments from variant
ted::insert_all(
ted::Position::first_child_of(strukt.syntax()),
take_all_comments(variant.syntax()),
);
// copy attributes from enum
ted::insert_all(
ted::Position::first_child_of(strukt.syntax()),
enum_
.attrs()
.flat_map(|it| {
vec![it.syntax().clone_for_update().into(), make::tokens::single_newline().into()]
})
.collect(),
);
strukt
make.struct_(enum_vis, name, generics, field_list)
}
fn update_variant(variant: &ast::Variant, generics: Option<ast::GenericParamList>) -> Option<()> {
fn update_variant(
make: &SyntaxFactory,
editor: &mut SyntaxEditor,
variant: &ast::Variant,
generics: Option<ast::GenericParamList>,
) -> Option<()> {
let name = variant.name()?;
let generic_args = generics
.filter(|generics| generics.generic_params().count() > 0)
.map(|generics| generics.to_generic_args());
// FIXME: replace with a `ast::make` constructor
let ty = match generic_args {
Some(generic_args) => make::ty(&format!("{name}{generic_args}")),
None => make::ty(&name.text()),
Some(generic_args) => make.ty(&format!("{name}{generic_args}")),
None => make.ty(&name.text()),
};
// change from a record to a tuple field list
let tuple_field = make::tuple_field(None, ty);
let field_list = make::tuple_field_list(iter::once(tuple_field)).clone_for_update();
ted::replace(variant.field_list()?.syntax(), field_list.syntax());
let tuple_field = make.tuple_field(None, ty);
let field_list = make.tuple_field_list(iter::once(tuple_field));
editor.replace(variant.field_list()?.syntax(), field_list.syntax());
// remove any ws after the name
if let Some(ws) = name
@@ -333,35 +366,39 @@ fn update_variant(variant: &ast::Variant, generics: Option<ast::GenericParamList
.siblings_with_tokens(syntax::Direction::Next)
.find_map(|tok| tok.into_token().filter(|tok| tok.kind() == WHITESPACE))
{
ted::remove(SyntaxElement::Token(ws));
editor.delete(ws);
}
Some(())
}
// Note: this also detaches whitespace after comments,
// since `SyntaxNode::splice_children` (and by extension `ted::insert_all_raw`)
// detaches nodes. If we only took the comments, we'd leave behind the old whitespace.
fn take_all_comments(node: &SyntaxNode) -> Vec<SyntaxElement> {
let mut remove_next_ws = false;
node.children_with_tokens()
.filter_map(move |child| match child.kind() {
fn collect_variant_comments(
make: &SyntaxFactory,
node: &SyntaxNode,
) -> (Vec<SyntaxElement>, Vec<SyntaxElement>) {
let mut to_insert: Vec<SyntaxElement> = Vec::new();
let mut to_delete: Vec<SyntaxElement> = Vec::new();
let mut after_comment = false;
for child in node.children_with_tokens() {
match child.kind() {
COMMENT => {
remove_next_ws = true;
child.detach();
Some(child)
after_comment = true;
to_insert.push(child.clone());
to_delete.push(child);
}
WHITESPACE if remove_next_ws => {
remove_next_ws = false;
child.detach();
Some(make::tokens::single_newline().into())
WHITESPACE if after_comment => {
after_comment = false;
to_insert.push(make.whitespace("\n").into());
to_delete.push(child);
}
_ => {
remove_next_ws = false;
None
after_comment = false;
}
})
.collect()
}
}
(to_insert, to_delete)
}
fn apply_references(
@@ -370,20 +407,27 @@ fn apply_references(
node: SyntaxNode,
import: Option<(ImportScope, hir::ModPath)>,
edition: Edition,
editor: &mut SyntaxEditor,
make: &SyntaxFactory,
) {
if let Some((scope, path)) = import {
insert_use(&scope, mod_path_to_ast(&path, edition), &insert_use_cfg);
insert_use_with_editor(
&scope,
mod_path_to_ast(&path, edition),
&insert_use_cfg,
editor,
make,
);
}
// deep clone to prevent cycle
let path = make::path_from_segments(iter::once(segment.clone_subtree()), false);
ted::insert_raw(ted::Position::before(segment.syntax()), path.clone_for_update().syntax());
ted::insert_raw(ted::Position::before(segment.syntax()), make::token(T!['(']));
ted::insert_raw(ted::Position::after(&node), make::token(T![')']));
let path = make.path_from_segments(iter::once(segment.clone()), false);
editor.insert(Position::before(segment.syntax()), make.token(T!['(']));
editor.insert(Position::before(segment.syntax()), path.syntax());
editor.insert(Position::after(&node), make.token(T![')']));
}
fn process_references(
ctx: &AssistContext<'_>,
builder: &mut SourceChangeBuilder,
visited_modules: &mut FxHashSet<Module>,
enum_module_def: &ModuleDef,
variant_hir_name: &Name,
@@ -394,8 +438,6 @@ fn process_references(
refs.into_iter()
.flat_map(|reference| {
let (segment, scope_node, module) = reference_to_node(&ctx.sema, reference)?;
let segment = builder.make_mut(segment);
let scope_node = builder.make_syntax_mut(scope_node);
if !visited_modules.contains(&module) {
let cfg =
ctx.config.find_path_config(ctx.sema.is_nightly(module.krate(ctx.sema.db)));
@@ -709,7 +709,11 @@ fn insert_use_with_editor_(
Some(b) => {
cov_mark::hit!(insert_empty_module);
syntax_editor.insert(Position::after(&b), syntax_factory.whitespace("\n"));
syntax_editor.insert(Position::after(&b), use_item.syntax());
syntax_editor.insert_with_whitespace(
Position::after(&b),
use_item.syntax(),
syntax_factory,
);
}
None => {
cov_mark::hit!(insert_empty_file);
@@ -14,7 +14,10 @@
use rowan::TextRange;
use rustc_hash::FxHashMap;
use crate::{AstNode, SyntaxElement, SyntaxNode, SyntaxToken};
use crate::{
AstNode, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, T,
ast::{self, edit::IndentLevel, syntax_factory::SyntaxFactory},
};
mod edit_algo;
mod edits;
@@ -101,6 +104,34 @@ pub fn insert_all(&mut self, position: Position, elements: Vec<SyntaxElement>) {
self.changes.push(Change::InsertAll(position, elements))
}
pub fn insert_with_whitespace(
&mut self,
position: Position,
element: impl Element,
factory: &SyntaxFactory,
) {
self.insert_all_with_whitespace(position, vec![element.syntax_element()], factory)
}
pub fn insert_all_with_whitespace(
&mut self,
position: Position,
mut elements: Vec<SyntaxElement>,
factory: &SyntaxFactory,
) {
if let Some(first) = elements.first()
&& let Some(ws) = ws_before(&position, first, factory)
{
elements.insert(0, ws.into());
}
if let Some(last) = elements.last()
&& let Some(ws) = ws_after(&position, last, factory)
{
elements.push(ws.into());
}
self.insert_all(position, elements)
}
pub fn delete(&mut self, element: impl Element) {
let element = element.syntax_element();
debug_assert!(is_ancestor_or_self_of_element(&element, &self.root));
@@ -412,6 +443,86 @@ fn syntax_element(self) -> SyntaxElement {
}
}
fn ws_before(
position: &Position,
new: &SyntaxElement,
factory: &SyntaxFactory,
) -> Option<SyntaxToken> {
let prev = match &position.repr {
PositionRepr::FirstChild(_) => return None,
PositionRepr::After(it) => it,
};
if prev.kind() == T!['{']
&& new.kind() == SyntaxKind::USE
&& let Some(item_list) = prev.parent().and_then(ast::ItemList::cast)
{
let mut indent = IndentLevel::from_element(&item_list.syntax().clone().into());
indent.0 += 1;
return Some(factory.whitespace(&format!("\n{indent}")));
}
if prev.kind() == T!['{']
&& ast::Stmt::can_cast(new.kind())
&& let Some(stmt_list) = prev.parent().and_then(ast::StmtList::cast)
{
let mut indent = IndentLevel::from_element(&stmt_list.syntax().clone().into());
indent.0 += 1;
return Some(factory.whitespace(&format!("\n{indent}")));
}
ws_between(prev, new, factory)
}
fn ws_after(
position: &Position,
new: &SyntaxElement,
factory: &SyntaxFactory,
) -> Option<SyntaxToken> {
let next = match &position.repr {
PositionRepr::FirstChild(parent) => parent.first_child_or_token()?,
PositionRepr::After(sibling) => sibling.next_sibling_or_token()?,
};
ws_between(new, &next, factory)
}
fn ws_between(
left: &SyntaxElement,
right: &SyntaxElement,
factory: &SyntaxFactory,
) -> Option<SyntaxToken> {
if left.kind() == SyntaxKind::WHITESPACE || right.kind() == SyntaxKind::WHITESPACE {
return None;
}
if right.kind() == T![;] || right.kind() == T![,] {
return None;
}
if left.kind() == T![<] || right.kind() == T![>] {
return None;
}
if left.kind() == T![&] && right.kind() == SyntaxKind::LIFETIME {
return None;
}
if right.kind() == SyntaxKind::GENERIC_ARG_LIST {
return None;
}
if right.kind() == SyntaxKind::USE {
let mut indent = IndentLevel::from_element(left);
if left.kind() == SyntaxKind::USE {
indent.0 = IndentLevel::from_element(right).0.max(indent.0);
}
return Some(factory.whitespace(&format!("\n{indent}")));
}
if left.kind() == SyntaxKind::ATTR {
let mut indent = IndentLevel::from_element(right);
if right.kind() == SyntaxKind::ATTR {
indent.0 = IndentLevel::from_element(left).0.max(indent.0);
}
return Some(factory.whitespace(&format!("\n{indent}")));
}
Some(factory.whitespace(" "))
}
fn is_ancestor_or_self(node: &SyntaxNode, ancestor: &SyntaxNode) -> bool {
node == ancestor || node.ancestors().any(|it| &it == ancestor)
}