From ddefc4b7b4e68e8dace3ddd9f485a663c40b0017 Mon Sep 17 00:00:00 2001 From: bit-aloo Date: Mon, 24 Nov 2025 11:18:14 +0530 Subject: [PATCH] add codec and framing to abstract encoding and decoding logic from run --- .../crates/proc-macro-api/src/codec.rs | 12 ++ .../crates/proc-macro-api/src/framing.rs | 14 ++ .../proc-macro-api/src/legacy_protocol.rs | 33 ++-- .../src/legacy_protocol/json.rs | 72 ++++++--- .../proc-macro-api/src/legacy_protocol/msg.rs | 56 +------ .../src/legacy_protocol/postcard.rs | 51 +++--- .../crates/proc-macro-api/src/lib.rs | 11 +- .../crates/proc-macro-api/src/process.rs | 13 +- .../crates/proc-macro-srv-cli/Cargo.toml | 2 +- .../proc-macro-srv-cli/src/main_loop.rs | 148 ++---------------- 10 files changed, 142 insertions(+), 270 deletions(-) create mode 100644 src/tools/rust-analyzer/crates/proc-macro-api/src/codec.rs create mode 100644 src/tools/rust-analyzer/crates/proc-macro-api/src/framing.rs diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/codec.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/codec.rs new file mode 100644 index 000000000000..baccaa6be4c2 --- /dev/null +++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/codec.rs @@ -0,0 +1,12 @@ +//! Protocol codec + +use std::io; + +use serde::de::DeserializeOwned; + +use crate::framing::Framing; + +pub trait Codec: Framing { + fn encode(msg: &T) -> io::Result; + fn decode(buf: &mut Self::Buf) -> io::Result; +} diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/framing.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/framing.rs new file mode 100644 index 000000000000..a1e6fc05ca11 --- /dev/null +++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/framing.rs @@ -0,0 +1,14 @@ +//! Protocol framing + +use std::io::{self, BufRead, Write}; + +pub trait Framing { + type Buf: Default; + + fn read<'a, R: BufRead>( + inp: &mut R, + buf: &'a mut Self::Buf, + ) -> io::Result>; + + fn write(out: &mut W, buf: &Self::Buf) -> io::Result<()>; +} diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol.rs index 6d521d00cd90..c2b132ddcc1d 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol.rs @@ -14,14 +14,15 @@ use crate::{ ProcMacro, ProcMacroKind, ServerError, + codec::Codec, legacy_protocol::{ - json::{read_json, write_json}, + json::JsonProtocol, msg::{ ExpandMacro, ExpandMacroData, ExpnGlobals, FlatTree, Message, Request, Response, ServerConfig, SpanDataIndexMap, deserialize_span_data_index_map, flat::serialize_span_data_index_map, }, - postcard::{read_postcard, write_postcard}, + postcard::PostcardProtocol, }, process::ProcMacroServerProcess, version, @@ -154,42 +155,26 @@ fn send_task(srv: &ProcMacroServerProcess, req: Request) -> Result, req) } else { - srv.send_task(send_request, req) + srv.send_task(send_request::, req) } } /// Sends a request to the server and reads the response. -fn send_request( +fn send_request( mut writer: &mut dyn Write, mut reader: &mut dyn BufRead, req: Request, - buf: &mut String, + buf: &mut P::Buf, ) -> Result, ServerError> { - req.write(write_json, &mut writer).map_err(|err| ServerError { + req.write::<_, P>(&mut writer).map_err(|err| ServerError { message: "failed to write request".into(), io: Some(Arc::new(err)), })?; - let res = Response::read(read_json, &mut reader, buf).map_err(|err| ServerError { + let res = Response::read::<_, P>(&mut reader, buf).map_err(|err| ServerError { message: "failed to read response".into(), io: Some(Arc::new(err)), })?; Ok(res) } - -fn send_request_postcard( - mut writer: &mut dyn Write, - mut reader: &mut dyn BufRead, - req: Request, - buf: &mut Vec, -) -> Result, ServerError> { - req.write_postcard(write_postcard, &mut writer).map_err(|err| ServerError { - message: "failed to write request".into(), - io: Some(Arc::new(err)), - })?; - let res = Response::read_postcard(read_postcard, &mut reader, buf).map_err(|err| { - ServerError { message: "failed to read response".into(), io: Some(Arc::new(err)) } - })?; - Ok(res) -} diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/json.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/json.rs index cf8535f77d53..1359c0568402 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/json.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/json.rs @@ -1,36 +1,58 @@ //! Protocol functions for json. use std::io::{self, BufRead, Write}; -/// Reads a JSON message from the input stream. -pub fn read_json<'a>( - inp: &mut impl BufRead, - buf: &'a mut String, -) -> io::Result> { - loop { - buf.clear(); +use serde::{Serialize, de::DeserializeOwned}; - inp.read_line(buf)?; - buf.pop(); // Remove trailing '\n' +use crate::{codec::Codec, framing::Framing}; - if buf.is_empty() { - return Ok(None); +pub struct JsonProtocol; + +impl Framing for JsonProtocol { + type Buf = String; + + fn read<'a, R: BufRead>( + inp: &mut R, + buf: &'a mut String, + ) -> io::Result> { + loop { + buf.clear(); + + inp.read_line(buf)?; + buf.pop(); // Remove trailing '\n' + + if buf.is_empty() { + return Ok(None); + } + + // Some ill behaved macro try to use stdout for debugging + // We ignore it here + if !buf.starts_with('{') { + tracing::error!("proc-macro tried to print : {}", buf); + continue; + } + + return Ok(Some(buf)); } + } - // Some ill behaved macro try to use stdout for debugging - // We ignore it here - if !buf.starts_with('{') { - tracing::error!("proc-macro tried to print : {}", buf); - continue; - } - - return Ok(Some(buf)); + fn write(out: &mut W, buf: &String) -> io::Result<()> { + tracing::debug!("> {}", buf); + out.write_all(buf.as_bytes())?; + out.write_all(b"\n")?; + out.flush() } } -/// Writes a JSON message to the output stream. -pub fn write_json(out: &mut impl Write, msg: &String) -> io::Result<()> { - tracing::debug!("> {}", msg); - out.write_all(msg.as_bytes())?; - out.write_all(b"\n")?; - out.flush() +impl Codec for JsonProtocol { + fn encode(msg: &T) -> io::Result { + Ok(serde_json::to_string(msg)?) + } + + fn decode(buf: &mut String) -> io::Result { + let mut deserializer = serde_json::Deserializer::from_str(buf); + // Note that some proc-macro generate very deep syntax tree + // We have to disable the current limit of serde here + deserializer.disable_recursion_limit(); + Ok(T::deserialize(&mut deserializer)?) + } } diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg.rs index 6df184630de7..1c77863aac34 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg.rs @@ -8,10 +8,7 @@ use serde::de::DeserializeOwned; use serde_derive::{Deserialize, Serialize}; -use crate::{ - ProcMacroKind, - legacy_protocol::postcard::{decode_cobs, encode_cobs}, -}; +use crate::{ProcMacroKind, codec::Codec}; /// Represents requests sent from the client to the proc-macro-srv. #[derive(Debug, Serialize, Deserialize)] @@ -152,60 +149,21 @@ fn skip_serializing_if(&self) -> bool { } pub trait Message: serde::Serialize + DeserializeOwned { - fn read( - from_proto: ProtocolRead, - inp: &mut R, - buf: &mut String, - ) -> io::Result> { - Ok(match from_proto(inp, buf)? { + fn read(inp: &mut R, buf: &mut C::Buf) -> io::Result> { + Ok(match C::read(inp, buf)? { None => None, - Some(text) => { - let mut deserializer = serde_json::Deserializer::from_str(text); - // Note that some proc-macro generate very deep syntax tree - // We have to disable the current limit of serde here - deserializer.disable_recursion_limit(); - Some(Self::deserialize(&mut deserializer)?) - } + Some(buf) => C::decode(buf)?, }) } - fn write(self, to_proto: ProtocolWrite, out: &mut W) -> io::Result<()> { - let text = serde_json::to_string(&self)?; - to_proto(out, &text) - } - - fn read_postcard( - from_proto: ProtocolRead>, - inp: &mut R, - buf: &mut Vec, - ) -> io::Result> { - Ok(match from_proto(inp, buf)? { - None => None, - Some(buf) => Some(decode_cobs(buf)?), - }) - } - - fn write_postcard( - self, - to_proto: ProtocolWrite>, - out: &mut W, - ) -> io::Result<()> { - let buf = encode_cobs(&self)?; - to_proto(out, &buf) + fn write(self, out: &mut W) -> io::Result<()> { + let value = C::encode(&self)?; + C::write(out, &value) } } impl Message for Request {} impl Message for Response {} -/// Type alias for a function that reads protocol messages from a buffered input stream. -#[allow(type_alias_bounds)] -type ProtocolRead = - for<'i, 'buf> fn(inp: &'i mut R, buf: &'buf mut Buf) -> io::Result>; -/// Type alias for a function that writes protocol messages to an output stream. -#[allow(type_alias_bounds)] -type ProtocolWrite = - for<'o, 'msg> fn(out: &'o mut W, msg: &'msg Buf) -> io::Result<()>; - #[cfg(test)] mod tests { use intern::{Symbol, sym}; diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/postcard.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/postcard.rs index 305e4de93415..c28a9bfe3a1a 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/postcard.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/postcard.rs @@ -2,28 +2,39 @@ use std::io::{self, BufRead, Write}; -pub fn read_postcard<'a>( - input: &mut impl BufRead, - buf: &'a mut Vec, -) -> io::Result>> { - buf.clear(); - let n = input.read_until(0, buf)?; - if n == 0 { - return Ok(None); +use serde::{Serialize, de::DeserializeOwned}; + +use crate::{codec::Codec, framing::Framing}; + +pub struct PostcardProtocol; + +impl Framing for PostcardProtocol { + type Buf = Vec; + + fn read<'a, R: BufRead>( + inp: &mut R, + buf: &'a mut Vec, + ) -> io::Result>> { + buf.clear(); + let n = inp.read_until(0, buf)?; + if n == 0 { + return Ok(None); + } + Ok(Some(buf)) + } + + fn write(out: &mut W, buf: &Vec) -> io::Result<()> { + out.write_all(buf)?; + out.flush() } - Ok(Some(buf)) } -#[allow(clippy::ptr_arg)] -pub fn write_postcard(out: &mut impl Write, msg: &Vec) -> io::Result<()> { - out.write_all(msg)?; - out.flush() -} +impl Codec for PostcardProtocol { + fn encode(msg: &T) -> io::Result> { + postcard::to_allocvec_cobs(msg).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + } -pub fn encode_cobs(value: &T) -> io::Result> { - postcard::to_allocvec_cobs(value).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) -} - -pub fn decode_cobs(bytes: &mut [u8]) -> io::Result { - postcard::from_bytes_cobs(bytes).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + fn decode(buf: &mut Self::Buf) -> io::Result { + postcard::from_bytes_cobs(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + } } diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/lib.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/lib.rs index 2cdb33ff81eb..a725b94f04b2 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-api/src/lib.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/lib.rs @@ -12,6 +12,8 @@ )] #![allow(internal_features)] +mod codec; +mod framing; pub mod legacy_protocol; mod process; @@ -19,7 +21,8 @@ use span::{ErasedFileAstId, FIXUP_ERASED_FILE_AST_ID_MARKER, Span}; use std::{fmt, io, sync::Arc, time::SystemTime}; -use crate::process::ProcMacroServerProcess; +pub use crate::codec::Codec; +use crate::{legacy_protocol::SpanMode, process::ProcMacroServerProcess}; /// The versions of the server protocol pub mod version { @@ -123,7 +126,11 @@ pub fn spawn<'a>( Item = (impl AsRef, &'a Option>), > + Clone, ) -> io::Result { - let process = ProcMacroServerProcess::run(process_path, env, process::Protocol::default())?; + let process = ProcMacroServerProcess::run( + process_path, + env, + process::Protocol::Postcard { mode: SpanMode::Id }, + )?; Ok(ProcMacroClient { process: Arc::new(process), path: process_path.to_owned() }) } diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/process.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/process.rs index 7f0cd05c8058..1365245f9846 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-api/src/process.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/process.rs @@ -34,12 +34,6 @@ pub(crate) enum Protocol { Postcard { mode: SpanMode }, } -impl Default for Protocol { - fn default() -> Self { - Protocol::Postcard { mode: SpanMode::Id } - } -} - /// Maintains the state of the proc-macro server process. #[derive(Debug)] struct ProcessSrvState { @@ -122,11 +116,10 @@ pub(crate) fn run<'a>( srv.version = version; if version >= version::RUST_ANALYZER_SPAN_SUPPORT - && let Ok(mode) = srv.enable_rust_analyzer_spans() + && let Ok(new_mode) = srv.enable_rust_analyzer_spans() { - srv.protocol = match protocol { - Protocol::Postcard { .. } => Protocol::Postcard { mode }, - Protocol::LegacyJson { .. } => Protocol::LegacyJson { mode }, + match &mut srv.protocol { + Protocol::Postcard { mode } | Protocol::LegacyJson { mode } => *mode = new_mode, }; } diff --git a/src/tools/rust-analyzer/crates/proc-macro-srv-cli/Cargo.toml b/src/tools/rust-analyzer/crates/proc-macro-srv-cli/Cargo.toml index f6022cf2c7bd..aa153897fa96 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-srv-cli/Cargo.toml +++ b/src/tools/rust-analyzer/crates/proc-macro-srv-cli/Cargo.toml @@ -18,7 +18,7 @@ postcard.workspace = true clap = {version = "4.5.42", default-features = false, features = ["std"]} [features] -default = ["postcard"] +default = [] sysroot-abi = ["proc-macro-srv/sysroot-abi", "proc-macro-api/sysroot-abi"] in-rust-tree = ["proc-macro-srv/in-rust-tree", "sysroot-abi"] diff --git a/src/tools/rust-analyzer/crates/proc-macro-srv-cli/src/main_loop.rs b/src/tools/rust-analyzer/crates/proc-macro-srv-cli/src/main_loop.rs index b0e7108d20a5..029ab6eca941 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-srv-cli/src/main_loop.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-srv-cli/src/main_loop.rs @@ -2,13 +2,14 @@ use std::io; use proc_macro_api::{ + Codec, legacy_protocol::{ - json::{read_json, write_json}, + json::JsonProtocol, msg::{ self, ExpandMacroData, ExpnGlobals, Message, SpanMode, SpanTransformer, deserialize_span_data_index_map, serialize_span_data_index_map, }, - postcard::{read_postcard, write_postcard}, + postcard::PostcardProtocol, }, version::CURRENT_API_VERSION, }; @@ -36,12 +37,12 @@ fn span_for_token_id( pub(crate) fn run(format: ProtocolFormat) -> io::Result<()> { match format { - ProtocolFormat::Json => run_json(), - ProtocolFormat::Postcard => run_postcard(), + ProtocolFormat::Json => run_::(), + ProtocolFormat::Postcard => run_::(), } } -fn run_json() -> io::Result<()> { +fn run_() -> io::Result<()> { fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind { match kind { proc_macro_srv::ProcMacroKind::CustomDerive => { @@ -52,9 +53,9 @@ fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::Pro } } - let mut buf = String::new(); - let mut read_request = || msg::Request::read(read_json, &mut io::stdin().lock(), &mut buf); - let write_response = |msg: msg::Response| msg.write(write_json, &mut io::stdout().lock()); + let mut buf = C::Buf::default(); + let mut read_request = || msg::Request::read::<_, C>(&mut io::stdin().lock(), &mut buf); + let write_response = |msg: msg::Response| msg.write::<_, C>(&mut io::stdout().lock()); let env = EnvSnapshot::default(); let srv = proc_macro_srv::ProcMacroSrv::new(&env); @@ -170,134 +171,3 @@ fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::Pro Ok(()) } - -fn run_postcard() -> io::Result<()> { - fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind { - match kind { - proc_macro_srv::ProcMacroKind::CustomDerive => { - proc_macro_api::ProcMacroKind::CustomDerive - } - proc_macro_srv::ProcMacroKind::Bang => proc_macro_api::ProcMacroKind::Bang, - proc_macro_srv::ProcMacroKind::Attr => proc_macro_api::ProcMacroKind::Attr, - } - } - - let mut buf = Vec::new(); - let mut read_request = - || msg::Request::read_postcard(read_postcard, &mut io::stdin().lock(), &mut buf); - let write_response = - |msg: msg::Response| msg.write_postcard(write_postcard, &mut io::stdout().lock()); - - let env = proc_macro_srv::EnvSnapshot::default(); - let srv = proc_macro_srv::ProcMacroSrv::new(&env); - - let mut span_mode = msg::SpanMode::Id; - - while let Some(req) = read_request()? { - let res = match req { - msg::Request::ListMacros { dylib_path } => { - msg::Response::ListMacros(srv.list_macros(&dylib_path).map(|macros| { - macros.into_iter().map(|(name, kind)| (name, macro_kind_to_api(kind))).collect() - })) - } - msg::Request::ExpandMacro(task) => { - let msg::ExpandMacro { - lib, - env, - current_dir, - data: - msg::ExpandMacroData { - macro_body, - macro_name, - attributes, - has_global_spans: - msg::ExpnGlobals { serialize: _, def_site, call_site, mixed_site }, - span_data_table, - }, - } = *task; - match span_mode { - msg::SpanMode::Id => msg::Response::ExpandMacro({ - let def_site = proc_macro_srv::SpanId(def_site as u32); - let call_site = proc_macro_srv::SpanId(call_site as u32); - let mixed_site = proc_macro_srv::SpanId(mixed_site as u32); - - let macro_body = - macro_body.to_subtree_unresolved::(CURRENT_API_VERSION); - let attributes = attributes - .map(|it| it.to_subtree_unresolved::(CURRENT_API_VERSION)); - - srv.expand( - lib, - &env, - current_dir, - ¯o_name, - macro_body, - attributes, - def_site, - call_site, - mixed_site, - ) - .map(|it| { - msg::FlatTree::new_raw::( - tt::SubtreeView::new(&it), - CURRENT_API_VERSION, - ) - }) - .map_err(|e| e.into_string().unwrap_or_default()) - .map_err(msg::PanicMessage) - }), - msg::SpanMode::RustAnalyzer => msg::Response::ExpandMacroExtended({ - let mut span_data_table = - msg::deserialize_span_data_index_map(&span_data_table); - - let def_site = span_data_table[def_site]; - let call_site = span_data_table[call_site]; - let mixed_site = span_data_table[mixed_site]; - - let macro_body = - macro_body.to_subtree_resolved(CURRENT_API_VERSION, &span_data_table); - let attributes = attributes.map(|it| { - it.to_subtree_resolved(CURRENT_API_VERSION, &span_data_table) - }); - srv.expand( - lib, - &env, - current_dir, - ¯o_name, - macro_body, - attributes, - def_site, - call_site, - mixed_site, - ) - .map(|it| { - ( - msg::FlatTree::new( - tt::SubtreeView::new(&it), - CURRENT_API_VERSION, - &mut span_data_table, - ), - msg::serialize_span_data_index_map(&span_data_table), - ) - }) - .map(|(tree, span_data_table)| msg::ExpandMacroExtended { - tree, - span_data_table, - }) - .map_err(|e| e.into_string().unwrap_or_default()) - .map_err(msg::PanicMessage) - }), - } - } - msg::Request::ApiVersionCheck {} => msg::Response::ApiVersionCheck(CURRENT_API_VERSION), - msg::Request::SetConfig(config) => { - span_mode = config.span_mode; - msg::Response::SetConfig(config) - } - }; - - write_response(res)?; - } - - Ok(()) -}