add codec and framing to abstract encoding and decoding logic from run

This commit is contained in:
bit-aloo
2025-11-24 11:18:14 +05:30
parent c3708c7b43
commit ddefc4b7b4
10 changed files with 142 additions and 270 deletions
@@ -0,0 +1,12 @@
//! Protocol codec
use std::io;
use serde::de::DeserializeOwned;
use crate::framing::Framing;
pub trait Codec: Framing {
fn encode<T: serde::Serialize>(msg: &T) -> io::Result<Self::Buf>;
fn decode<T: DeserializeOwned>(buf: &mut Self::Buf) -> io::Result<T>;
}
@@ -0,0 +1,14 @@
//! Protocol framing
use std::io::{self, BufRead, Write};
pub trait Framing {
type Buf: Default;
fn read<'a, R: BufRead>(
inp: &mut R,
buf: &'a mut Self::Buf,
) -> io::Result<Option<&'a mut Self::Buf>>;
fn write<W: Write>(out: &mut W, buf: &Self::Buf) -> io::Result<()>;
}
@@ -14,14 +14,15 @@
use crate::{
ProcMacro, ProcMacroKind, ServerError,
codec::Codec,
legacy_protocol::{
json::{read_json, write_json},
json::JsonProtocol,
msg::{
ExpandMacro, ExpandMacroData, ExpnGlobals, FlatTree, Message, Request, Response,
ServerConfig, SpanDataIndexMap, deserialize_span_data_index_map,
flat::serialize_span_data_index_map,
},
postcard::{read_postcard, write_postcard},
postcard::PostcardProtocol,
},
process::ProcMacroServerProcess,
version,
@@ -154,42 +155,26 @@ fn send_task(srv: &ProcMacroServerProcess, req: Request) -> Result<Response, Ser
}
if srv.use_postcard() {
srv.send_task(send_request_postcard, req)
srv.send_task(send_request::<PostcardProtocol>, req)
} else {
srv.send_task(send_request, req)
srv.send_task(send_request::<JsonProtocol>, req)
}
}
/// Sends a request to the server and reads the response.
fn send_request(
fn send_request<P: Codec>(
mut writer: &mut dyn Write,
mut reader: &mut dyn BufRead,
req: Request,
buf: &mut String,
buf: &mut P::Buf,
) -> Result<Option<Response>, ServerError> {
req.write(write_json, &mut writer).map_err(|err| ServerError {
req.write::<_, P>(&mut writer).map_err(|err| ServerError {
message: "failed to write request".into(),
io: Some(Arc::new(err)),
})?;
let res = Response::read(read_json, &mut reader, buf).map_err(|err| ServerError {
let res = Response::read::<_, P>(&mut reader, buf).map_err(|err| ServerError {
message: "failed to read response".into(),
io: Some(Arc::new(err)),
})?;
Ok(res)
}
fn send_request_postcard(
mut writer: &mut dyn Write,
mut reader: &mut dyn BufRead,
req: Request,
buf: &mut Vec<u8>,
) -> Result<Option<Response>, ServerError> {
req.write_postcard(write_postcard, &mut writer).map_err(|err| ServerError {
message: "failed to write request".into(),
io: Some(Arc::new(err)),
})?;
let res = Response::read_postcard(read_postcard, &mut reader, buf).map_err(|err| {
ServerError { message: "failed to read response".into(), io: Some(Arc::new(err)) }
})?;
Ok(res)
}
@@ -1,36 +1,58 @@
//! Protocol functions for json.
use std::io::{self, BufRead, Write};
/// Reads a JSON message from the input stream.
pub fn read_json<'a>(
inp: &mut impl BufRead,
buf: &'a mut String,
) -> io::Result<Option<&'a mut String>> {
loop {
buf.clear();
use serde::{Serialize, de::DeserializeOwned};
inp.read_line(buf)?;
buf.pop(); // Remove trailing '\n'
use crate::{codec::Codec, framing::Framing};
if buf.is_empty() {
return Ok(None);
pub struct JsonProtocol;
impl Framing for JsonProtocol {
type Buf = String;
fn read<'a, R: BufRead>(
inp: &mut R,
buf: &'a mut String,
) -> io::Result<Option<&'a mut String>> {
loop {
buf.clear();
inp.read_line(buf)?;
buf.pop(); // Remove trailing '\n'
if buf.is_empty() {
return Ok(None);
}
// Some ill behaved macro try to use stdout for debugging
// We ignore it here
if !buf.starts_with('{') {
tracing::error!("proc-macro tried to print : {}", buf);
continue;
}
return Ok(Some(buf));
}
}
// Some ill behaved macro try to use stdout for debugging
// We ignore it here
if !buf.starts_with('{') {
tracing::error!("proc-macro tried to print : {}", buf);
continue;
}
return Ok(Some(buf));
fn write<W: Write>(out: &mut W, buf: &String) -> io::Result<()> {
tracing::debug!("> {}", buf);
out.write_all(buf.as_bytes())?;
out.write_all(b"\n")?;
out.flush()
}
}
/// Writes a JSON message to the output stream.
pub fn write_json(out: &mut impl Write, msg: &String) -> io::Result<()> {
tracing::debug!("> {}", msg);
out.write_all(msg.as_bytes())?;
out.write_all(b"\n")?;
out.flush()
impl Codec for JsonProtocol {
fn encode<T: Serialize>(msg: &T) -> io::Result<String> {
Ok(serde_json::to_string(msg)?)
}
fn decode<T: DeserializeOwned>(buf: &mut String) -> io::Result<T> {
let mut deserializer = serde_json::Deserializer::from_str(buf);
// Note that some proc-macro generate very deep syntax tree
// We have to disable the current limit of serde here
deserializer.disable_recursion_limit();
Ok(T::deserialize(&mut deserializer)?)
}
}
@@ -8,10 +8,7 @@
use serde::de::DeserializeOwned;
use serde_derive::{Deserialize, Serialize};
use crate::{
ProcMacroKind,
legacy_protocol::postcard::{decode_cobs, encode_cobs},
};
use crate::{ProcMacroKind, codec::Codec};
/// Represents requests sent from the client to the proc-macro-srv.
#[derive(Debug, Serialize, Deserialize)]
@@ -152,60 +149,21 @@ fn skip_serializing_if(&self) -> bool {
}
pub trait Message: serde::Serialize + DeserializeOwned {
fn read<R: BufRead>(
from_proto: ProtocolRead<R, String>,
inp: &mut R,
buf: &mut String,
) -> io::Result<Option<Self>> {
Ok(match from_proto(inp, buf)? {
fn read<R: BufRead, C: Codec>(inp: &mut R, buf: &mut C::Buf) -> io::Result<Option<Self>> {
Ok(match C::read(inp, buf)? {
None => None,
Some(text) => {
let mut deserializer = serde_json::Deserializer::from_str(text);
// Note that some proc-macro generate very deep syntax tree
// We have to disable the current limit of serde here
deserializer.disable_recursion_limit();
Some(Self::deserialize(&mut deserializer)?)
}
Some(buf) => C::decode(buf)?,
})
}
fn write<W: Write>(self, to_proto: ProtocolWrite<W, String>, out: &mut W) -> io::Result<()> {
let text = serde_json::to_string(&self)?;
to_proto(out, &text)
}
fn read_postcard<R: BufRead>(
from_proto: ProtocolRead<R, Vec<u8>>,
inp: &mut R,
buf: &mut Vec<u8>,
) -> io::Result<Option<Self>> {
Ok(match from_proto(inp, buf)? {
None => None,
Some(buf) => Some(decode_cobs(buf)?),
})
}
fn write_postcard<W: Write>(
self,
to_proto: ProtocolWrite<W, Vec<u8>>,
out: &mut W,
) -> io::Result<()> {
let buf = encode_cobs(&self)?;
to_proto(out, &buf)
fn write<W: Write, C: Codec>(self, out: &mut W) -> io::Result<()> {
let value = C::encode(&self)?;
C::write(out, &value)
}
}
impl Message for Request {}
impl Message for Response {}
/// Type alias for a function that reads protocol messages from a buffered input stream.
#[allow(type_alias_bounds)]
type ProtocolRead<R: BufRead, Buf> =
for<'i, 'buf> fn(inp: &'i mut R, buf: &'buf mut Buf) -> io::Result<Option<&'buf mut Buf>>;
/// Type alias for a function that writes protocol messages to an output stream.
#[allow(type_alias_bounds)]
type ProtocolWrite<W: Write, Buf> =
for<'o, 'msg> fn(out: &'o mut W, msg: &'msg Buf) -> io::Result<()>;
#[cfg(test)]
mod tests {
use intern::{Symbol, sym};
@@ -2,28 +2,39 @@
use std::io::{self, BufRead, Write};
pub fn read_postcard<'a>(
input: &mut impl BufRead,
buf: &'a mut Vec<u8>,
) -> io::Result<Option<&'a mut Vec<u8>>> {
buf.clear();
let n = input.read_until(0, buf)?;
if n == 0 {
return Ok(None);
use serde::{Serialize, de::DeserializeOwned};
use crate::{codec::Codec, framing::Framing};
pub struct PostcardProtocol;
impl Framing for PostcardProtocol {
type Buf = Vec<u8>;
fn read<'a, R: BufRead>(
inp: &mut R,
buf: &'a mut Vec<u8>,
) -> io::Result<Option<&'a mut Vec<u8>>> {
buf.clear();
let n = inp.read_until(0, buf)?;
if n == 0 {
return Ok(None);
}
Ok(Some(buf))
}
fn write<W: Write>(out: &mut W, buf: &Vec<u8>) -> io::Result<()> {
out.write_all(buf)?;
out.flush()
}
Ok(Some(buf))
}
#[allow(clippy::ptr_arg)]
pub fn write_postcard(out: &mut impl Write, msg: &Vec<u8>) -> io::Result<()> {
out.write_all(msg)?;
out.flush()
}
impl Codec for PostcardProtocol {
fn encode<T: Serialize>(msg: &T) -> io::Result<Vec<u8>> {
postcard::to_allocvec_cobs(msg).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}
pub fn encode_cobs<T: serde::Serialize>(value: &T) -> io::Result<Vec<u8>> {
postcard::to_allocvec_cobs(value).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}
pub fn decode_cobs<T: serde::de::DeserializeOwned>(bytes: &mut [u8]) -> io::Result<T> {
postcard::from_bytes_cobs(bytes).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
fn decode<T: DeserializeOwned>(buf: &mut Self::Buf) -> io::Result<T> {
postcard::from_bytes_cobs(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}
}
@@ -12,6 +12,8 @@
)]
#![allow(internal_features)]
mod codec;
mod framing;
pub mod legacy_protocol;
mod process;
@@ -19,7 +21,8 @@
use span::{ErasedFileAstId, FIXUP_ERASED_FILE_AST_ID_MARKER, Span};
use std::{fmt, io, sync::Arc, time::SystemTime};
use crate::process::ProcMacroServerProcess;
pub use crate::codec::Codec;
use crate::{legacy_protocol::SpanMode, process::ProcMacroServerProcess};
/// The versions of the server protocol
pub mod version {
@@ -123,7 +126,11 @@ pub fn spawn<'a>(
Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
> + Clone,
) -> io::Result<ProcMacroClient> {
let process = ProcMacroServerProcess::run(process_path, env, process::Protocol::default())?;
let process = ProcMacroServerProcess::run(
process_path,
env,
process::Protocol::Postcard { mode: SpanMode::Id },
)?;
Ok(ProcMacroClient { process: Arc::new(process), path: process_path.to_owned() })
}
@@ -34,12 +34,6 @@ pub(crate) enum Protocol {
Postcard { mode: SpanMode },
}
impl Default for Protocol {
fn default() -> Self {
Protocol::Postcard { mode: SpanMode::Id }
}
}
/// Maintains the state of the proc-macro server process.
#[derive(Debug)]
struct ProcessSrvState {
@@ -122,11 +116,10 @@ pub(crate) fn run<'a>(
srv.version = version;
if version >= version::RUST_ANALYZER_SPAN_SUPPORT
&& let Ok(mode) = srv.enable_rust_analyzer_spans()
&& let Ok(new_mode) = srv.enable_rust_analyzer_spans()
{
srv.protocol = match protocol {
Protocol::Postcard { .. } => Protocol::Postcard { mode },
Protocol::LegacyJson { .. } => Protocol::LegacyJson { mode },
match &mut srv.protocol {
Protocol::Postcard { mode } | Protocol::LegacyJson { mode } => *mode = new_mode,
};
}
@@ -18,7 +18,7 @@ postcard.workspace = true
clap = {version = "4.5.42", default-features = false, features = ["std"]}
[features]
default = ["postcard"]
default = []
sysroot-abi = ["proc-macro-srv/sysroot-abi", "proc-macro-api/sysroot-abi"]
in-rust-tree = ["proc-macro-srv/in-rust-tree", "sysroot-abi"]
@@ -2,13 +2,14 @@
use std::io;
use proc_macro_api::{
Codec,
legacy_protocol::{
json::{read_json, write_json},
json::JsonProtocol,
msg::{
self, ExpandMacroData, ExpnGlobals, Message, SpanMode, SpanTransformer,
deserialize_span_data_index_map, serialize_span_data_index_map,
},
postcard::{read_postcard, write_postcard},
postcard::PostcardProtocol,
},
version::CURRENT_API_VERSION,
};
@@ -36,12 +37,12 @@ fn span_for_token_id(
pub(crate) fn run(format: ProtocolFormat) -> io::Result<()> {
match format {
ProtocolFormat::Json => run_json(),
ProtocolFormat::Postcard => run_postcard(),
ProtocolFormat::Json => run_::<JsonProtocol>(),
ProtocolFormat::Postcard => run_::<PostcardProtocol>(),
}
}
fn run_json() -> io::Result<()> {
fn run_<C: Codec>() -> io::Result<()> {
fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind {
match kind {
proc_macro_srv::ProcMacroKind::CustomDerive => {
@@ -52,9 +53,9 @@ fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::Pro
}
}
let mut buf = String::new();
let mut read_request = || msg::Request::read(read_json, &mut io::stdin().lock(), &mut buf);
let write_response = |msg: msg::Response| msg.write(write_json, &mut io::stdout().lock());
let mut buf = C::Buf::default();
let mut read_request = || msg::Request::read::<_, C>(&mut io::stdin().lock(), &mut buf);
let write_response = |msg: msg::Response| msg.write::<_, C>(&mut io::stdout().lock());
let env = EnvSnapshot::default();
let srv = proc_macro_srv::ProcMacroSrv::new(&env);
@@ -170,134 +171,3 @@ fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::Pro
Ok(())
}
fn run_postcard() -> io::Result<()> {
fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind {
match kind {
proc_macro_srv::ProcMacroKind::CustomDerive => {
proc_macro_api::ProcMacroKind::CustomDerive
}
proc_macro_srv::ProcMacroKind::Bang => proc_macro_api::ProcMacroKind::Bang,
proc_macro_srv::ProcMacroKind::Attr => proc_macro_api::ProcMacroKind::Attr,
}
}
let mut buf = Vec::new();
let mut read_request =
|| msg::Request::read_postcard(read_postcard, &mut io::stdin().lock(), &mut buf);
let write_response =
|msg: msg::Response| msg.write_postcard(write_postcard, &mut io::stdout().lock());
let env = proc_macro_srv::EnvSnapshot::default();
let srv = proc_macro_srv::ProcMacroSrv::new(&env);
let mut span_mode = msg::SpanMode::Id;
while let Some(req) = read_request()? {
let res = match req {
msg::Request::ListMacros { dylib_path } => {
msg::Response::ListMacros(srv.list_macros(&dylib_path).map(|macros| {
macros.into_iter().map(|(name, kind)| (name, macro_kind_to_api(kind))).collect()
}))
}
msg::Request::ExpandMacro(task) => {
let msg::ExpandMacro {
lib,
env,
current_dir,
data:
msg::ExpandMacroData {
macro_body,
macro_name,
attributes,
has_global_spans:
msg::ExpnGlobals { serialize: _, def_site, call_site, mixed_site },
span_data_table,
},
} = *task;
match span_mode {
msg::SpanMode::Id => msg::Response::ExpandMacro({
let def_site = proc_macro_srv::SpanId(def_site as u32);
let call_site = proc_macro_srv::SpanId(call_site as u32);
let mixed_site = proc_macro_srv::SpanId(mixed_site as u32);
let macro_body =
macro_body.to_subtree_unresolved::<SpanTrans>(CURRENT_API_VERSION);
let attributes = attributes
.map(|it| it.to_subtree_unresolved::<SpanTrans>(CURRENT_API_VERSION));
srv.expand(
lib,
&env,
current_dir,
&macro_name,
macro_body,
attributes,
def_site,
call_site,
mixed_site,
)
.map(|it| {
msg::FlatTree::new_raw::<SpanTrans>(
tt::SubtreeView::new(&it),
CURRENT_API_VERSION,
)
})
.map_err(|e| e.into_string().unwrap_or_default())
.map_err(msg::PanicMessage)
}),
msg::SpanMode::RustAnalyzer => msg::Response::ExpandMacroExtended({
let mut span_data_table =
msg::deserialize_span_data_index_map(&span_data_table);
let def_site = span_data_table[def_site];
let call_site = span_data_table[call_site];
let mixed_site = span_data_table[mixed_site];
let macro_body =
macro_body.to_subtree_resolved(CURRENT_API_VERSION, &span_data_table);
let attributes = attributes.map(|it| {
it.to_subtree_resolved(CURRENT_API_VERSION, &span_data_table)
});
srv.expand(
lib,
&env,
current_dir,
&macro_name,
macro_body,
attributes,
def_site,
call_site,
mixed_site,
)
.map(|it| {
(
msg::FlatTree::new(
tt::SubtreeView::new(&it),
CURRENT_API_VERSION,
&mut span_data_table,
),
msg::serialize_span_data_index_map(&span_data_table),
)
})
.map(|(tree, span_data_table)| msg::ExpandMacroExtended {
tree,
span_data_table,
})
.map_err(|e| e.into_string().unwrap_or_default())
.map_err(msg::PanicMessage)
}),
}
}
msg::Request::ApiVersionCheck {} => msg::Response::ApiVersionCheck(CURRENT_API_VERSION),
msg::Request::SetConfig(config) => {
span_mode = config.span_mode;
msg::Response::SetConfig(config)
}
};
write_response(res)?;
}
Ok(())
}