From e05af2da131b5d9353711ffb5979c67f4bd8b5af Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 1 May 2025 14:52:41 -0700 Subject: [PATCH] std.compress.zstd: it's compiling --- lib/std/compress/zstd.zig | 2 +- lib/std/compress/zstd/Decompress.zig | 509 ++++++++++++++++++++------- lib/std/io.zig | 12 +- lib/std/io/bit_reader.zig | 236 ------------- lib/std/io/buffered_atomic_file.zig | 55 --- 5 files changed, 379 insertions(+), 435 deletions(-) delete mode 100644 lib/std/io/bit_reader.zig delete mode 100644 lib/std/io/buffered_atomic_file.zig diff --git a/lib/std/compress/zstd.zig b/lib/std/compress/zstd.zig index f46af8f8c0..2654b6fed1 100644 --- a/lib/std/compress/zstd.zig +++ b/lib/std/compress/zstd.zig @@ -109,7 +109,7 @@ fn testExpectDecompressError(err: anyerror, compressed: []const u8) !void { in.initFixed(@constCast(compressed)); var zstd_stream: Decompress = .init(&in, .{}); try std.testing.expectError(error.ReadFailed, zstd_stream.reader().readRemainingArrayList(gpa, null, &out, .unlimited)); - try std.testing.expectError(err, zstd_stream.err.?); + try std.testing.expectError(err, zstd_stream.err orelse {}); return error.TestFailed; } diff --git a/lib/std/compress/zstd/Decompress.zig b/lib/std/compress/zstd/Decompress.zig index d2f33e2b7b..e7980a195a 100644 --- a/lib/std/compress/zstd/Decompress.zig +++ b/lib/std/compress/zstd/Decompress.zig @@ -11,12 +11,10 @@ state: State, verify_checksum: bool, err: ?Error = null, -const table_size_max = zstd.compressed_block.table_size_max; - const State = union(enum) { new_frame, in_frame: InFrame, - skipping_frame: u32, + skipping_frame: usize, end, const InFrame = struct { @@ -31,11 +29,38 @@ pub const Options = struct { }; pub const Error = error{ + BadMagic, + BlockOversize, ChecksumFailure, + ContentOversize, DictionaryIdFlagUnsupported, - MalformedBlock, - MalformedFrame, EndOfStream, + HuffmanTreeIncomplete, + InvalidBitStream, + LiteralsBufferUndersize, + MalformedAccuracyLog, + MalformedBlock, + MalformedCompressedBlock, + MalformedFrame, + MalformedFseBits, + MalformedFseTable, + MalformedHuffmanTree, + MalformedLiteralsHeader, + MalformedLiteralsLength, + MalformedLiteralsSection, + MalformedSequence, + MissingStartBit, + OutputBufferUndersize, + InputBufferUndersize, + ReadFailed, + RepeatModeFirst, + ReservedBitSet, + ReservedBlock, + SequenceBufferUndersize, + TreelessLiteralsFirst, + UnexpectedEndOfLiteralStream, + WindowOversize, + WindowSizeUnknown, }; pub fn init(input: *BufferedReader, options: Options) Decompress { @@ -69,25 +94,32 @@ fn read(context: ?*anyopaque, bw: *BufferedWriter, limit: Reader.Limit) Reader.R d.err = err; return error.ReadFailed; }; - return readInFrame(d, bw, limit, &d.state.in_frame) catch |err| { - d.err = err; - return error.ReadFailed; + return readInFrame(d, bw, limit, &d.state.in_frame) catch |err| switch (err) { + error.ReadFailed => return error.ReadFailed, + error.WriteFailed => return error.WriteFailed, + else => |e| { + d.err = e; + return error.ReadFailed; + }, }; }, .in_frame => |*in_frame| { - return readInFrame(d, bw, limit, in_frame) catch |err| { - d.err = err; - return error.ReadFailed; + return readInFrame(d, bw, limit, in_frame) catch |err| switch (err) { + error.ReadFailed => return error.ReadFailed, + error.WriteFailed => return error.WriteFailed, + else => |e| { + d.err = e; + return error.ReadFailed; + }, }; }, .skipping_frame => |*remaining| { - const requested = remaining.*; - const n = in.discard(.limited(requested)) catch |err| { + const n = in.discard(.limited(remaining.*)) catch |err| { d.err = err; return error.ReadFailed; }; - if (requested == n) d.state = .new_frame; - remaining.* = requested - n; + remaining.* -= n; + if (remaining.* == 0) d.state = .new_frame; return 0; }, .end => return error.EndOfStream, @@ -115,9 +147,9 @@ fn initFrame(d: *Decompress, window_size_max: usize, magic: Frame.Magic) !void { fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state: *State.InFrame) !usize { const in = d.input; - var literal_fse_buffer: [table_size_max.literal]Table.Fse = undefined; - var match_fse_buffer: [table_size_max.match]Table.Fse = undefined; - var offset_fse_buffer: [table_size_max.offset]Table.Fse = undefined; + var literal_fse_buffer: [zstd.table_size_max.literal]Table.Fse = undefined; + var match_fse_buffer: [zstd.table_size_max.match]Table.Fse = undefined; + var offset_fse_buffer: [zstd.table_size_max.offset]Table.Fse = undefined; var literals_buffer: [zstd.block_size_max]u8 = undefined; var sequence_buffer: [zstd.block_size_max]u8 = undefined; @@ -125,10 +157,10 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state: const header_bytes = try in.takeArray(3); const block_header: Frame.Zstandard.Block.Header = @bitCast(header_bytes.*); - const block_size = block_header.block_size; + const block_size = block_header.size; if (state.frame.block_size_max < block_size) return error.BlockOversize; if (@intFromEnum(limit) < block_size) return error.OutputBufferUndersize; - switch (block_header.block_type) { + switch (block_header.type) { .raw => { try in.readAll(bw, .limited(block_size)); return block_size; @@ -151,9 +183,10 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state: var bytes_written: usize = 0; { if (sequence_buffer.len < @intFromEnum(remaining)) - return error.SequenceBufferTooSmall; - const seq_len = try in.readSlice(remaining.slice(&sequence_buffer)); - var bit_stream = try ReverseBitReader.init(sequence_buffer[0..seq_len]); + return error.SequenceBufferUndersize; + const seq_slice = remaining.slice(&sequence_buffer); + try in.readSlice(seq_slice); + var bit_stream = try ReverseBitReader.init(seq_slice); if (sequences_header.sequence_count > 0) { try decode.readInitialFseState(&bit_stream); @@ -205,16 +238,16 @@ fn readInFrame(d: *Decompress, bw: *BufferedWriter, limit: Reader.Limit, state: } } - if (block_header.last_block) { + if (block_header.last) { if (state.frame.has_checksum) { - const expected_checksum = try in.readInt(u32, .little); + const expected_checksum = try in.takeInt(u32, .little); if (state.frame.hasher_opt) |*hasher| { const actual_checksum: u32 = @truncate(hasher.final()); if (expected_checksum != actual_checksum) return error.ChecksumFailure; } } - if (d.frame.content_size) |content_size| { - if (content_size != d.current_frame_decompressed_size) { + if (state.frame.content_size) |content_size| { + if (content_size != state.decompressed_size) { return error.MalformedFrame; } } @@ -249,16 +282,16 @@ pub const Frame = struct { _, pub fn kind(m: Magic) ?Kind { - return switch (m) { - .zstandard => .zstandard, - Skippable.magic_min...Skippable.magic_max => .skippable, + return switch (@intFromEnum(m)) { + @intFromEnum(Magic.zstandard) => .zstandard, + @intFromEnum(Skippable.magic_min)...@intFromEnum(Skippable.magic_max) => .skippable, else => null, }; } pub fn isSkippable(m: Magic) bool { - return switch (m) { - Skippable.magic_min...Skippable.magic_max => true, + return switch (@intFromEnum(m)) { + @intFromEnum(Skippable.magic_min)...@intFromEnum(Skippable.magic_max) => true, else => false, }; } @@ -384,9 +417,9 @@ pub const Frame = struct { ) Decode { return .{ .repeat_offsets = .{ - zstd.compressed_block.start_repeated_offset_1, - zstd.compressed_block.start_repeated_offset_2, - zstd.compressed_block.start_repeated_offset_3, + zstd.start_repeated_offset_1, + zstd.start_repeated_offset_2, + zstd.start_repeated_offset_3, }, .offset = undefined, @@ -410,7 +443,7 @@ pub const Frame = struct { pub const PrepareError = error{ /// the (reversed) literal bitstream's first byte does not have any bits set - BitStreamHasNoStartBit, + MissingStartBit, /// `literals` is a treeless literals section and the decode state does not /// have a Huffman tree from a previous block TreelessLiteralsFirst, @@ -422,6 +455,8 @@ pub const Frame = struct { MalformedFseTable, /// input stream ends before all FSE tables are read EndOfStream, + ReadFailed, + InputBufferUndersize, }; /// Prepare the decoder to decode a compressed block. Loads the literals @@ -430,6 +465,7 @@ pub const Frame = struct { pub fn prepare( self: *Decode, in: *BufferedReader, + remaining: *Reader.Limit, literals: LiteralsSection, sequences_header: SequencesSection.Header, ) PrepareError!void { @@ -455,17 +491,14 @@ pub const Frame = struct { } if (sequences_header.sequence_count > 0) { - try self.updateFseTable(in, .literal, sequences_header.literal_lengths); - try self.updateFseTable(in, .offset, sequences_header.offsets); - try self.updateFseTable(in, .match, sequences_header.match_lengths); + try self.updateFseTable(in, remaining, .literal, sequences_header.literal_lengths); + try self.updateFseTable(in, remaining, .offset, sequences_header.offsets); + try self.updateFseTable(in, remaining, .match, sequences_header.match_lengths); self.fse_tables_undefined = false; } } /// Read initial FSE states for sequence decoding. - /// - /// Errors returned: - /// - `error.EndOfStream` if `bit_reader` does not contain enough bits. pub fn readInitialFseState(self: *Decode, bit_reader: *ReverseBitReader) error{EndOfStream}!void { self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log); self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log); @@ -490,6 +523,7 @@ pub const Frame = struct { const DataType = enum { offset, match, literal }; + /// TODO: don't use `@field` fn updateState( self: *Decode, comptime choice: DataType, @@ -517,9 +551,11 @@ pub const Frame = struct { EndOfStream, }; + /// TODO: don't use `@field` fn updateFseTable( self: *Decode, - source: *BufferedReader, + in: *BufferedReader, + remaining: *Reader.Limit, comptime choice: DataType, mode: SequencesSection.Header.Mode, ) !void { @@ -527,28 +563,32 @@ pub const Frame = struct { switch (mode) { .predefined => { @field(self, field_name).accuracy_log = - @field(zstd.compressed_block.default_accuracy_log, field_name); + @field(zstd.default_accuracy_log, field_name); @field(self, field_name).table = @field(Table, "predefined_" ++ field_name); }, .rle => { @field(self, field_name).accuracy_log = 0; - @field(self, field_name).table = .{ .rle = try source.readByte() }; + remaining.* = remaining.subtract(1) orelse return error.EndOfStream; + @field(self, field_name).table = .{ .rle = try in.takeByte() }; }, .fse => { - var bit_reader: std.io.BitReader(.little) = .init(source); - + if (in.buffer.len < @intFromEnum(remaining.*)) return error.InputBufferUndersize; + const limited_buffer = try in.peek(@intFromEnum(remaining.*)); + var bit_reader: BitReader = .{ .bytes = limited_buffer }; const table_size = try Table.decode( &bit_reader, - @field(zstd.compressed_block.table_symbol_count_max, field_name), - @field(zstd.compressed_block.table_accuracy_log_max, field_name), + @field(zstd.table_symbol_count_max, field_name), + @field(zstd.table_accuracy_log_max, field_name), @field(self, field_name ++ "_fse_buffer"), ); @field(self, field_name).table = .{ .fse = @field(self, field_name ++ "_fse_buffer")[0..table_size], }; @field(self, field_name).accuracy_log = std.math.log2_int_ceil(usize, table_size); + in.toss(bit_reader.index); + remaining.* = remaining.subtract(bit_reader.index).?; }, .repeat => if (self.fse_tables_undefined) return error.RepeatModeFirst, } @@ -571,15 +611,15 @@ pub const Frame = struct { const offset_value = (@as(u32, 1) << offset_code) + try bit_reader.readBitsNoEof(u32, offset_code); const match_code = self.getCode(.match); - if (match_code >= zstd.compressed_block.match_length_code_table.len) + if (match_code >= zstd.match_length_code_table.len) return error.InvalidBitStream; - const match = zstd.compressed_block.match_length_code_table[match_code]; + const match = zstd.match_length_code_table[match_code]; const match_length = match[0] + try bit_reader.readBitsNoEof(u32, match[1]); const literal_code = self.getCode(.literal); - if (literal_code >= zstd.compressed_block.literals_length_code_table.len) + if (literal_code >= zstd.literals_length_code_table.len) return error.InvalidBitStream; - const literal = zstd.compressed_block.literals_length_code_table[literal_code]; + const literal = zstd.literals_length_code_table[literal_code]; const literal_length = literal[0] + try bit_reader.readBitsNoEof(u32, literal[1]); const offset = if (offset_value > 3) offset: { @@ -622,12 +662,17 @@ pub const Frame = struct { /// The `BufferedWriter` storage capacity is not large enough to /// accept this stream. OutputBufferUndersize, + WriteFailed, + MalformedLiteralsLength, + MalformedFseBits, + MissingStartBit, + HuffmanTreeIncomplete, }; /// Decode one sequence from `bit_reader` into `dest`. Updates FSE states /// if `last_sequence` is `false`. Assumes `prepare` called for the block /// before attempting to decode sequences. - pub fn decodeSequence( + fn decodeSequence( self: *Decode, dest: *BufferedWriter, bit_reader: *ReverseBitReader, @@ -662,13 +707,13 @@ pub const Frame = struct { return sequence_length; } - fn nextLiteralMultiStream(self: *Decode) error{BitStreamHasNoStartBit}!void { + fn nextLiteralMultiStream(self: *Decode) error{MissingStartBit}!void { self.literal_stream_index += 1; try self.initLiteralStream(self.literal_streams.four[self.literal_stream_index]); } - fn initLiteralStream(self: *Decode, bytes: []const u8) error{BitStreamHasNoStartBit}!void { - try self.literal_stream_reader.init(bytes); + fn initLiteralStream(self: *Decode, bytes: []const u8) error{MissingStartBit}!void { + self.literal_stream_reader = try ReverseBitReader.init(bytes); } fn isLiteralStreamEmpty(self: *Decode) bool { @@ -679,7 +724,7 @@ pub const Frame = struct { } const LiteralBitsError = error{ - BitStreamHasNoStartBit, + MissingStartBit, UnexpectedEndOfLiteralStream, }; fn readLiteralsBits( @@ -704,6 +749,9 @@ pub const Frame = struct { /// Problems decoding Huffman compressed literals UnexpectedEndOfLiteralStream, OutputBufferUndersize, + WriteFailed, + MissingStartBit, + HuffmanTreeIncomplete, }; /// Decode `len` bytes of literals into `dest`. @@ -765,6 +813,7 @@ pub const Frame = struct { } } + /// TODO: don't use `@field` fn getCode(self: *Decode, comptime choice: DataType) u32 { return switch (@field(self, @tagName(choice)).table) { .rle => |value| value, @@ -785,21 +834,17 @@ pub const Frame = struct { }; const InitError = error{ + /// Frame uses a dictionary. DictionaryIdFlagUnsupported, + /// Frame does not have a valid window size. WindowSizeUnknown, - WindowTooLarge, - ContentSizeTooLarge, + /// Window size exceeds `window_size_max` or max `usize` value. + WindowOversize, + /// Frame header indicates a content size exceeding max `usize` value. + ContentOversize, }; + /// Validates `frame_header` and returns the associated `Frame`. - /// - /// Errors returned: - /// - `error.DictionaryIdFlagUnsupported` if the frame uses a dictionary - /// - `error.WindowSizeUnknown` if the frame does not have a valid window - /// size - /// - `error.WindowTooLarge` if the window size is larger than - /// `window_size_max` or `std.math.intMax(usize)` - /// - `error.ContentSizeTooLarge` if the frame header indicates a content - /// size larger than `std.math.maxInt(usize)` pub fn init( frame_header: Frame.Zstandard.Header, window_size_max: usize, @@ -810,15 +855,15 @@ pub const Frame = struct { const window_size_raw = frame_header.windowSize() orelse return error.WindowSizeUnknown; const window_size = if (window_size_raw > window_size_max) - return error.WindowTooLarge + return error.WindowOversize else - std.math.cast(usize, window_size_raw) orelse return error.WindowTooLarge; + std.math.cast(usize, window_size_raw) orelse return error.WindowOversize; const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum; const content_size = if (frame_header.content_size) |size| - std.math.cast(usize, size) orelse return error.ContentSizeTooLarge + std.math.cast(usize, size) orelse return error.ContentOversize else null; @@ -875,13 +920,11 @@ pub const LiteralsSection = struct { compressed_size: ?u18, /// Decode a literals section header. - /// - /// Errors returned: - /// - `error.EndOfStream` if there are not enough bytes in `source` - pub fn decode(source: *BufferedReader) !Header { - const byte0 = try source.readByte(); - const block_type = @as(BlockType, @enumFromInt(byte0 & 0b11)); - const size_format = @as(u2, @intCast((byte0 & 0b1100) >> 2)); + pub fn decode(in: *BufferedReader, remaining: *Reader.Limit) !Header { + remaining.* = remaining.subtract(1) orelse return error.EndOfStream; + const byte0 = try in.takeByte(); + const block_type: BlockType = @enumFromInt(byte0 & 0b11); + const size_format: u2 = @intCast((byte0 & 0b1100) >> 2); var regenerated_size: u20 = undefined; var compressed_size: ?u18 = null; switch (block_type) { @@ -890,28 +933,37 @@ pub const LiteralsSection = struct { 0, 2 => { regenerated_size = byte0 >> 3; }, - 1 => regenerated_size = (byte0 >> 4) + (@as(u20, try source.readByte()) << 4), - 3 => regenerated_size = (byte0 >> 4) + - (@as(u20, try source.readByte()) << 4) + - (@as(u20, try source.readByte()) << 12), + 1 => { + remaining.* = remaining.subtract(1) orelse return error.EndOfStream; + regenerated_size = (byte0 >> 4) + (@as(u20, try in.takeByte()) << 4); + }, + 3 => { + remaining.* = remaining.subtract(2) orelse return error.EndOfStream; + regenerated_size = (byte0 >> 4) + + (@as(u20, try in.takeByte()) << 4) + + (@as(u20, try in.takeByte()) << 12); + }, } }, .compressed, .treeless => { - const byte1 = try source.readByte(); - const byte2 = try source.readByte(); + remaining.* = remaining.subtract(2) orelse return error.EndOfStream; + const byte1 = try in.takeByte(); + const byte2 = try in.takeByte(); switch (size_format) { 0, 1 => { regenerated_size = (byte0 >> 4) + ((@as(u20, byte1) & 0b00111111) << 4); compressed_size = ((byte1 & 0b11000000) >> 6) + (@as(u18, byte2) << 2); }, 2 => { - const byte3 = try source.readByte(); + remaining.* = remaining.subtract(1) orelse return error.EndOfStream; + const byte3 = try in.takeByte(); regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00000011) << 12); compressed_size = ((byte2 & 0b11111100) >> 2) + (@as(u18, byte3) << 6); }, 3 => { - const byte3 = try source.readByte(); - const byte4 = try source.readByte(); + remaining.* = remaining.subtract(2) orelse return error.EndOfStream; + const byte3 = try in.takeByte(); + const byte4 = try in.takeByte(); regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00111111) << 12); compressed_size = ((byte2 & 0b11000000) >> 6) + (@as(u18, byte3) << 2) + (@as(u18, byte4) << 10); }, @@ -950,17 +1002,17 @@ pub const LiteralsSection = struct { index: usize, }; - pub fn query(self: HuffmanTree, index: usize, prefix: u16) error{NotFound}!Result { + pub fn query(self: HuffmanTree, index: usize, prefix: u16) error{HuffmanTreeIncomplete}!Result { var node = self.nodes[index]; const weight = node.weight; var i: usize = index; while (node.weight == weight) { - if (node.prefix == prefix) return Result{ .symbol = node.symbol }; - if (i == 0) return error.NotFound; + if (node.prefix == prefix) return .{ .symbol = node.symbol }; + if (i == 0) return error.HuffmanTreeIncomplete; i -= 1; node = self.nodes[i]; } - return Result{ .index = i }; + return .{ .index = i }; } pub fn weightToBitCount(weight: u4, max_bit_count: u4) u4 { @@ -975,20 +1027,26 @@ pub const LiteralsSection = struct { MissingStartBit, }; - pub fn decode(in: *BufferedReader) HuffmanTree.DecodeError!HuffmanTree { + pub fn decode(in: *BufferedReader, remaining: *Reader.Limit) HuffmanTree.DecodeError!HuffmanTree { + remaining.* = remaining.subtract(1) orelse return error.EndOfStream; const header = try in.takeByte(); if (header < 128) { - return decodeFse(in, header); + return decodeFse(in, remaining, header); } else { - return decodeDirect(in, header - 127); + return decodeDirect(in, remaining, header - 127); } } - fn decodeDirect(source: *BufferedReader, encoded_symbol_count: usize) HuffmanTree.DecodeError!HuffmanTree { + fn decodeDirect( + in: *BufferedReader, + remaining: *Reader.Limit, + encoded_symbol_count: usize, + ) HuffmanTree.DecodeError!HuffmanTree { var weights: [256]u4 = undefined; const weights_byte_count = (encoded_symbol_count + 1) / 2; + remaining.* = remaining.subtract(weights_byte_count) orelse return error.EndOfStream; for (0..weights_byte_count) |i| { - const byte = try source.takeByte(); + const byte = try in.takeByte(); weights[2 * i] = @as(u4, @intCast(byte >> 4)); weights[2 * i + 1] = @as(u4, @intCast(byte & 0xF)); } @@ -996,22 +1054,25 @@ pub const LiteralsSection = struct { return build(&weights, symbol_count); } - fn decodeFse(in: *BufferedReader, compressed_size: usize) HuffmanTree.DecodeError!HuffmanTree { + fn decodeFse( + in: *BufferedReader, + remaining: *Reader.Limit, + compressed_size: usize, + ) HuffmanTree.DecodeError!HuffmanTree { var weights: [256]u4 = undefined; + remaining.* = remaining.subtract(compressed_size) orelse return error.EndOfStream; const compressed_buffer = try in.take(compressed_size); - var limited_stream: BufferedReader = undefined; - limited_stream.initFixed(compressed_buffer); - var bit_reader: std.io.BitReader(.little) = .init(&limited_stream); + var bit_reader: BitReader = .{ .bytes = compressed_buffer }; var entries: [1 << 6]Table.Fse = undefined; const table_size = try Table.decode(&bit_reader, 256, 6, &entries); const accuracy_log = std.math.log2_int_ceil(usize, table_size); - const remaining = limited_stream.bufferContents(); - const symbol_count = try assignWeights(remaining, accuracy_log, &entries, weights); + const remaining_buffer = bit_reader.bytes[bit_reader.index..]; + const symbol_count = try assignWeights(remaining_buffer, accuracy_log, &entries, &weights); return build(&weights, symbol_count); } fn assignWeights( - huff_bits_buffer: []u8, + huff_bits_buffer: []const u8, accuracy_log: u16, entries: *[1 << 6]Table.Fse, weights: *[256]u4, @@ -1159,14 +1220,18 @@ pub const LiteralsSection = struct { MalformedHuffmanTree, /// Not enough bytes to complete the section. EndOfStream, + ReadFailed, + LiteralsBufferUndersize, + MissingStartBit, }; - pub fn decode(source: *BufferedReader, buffer: []u8) DecodeError!LiteralsSection { - const header = try Header.decode(source); + pub fn decode(in: *BufferedReader, remaining: *Reader.Limit, buffer: []u8) DecodeError!LiteralsSection { + const header = try Header.decode(in, remaining); switch (header.block_type) { .raw => { - if (buffer.len < header.regenerated_size) return error.LiteralsBufferTooSmall; - try source.readNoEof(buffer[0..header.regenerated_size]); + if (buffer.len < header.regenerated_size) return error.LiteralsBufferUndersize; + remaining.* = remaining.subtract(header.regenerated_size) orelse return error.EndOfStream; + try in.readSlice(buffer[0..header.regenerated_size]); return .{ .header = header, .huffman_tree = null, @@ -1174,7 +1239,8 @@ pub const LiteralsSection = struct { }; }, .rle => { - buffer[0] = try source.readByte(); + remaining.* = remaining.subtract(1) orelse return error.EndOfStream; + buffer[0] = try in.takeByte(); return .{ .header = header, .huffman_tree = null, @@ -1182,19 +1248,18 @@ pub const LiteralsSection = struct { }; }, .compressed, .treeless => { - var counting_reader = std.io.countingReader(source); + const before_remaining = remaining.*; const huffman_tree = if (header.block_type == .compressed) - try HuffmanTree.decode(counting_reader.reader(), buffer) + try HuffmanTree.decode(in, remaining) else null; - const huffman_tree_size = @as(usize, @intCast(counting_reader.bytes_read)); + const huffman_tree_size = @intFromEnum(before_remaining) - @intFromEnum(remaining.*); const total_streams_size = std.math.sub(usize, header.compressed_size.?, huffman_tree_size) catch return error.MalformedLiteralsSection; - - if (total_streams_size > buffer.len) return error.LiteralsBufferTooSmall; - try source.readNoEof(buffer[0..total_streams_size]); + if (total_streams_size > buffer.len) return error.LiteralsBufferUndersize; + remaining.* = remaining.subtract(total_streams_size) orelse return error.EndOfStream; + try in.readSlice(buffer[0..total_streams_size]); const stream_data = buffer[0..total_streams_size]; - const streams = try Streams.decode(header.size_format, stream_data); return .{ .header = header, @@ -1207,7 +1272,7 @@ pub const LiteralsSection = struct { }; pub const SequencesSection = struct { - header: SequencesSection.Header, + header: Header, literals_length_table: Table, offset_table: Table, match_length_table: Table, @@ -1228,32 +1293,37 @@ pub const SequencesSection = struct { pub const DecodeError = error{ ReservedBitSet, EndOfStream, + ReadFailed, }; - pub fn decode(source: *BufferedReader) DecodeError!Header { + pub fn decode(in: *BufferedReader, remaining: *Reader.Limit) DecodeError!Header { var sequence_count: u24 = undefined; - const byte0 = try source.readByte(); + remaining.* = remaining.subtract(1) orelse return error.EndOfStream; + const byte0 = try in.takeByte(); if (byte0 == 0) { - return SequencesSection.Header{ + return .{ .sequence_count = 0, .offsets = undefined, .match_lengths = undefined, .literal_lengths = undefined, }; } else if (byte0 < 128) { + remaining.* = remaining.subtract(1) orelse return error.EndOfStream; sequence_count = byte0; } else if (byte0 < 255) { - sequence_count = (@as(u24, (byte0 - 128)) << 8) + try source.readByte(); + remaining.* = remaining.subtract(2) orelse return error.EndOfStream; + sequence_count = (@as(u24, (byte0 - 128)) << 8) + try in.takeByte(); } else { - sequence_count = (try source.readByte()) + (@as(u24, try source.readByte()) << 8) + 0x7F00; + remaining.* = remaining.subtract(3) orelse return error.EndOfStream; + sequence_count = (try in.takeByte()) + (@as(u24, try in.takeByte()) << 8) + 0x7F00; } - const compression_modes = try source.readByte(); + const compression_modes = try in.takeByte(); - const matches_mode = @as(SequencesSection.Header.Mode, @enumFromInt((compression_modes & 0b00001100) >> 2)); - const offsets_mode = @as(SequencesSection.Header.Mode, @enumFromInt((compression_modes & 0b00110000) >> 4)); - const literal_mode = @as(SequencesSection.Header.Mode, @enumFromInt((compression_modes & 0b11000000) >> 6)); + const matches_mode: Header.Mode = @enumFromInt((compression_modes & 0b00001100) >> 2); + const offsets_mode: Header.Mode = @enumFromInt((compression_modes & 0b00110000) >> 4); + const literal_mode: Header.Mode = @enumFromInt((compression_modes & 0b11000000) >> 6); if (compression_modes & 0b11 != 0) return error.ReservedBitSet; return .{ @@ -1277,7 +1347,7 @@ pub const Table = union(enum) { }; pub fn decode( - bit_reader: *std.io.BitReader(.little), + bit_reader: *BitReader, expected_symbol_count: usize, max_accuracy_log: u4, entries: []Table.Fse, @@ -1600,6 +1670,22 @@ pub const Table = union(enum) { }; }; +const low_bit_mask = [9]u8{ + 0b00000000, + 0b00000001, + 0b00000011, + 0b00000111, + 0b00001111, + 0b00011111, + 0b00111111, + 0b01111111, + 0b11111111, +}; + +fn Bits(comptime T: type) type { + return struct { T, u16 }; +} + /// For reading the reversed bit streams used to encode FSE compressed data. const ReverseBitReader = struct { bytes: []const u8, @@ -1619,20 +1705,175 @@ const ReverseBitReader = struct { return error.MissingStartBit; } - fn readBitsNoEof(self: *ReverseBitReader, comptime U: type, num_bits: u16) error{EndOfStream}!U { - return self.bit_reader.readBitsNoEof(U, num_bits); + fn initBits(comptime T: type, out: anytype, num: u16) Bits(T) { + const UT = std.meta.Int(.unsigned, @bitSizeOf(T)); + return .{ + @bitCast(@as(UT, @intCast(out))), + num, + }; } - fn readBits(self: *ReverseBitReader, comptime U: type, num_bits: u16, out_bits: *u16) error{}!U { - return try self.bit_reader.readBits(U, num_bits, out_bits); + fn readBitsNoEof(self: *ReverseBitReader, comptime T: type, num: u16) error{EndOfStream}!T { + const b, const c = try self.readBitsTuple(T, num); + if (c < num) return error.EndOfStream; + return b; } - fn alignToByte(self: *ReverseBitReader) void { - self.bit_reader.alignToByte(); + fn readBits(self: *ReverseBitReader, comptime T: type, num: u16, out_bits: *u16) !T { + const b, const c = try self.readBitsTuple(T, num); + out_bits.* = c; + return b; + } + + fn readBitsTuple(self: *ReverseBitReader, comptime T: type, num: u16) !Bits(T) { + const UT = std.meta.Int(.unsigned, @bitSizeOf(T)); + const U = if (@bitSizeOf(T) < 8) u8 else UT; + + if (num <= self.count) return initBits(T, self.removeBits(@intCast(num)), num); + + var out_count: u16 = self.count; + var out: U = self.removeBits(self.count); + + const full_bytes_left = (num - out_count) / 8; + + for (0..full_bytes_left) |_| { + const byte = takeByte(self) catch |err| switch (err) { + error.EndOfStream => return initBits(T, out, out_count), + }; + if (U == u8) out = 0 else out <<= 8; + out |= byte; + out_count += 8; + } + + const bits_left = num - out_count; + const keep = 8 - bits_left; + + if (bits_left == 0) return initBits(T, out, out_count); + + const final_byte = takeByte(self) catch |err| switch (err) { + error.EndOfStream => return initBits(T, out, out_count), + }; + + out <<= @intCast(bits_left); + out |= final_byte >> @intCast(keep); + self.bits = final_byte & low_bit_mask[keep]; + + self.count = @intCast(keep); + return initBits(T, out, num); + } + + fn takeByte(rbr: *ReverseBitReader) error{EndOfStream}!u8 { + if (rbr.remaining == 0) return error.EndOfStream; + rbr.remaining -= 1; + return rbr.bytes[rbr.remaining]; } fn isEmpty(self: *const ReverseBitReader) bool { - return self.byte_reader.remaining_bytes == 0 and self.bit_reader.count == 0; + return self.remaining == 0 and self.count == 0; + } + + fn removeBits(self: *ReverseBitReader, num: u4) u8 { + if (num == 8) { + self.count = 0; + return self.bits; + } + + const keep = self.count - num; + const bits = self.bits >> @intCast(keep); + self.bits &= low_bit_mask[keep]; + + self.count = keep; + return bits; + } +}; + +const BitReader = struct { + bytes: []const u8, + index: usize = 0, + bits: u8 = 0, + count: u4 = 0, + + fn initBits(comptime T: type, out: anytype, num: u16) Bits(T) { + const UT = std.meta.Int(.unsigned, @bitSizeOf(T)); + return .{ + @bitCast(@as(UT, @intCast(out))), + num, + }; + } + + fn readBitsNoEof(self: *@This(), comptime T: type, num: u16) !T { + const b, const c = try self.readBitsTuple(T, num); + if (c < num) return error.EndOfStream; + return b; + } + + fn readBits(self: *@This(), comptime T: type, num: u16, out_bits: *u16) !T { + const b, const c = try self.readBitsTuple(T, num); + out_bits.* = c; + return b; + } + + fn readBitsTuple(self: *@This(), comptime T: type, num: u16) !Bits(T) { + const UT = std.meta.Int(.unsigned, @bitSizeOf(T)); + const U = if (@bitSizeOf(T) < 8) u8 else UT; + + if (num <= self.count) return initBits(T, self.removeBits(@intCast(num)), num); + + var out_count: u16 = self.count; + var out: U = self.removeBits(self.count); + + const full_bytes_left = (num - out_count) / 8; + + for (0..full_bytes_left) |_| { + const byte = takeByte(self) catch |err| switch (err) { + error.EndOfStream => return initBits(T, out, out_count), + }; + + const pos = @as(U, byte) << @intCast(out_count); + out |= pos; + out_count += 8; + } + + const bits_left = num - out_count; + const keep = 8 - bits_left; + + if (bits_left == 0) return initBits(T, out, out_count); + + const final_byte = takeByte(self) catch |err| switch (err) { + error.EndOfStream => return initBits(T, out, out_count), + }; + + const pos = @as(U, final_byte & low_bit_mask[bits_left]) << @intCast(out_count); + out |= pos; + self.bits = final_byte >> @intCast(bits_left); + + self.count = @intCast(keep); + return initBits(T, out, num); + } + + fn takeByte(br: *BitReader) error{EndOfStream}!u8 { + if (br.bytes.len - br.index == 0) return error.EndOfStream; + const result = br.bytes[br.index]; + br.index += 1; + return result; + } + + fn removeBits(self: *@This(), num: u4) u8 { + if (num == 8) { + self.count = 0; + return self.bits; + } + + const keep = self.count - num; + const bits = self.bits & low_bit_mask[num]; + self.bits >>= @intCast(num); + self.count = keep; + return bits; + } + + fn alignToByte(self: *@This()) void { + self.bits = 0; + self.count = 0; } }; diff --git a/lib/std/io.zig b/lib/std/io.zig index 292812752d..c7b45fefa0 100644 --- a/lib/std/io.zig +++ b/lib/std/io.zig @@ -19,16 +19,12 @@ pub const AllocatingWriter = @import("io/AllocatingWriter.zig"); pub const MultiWriter = @import("io/multi_writer.zig").MultiWriter; pub const multiWriter = @import("io/multi_writer.zig").multiWriter; -pub const BitReader = @import("io/bit_reader.zig").Type; - pub const BitWriter = @import("io/bit_writer.zig").BitWriter; pub const bitWriter = @import("io/bit_writer.zig").bitWriter; pub const ChangeDetectionStream = @import("io/change_detection_stream.zig").ChangeDetectionStream; pub const changeDetectionStream = @import("io/change_detection_stream.zig").changeDetectionStream; -pub const BufferedAtomicFile = @import("io/buffered_atomic_file.zig").BufferedAtomicFile; - pub const tty = @import("io/tty.zig"); pub fn poll( @@ -437,13 +433,11 @@ pub fn PollFiles(comptime StreamEnum: type) type { } test { - _ = BufferedWriter; + _ = AllocatingWriter; + _ = BitWriter; _ = BufferedReader; + _ = BufferedWriter; _ = Reader; _ = Writer; - _ = AllocatingWriter; - _ = @import("io/bit_reader.zig"); - _ = @import("io/bit_writer.zig"); - _ = @import("io/buffered_atomic_file.zig"); _ = @import("io/test.zig"); } diff --git a/lib/std/io/bit_reader.zig b/lib/std/io/bit_reader.zig deleted file mode 100644 index 5740631380..0000000000 --- a/lib/std/io/bit_reader.zig +++ /dev/null @@ -1,236 +0,0 @@ -const std = @import("../std.zig"); -const bit_reader = @This(); - -//General note on endianess: -//Big endian is packed starting in the most significant part of the byte and subsequent -// bytes contain less significant bits. Thus we always take bits from the high -// end and place them below existing bits in our output. -//Little endian is packed starting in the least significant part of the byte and -// subsequent bytes contain more significant bits. Thus we always take bits from -// the low end and place them above existing bits in our output. -//Regardless of endianess, within any given byte the bits are always in most -// to least significant order. -//Also regardless of endianess, the buffer always aligns bits to the low end -// of the byte. - -/// Creates a bit reader which allows for reading bits from an underlying standard reader -pub fn Type(comptime endian: std.builtin.Endian) type { - return struct { - reader: *std.io.BufferedReader, - bits: u8, - count: u4, - - const low_bit_mask = [9]u8{ - 0b00000000, - 0b00000001, - 0b00000011, - 0b00000111, - 0b00001111, - 0b00011111, - 0b00111111, - 0b01111111, - 0b11111111, - }; - - pub fn init(reader: *std.io.BufferedReader) @This() { - return .{ .reader = reader, .bits = 0, .count = 0 }; - } - - fn Bits(comptime T: type) type { - return struct { T, u16 }; - } - - fn initBits(comptime T: type, out: anytype, num: u16) Bits(T) { - const UT = std.meta.Int(.unsigned, @bitSizeOf(T)); - return .{ - @bitCast(@as(UT, @intCast(out))), - num, - }; - } - - /// Reads `bits` bits from the reader and returns a specified type - /// containing them in the least significant end, returning an error if the - /// specified number of bits could not be read. - pub fn readBitsNoEof(self: *@This(), comptime T: type, num: u16) !T { - const b, const c = try self.readBitsTuple(T, num); - if (c < num) return error.EndOfStream; - return b; - } - - /// Reads `bits` bits from the reader and returns a specified type - /// containing them in the least significant end. The number of bits successfully - /// read is placed in `out_bits`, as reaching the end of the stream is not an error. - pub fn readBits(self: *@This(), comptime T: type, num: u16, out_bits: *u16) !T { - const b, const c = try self.readBitsTuple(T, num); - out_bits.* = c; - return b; - } - - /// Reads `bits` bits from the reader and returns a tuple of the specified type - /// containing them in the least significant end, and the number of bits successfully - /// read. Reaching the end of the stream is not an error. - pub fn readBitsTuple(self: *@This(), comptime T: type, num: u16) !Bits(T) { - const UT = std.meta.Int(.unsigned, @bitSizeOf(T)); - const U = if (@bitSizeOf(T) < 8) u8 else UT; //it is a pain to work with return initBits(T, out, out_count), - else => |e| return e, - }; - - switch (endian) { - .big => { - if (U == u8) out = 0 else out <<= 8; //shifting u8 by 8 is illegal in Zig - out |= byte; - }, - .little => { - const pos = @as(U, byte) << @intCast(out_count); - out |= pos; - }, - } - out_count += 8; - } - - const bits_left = num - out_count; - const keep = 8 - bits_left; - - if (bits_left == 0) return initBits(T, out, out_count); - - const final_byte = self.reader.takeByte() catch |err| switch (err) { - error.EndOfStream => return initBits(T, out, out_count), - else => |e| return e, - }; - - switch (endian) { - .big => { - out <<= @intCast(bits_left); - out |= final_byte >> @intCast(keep); - self.bits = final_byte & low_bit_mask[keep]; - }, - .little => { - const pos = @as(U, final_byte & low_bit_mask[bits_left]) << @intCast(out_count); - out |= pos; - self.bits = final_byte >> @intCast(bits_left); - }, - } - - self.count = @intCast(keep); - return initBits(T, out, num); - } - - //convenience function for removing bits from - //the appropriate part of the buffer based on - //endianess. - fn removeBits(self: *@This(), num: u4) u8 { - if (num == 8) { - self.count = 0; - return self.bits; - } - - const keep = self.count - num; - const bits = switch (endian) { - .big => self.bits >> @intCast(keep), - .little => self.bits & low_bit_mask[num], - }; - switch (endian) { - .big => self.bits &= low_bit_mask[keep], - .little => self.bits >>= @intCast(num), - } - - self.count = keep; - return bits; - } - - pub fn alignToByte(self: *@This()) void { - self.bits = 0; - self.count = 0; - } - }; -} - -/////////////////////////////// - -test "api coverage" { - const mem_be = [_]u8{ 0b11001101, 0b00001011 }; - const mem_le = [_]u8{ 0b00011101, 0b10010101 }; - - var mem_in_be = std.io.fixedBufferStream(&mem_be); - var bit_stream_be: bit_reader.Type(.big) = .init(mem_in_be.reader()); - - var out_bits: u16 = undefined; - - const expect = std.testing.expect; - const expectError = std.testing.expectError; - - try expect(1 == try bit_stream_be.readBits(u2, 1, &out_bits)); - try expect(out_bits == 1); - try expect(2 == try bit_stream_be.readBits(u5, 2, &out_bits)); - try expect(out_bits == 2); - try expect(3 == try bit_stream_be.readBits(u128, 3, &out_bits)); - try expect(out_bits == 3); - try expect(4 == try bit_stream_be.readBits(u8, 4, &out_bits)); - try expect(out_bits == 4); - try expect(5 == try bit_stream_be.readBits(u9, 5, &out_bits)); - try expect(out_bits == 5); - try expect(1 == try bit_stream_be.readBits(u1, 1, &out_bits)); - try expect(out_bits == 1); - - mem_in_be.pos = 0; - bit_stream_be.count = 0; - try expect(0b110011010000101 == try bit_stream_be.readBits(u15, 15, &out_bits)); - try expect(out_bits == 15); - - mem_in_be.pos = 0; - bit_stream_be.count = 0; - try expect(0b1100110100001011 == try bit_stream_be.readBits(u16, 16, &out_bits)); - try expect(out_bits == 16); - - _ = try bit_stream_be.readBits(u0, 0, &out_bits); - - try expect(0 == try bit_stream_be.readBits(u1, 1, &out_bits)); - try expect(out_bits == 0); - try expectError(error.EndOfStream, bit_stream_be.readBitsNoEof(u1, 1)); - - var mem_in_le = std.io.fixedBufferStream(&mem_le); - var bit_stream_le: bit_reader.Type(.little) = .init(mem_in_le.reader()); - - try expect(1 == try bit_stream_le.readBits(u2, 1, &out_bits)); - try expect(out_bits == 1); - try expect(2 == try bit_stream_le.readBits(u5, 2, &out_bits)); - try expect(out_bits == 2); - try expect(3 == try bit_stream_le.readBits(u128, 3, &out_bits)); - try expect(out_bits == 3); - try expect(4 == try bit_stream_le.readBits(u8, 4, &out_bits)); - try expect(out_bits == 4); - try expect(5 == try bit_stream_le.readBits(u9, 5, &out_bits)); - try expect(out_bits == 5); - try expect(1 == try bit_stream_le.readBits(u1, 1, &out_bits)); - try expect(out_bits == 1); - - mem_in_le.pos = 0; - bit_stream_le.count = 0; - try expect(0b001010100011101 == try bit_stream_le.readBits(u15, 15, &out_bits)); - try expect(out_bits == 15); - - mem_in_le.pos = 0; - bit_stream_le.count = 0; - try expect(0b1001010100011101 == try bit_stream_le.readBits(u16, 16, &out_bits)); - try expect(out_bits == 16); - - _ = try bit_stream_le.readBits(u0, 0, &out_bits); - - try expect(0 == try bit_stream_le.readBits(u1, 1, &out_bits)); - try expect(out_bits == 0); - try expectError(error.EndOfStream, bit_stream_le.readBitsNoEof(u1, 1)); -} diff --git a/lib/std/io/buffered_atomic_file.zig b/lib/std/io/buffered_atomic_file.zig deleted file mode 100644 index 71edabb20a..0000000000 --- a/lib/std/io/buffered_atomic_file.zig +++ /dev/null @@ -1,55 +0,0 @@ -const std = @import("../std.zig"); -const mem = std.mem; -const fs = std.fs; -const File = std.fs.File; - -pub const BufferedAtomicFile = struct { - atomic_file: fs.AtomicFile, - file_writer: File.Writer, - buffered_writer: BufferedWriter, - allocator: mem.Allocator, - - pub const buffer_size = 4096; - pub const BufferedWriter = std.io.BufferedWriter(buffer_size, File.Writer); - pub const Writer = std.io.Writer(*BufferedWriter, BufferedWriter.Error, BufferedWriter.write); - - /// TODO when https://github.com/ziglang/zig/issues/2761 is solved - /// this API will not need an allocator - pub fn create( - allocator: mem.Allocator, - dir: fs.Dir, - dest_path: []const u8, - atomic_file_options: fs.Dir.AtomicFileOptions, - ) !*BufferedAtomicFile { - var self = try allocator.create(BufferedAtomicFile); - self.* = BufferedAtomicFile{ - .atomic_file = undefined, - .file_writer = undefined, - .buffered_writer = undefined, - .allocator = allocator, - }; - errdefer allocator.destroy(self); - - self.atomic_file = try dir.atomicFile(dest_path, atomic_file_options); - errdefer self.atomic_file.deinit(); - - self.file_writer = self.atomic_file.file.writer(); - self.buffered_writer = .{ .unbuffered_writer = self.file_writer }; - return self; - } - - /// always call destroy, even after successful finish() - pub fn destroy(self: *BufferedAtomicFile) void { - self.atomic_file.deinit(); - self.allocator.destroy(self); - } - - pub fn finish(self: *BufferedAtomicFile) !void { - try self.buffered_writer.flush(); - try self.atomic_file.finish(); - } - - pub fn writer(self: *BufferedAtomicFile) Writer { - return .{ .context = &self.buffered_writer }; - } -};