Add simple getsockopt shim for TTL and SO_ERROR

This commit is contained in:
WhySoBad
2026-04-21 22:32:47 +02:00
parent 8e9c315b40
commit bba4079fed
10 changed files with 393 additions and 11 deletions
@@ -661,6 +661,17 @@ fn emulate_foreign_item_inner(
this.setsockopt(socket, level, option_name, option_value, option_len)?;
this.write_scalar(result, dest)?;
}
"getsockopt" => {
let [socket, level, option_name, option_value, option_len] = this.check_shim_sig(
shim_sig!(extern "C" fn(i32, i32, i32, *mut _, *mut _) -> i32),
link_name,
abi,
args,
)?;
let result =
this.getsockopt(socket, level, option_name, option_value, option_len)?;
this.write_scalar(result, dest)?;
}
"getsockname" => {
let [socket, address, address_len] = this.check_shim_sig(
shim_sig!(extern "C" fn(i32, *mut _, *mut _) -> i32),
+157 -5
View File
@@ -7,6 +7,7 @@
use mio::event::Source;
use mio::net::{TcpListener, TcpStream};
use rustc_abi::Size;
use rustc_const_eval::interpret::{InterpResult, interp_ok};
use rustc_middle::throw_unsup_format;
use rustc_target::spec::Os;
@@ -58,6 +59,8 @@ struct Socket {
is_non_block: Cell<bool>,
/// The current blocking I/O readiness of the file description.
io_readiness: RefCell<BlockingIoSourceReadiness>,
/// [`Some`] when the socket had an async error which has not yet been fetched via `SO_ERROR`.
error: RefCell<Option<io::Error>>,
}
impl FileDescription for Socket {
@@ -340,6 +343,7 @@ fn socket(
state: RefCell::new(SocketState::Initial),
is_non_block: Cell::new(is_sock_nonblock),
io_readiness: RefCell::new(BlockingIoSourceReadiness::empty()),
error: RefCell::new(None),
});
interp_ok(Scalar::from_i32(fds.insert(fd)))
@@ -950,6 +954,152 @@ fn setsockopt(
);
}
fn getsockopt(
&mut self,
socket: &OpTy<'tcx>,
level: &OpTy<'tcx>,
option_name: &OpTy<'tcx>,
option_value: &OpTy<'tcx>,
option_len: &OpTy<'tcx>,
) -> InterpResult<'tcx, Scalar> {
let this = self.eval_context_mut();
let socket = this.read_scalar(socket)?.to_i32()?;
let level = this.read_scalar(level)?.to_i32()?;
let option_name = this.read_scalar(option_name)?.to_i32()?;
// These two pointers are used to return the value: `len_ptr` initially stores how much space
// is available. If the actual value fits into that space, it is written to
// `value_ptr` and `len_ptr` is updated to represent how many bytes
// were actually written. If the value does not fit, it is silently truncated.
// Also see <https://pubs.opengroup.org/onlinepubs/9799919799/functions/getsockopt.html>.
let option_value_ptr = this.read_pointer(option_value)?;
let option_len_ptr = this.read_pointer(option_len)?;
// Get the file handle
let Some(fd) = this.machine.fds.get(socket) else {
return this.set_last_error_and_return_i32(LibcError("EBADF"));
};
let Some(socket) = fd.downcast::<Socket>() else {
// Man page specifies to return ENOTSOCK if `fd` is not a socket.
return this.set_last_error_and_return_i32(LibcError("ENOTSOCK"));
};
if option_value_ptr == Pointer::null() || option_len_ptr == Pointer::null() {
// This socket option returns a value and thus we need to return EFAULT
// when either the value or the length pointers are null pointers.
return this.set_last_error_and_return_i32(LibcError("EFAULT"));
}
let socklen_layout = this.libc_ty_layout("socklen_t");
let option_len_ptr_mplace = this.ptr_to_mplace(option_len_ptr, socklen_layout);
let option_len: usize = this
.read_scalar(&option_len_ptr_mplace)?
.to_int(socklen_layout.size)?
.try_into()
.unwrap();
// We need a temporary buffer as `option_value_ptr` might not point to a large enough
// buffer, in which case we have to truncate.
let value_buffer = if level == this.eval_libc_i32("SOL_SOCKET") {
let opt_so_error = this.eval_libc_i32("SO_ERROR");
if option_name == opt_so_error {
// Because `TcpStream::take_error()` and `TcpListener::take_error()` consume the latest async
// error, we know that our stored `socket.error` is outdated when `TcpStream::take_error()`/
// `TcpListener::take_error()` returns `Ok(Some(...))`.
// If they return `Ok(None)`, then we fall back to the stored `socket.error`.
let error = match &*socket.state.borrow() {
SocketState::Initial | SocketState::Bound(_) => socket.error.take(),
SocketState::Listening(listener) =>
listener.take_error().unwrap_or(socket.error.take()),
SocketState::Connecting(stream) | SocketState::Connected(stream) =>
stream.take_error().unwrap_or(socket.error.take()),
};
// Clear our own stored error -- it was either `take`n above or it is outdated.
socket.error.replace(None);
// We know there is no longer an async error and thus we need to update the
// I/O and epoll readiness of the socket.
socket.io_readiness.borrow_mut().error = false;
this.update_epoll_active_events(socket, /* force_edge */ false)?;
let return_value = match error {
Some(err) => this.io_error_to_errnum(err)?.to_i32()?,
// If there is no error, we write 0 into the option value buffer.
None => 0,
};
// Allocate new buffer on the stack with the `i32` layout.
let value_buffer = this.allocate(this.machine.layouts.i32, MemoryKind::Stack)?;
this.write_int(return_value, &value_buffer)?;
value_buffer
} else {
throw_unsup_format!(
"getsockopt: option {option_name:#x} is unsupported for level SOL_SOCKET",
);
}
} else if level == this.eval_libc_i32("IPPROTO_IP") {
let opt_ip_ttl = this.eval_libc_i32("IP_TTL");
if option_name == opt_ip_ttl {
let ttl = match &*socket.state.borrow() {
SocketState::Initial | SocketState::Bound(_) =>
throw_unsup_format!(
"getsockopt: reading option IP_TTL on level IPPROTO_IP is only supported \
on connected and listening sockets"
),
SocketState::Listening(listener) => listener.ttl(),
SocketState::Connecting(stream) | SocketState::Connected(stream) =>
stream.ttl(),
};
let ttl = match ttl {
Ok(ttl) => ttl,
Err(e) => return this.set_last_error_and_return_i32(e),
};
// Allocate new buffer on the stack with the `u32` layout.
let value_buffer = this.allocate(this.machine.layouts.u32, MemoryKind::Stack)?;
this.write_int(ttl, &value_buffer)?;
value_buffer
} else {
throw_unsup_format!(
"getsockopt: option {option_name:#x} is unsupported for level IPPROTO_IP",
);
}
} else {
throw_unsup_format!(
"getsockopt: level {level:#x} is unsupported, only SOL_SOCKET is allowed"
)
};
// Truncated size of the output value.
let output_value_len = value_buffer.layout.size.min(Size::from_bytes(option_len));
// Copy the truncated value into the buffer pointed to by `option_value_ptr`.
this.mem_copy(
value_buffer.ptr(),
option_value_ptr,
// Truncate the value to fit the provided buffer.
output_value_len,
// The buffers are guaranteed to not overlap since the `value_buffer`
// was just newly allocated on the stack.
true,
)?;
// Deallocate the value buffer as it was only needed to store the value and
// copy it into the buffer pointed to by `option_value_ptr`.
this.deallocate_ptr(value_buffer.ptr(), None, MemoryKind::Stack)?;
// On output, the length pointer contains the amount of bytes written -- not the size
// of the value before truncation.
this.write_scalar(
Scalar::from_uint(output_value_len.bytes(), socklen_layout.size),
&option_len_ptr_mplace,
)?;
interp_ok(Scalar::from_i32(0))
}
fn getsockname(
&mut self,
socket: &OpTy<'tcx>,
@@ -1232,6 +1382,7 @@ fn try_non_block_accept(
state: RefCell::new(SocketState::Connected(stream)),
is_non_block: Cell::new(is_client_sock_nonblock),
io_readiness: RefCell::new(BlockingIoSourceReadiness::empty()),
error: RefCell::new(None),
});
// Register the socket to the blocking I/O manager because
// there is an associated host socket.
@@ -1490,17 +1641,18 @@ fn ensure_connected(
};
// Manually check whether there were any errors since calling `connect`.
if let Ok(Some(_)) = stream.take_error() {
if let Ok(Some(err)) = stream.take_error() {
// There was an error during connecting and thus we
// return ENOTCONN. It's the program's responsibility
// to read SO_ERROR itself.
//
// Store the error such that we can return it when
// `getsockopt(SOL_SOCKET, SO_ERROR, ...)` is called on the socket.
socket.error.replace(Some(err));
// Go back to initial state since the only way of getting into the
// `Connecting` state is from the `Initial` state and at this point
// we know that the connection won't be established anymore.
//
// FIXME: We're currently just dropping the error information. Eventually
// we'll have to store it so that it can be recovered by the user.
*state = SocketState::Initial;
drop(state);
return action.call(this, Err(()))
@@ -272,7 +272,7 @@ fn write_socket_address(
.try_into()
.unwrap();
let (address_buffer, address_layout) = match address {
let address_buffer = match address {
SocketAddr::V4(address) => {
// IPv4 address bytes; already stored in network byte order.
let address_bytes = address.ip().octets();
@@ -310,7 +310,7 @@ fn write_socket_address(
let s_addr_field = this.project_field_named(&sin_addr_field, "s_addr")?;
this.write_bytes_ptr(s_addr_field.ptr(), address_bytes)?;
(address_buffer, sockaddr_in_layout)
address_buffer
}
SocketAddr::V6(address) => {
// IPv6 address bytes; already stored in network byte order.
@@ -363,7 +363,7 @@ fn write_socket_address(
let s6_addr_field = this.project_field_named(&sin6_addr_field, "s6_addr")?;
this.write_bytes_ptr(s6_addr_field.ptr(), address_bytes)?;
(address_buffer, sockaddr_in6_layout)
address_buffer
}
};
@@ -372,7 +372,7 @@ fn write_socket_address(
address_buffer.ptr(),
address_ptr,
// Truncate the address to fit the provided buffer.
address_layout.size.min(Size::from_bytes(address_buffer_len)),
address_buffer.layout.size.min(Size::from_bytes(address_buffer_len)),
// The buffers are guaranteed to not overlap since the `address_buffer`
// was just newly allocated on the stack.
true,
@@ -381,7 +381,7 @@ fn write_socket_address(
// copy it into the buffer pointed to by `address_ptr`.
this.deallocate_ptr(address_buffer.ptr(), None, MemoryKind::Stack)?;
// Size of the non-truncated address.
let address_len = address_layout.size.bytes();
let address_len = address_buffer.layout.size.bytes();
this.write_scalar(
Scalar::from_uint(address_len, socklen_layout.size),
@@ -151,6 +151,17 @@ fn emulate_foreign_item_inner(
let result = this.getaddrinfo(node, service, hints, res)?;
this.write_scalar(result, dest)?;
}
"__xnet_getsockopt" => {
let [socket, level, option_name, option_value, option_len] = this.check_shim_sig(
shim_sig!(extern "C" fn(i32, i32, i32, *mut _, *mut _) -> i32),
link_name,
abi,
args,
)?;
let result =
this.getsockopt(socket, level, option_name, option_value, option_len)?;
this.write_scalar(result, dest)?;
}
// Miscellaneous
"___errno" => {
@@ -21,6 +21,7 @@
fn main() {
test_connect_nonblock();
test_accept_nonblock();
test_connect_nonblock_err();
test_recv_nonblock();
#[cfg(not(windows_hosts))]
test_send_nonblock();
@@ -60,7 +61,10 @@ fn test_connect_nonblock() {
// Wait until we are done connecting.
check_epoll_wait::<8>(epfd, &[Ev { events: EPOLLOUT, data: client_sockfd }], -1);
// FIXME: Check SO_ERROR here once we implemented `getsockopt`.
// There should be no error during async connection.
let errno =
net::getsockopt::<libc::c_int>(client_sockfd, libc::SOL_SOCKET, libc::SO_ERROR).unwrap();
assert_eq!(errno, 0);
// We should now be connected and thus getting the peer name should work.
net::sockname_ipv4(|storage, len| unsafe { libc::getpeername(client_sockfd, storage, len) })
@@ -112,6 +116,54 @@ fn test_accept_nonblock() {
server_thread.join().unwrap();
}
/// Test that the SO_ERROR socket option is set when attempting to
/// connect to an unbound address without blocking.
fn test_connect_nonblock_err() {
let client_sockfd =
unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() };
let epfd = errno_result(unsafe { libc::epoll_create1(0) }).unwrap();
unsafe {
// Change client socket to be non-blocking.
errno_check(libc::fcntl(client_sockfd, libc::F_SETFL, libc::O_NONBLOCK));
}
// We cannot attempt to connect to a localhost address because
// it could be the case that a socket from another test is
// currently listening on `localhost:12321` because we bind to
// random ports everywhere. For `127.0.1.1` we know that it's a loopback
// address and thus exists but because it's not the standard loopback
// address we also assume that nothing is bound to it.
// The port `12321` is just a random non-zero port because Windows
// and Apple hosts return EADDRNOTAVAIL when attempting to connect to
// a zero port.
let addr = net::sock_addr_ipv4([127, 0, 1, 1], 12321);
// Non-blocking connect should fail with EINPROGRESS.
let err = net::connect_ipv4(client_sockfd, addr).unwrap_err();
assert_eq!(err.kind(), ErrorKind::InProgress);
// Add interest for client socket.
epoll_ctl_add(epfd, client_sockfd, EPOLLOUT | EPOLLET | libc::EPOLLERR).unwrap();
// Wait until the socket has an error.
check_epoll_wait::<8>(
epfd,
&[Ev { events: libc::EPOLLERR | EPOLLOUT | EPOLLHUP, data: client_sockfd }],
-1,
);
let errno =
net::getsockopt::<libc::c_int>(client_sockfd, libc::SOL_SOCKET, libc::SO_ERROR).unwrap();
// Depending on the host we receive different error kinds. Thus, we only check
// that it's a nonzero error code.
assert!(errno != 0);
// Ensure that error readiness is cleared after reading SO_ERROR.
let readiness = current_epoll_readiness::<8>(client_sockfd, EPOLLET | EPOLLOUT | EPOLLERR);
assert!(readiness & EPOLLERR == 0);
}
/// Test receiving bytes from a connected stream without blocking.
/// Instead of busy waiting until we no longer receive EWOULDBLOCK when trying to
/// read from the client, we register the client socket to epoll and wait for
@@ -227,6 +227,11 @@ fn test_connect_nonblock() {
assert_eq!(err.kind(), ErrorKind::InProgress);
loop {
// There should be no error during async connection.
let errno = net::getsockopt::<libc::c_int>(client_sockfd, libc::SOL_SOCKET, libc::SO_ERROR)
.unwrap();
assert_eq!(errno, 0);
let result = net::sockname_ipv4(|storage, len| unsafe {
libc::getpeername(client_sockfd, storage, len)
});
@@ -699,6 +704,11 @@ fn test_getpeername_ipv4_nonblock_no_peer() {
let err = net::connect_ipv4(client_sockfd, addr).unwrap_err();
assert_eq!(err.kind(), ErrorKind::InProgress);
// There should be no error during async connection.
let errno =
net::getsockopt::<libc::c_int>(client_sockfd, libc::SOL_SOCKET, libc::SO_ERROR).unwrap();
assert_eq!(errno, 0);
// Since we're never accepting the connection, the socket should never be
// successfully connected and thus we should be unable to read the peername.
let Err(err) = net::sockname_ipv4(|storage, len| unsafe {
@@ -50,6 +50,8 @@ fn main() {
test_shutdown();
test_shutdown_readable_after_write_close();
test_shutdown_writable_after_read_close();
test_getsockopt_truncate();
}
/// Test creating a socket and then closing it afterwards.
@@ -640,3 +642,53 @@ fn test_shutdown_writable_after_read_close() {
server_thread.join().unwrap();
}
/// Test that the value gets silently truncated when a too small
/// length is provided and that the length gets reduced when the value
/// is smaller than the provided length.
fn test_getsockopt_truncate() {
let (sockfd, _) = net::make_listener_ipv4().unwrap();
// The actual TTL with a correctly sized buffer.
let ttl = net::getsockopt::<libc::c_uint>(sockfd, libc::IPPROTO_IP, libc::IP_TTL).unwrap();
let mut option_value = std::mem::MaybeUninit::<u32>::zeroed();
// The actual length is 4 bytes.
let mut short_option_len = 2 as libc::socklen_t;
errno_result(unsafe {
libc::getsockopt(
sockfd,
libc::IPPROTO_IP,
libc::IP_TTL,
option_value.as_mut_ptr().cast(),
&mut short_option_len,
)
})
.unwrap();
// Ensure that the size wasn't changed.
assert_eq!(short_option_len, 2);
let short_ttl = unsafe { option_value.assume_init() };
// Assert that the value was silently truncated.
assert_eq!(short_ttl.to_ne_bytes()[0..2], ttl.to_ne_bytes()[0..2]);
let mut option_value = std::mem::MaybeUninit::<u32>::zeroed();
// The actual length is 4 bytes.
let mut long_option_len = 6 as libc::socklen_t;
errno_result(unsafe {
libc::getsockopt(
sockfd,
libc::IPPROTO_IP,
libc::IP_TTL,
option_value.as_mut_ptr().cast(),
&mut long_option_len,
)
})
.unwrap();
// Ensure that the size was shortened to the actual length.
assert_eq!(long_option_len, 4);
let long_ttl = unsafe { option_value.assume_init() };
assert_eq!(long_ttl, ttl);
}
@@ -0,0 +1,56 @@
//@only-target: linux # We only support tokio on Linux
//@compile-flags: -Zmiri-disable-isolation
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
const TEST_BYTES: &[u8] = b"these are some test bytes!";
#[tokio::main]
async fn main() {
test_accept_and_connect().await;
test_read_write().await;
}
/// Test connecting and accepting a connection.
async fn test_accept_and_connect() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
// Get local address with randomized port to know where
// we need to connect to.
let address = listener.local_addr().unwrap();
// Start server thread.
tokio::spawn(async move {
let (_stream, _addr) = listener.accept().await.unwrap();
});
let _stream = TcpStream::connect(address).await.unwrap();
}
/// Test writing bytes into and reading bytes from a connected stream.
async fn test_read_write() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
// Get local address with randomized port to know where
// we need to connect to.
let address = listener.local_addr().unwrap();
// Start server thread.
tokio::spawn(async move {
let (mut stream, _addr) = listener.accept().await.unwrap();
stream.write_all(TEST_BYTES).await.unwrap();
let mut buffer = [0; TEST_BYTES.len()];
stream.read_exact(&mut buffer).await.unwrap();
assert_eq!(&buffer, TEST_BYTES);
});
let mut stream = TcpStream::connect(address).await.unwrap();
let mut buffer = [0; TEST_BYTES.len()];
stream.read_exact(&mut buffer).await.unwrap();
assert_eq!(&buffer, TEST_BYTES);
stream.write_all(TEST_BYTES).await.unwrap();
}
@@ -15,6 +15,7 @@ fn main() {
test_peek();
test_peer_addr();
test_shutdown();
test_sockopt_ttl();
}
fn test_create_ipv4_listener() {
@@ -152,3 +153,11 @@ fn test_shutdown() {
let _stream = handle.join().unwrap();
}
/// Test setting and reading the TTL socket option.
fn test_sockopt_ttl() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
listener.ttl().unwrap();
// TODO: Once we support setting the TTL we should also test it here.
}
+29
View File
@@ -388,6 +388,35 @@ pub fn setsockopt<T>(
Ok(())
}
/// Get a socket option. It's the caller's responsibility that `T` is
/// associated with the given socket option.
///
/// This function is directly copied from the standard library implementation
/// for sockets on UNIX targets.
pub fn getsockopt<T: Copy>(
sockfd: libc::c_int,
level: libc::c_int,
option_name: libc::c_int,
) -> io::Result<T> {
let mut option_value = std::mem::MaybeUninit::<T>::zeroed();
let mut option_len = size_of::<T>() as libc::socklen_t;
let provided_len = option_len;
errno_result(unsafe {
libc::getsockopt(
sockfd,
level,
option_name,
option_value.as_mut_ptr().cast(),
&mut option_len,
)
})?;
// Ensure that there was no truncation.
assert!(option_len == provided_len);
Ok(unsafe { option_value.assume_init() })
}
/// Wraps a call to a platform function that returns an IPv4 socket address.
/// Returns a tuple containing the actual return value of the performed
/// syscall and the written address of it.