diff --git a/lib/std/Io/RwLock.zig b/lib/std/Io/RwLock.zig index 7be4f026e1..de9bf86f36 100644 --- a/lib/std/Io/RwLock.zig +++ b/lib/std/Io/RwLock.zig @@ -48,6 +48,25 @@ pub fn lockUncancelable(rl: *RwLock, io: Io) void { rl.semaphore.waitUncancelable(io); } +pub fn lock(rl: *RwLock, io: Io) Io.Cancelable!void { + _ = @atomicRmw(usize, &rl.state, .Add, writer, .seq_cst); + rl.mutex.lock(io) catch |err| switch (err) { + error.Canceled => { + _ = @atomicRmw(usize, &rl.state, .Sub, writer, .seq_cst); + return error.Canceled; + }, + }; + + const state = @atomicRmw(usize, &rl.state, .Add, is_writing -% writer, .seq_cst); + if (state & reader_mask != 0) + rl.semaphore.wait(io) catch |err| switch (err) { + error.Canceled => { + rl.unlock(io); + return error.Canceled; + }, + }; +} + pub fn unlock(rl: *RwLock, io: Io) void { _ = @atomicRmw(usize, &rl.state, .And, ~is_writing, .seq_cst); rl.mutex.unlock(io); @@ -93,6 +112,24 @@ pub fn lockSharedUncancelable(rl: *RwLock, io: Io) void { rl.mutex.unlock(io); } +pub fn lockShared(rl: *RwLock, io: Io) Io.Cancelable!void { + var state = @atomicLoad(usize, &rl.state, .seq_cst); + while (state & (is_writing | writer_mask) == 0) { + state = @cmpxchgWeak( + usize, + &rl.state, + state, + state + reader, + .seq_cst, + .seq_cst, + ) orelse return; + } + + try rl.mutex.lock(io); + _ = @atomicRmw(usize, &rl.state, .Add, reader, .seq_cst); + rl.mutex.unlock(io); +} + pub fn unlockShared(rl: *RwLock, io: Io) void { const state = @atomicRmw(usize, &rl.state, .Sub, reader, .seq_cst); @@ -111,6 +148,10 @@ test "internal state" { rl.lockUncancelable(io); rl.unlock(io); try testing.expectEqual(rl, Io.RwLock.init); + + try rl.lock(io); + rl.unlock(io); + try testing.expectEqual(rl, Io.RwLock.init); } test "smoke test" { @@ -123,6 +164,11 @@ test "smoke test" { try testing.expect(!rl.tryLockShared(io)); rl.unlock(io); + try rl.lock(io); + try testing.expect(!rl.tryLock(io)); + try testing.expect(!rl.tryLockShared(io)); + rl.unlock(io); + try testing.expect(rl.tryLock(io)); try testing.expect(!rl.tryLock(io)); try testing.expect(!rl.tryLockShared(io)); @@ -236,3 +282,37 @@ test "concurrent access" { try testing.expect(run.writes == num_writes); try testing.expect(run.reads.raw >= num_reads); } + +test "lock canceling" { + const io = testing.io; + + var rl: Io.RwLock = .init; + + rl.lockSharedUncancelable(io); + var sfuture = io.concurrent(semaphoreLockCancel, .{ &rl, io }) catch |err| switch (err) { + error.ConcurrencyUnavailable => return error.SkipZigTest, + }; + try std.testing.expectEqual(error.Canceled, sfuture.cancel(io)); + rl.unlockShared(io); + try testing.expectEqual(rl, Io.RwLock.init); + + rl.lockUncancelable(io); + var mfuture = io.concurrent(mutexLockCancel, .{ &rl, io }) catch |err| switch (err) { + error.ConcurrencyUnavailable => return error.SkipZigTest, + }; + try std.testing.expectEqual(error.Canceled, mfuture.cancel(io)); + rl.unlock(io); + try testing.expectEqual(rl, Io.RwLock.init); +} + +fn semaphoreLockCancel(rl: *Io.RwLock, io: Io) !void { + try rl.lock(io); //tests semaphore cancelling +} + +fn mutexLockCancel(rl: *Io.RwLock, io: Io) !void { + //tests mutex canceling + try std.testing.expectEqual(error.Canceled, rl.lockShared(io)); + io.recancel(); + try std.testing.expectEqual(error.Canceled, rl.lock(io)); + return error.Canceled; +}