diff --git a/src/tools/miri/src/shims/unix/foreign_items.rs b/src/tools/miri/src/shims/unix/foreign_items.rs index 3a47c9552788..0766352bab20 100644 --- a/src/tools/miri/src/shims/unix/foreign_items.rs +++ b/src/tools/miri/src/shims/unix/foreign_items.rs @@ -661,6 +661,17 @@ fn emulate_foreign_item_inner( this.setsockopt(socket, level, option_name, option_value, option_len)?; this.write_scalar(result, dest)?; } + "getsockopt" => { + let [socket, level, option_name, option_value, option_len] = this.check_shim_sig( + shim_sig!(extern "C" fn(i32, i32, i32, *mut _, *mut _) -> i32), + link_name, + abi, + args, + )?; + let result = + this.getsockopt(socket, level, option_name, option_value, option_len)?; + this.write_scalar(result, dest)?; + } "getsockname" => { let [socket, address, address_len] = this.check_shim_sig( shim_sig!(extern "C" fn(i32, *mut _, *mut _) -> i32), diff --git a/src/tools/miri/src/shims/unix/socket.rs b/src/tools/miri/src/shims/unix/socket.rs index ca0ddfd4726a..853e69c23411 100644 --- a/src/tools/miri/src/shims/unix/socket.rs +++ b/src/tools/miri/src/shims/unix/socket.rs @@ -7,6 +7,7 @@ use mio::event::Source; use mio::net::{TcpListener, TcpStream}; +use rustc_abi::Size; use rustc_const_eval::interpret::{InterpResult, interp_ok}; use rustc_middle::throw_unsup_format; use rustc_target::spec::Os; @@ -58,6 +59,8 @@ struct Socket { is_non_block: Cell, /// The current blocking I/O readiness of the file description. io_readiness: RefCell, + /// [`Some`] when the socket had an async error which has not yet been fetched via `SO_ERROR`. + error: RefCell>, } impl FileDescription for Socket { @@ -340,6 +343,7 @@ fn socket( state: RefCell::new(SocketState::Initial), is_non_block: Cell::new(is_sock_nonblock), io_readiness: RefCell::new(BlockingIoSourceReadiness::empty()), + error: RefCell::new(None), }); interp_ok(Scalar::from_i32(fds.insert(fd))) @@ -950,6 +954,152 @@ fn setsockopt( ); } + fn getsockopt( + &mut self, + socket: &OpTy<'tcx>, + level: &OpTy<'tcx>, + option_name: &OpTy<'tcx>, + option_value: &OpTy<'tcx>, + option_len: &OpTy<'tcx>, + ) -> InterpResult<'tcx, Scalar> { + let this = self.eval_context_mut(); + + let socket = this.read_scalar(socket)?.to_i32()?; + let level = this.read_scalar(level)?.to_i32()?; + let option_name = this.read_scalar(option_name)?.to_i32()?; + // These two pointers are used to return the value: `len_ptr` initially stores how much space + // is available. If the actual value fits into that space, it is written to + // `value_ptr` and `len_ptr` is updated to represent how many bytes + // were actually written. If the value does not fit, it is silently truncated. + // Also see . + let option_value_ptr = this.read_pointer(option_value)?; + let option_len_ptr = this.read_pointer(option_len)?; + + // Get the file handle + let Some(fd) = this.machine.fds.get(socket) else { + return this.set_last_error_and_return_i32(LibcError("EBADF")); + }; + + let Some(socket) = fd.downcast::() else { + // Man page specifies to return ENOTSOCK if `fd` is not a socket. + return this.set_last_error_and_return_i32(LibcError("ENOTSOCK")); + }; + + if option_value_ptr == Pointer::null() || option_len_ptr == Pointer::null() { + // This socket option returns a value and thus we need to return EFAULT + // when either the value or the length pointers are null pointers. + return this.set_last_error_and_return_i32(LibcError("EFAULT")); + } + + let socklen_layout = this.libc_ty_layout("socklen_t"); + let option_len_ptr_mplace = this.ptr_to_mplace(option_len_ptr, socklen_layout); + let option_len: usize = this + .read_scalar(&option_len_ptr_mplace)? + .to_int(socklen_layout.size)? + .try_into() + .unwrap(); + + // We need a temporary buffer as `option_value_ptr` might not point to a large enough + // buffer, in which case we have to truncate. + let value_buffer = if level == this.eval_libc_i32("SOL_SOCKET") { + let opt_so_error = this.eval_libc_i32("SO_ERROR"); + + if option_name == opt_so_error { + // Because `TcpStream::take_error()` and `TcpListener::take_error()` consume the latest async + // error, we know that our stored `socket.error` is outdated when `TcpStream::take_error()`/ + // `TcpListener::take_error()` returns `Ok(Some(...))`. + // If they return `Ok(None)`, then we fall back to the stored `socket.error`. + let error = match &*socket.state.borrow() { + SocketState::Initial | SocketState::Bound(_) => socket.error.take(), + SocketState::Listening(listener) => + listener.take_error().unwrap_or(socket.error.take()), + SocketState::Connecting(stream) | SocketState::Connected(stream) => + stream.take_error().unwrap_or(socket.error.take()), + }; + // Clear our own stored error -- it was either `take`n above or it is outdated. + socket.error.replace(None); + + // We know there is no longer an async error and thus we need to update the + // I/O and epoll readiness of the socket. + socket.io_readiness.borrow_mut().error = false; + this.update_epoll_active_events(socket, /* force_edge */ false)?; + + let return_value = match error { + Some(err) => this.io_error_to_errnum(err)?.to_i32()?, + // If there is no error, we write 0 into the option value buffer. + None => 0, + }; + + // Allocate new buffer on the stack with the `i32` layout. + let value_buffer = this.allocate(this.machine.layouts.i32, MemoryKind::Stack)?; + this.write_int(return_value, &value_buffer)?; + value_buffer + } else { + throw_unsup_format!( + "getsockopt: option {option_name:#x} is unsupported for level SOL_SOCKET", + ); + } + } else if level == this.eval_libc_i32("IPPROTO_IP") { + let opt_ip_ttl = this.eval_libc_i32("IP_TTL"); + + if option_name == opt_ip_ttl { + let ttl = match &*socket.state.borrow() { + SocketState::Initial | SocketState::Bound(_) => + throw_unsup_format!( + "getsockopt: reading option IP_TTL on level IPPROTO_IP is only supported \ + on connected and listening sockets" + ), + SocketState::Listening(listener) => listener.ttl(), + SocketState::Connecting(stream) | SocketState::Connected(stream) => + stream.ttl(), + }; + + let ttl = match ttl { + Ok(ttl) => ttl, + Err(e) => return this.set_last_error_and_return_i32(e), + }; + + // Allocate new buffer on the stack with the `u32` layout. + let value_buffer = this.allocate(this.machine.layouts.u32, MemoryKind::Stack)?; + this.write_int(ttl, &value_buffer)?; + value_buffer + } else { + throw_unsup_format!( + "getsockopt: option {option_name:#x} is unsupported for level IPPROTO_IP", + ); + } + } else { + throw_unsup_format!( + "getsockopt: level {level:#x} is unsupported, only SOL_SOCKET is allowed" + ) + }; + + // Truncated size of the output value. + let output_value_len = value_buffer.layout.size.min(Size::from_bytes(option_len)); + // Copy the truncated value into the buffer pointed to by `option_value_ptr`. + this.mem_copy( + value_buffer.ptr(), + option_value_ptr, + // Truncate the value to fit the provided buffer. + output_value_len, + // The buffers are guaranteed to not overlap since the `value_buffer` + // was just newly allocated on the stack. + true, + )?; + // Deallocate the value buffer as it was only needed to store the value and + // copy it into the buffer pointed to by `option_value_ptr`. + this.deallocate_ptr(value_buffer.ptr(), None, MemoryKind::Stack)?; + + // On output, the length pointer contains the amount of bytes written -- not the size + // of the value before truncation. + this.write_scalar( + Scalar::from_uint(output_value_len.bytes(), socklen_layout.size), + &option_len_ptr_mplace, + )?; + + interp_ok(Scalar::from_i32(0)) + } + fn getsockname( &mut self, socket: &OpTy<'tcx>, @@ -1232,6 +1382,7 @@ fn try_non_block_accept( state: RefCell::new(SocketState::Connected(stream)), is_non_block: Cell::new(is_client_sock_nonblock), io_readiness: RefCell::new(BlockingIoSourceReadiness::empty()), + error: RefCell::new(None), }); // Register the socket to the blocking I/O manager because // there is an associated host socket. @@ -1490,17 +1641,18 @@ fn ensure_connected( }; // Manually check whether there were any errors since calling `connect`. - if let Ok(Some(_)) = stream.take_error() { + if let Ok(Some(err)) = stream.take_error() { // There was an error during connecting and thus we // return ENOTCONN. It's the program's responsibility // to read SO_ERROR itself. - // + + // Store the error such that we can return it when + // `getsockopt(SOL_SOCKET, SO_ERROR, ...)` is called on the socket. + socket.error.replace(Some(err)); + // Go back to initial state since the only way of getting into the // `Connecting` state is from the `Initial` state and at this point // we know that the connection won't be established anymore. - // - // FIXME: We're currently just dropping the error information. Eventually - // we'll have to store it so that it can be recovered by the user. *state = SocketState::Initial; drop(state); return action.call(this, Err(())) diff --git a/src/tools/miri/src/shims/unix/socket_address.rs b/src/tools/miri/src/shims/unix/socket_address.rs index c0f7e8e1720f..90c316d0d116 100644 --- a/src/tools/miri/src/shims/unix/socket_address.rs +++ b/src/tools/miri/src/shims/unix/socket_address.rs @@ -272,7 +272,7 @@ fn write_socket_address( .try_into() .unwrap(); - let (address_buffer, address_layout) = match address { + let address_buffer = match address { SocketAddr::V4(address) => { // IPv4 address bytes; already stored in network byte order. let address_bytes = address.ip().octets(); @@ -310,7 +310,7 @@ fn write_socket_address( let s_addr_field = this.project_field_named(&sin_addr_field, "s_addr")?; this.write_bytes_ptr(s_addr_field.ptr(), address_bytes)?; - (address_buffer, sockaddr_in_layout) + address_buffer } SocketAddr::V6(address) => { // IPv6 address bytes; already stored in network byte order. @@ -363,7 +363,7 @@ fn write_socket_address( let s6_addr_field = this.project_field_named(&sin6_addr_field, "s6_addr")?; this.write_bytes_ptr(s6_addr_field.ptr(), address_bytes)?; - (address_buffer, sockaddr_in6_layout) + address_buffer } }; @@ -372,7 +372,7 @@ fn write_socket_address( address_buffer.ptr(), address_ptr, // Truncate the address to fit the provided buffer. - address_layout.size.min(Size::from_bytes(address_buffer_len)), + address_buffer.layout.size.min(Size::from_bytes(address_buffer_len)), // The buffers are guaranteed to not overlap since the `address_buffer` // was just newly allocated on the stack. true, @@ -381,7 +381,7 @@ fn write_socket_address( // copy it into the buffer pointed to by `address_ptr`. this.deallocate_ptr(address_buffer.ptr(), None, MemoryKind::Stack)?; // Size of the non-truncated address. - let address_len = address_layout.size.bytes(); + let address_len = address_buffer.layout.size.bytes(); this.write_scalar( Scalar::from_uint(address_len, socklen_layout.size), diff --git a/src/tools/miri/src/shims/unix/solarish/foreign_items.rs b/src/tools/miri/src/shims/unix/solarish/foreign_items.rs index 7f2af20a5a7e..37b665ceebd1 100644 --- a/src/tools/miri/src/shims/unix/solarish/foreign_items.rs +++ b/src/tools/miri/src/shims/unix/solarish/foreign_items.rs @@ -151,6 +151,17 @@ fn emulate_foreign_item_inner( let result = this.getaddrinfo(node, service, hints, res)?; this.write_scalar(result, dest)?; } + "__xnet_getsockopt" => { + let [socket, level, option_name, option_value, option_len] = this.check_shim_sig( + shim_sig!(extern "C" fn(i32, i32, i32, *mut _, *mut _) -> i32), + link_name, + abi, + args, + )?; + let result = + this.getsockopt(socket, level, option_name, option_value, option_len)?; + this.write_scalar(result, dest)?; + } // Miscellaneous "___errno" => { diff --git a/src/tools/miri/tests/pass-dep/libc/libc-socket-no-blocking-epoll.rs b/src/tools/miri/tests/pass-dep/libc/libc-socket-no-blocking-epoll.rs index b162319d7557..132f1c81d884 100644 --- a/src/tools/miri/tests/pass-dep/libc/libc-socket-no-blocking-epoll.rs +++ b/src/tools/miri/tests/pass-dep/libc/libc-socket-no-blocking-epoll.rs @@ -21,6 +21,7 @@ fn main() { test_connect_nonblock(); test_accept_nonblock(); + test_connect_nonblock_err(); test_recv_nonblock(); #[cfg(not(windows_hosts))] test_send_nonblock(); @@ -60,7 +61,10 @@ fn test_connect_nonblock() { // Wait until we are done connecting. check_epoll_wait::<8>(epfd, &[Ev { events: EPOLLOUT, data: client_sockfd }], -1); - // FIXME: Check SO_ERROR here once we implemented `getsockopt`. + // There should be no error during async connection. + let errno = + net::getsockopt::(client_sockfd, libc::SOL_SOCKET, libc::SO_ERROR).unwrap(); + assert_eq!(errno, 0); // We should now be connected and thus getting the peer name should work. net::sockname_ipv4(|storage, len| unsafe { libc::getpeername(client_sockfd, storage, len) }) @@ -112,6 +116,54 @@ fn test_accept_nonblock() { server_thread.join().unwrap(); } +/// Test that the SO_ERROR socket option is set when attempting to +/// connect to an unbound address without blocking. +fn test_connect_nonblock_err() { + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + let epfd = errno_result(unsafe { libc::epoll_create1(0) }).unwrap(); + + unsafe { + // Change client socket to be non-blocking. + errno_check(libc::fcntl(client_sockfd, libc::F_SETFL, libc::O_NONBLOCK)); + } + + // We cannot attempt to connect to a localhost address because + // it could be the case that a socket from another test is + // currently listening on `localhost:12321` because we bind to + // random ports everywhere. For `127.0.1.1` we know that it's a loopback + // address and thus exists but because it's not the standard loopback + // address we also assume that nothing is bound to it. + // The port `12321` is just a random non-zero port because Windows + // and Apple hosts return EADDRNOTAVAIL when attempting to connect to + // a zero port. + let addr = net::sock_addr_ipv4([127, 0, 1, 1], 12321); + + // Non-blocking connect should fail with EINPROGRESS. + let err = net::connect_ipv4(client_sockfd, addr).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::InProgress); + + // Add interest for client socket. + epoll_ctl_add(epfd, client_sockfd, EPOLLOUT | EPOLLET | libc::EPOLLERR).unwrap(); + + // Wait until the socket has an error. + check_epoll_wait::<8>( + epfd, + &[Ev { events: libc::EPOLLERR | EPOLLOUT | EPOLLHUP, data: client_sockfd }], + -1, + ); + + let errno = + net::getsockopt::(client_sockfd, libc::SOL_SOCKET, libc::SO_ERROR).unwrap(); + // Depending on the host we receive different error kinds. Thus, we only check + // that it's a nonzero error code. + assert!(errno != 0); + + // Ensure that error readiness is cleared after reading SO_ERROR. + let readiness = current_epoll_readiness::<8>(client_sockfd, EPOLLET | EPOLLOUT | EPOLLERR); + assert!(readiness & EPOLLERR == 0); +} + /// Test receiving bytes from a connected stream without blocking. /// Instead of busy waiting until we no longer receive EWOULDBLOCK when trying to /// read from the client, we register the client socket to epoll and wait for diff --git a/src/tools/miri/tests/pass-dep/libc/libc-socket-no-blocking.rs b/src/tools/miri/tests/pass-dep/libc/libc-socket-no-blocking.rs index 40c821e858cd..236e72d4b4ec 100644 --- a/src/tools/miri/tests/pass-dep/libc/libc-socket-no-blocking.rs +++ b/src/tools/miri/tests/pass-dep/libc/libc-socket-no-blocking.rs @@ -227,6 +227,11 @@ fn test_connect_nonblock() { assert_eq!(err.kind(), ErrorKind::InProgress); loop { + // There should be no error during async connection. + let errno = net::getsockopt::(client_sockfd, libc::SOL_SOCKET, libc::SO_ERROR) + .unwrap(); + assert_eq!(errno, 0); + let result = net::sockname_ipv4(|storage, len| unsafe { libc::getpeername(client_sockfd, storage, len) }); @@ -699,6 +704,11 @@ fn test_getpeername_ipv4_nonblock_no_peer() { let err = net::connect_ipv4(client_sockfd, addr).unwrap_err(); assert_eq!(err.kind(), ErrorKind::InProgress); + // There should be no error during async connection. + let errno = + net::getsockopt::(client_sockfd, libc::SOL_SOCKET, libc::SO_ERROR).unwrap(); + assert_eq!(errno, 0); + // Since we're never accepting the connection, the socket should never be // successfully connected and thus we should be unable to read the peername. let Err(err) = net::sockname_ipv4(|storage, len| unsafe { diff --git a/src/tools/miri/tests/pass-dep/libc/libc-socket.rs b/src/tools/miri/tests/pass-dep/libc/libc-socket.rs index 33f99371eee6..f05c9c4da0c9 100644 --- a/src/tools/miri/tests/pass-dep/libc/libc-socket.rs +++ b/src/tools/miri/tests/pass-dep/libc/libc-socket.rs @@ -50,6 +50,8 @@ fn main() { test_shutdown(); test_shutdown_readable_after_write_close(); test_shutdown_writable_after_read_close(); + + test_getsockopt_truncate(); } /// Test creating a socket and then closing it afterwards. @@ -640,3 +642,53 @@ fn test_shutdown_writable_after_read_close() { server_thread.join().unwrap(); } + +/// Test that the value gets silently truncated when a too small +/// length is provided and that the length gets reduced when the value +/// is smaller than the provided length. +fn test_getsockopt_truncate() { + let (sockfd, _) = net::make_listener_ipv4().unwrap(); + + // The actual TTL with a correctly sized buffer. + let ttl = net::getsockopt::(sockfd, libc::IPPROTO_IP, libc::IP_TTL).unwrap(); + + let mut option_value = std::mem::MaybeUninit::::zeroed(); + // The actual length is 4 bytes. + let mut short_option_len = 2 as libc::socklen_t; + + errno_result(unsafe { + libc::getsockopt( + sockfd, + libc::IPPROTO_IP, + libc::IP_TTL, + option_value.as_mut_ptr().cast(), + &mut short_option_len, + ) + }) + .unwrap(); + // Ensure that the size wasn't changed. + assert_eq!(short_option_len, 2); + let short_ttl = unsafe { option_value.assume_init() }; + + // Assert that the value was silently truncated. + assert_eq!(short_ttl.to_ne_bytes()[0..2], ttl.to_ne_bytes()[0..2]); + + let mut option_value = std::mem::MaybeUninit::::zeroed(); + // The actual length is 4 bytes. + let mut long_option_len = 6 as libc::socklen_t; + + errno_result(unsafe { + libc::getsockopt( + sockfd, + libc::IPPROTO_IP, + libc::IP_TTL, + option_value.as_mut_ptr().cast(), + &mut long_option_len, + ) + }) + .unwrap(); + // Ensure that the size was shortened to the actual length. + assert_eq!(long_option_len, 4); + let long_ttl = unsafe { option_value.assume_init() }; + assert_eq!(long_ttl, ttl); +} diff --git a/src/tools/miri/tests/pass-dep/tokio/socket.rs b/src/tools/miri/tests/pass-dep/tokio/socket.rs new file mode 100644 index 000000000000..a8014fa8ce7b --- /dev/null +++ b/src/tools/miri/tests/pass-dep/tokio/socket.rs @@ -0,0 +1,56 @@ +//@only-target: linux # We only support tokio on Linux +//@compile-flags: -Zmiri-disable-isolation + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; + +const TEST_BYTES: &[u8] = b"these are some test bytes!"; + +#[tokio::main] +async fn main() { + test_accept_and_connect().await; + test_read_write().await; +} + +/// Test connecting and accepting a connection. +async fn test_accept_and_connect() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + // Get local address with randomized port to know where + // we need to connect to. + let address = listener.local_addr().unwrap(); + + // Start server thread. + tokio::spawn(async move { + let (_stream, _addr) = listener.accept().await.unwrap(); + }); + + let _stream = TcpStream::connect(address).await.unwrap(); +} + +/// Test writing bytes into and reading bytes from a connected stream. +async fn test_read_write() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + // Get local address with randomized port to know where + // we need to connect to. + let address = listener.local_addr().unwrap(); + + // Start server thread. + tokio::spawn(async move { + let (mut stream, _addr) = listener.accept().await.unwrap(); + + stream.write_all(TEST_BYTES).await.unwrap(); + + let mut buffer = [0; TEST_BYTES.len()]; + stream.read_exact(&mut buffer).await.unwrap(); + + assert_eq!(&buffer, TEST_BYTES); + }); + + let mut stream = TcpStream::connect(address).await.unwrap(); + + let mut buffer = [0; TEST_BYTES.len()]; + stream.read_exact(&mut buffer).await.unwrap(); + assert_eq!(&buffer, TEST_BYTES); + + stream.write_all(TEST_BYTES).await.unwrap(); +} diff --git a/src/tools/miri/tests/pass/shims/socket.rs b/src/tools/miri/tests/pass/shims/socket.rs index 957eefc628a3..4f448fa44780 100644 --- a/src/tools/miri/tests/pass/shims/socket.rs +++ b/src/tools/miri/tests/pass/shims/socket.rs @@ -15,6 +15,7 @@ fn main() { test_peek(); test_peer_addr(); test_shutdown(); + test_sockopt_ttl(); } fn test_create_ipv4_listener() { @@ -152,3 +153,11 @@ fn test_shutdown() { let _stream = handle.join().unwrap(); } + +/// Test setting and reading the TTL socket option. +fn test_sockopt_ttl() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + listener.ttl().unwrap(); + + // TODO: Once we support setting the TTL we should also test it here. +} diff --git a/src/tools/miri/tests/utils/libc.rs b/src/tools/miri/tests/utils/libc.rs index 89a610194ccc..8f4c8f4473c5 100644 --- a/src/tools/miri/tests/utils/libc.rs +++ b/src/tools/miri/tests/utils/libc.rs @@ -388,6 +388,35 @@ pub fn setsockopt( Ok(()) } + /// Get a socket option. It's the caller's responsibility that `T` is + /// associated with the given socket option. + /// + /// This function is directly copied from the standard library implementation + /// for sockets on UNIX targets. + pub fn getsockopt( + sockfd: libc::c_int, + level: libc::c_int, + option_name: libc::c_int, + ) -> io::Result { + let mut option_value = std::mem::MaybeUninit::::zeroed(); + let mut option_len = size_of::() as libc::socklen_t; + let provided_len = option_len; + + errno_result(unsafe { + libc::getsockopt( + sockfd, + level, + option_name, + option_value.as_mut_ptr().cast(), + &mut option_len, + ) + })?; + // Ensure that there was no truncation. + assert!(option_len == provided_len); + + Ok(unsafe { option_value.assume_init() }) + } + /// Wraps a call to a platform function that returns an IPv4 socket address. /// Returns a tuple containing the actual return value of the performed /// syscall and the written address of it.