From ee8da29feb2be6a1e8a5c891840ff89bd452df46 Mon Sep 17 00:00:00 2001 From: WhySoBad <49595640+WhySoBad@users.noreply.github.com> Date: Wed, 13 May 2026 15:16:36 +0200 Subject: [PATCH] Add support for socket read and write timeouts --- src/tools/miri/src/clock.rs | 6 +- src/tools/miri/src/helpers.rs | 26 --- src/tools/miri/src/provenance_gc.rs | 2 +- src/tools/miri/src/shims/time.rs | 47 ++++ src/tools/miri/src/shims/unix/socket.rs | 207 ++++++++++++++---- .../miri/tests/pass-dep/libc/libc-socket.rs | 97 +++++++- src/tools/miri/tests/pass/shims/socket.rs | 56 +++++ 7 files changed, 366 insertions(+), 75 deletions(-) diff --git a/src/tools/miri/src/clock.rs b/src/tools/miri/src/clock.rs index 655697509552..25b645189ce2 100644 --- a/src/tools/miri/src/clock.rs +++ b/src/tools/miri/src/clock.rs @@ -11,12 +11,12 @@ const NANOSECONDS_PER_BASIC_BLOCK: u128 = 5000; /// An instant (a fixed moment in time) in Miri's monotone clock. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Instant { kind: InstantKind, } -#[derive(Debug)] +#[derive(Clone, Debug)] enum InstantKind { Host(StdInstant), Virtual { nanoseconds: u128 }, @@ -134,7 +134,7 @@ pub fn now(&self) -> Instant { } /// A deadline for some event to occur. -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum Deadline { Monotonic(Instant), RealTime(SystemTime), diff --git a/src/tools/miri/src/helpers.rs b/src/tools/miri/src/helpers.rs index 349cb2d66482..c65ff125f61c 100644 --- a/src/tools/miri/src/helpers.rs +++ b/src/tools/miri/src/helpers.rs @@ -1,6 +1,5 @@ use std::num::NonZero; use std::sync::Mutex; -use std::time::Duration; use std::{cmp, iter}; use rand::Rng; @@ -714,31 +713,6 @@ fn deref_pointer_and_write( this.write_scalar(value, &value_place) } - /// Parse a `timespec` struct and return it as a `std::time::Duration`. It returns `None` - /// if the value in the `timespec` struct is invalid. Some libc functions will return - /// `EINVAL` in this case. - fn read_timespec(&mut self, tp: &MPlaceTy<'tcx>) -> InterpResult<'tcx, Option> { - let this = self.eval_context_mut(); - let seconds_place = this.project_field(tp, FieldIdx::ZERO)?; - let seconds_scalar = this.read_scalar(&seconds_place)?; - let seconds = seconds_scalar.to_target_isize(this)?; - let nanoseconds_place = this.project_field(tp, FieldIdx::ONE)?; - let nanoseconds_scalar = this.read_scalar(&nanoseconds_place)?; - let nanoseconds = nanoseconds_scalar.to_target_isize(this)?; - - interp_ok(try { - // tv_sec must be non-negative. - let seconds: u64 = seconds.try_into().ok()?; - // tv_nsec must be non-negative. - let nanoseconds: u32 = nanoseconds.try_into().ok()?; - if nanoseconds >= 1_000_000_000 { - // tv_nsec must not be greater than 999,999,999. - None? - } - Duration::new(seconds, nanoseconds) - }) - } - /// Read bytes from a byte slice. fn read_byte_slice<'a>(&'a self, slice: &ImmTy<'tcx>) -> InterpResult<'tcx, &'a [u8]> where diff --git a/src/tools/miri/src/provenance_gc.rs b/src/tools/miri/src/provenance_gc.rs index 3656a9eaa87c..c292f764d6d1 100644 --- a/src/tools/miri/src/provenance_gc.rs +++ b/src/tools/miri/src/provenance_gc.rs @@ -19,7 +19,7 @@ fn visit_provenance(&self, _visit: &mut VisitWith<'_>) {} )+ } } -no_provenance!(i8 i16 i32 i64 isize u8 u16 u32 u64 usize bool ThreadId); +no_provenance!(i8 i16 i32 i64 isize u8 u16 u32 u64 usize bool ThreadId Deadline); impl VisitProvenance for &'static str { fn visit_provenance(&self, _visit: &mut VisitWith<'_>) {} diff --git a/src/tools/miri/src/shims/time.rs b/src/tools/miri/src/shims/time.rs index 9dfce51d2ea4..acc0d5bc8841 100644 --- a/src/tools/miri/src/shims/time.rs +++ b/src/tools/miri/src/shims/time.rs @@ -459,4 +459,51 @@ fn Sleep(&mut self, timeout: &OpTy<'tcx>) -> InterpResult<'tcx> { ); interp_ok(()) } + + /// Parse a `timespec` struct and return it as a [`Duration`]. It returns [`None`] + /// if the value in the `timespec` struct is invalid. Some libc functions will return + /// EINVAL in this case. + fn read_timespec(&mut self, tp: &MPlaceTy<'tcx>) -> InterpResult<'tcx, Option> { + let this = self.eval_context_mut(); + let sec_field = this.project_field_named(tp, "tv_sec")?; + let sec = this.read_scalar(&sec_field)?.to_int(sec_field.layout.size)?; + let nsec_field = this.project_field_named(tp, "tv_nsec")?; + let nsec = this.read_scalar(&nsec_field)?.to_int(nsec_field.layout.size)?; + + interp_ok(try { + // tv_sec must be non-negative. + let seconds: u64 = sec.try_into().ok()?; + // tv_nsec must be non-negative. + let nanoseconds: u32 = nsec.try_into().ok()?; + if nanoseconds >= 1_000_000_000 { + // tv_nsec must not be greater than 999,999,999. + None? + } + Duration::new(seconds, nanoseconds) + }) + } + + /// Parse a `timeval` struct and return it as a [`Duration`]. It returns [`None`] + /// if the value in the `timeval` struct is invalid. Some libc functions will return + /// EINVAL in this case. + fn read_timeval(&mut self, tp: &MPlaceTy<'tcx>) -> InterpResult<'tcx, Option> { + let this = self.eval_context_mut(); + let sec_field = this.project_field_named(tp, "tv_sec")?; + let sec = this.read_scalar(&sec_field)?.to_int(sec_field.layout.size)?; + + let usec_field = this.project_field_named(tp, "tv_usec")?; + let usec = this.read_scalar(&usec_field)?.to_int(usec_field.layout.size)?; + + interp_ok(try { + // tv_sec must be non-negative. + let seconds: u64 = sec.try_into().ok()?; + // tv_usec must be non-negative. + let microseconds: u32 = usec.try_into().ok()?; + if microseconds >= 1_000_000 { + // tv_usec must not be greater than 999,999. + None? + } + Duration::new(seconds, microseconds.strict_mul(1000)) + }) + } } diff --git a/src/tools/miri/src/shims/unix/socket.rs b/src/tools/miri/src/shims/unix/socket.rs index e7754ceedd7c..ae882f8ff3a4 100644 --- a/src/tools/miri/src/shims/unix/socket.rs +++ b/src/tools/miri/src/shims/unix/socket.rs @@ -3,6 +3,7 @@ use std::io::Read; use std::net::{Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4}; use std::sync::atomic::AtomicBool; +use std::time::Duration; use mio::event::Source; use mio::net::{TcpListener, TcpStream}; @@ -67,6 +68,18 @@ struct Socket { io_readiness: RefCell, /// [`Some`] when the socket had an async error which has not yet been fetched via `SO_ERROR`. error: RefCell>, + /// Read timeout of the socket. [`None`] means that reads can block indefinitely. + /// The timeout is applied to the monotonic clock (the Unix specification doesn't + /// specify which clock to use, but the monotonic clock is more common for + /// relative timeouts). + /// This is ignored when the socket is non-blocking. + read_timeout: Cell>, + /// Write timeout of the socket. [`None`] means that writes can block indefinitely. + /// The timeout is applied to the monotonic clock (the Unix specification doesn't + /// specify which clock to use, but the monotonic clock is more common + /// for relative timeouts). + /// This is ignored when the socket is non-blocking. + write_timeout: Cell>, } impl FileDescription for Socket { @@ -108,14 +121,16 @@ fn read<'tcx>( assert!(communicate_allowed, "cannot have `Socket` with isolation enabled!"); let socket = self; + let deadline = ecx.action_deadline(socket.is_non_block.get(), socket.read_timeout.get()); ecx.ensure_connected( socket.clone(), - !socket.is_non_block.get(), + deadline.clone(), "read", callback!( @capture<'tcx> { socket: FileDescriptionRef, + deadline: Option, ptr: Pointer, len: usize, finish: DynMachineCallback<'tcx, Result>, @@ -134,8 +149,8 @@ fn read<'tcx>( finish.call(this, result) } else { // The socket is in blocking mode and thus the read call should block - // until we can read some bytes from the socket. - this.block_for_recv(socket, ptr, len, /* should_peek */ false, finish) + // until we can read some bytes from the socket or the timeout exceeded. + this.block_for_recv(socket, deadline, ptr, len, /* should_peek */ false, finish) } } ), @@ -153,14 +168,16 @@ fn write<'tcx>( assert!(communicate_allowed, "cannot have `Socket` with isolation enabled!"); let socket = self; + let deadline = ecx.action_deadline(socket.is_non_block.get(), socket.write_timeout.get()); ecx.ensure_connected( socket.clone(), - !socket.is_non_block.get(), + deadline.clone(), "write", callback!( @capture<'tcx> { socket: FileDescriptionRef, + deadline: Option, ptr: Pointer, len: usize, finish: DynMachineCallback<'tcx, Result> @@ -179,8 +196,8 @@ fn write<'tcx>( return finish.call(this, result) } else { // The socket is in blocking mode and thus the write call should block - // until we can write some bytes into the socket. - this.block_for_send(socket, ptr, len, finish) + // until we can write some bytes into the socket or the timeout exceeded. + this.block_for_send(socket, deadline, ptr, len, finish) } } ), @@ -353,6 +370,8 @@ fn socket( is_non_block: Cell::new(is_sock_nonblock), io_readiness: RefCell::new(BlockingIoSourceReadiness::empty()), error: RefCell::new(None), + read_timeout: Cell::new(None), + write_timeout: Cell::new(None), }); interp_ok(Scalar::from_i32(fds.insert(fd))) @@ -567,6 +586,17 @@ fn accept4( } else { // The socket is in blocking mode and thus the accept call should block // until an incoming connection is ready. + + if socket.read_timeout.get().is_some() { + // Some Unixes like Linux also apply the SO_RCVTIMEO socket option + // to `accept` calls: + // + // This is currently not supported by Miri. + throw_unsup_format!( + "accept4: blocking accept is not supported when SO_RCVTIMEO is non-zero" + ) + } + this.block_for_accept( socket, address_ptr, @@ -645,11 +675,20 @@ fn connect( // The socket is in blocking mode and thus the connect call should block // until the connection with the server is established. - let dest = dest.clone(); + if socket.write_timeout.get().is_some() { + // Some Unixes like Linux also apply the SO_SNDTIMEO socket option + // to `connect` calls: + // + // This is currently not supported by Miri. + throw_unsup_format!( + "connect: blocking connect is not supported when SO_SNDTIMEO is non-zero" + ) + } + let dest = dest.clone(); this.ensure_connected( socket.clone(), - /* should_wait */ true, + /* deadline */ None, "connect", callback!( @capture<'tcx> { @@ -729,29 +768,29 @@ fn send( ); } - // If either the operation or the socket is non-blocking, we don't want - // to wait until the connection is established. - let should_wait = !is_op_non_block && !socket.is_non_block.get(); + let is_non_block = is_op_non_block || socket.is_non_block.get(); + let deadline = this.action_deadline(is_non_block, socket.write_timeout.get()); let dest = dest.clone(); this.ensure_connected( socket.clone(), - should_wait, + deadline.clone(), "send", callback!( @capture<'tcx> { socket: FileDescriptionRef, + deadline: Option, flags: i32, buffer_ptr: Pointer, length: usize, - is_op_non_block: bool, + is_non_block: bool, dest: MPlaceTy<'tcx>, } |this, result: Result<(), ()>| { if result.is_err() { return this.set_errno_and_return_neg1(LibcError("ENOTCONN"), &dest) } - if is_op_non_block || socket.is_non_block.get() { + if is_non_block { // We have a non-blocking operation or a non-blocking socket and // thus don't want to block until we can send. match this.try_non_block_send(&socket, buffer_ptr, length)? { @@ -760,9 +799,10 @@ fn send( } } else { // The socket is in blocking mode and thus the send call should block - // until we can send some bytes into the socket. + // until we can send some bytes into the socket or the timeout exceeded. this.block_for_send( socket, + deadline, buffer_ptr, length, callback!(@capture<'tcx> { @@ -850,29 +890,29 @@ fn recv( ); } - // If either the operation or the socket is non-blocking, we don't want - // to wait until the connection is established. - let should_wait = !is_op_non_block && !socket.is_non_block.get(); + let is_non_block = is_op_non_block || socket.is_non_block.get(); + let deadline = this.action_deadline(is_non_block, socket.read_timeout.get()); let dest = dest.clone(); this.ensure_connected( socket.clone(), - should_wait, + deadline.clone(), "recv", callback!( @capture<'tcx> { socket: FileDescriptionRef, + deadline: Option, buffer_ptr: Pointer, length: usize, should_peek: bool, - is_op_non_block: bool, + is_non_block: bool, dest: MPlaceTy<'tcx>, } |this, result: Result<(), ()>| { if result.is_err() { return this.set_errno_and_return_neg1(LibcError("ENOTCONN"), &dest) } - if is_op_non_block || socket.is_non_block.get() { + if is_non_block { // We have a non-blocking operation or a non-blocking socket and // thus don't want to block until we can receive. match this.try_non_block_recv(&socket, buffer_ptr, length, should_peek)? { @@ -881,9 +921,10 @@ fn recv( } } else { // The socket is in blocking mode and thus the receive call should block - // until we can receive some bytes from the socket. + // until we can receive some bytes from the socket or the timeout exceeded. this.block_for_recv( socket, + deadline, buffer_ptr, length, should_peek, @@ -930,6 +971,8 @@ fn setsockopt( }; if level == this.eval_libc_i32("SOL_SOCKET") { + let opt_so_rcvtimeo = this.eval_libc_i32("SO_RCVTIMEO"); + let opt_so_sndtimeo = this.eval_libc_i32("SO_SNDTIMEO"); let opt_so_reuseaddr = this.eval_libc_i32("SO_REUSEADDR"); if matches!(this.tcx.sess.target.os, Os::MacOs | Os::FreeBsd | Os::NetBsd) { @@ -950,6 +993,25 @@ fn setsockopt( } } + if option_name == opt_so_rcvtimeo || option_name == opt_so_sndtimeo { + let timeval_layout = this.libc_ty_layout("timeval"); + let option_value = this.ptr_to_mplace(option_value_ptr, timeval_layout); + + let timeout = match this.read_timeval(&option_value)? { + None => return this.set_errno_and_return_neg1_i32(LibcError("EINVAL")), + Some(Duration::ZERO) => None, + Some(duration) => Some(duration), + }; + + if option_name == opt_so_rcvtimeo { + socket.read_timeout.set(timeout); + } else { + socket.write_timeout.set(timeout); + } + + return interp_ok(Scalar::from_i32(0)); + } + if option_name == opt_so_reuseaddr { if option_len != 4 { // Option value should be C-int which is usually 4 bytes. @@ -1085,6 +1147,8 @@ fn getsockopt( // buffer, in which case we have to truncate. let value_buffer = if level == this.eval_libc_i32("SOL_SOCKET") { let opt_so_error = this.eval_libc_i32("SO_ERROR"); + let opt_so_rcvtimeo = this.eval_libc_i32("SO_RCVTIMEO"); + let opt_so_sndtimeo = this.eval_libc_i32("SO_SNDTIMEO"); if option_name == opt_so_error { // Reading SO_ERROR should always return the latest async error. Because our stored @@ -1109,6 +1173,28 @@ fn getsockopt( let value_buffer = this.allocate(this.machine.layouts.i32, MemoryKind::Stack)?; this.write_int(return_value, &value_buffer)?; value_buffer + } else if option_name == opt_so_rcvtimeo || option_name == opt_so_sndtimeo { + let timeout = if option_name == opt_so_rcvtimeo { + socket.read_timeout.get() + } else { + socket.write_timeout.get() + } + .unwrap_or_default(); + + let secs = timeout.as_secs(); + let usecs = timeout.subsec_micros(); + + let timeval_layout = this.libc_ty_layout("timeval"); + // Allocate new buffer on the stack with the `timeval` layout. + let timeval_buffer = this.allocate(timeval_layout, MemoryKind::Stack)?; + + let sec_field = this.project_field_named(&timeval_buffer, "tv_sec")?; + this.write_int(secs, &sec_field)?; + + let usec_field = this.project_field_named(&timeval_buffer, "tv_usec")?; + this.write_int(usecs, &usec_field)?; + + timeval_buffer } else { throw_unsup_format!( "getsockopt: option {option_name:#x} is unsupported for level SOL_SOCKET", @@ -1312,7 +1398,8 @@ fn getpeername( // UNIX targets should return ENOTCONN when the connection is not yet established. this.ensure_connected( socket.clone(), - /* should_wait */ false, + // Check whether the socket is connected without blocking. + Some(this.machine.monotonic_clock.now().into()), "getpeername", callback!( @capture<'tcx> { @@ -1415,6 +1502,28 @@ fn shutdown(&mut self, socket: &OpTy<'tcx>, how: &OpTy<'tcx>) -> InterpResult<'t impl<'tcx> EvalContextPrivExt<'tcx> for crate::MiriInterpCx<'tcx> {} trait EvalContextPrivExt<'tcx>: crate::MiriInterpCxExt<'tcx> { + /// Get the deadline for an action (e.g. reading or writing). + /// When `is_non_block` is [`true`], the returned deadline is "now", i.e., + /// we wake up immediately if the action cannot be completed. + /// If `action_timeout` is `Some(duration)`, the returned deadline is in the + /// future be the specified `duration`. Otherwise, no deadline ([`None`]) is + /// returned, indicating that the action can block indefinitely. + fn action_deadline( + &self, + is_non_block: bool, + action_timeout: Option, + ) -> Option { + let this = self.eval_context_ref(); + + if is_non_block { + // Non-blocking sockets always have a zero timeout. + Some(this.machine.monotonic_clock.now().into()) + } else { + action_timeout + .map(|duration| this.machine.monotonic_clock.now().add_lossy(duration).into()) + } + } + /// Block the thread until there's an incoming connection or an error occurred. /// /// This recursively calls itself should the operation still block for some reason. @@ -1433,19 +1542,23 @@ fn block_for_accept( this.block_thread_for_io( socket.clone(), BlockingIoInterest::Read, - None, + /* deadline */ None, callback!(@capture<'tcx> { + socket: FileDescriptionRef, address_ptr: Pointer, address_len_ptr: Pointer, is_client_sock_nonblock: bool, - socket: FileDescriptionRef, dest: MPlaceTy<'tcx>, } |this, kind: UnblockKind| { - assert_eq!(kind, UnblockKind::Ready); - // Remove the blocking I/O interest for unblocking this thread. this.machine.blocking_io.remove_blocked_thread(socket.id(), this.machine.threads.active_thread()); + match kind { + UnblockKind::Ready => { /* fall-through to below */ }, + // When the read timeout is exceeded EAGAIN/EWOULDBLOCK is returned. + UnblockKind::TimedOut => return this.set_errno_and_return_neg1(LibcError("EWOULDBLOCK"), &dest) + } + match this.try_non_block_accept(&socket, address_ptr, address_len_ptr, is_client_sock_nonblock)? { Ok(sockfd) => { // We need to create the scalar using the destination size since @@ -1515,6 +1628,8 @@ fn try_non_block_accept( is_non_block: Cell::new(is_client_sock_nonblock), io_readiness: RefCell::new(BlockingIoSourceReadiness::empty()), error: RefCell::new(None), + read_timeout: Cell::new(None), + write_timeout: Cell::new(None), }); // Register the socket to the blocking I/O manager because // there is an associated host socket. @@ -1533,6 +1648,7 @@ fn try_non_block_accept( fn block_for_send( &mut self, socket: FileDescriptionRef, + deadline: Option, buffer_ptr: Pointer, length: usize, finish: DynMachineCallback<'tcx, Result>, @@ -1541,22 +1657,27 @@ fn block_for_send( this.block_thread_for_io( socket.clone(), BlockingIoInterest::Write, - None, + deadline.clone(), callback!(@capture<'tcx> { socket: FileDescriptionRef, + deadline: Option, buffer_ptr: Pointer, length: usize, finish: DynMachineCallback<'tcx, Result>, } |this, kind: UnblockKind| { - assert_eq!(kind, UnblockKind::Ready); - // Remove the blocking I/O interest for unblocking this thread. this.machine.blocking_io.remove_blocked_thread(socket.id(), this.machine.threads.active_thread()); + match kind { + UnblockKind::Ready => { /* fall-through to below */ }, + // When the write timeout is exceeded EAGAIN/EWOULDBLOCK is returned. + UnblockKind::TimedOut => return finish.call(this, Err(LibcError("EWOULDBLOCK"))) + } + match this.try_non_block_send(&socket, buffer_ptr, length)? { Err(IoError::HostError(e)) if e.kind() == io::ErrorKind::WouldBlock => { // We need to block the thread again as it would still block. - this.block_for_send(socket, buffer_ptr, length, finish) + this.block_for_send(socket, deadline, buffer_ptr, length, finish) }, result => finish.call(this, result) } @@ -1647,6 +1768,7 @@ fn try_non_block_send( fn block_for_recv( &mut self, socket: FileDescriptionRef, + deadline: Option, buffer_ptr: Pointer, length: usize, should_peek: bool, @@ -1656,23 +1778,28 @@ fn block_for_recv( this.block_thread_for_io( socket.clone(), BlockingIoInterest::Read, - None, + deadline.clone(), callback!(@capture<'tcx> { socket: FileDescriptionRef, + deadline: Option, buffer_ptr: Pointer, length: usize, should_peek: bool, finish: DynMachineCallback<'tcx, Result>, } |this, kind: UnblockKind| { - assert_eq!(kind, UnblockKind::Ready); - // Remove the blocking I/O interest for unblocking this thread. this.machine.blocking_io.remove_blocked_thread(socket.id(), this.machine.threads.active_thread()); + match kind { + UnblockKind::Ready => { /* fall-through to below */ }, + // When the read timeout is exceeded EAGAIN/EWOULDBLOCK is returned. + UnblockKind::TimedOut => return finish.call(this, Err(LibcError("EWOULDBLOCK"))) + } + match this.try_non_block_recv(&socket, buffer_ptr, length, should_peek)? { Err(IoError::HostError(e)) if e.kind() == io::ErrorKind::WouldBlock => { // We need to block the thread again as it would still block. - this.block_for_recv(socket, buffer_ptr, length, should_peek, finish) + this.block_for_recv(socket, deadline, buffer_ptr, length, should_peek, finish) }, result => finish.call(this, result) } @@ -1777,7 +1904,7 @@ fn try_non_block_recv( fn ensure_connected( &mut self, socket: FileDescriptionRef, - should_wait: bool, + deadline: Option, foreign_name: &'static str, action: DynMachineCallback<'tcx, Result<(), ()>>, ) -> InterpResult<'tcx> { @@ -1801,11 +1928,6 @@ fn ensure_connected( // We're currently connecting. Since the underlying mio socket is non-blocking, // the only way to determine whether we are done connecting is by polling. - // If we should wait until the connection is established, the timeout is `None`. - // Otherwise, we use a zero duration timeout, i.e. we return immediately - // (but we still go through the scheduler once -- which is fine). - let deadline = - if should_wait { None } else { Some(this.machine.monotonic_clock.now().into()) }; this.block_thread_for_io( socket.clone(), @@ -1814,7 +1936,6 @@ fn ensure_connected( callback!( @capture<'tcx> { socket: FileDescriptionRef, - should_wait: bool, foreign_name: &'static str, action: DynMachineCallback<'tcx, Result<(), ()>>, } |this, kind: UnblockKind| { @@ -1822,9 +1943,7 @@ fn ensure_connected( this.machine.blocking_io.remove_blocked_thread(socket.id(), this.machine.threads.active_thread()); if UnblockKind::TimedOut == kind { - // We can only time out when `should_wait` is false. // This then means that the socket is not yet connected. - assert!(!should_wait); return action.call(this, Err(())) } diff --git a/src/tools/miri/tests/pass-dep/libc/libc-socket.rs b/src/tools/miri/tests/pass-dep/libc/libc-socket.rs index 1ecfba3ff6ff..9262bfc8829a 100644 --- a/src/tools/miri/tests/pass-dep/libc/libc-socket.rs +++ b/src/tools/miri/tests/pass-dep/libc/libc-socket.rs @@ -7,7 +7,7 @@ mod utils; use std::io::ErrorKind; -use std::time::Duration; +use std::time::{Duration, Instant}; use std::{ptr, thread}; use libc_utils::*; @@ -55,6 +55,9 @@ fn main() { test_shutdown_writable_after_read_close(); test_getsockopt_truncate(); + + test_sockopt_sndtimeo(); + test_sockopt_rcvtimeo(); } /// Test creating a socket and then closing it afterwards. @@ -759,3 +762,95 @@ fn test_getsockopt_truncate() { let long_ttl = unsafe { option_value.assume_init() }; assert_eq!(long_ttl, ttl); } + +/// Test setting and getting the SO_SNDTIMEO socket option. +/// Also test that writes don't block indefinitely when we +/// have a nonzero timeout. +fn test_sockopt_sndtimeo() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + net::connect_ipv4(client_sockfd, addr).unwrap(); + net::accept_ipv4(server_sockfd).unwrap(); + + let timeout = + net::getsockopt::(client_sockfd, libc::SOL_SOCKET, libc::SO_SNDTIMEO) + .unwrap(); + // By default, no write timeout should be set. + assert_eq!(timeout.tv_sec, 0); + assert_eq!(timeout.tv_usec, 0); + + // A 50 millisecond timeout. + let short_timeout = libc::timeval { tv_sec: 0, tv_usec: 50_000 }; + net::setsockopt(client_sockfd, libc::SOL_SOCKET, libc::SO_SNDTIMEO, short_timeout).unwrap(); + + let timeout = + net::getsockopt::(client_sockfd, libc::SOL_SOCKET, libc::SO_SNDTIMEO) + .unwrap(); + // We should now read the same value as we wrote above. + assert_eq!(timeout.tv_sec, short_timeout.tv_sec); + assert_eq!(timeout.tv_usec, short_timeout.tv_usec); + + let buffer = [1u8; 32_000]; + loop { + let before = Instant::now(); + let result = unsafe { + errno_result(libc::write(client_sockfd, buffer.as_ptr().cast(), buffer.len())) + }; + match result { + Ok(_) => { /* continue to fill up buffer */ } + // When we get an EAGAIN/EWOULDBLOCK when writing into a blocking socket, we know + // it's because of the write timeout exceeding because the write buffer + // is full. + Err(err) if err.kind() == ErrorKind::WouldBlock => { + // The last write should return an EAGAIN/EWOULDBLOCK after ~50ms instead + // of blocking indefinitely. + assert!(Instant::now().duration_since(before) >= Duration::from_millis(50)); + break; + } + Err(err) => panic!("unexpected error whilst filling up buffer: {err}"), + } + } +} + +/// Test setting and getting the SO_RCVTIMEO socket option. +/// Also test that reads don't block indefinitely when we +/// have a nonzero timeout. +fn test_sockopt_rcvtimeo() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + net::connect_ipv4(client_sockfd, addr).unwrap(); + net::accept_ipv4(server_sockfd).unwrap(); + + let timeout = + net::getsockopt::(client_sockfd, libc::SOL_SOCKET, libc::SO_RCVTIMEO) + .unwrap(); + // By default, no read timeout should be set. + assert_eq!(timeout.tv_sec, 0); + assert_eq!(timeout.tv_usec, 0); + + // A 50 millisecond timeout. + let short_timeout = libc::timeval { tv_sec: 0, tv_usec: 50_000 }; + net::setsockopt(client_sockfd, libc::SOL_SOCKET, libc::SO_RCVTIMEO, short_timeout).unwrap(); + + let timeout = + net::getsockopt::(client_sockfd, libc::SOL_SOCKET, libc::SO_RCVTIMEO) + .unwrap(); + // We should now read the same value as we wrote above. + assert_eq!(timeout.tv_sec, short_timeout.tv_sec); + assert_eq!(timeout.tv_usec, short_timeout.tv_usec); + + let mut buffer = [0u8; 16]; + // The read should return an EAGAIN/EWOULDBLOCK after ~10ms instead of blocking indefinitely. + let before = Instant::now(); + let err = unsafe { + errno_result(libc::read(client_sockfd, buffer.as_mut_ptr().cast(), buffer.len())) + .unwrap_err() + }; + assert_eq!(err.kind(), ErrorKind::WouldBlock); + // Ensure that we blocked for at least 50 milliseconds. + assert!(Instant::now().duration_since(before) >= Duration::from_millis(50)) +} diff --git a/src/tools/miri/tests/pass/shims/socket.rs b/src/tools/miri/tests/pass/shims/socket.rs index 335df02b54f8..785be17e2c0d 100644 --- a/src/tools/miri/tests/pass/shims/socket.rs +++ b/src/tools/miri/tests/pass/shims/socket.rs @@ -4,6 +4,7 @@ use std::io::{ErrorKind, Read, Write}; use std::net::{Shutdown, TcpListener, TcpStream}; use std::thread; +use std::time::Duration; const TEST_BYTES: &[u8] = b"these are some test bytes!"; @@ -17,6 +18,8 @@ fn main() { test_shutdown(); test_sockopt_ttl(); test_sockopt_nodelay(); + test_sockopt_read_timeout(); + test_sockopt_write_timeout(); } fn test_create_ipv4_listener() { @@ -167,3 +170,56 @@ fn test_sockopt_nodelay() { stream.set_nodelay(false).unwrap(); assert_eq!(stream.nodelay().unwrap(), false); } + +/// Test setting and reading the SNDTIMEO socket option. +/// This also tests that a read won't block indefinitely +/// when the read timeout is set to [`Some`] duration. +fn test_sockopt_read_timeout() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + + let mut stream = TcpStream::connect(address).unwrap(); + let _other_end = listener.accept().unwrap(); + + // By default, reads on blocking sockets should block indefinitely. + assert_eq!(stream.read_timeout().unwrap(), None); + + let short_read_timeout = Some(Duration::from_millis(10)); + stream.set_read_timeout(short_read_timeout).unwrap(); + assert_eq!(stream.read_timeout().unwrap(), short_read_timeout); + + let mut buffer = [0u8; 128]; + // This should not block indefinitely and instead return EAGAIN/EWOULDBLOCK. + let err = stream.read(&mut buffer).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::WouldBlock); +} + +/// Test setting and reading the RCVTIMEO socket option. +/// This also tests that a write won't block indefinitely when +/// the write timeout is set to [`Some`] duration. +fn test_sockopt_write_timeout() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + + let mut stream = TcpStream::connect(address).unwrap(); + let _other_end = listener.accept().unwrap(); + + // By default, writes on blocking sockets should block indefinitely. + assert_eq!(stream.write_timeout().unwrap(), None); + + let short_write_timeout = Some(Duration::from_millis(10)); + stream.set_write_timeout(short_write_timeout).unwrap(); + assert_eq!(stream.write_timeout().unwrap(), short_write_timeout); + + let fill_buffer = [1u8; 1024]; + loop { + match stream.write_all(&fill_buffer) { + Ok(_) => { /* continue to fill up buffer */ } + // When we get an EAGAIN/EWOULDBLOCK when writing into a blocking socket, + // we know it's because of the write timeout exceeding because the write + // buffer is full. + Err(err) if err.kind() == ErrorKind::WouldBlock => break, + Err(err) => panic!("unexpected error whilst filling up buffer: {err}"), + } + } +}