diff --git a/lib/std/Io.zig b/lib/std/Io.zig index 0c25249218..6dc0e24731 100644 --- a/lib/std/Io.zig +++ b/lib/std/Io.zig @@ -728,6 +728,14 @@ pub const Limit = enum(usize) { return @enumFromInt(@min(@intFromEnum(a), @intFromEnum(b))); } + pub fn max(a: Limit, b: Limit) Limit { + if (a == .unlimited or b == .unlimited) { + return .unlimited; + } + + return @enumFromInt(@max(@intFromEnum(a), @intFromEnum(b))); + } + pub fn minInt(l: Limit, n: usize) usize { return @min(n, @intFromEnum(l)); } diff --git a/lib/std/compress/zstd.zig b/lib/std/compress/zstd.zig index 39073b51c5..51168889c6 100644 --- a/lib/std/compress/zstd.zig +++ b/lib/std/compress/zstd.zig @@ -88,6 +88,17 @@ fn testDecompress(gpa: std.mem.Allocator, compressed: []const u8) ![]u8 { return out.toOwnedSlice(); } +/// Create a `Decompress` from `compressed` and immediately discard all output. Returns the number +/// of discarded bytes. +fn testDiscard(gpa: std.mem.Allocator, compressed: []const u8) !usize { + const buf: []u8 = try gpa.alloc(u8, default_window_len + block_size_max); + defer gpa.free(buf); + + var in: std.Io.Reader = .fixed(compressed); + var zstd_stream: Decompress = .init(&in, buf, .{}); + return try zstd_stream.reader.discardRemaining(); +} + fn testExpectDecompress(uncompressed: []const u8, compressed: []const u8) !void { const gpa = std.testing.allocator; const result = try testDecompress(gpa, compressed); @@ -117,6 +128,8 @@ test Decompress { try testExpectDecompress(uncompressed, compressed3); try testExpectDecompress(uncompressed, compressed19); + try std.testing.expectEqual(uncompressed.len, testDiscard(std.testing.allocator, compressed3)); + try std.testing.expectEqual(uncompressed.len, testDiscard(std.testing.allocator, compressed19)); } test "partial magic number" { diff --git a/lib/std/compress/zstd/Decompress.zig b/lib/std/compress/zstd/Decompress.zig index 0acef462e7..cab1ee99f4 100644 --- a/lib/std/compress/zstd/Decompress.zig +++ b/lib/std/compress/zstd/Decompress.zig @@ 
-123,9 +123,13 @@ fn rebaseFallible(r: *Reader, capacity: usize) Reader.RebaseError!void { rebase(r, capacity); } +/// Rebase the buffer, keeping at least the sliding window (`d.window_len` bytes) buffered fn rebase(r: *Reader, capacity: usize) void { const d: *Decompress = @alignCast(@fieldParentPtr("reader", r)); + // `capacity` must fit in the buffer along with the required sliding window assert(capacity <= r.buffer.len - d.window_len); + // According to the vtable contract, this function will only be called if the free space in the + // buffer cannot already fit `capacity` bytes assert(r.end + capacity > r.buffer.len); const discard_n = @min(r.seek, r.end - d.window_len); const keep = r.buffer[discard_n..r.end]; @@ -134,11 +138,27 @@ fn rebase(r: *Reader, capacity: usize) void { r.seek -= discard_n; } +/// Rebase `d.reader.buffer` as much as needed for a discard limited by `limit` +fn rebaseForDiscard(d: *Decompress, limit: std.Io.Limit) void { + // Number of bytes desired to rebase, always rebase for at least `zstd.block_size_max` + const desire_n = limit.max(Limit.limited(zstd.block_size_max)); + // Maximum number of bytes possible to rebase + const max_n = d.reader.buffer.len -| d.window_len; + // Number of bytes to rebase + const n = desire_n.minInt(max_n); + + // Current buffer free space + const current_cap = d.reader.buffer.len - d.reader.end; + if (current_cap < n) { + rebase(&d.reader, n); + } +} + /// This could be improved so that when an amount is discarded that includes an /// entire frame, skip decoding that frame.
fn discardDirect(r: *Reader, limit: std.Io.Limit) Reader.Error!usize { const d: *Decompress = @alignCast(@fieldParentPtr("reader", r)); - rebase(r, d.window_len); + rebaseForDiscard(d, limit); var writer: Writer = .{ .vtable = &.{ .drain = std.Io.Writer.Discarding.drain, @@ -162,7 +182,7 @@ fn discardDirect(r: *Reader, limit: std.Io.Limit) Reader.Error!usize { fn discardIndirect(r: *Reader, limit: std.Io.Limit) Reader.Error!usize { const d: *Decompress = @alignCast(@fieldParentPtr("reader", r)); - rebase(r, d.window_len); + rebaseForDiscard(d, limit); var writer: Writer = .{ .buffer = r.buffer, .end = r.end,