Skip to content

Commit

Permalink
add a custom error type
Browse files Browse the repository at this point in the history
  • Loading branch information
nhz2 committed Jul 26, 2024
1 parent b1f2a47 commit 16dfc95
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 21 deletions.
2 changes: 2 additions & 0 deletions src/CodecInflate64.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ export DeflateDecompressor
export DeflateDecompressorStream
export Deflate64Decompressor
export Deflate64DecompressorStream
export DecompressionError

include("errors.jl")
include("huffmantree.jl")
include("stream.jl")
include("codecs.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/codecs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function TranscodingStreams.process(
status, Δin, Δout = main_run!(input, output, codec.s)
catch e
# rethrow()
e isa InterruptException && rethrow()
e isa DecompressionError || rethrow()
error_ref[] = e
return 0, 0, :error
end
Expand All @@ -72,7 +72,7 @@ function TranscodingStreams.process(
elseif status === :input
# need more input
if iszero(input.size)
error_ref[] = ErrorException("not enough input")
error_ref[] = DecompressionError("not enough input")
return Δin, Δout, :error
else
return Δin, Δout, :ok
Expand Down
13 changes: 13 additions & 0 deletions src/errors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
struct DecompressionError <: Exception
The data is not valid for decompression
"""
struct DecompressionError <: Exception
msg::String
end

function Base.showerror(io::IO, err::DecompressionError)
print(io, "DecompressionError: ")
print(io, err.msg)
end
27 changes: 26 additions & 1 deletion src/huffmantree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@ function parse_huffman!(
sorted_ops .= 0xFFFF
op_offset_per_num_bit .= 0x0000
max_num_bits = length(num_ops_per_num_bit)
code_space::UInt64 = 0 # this keeps track of the amount of code_space used out of 2^32
@assert max_num_bits maximum(num_bits_per_op)
@assert length(op_offset_per_num_bit) == max_num_bits + 1
@assert length(sorted_ops) length(num_bits_per_op)
for n in num_bits_per_op
if !iszero(n)
num_ops_per_num_bit[n] += 1
code_space += UInt64(1)<<(32-n) # a 1 bit code uses half code_space, 2 bit code uses 1/4 code_space ...
end
end

op_offset_per_num_bit[1] = 1
op_offset_per_num_bit[2] = 1
for n in 2:max_num_bits
Expand All @@ -49,6 +52,28 @@ function parse_huffman!(
op_offset_per_num_bit[n+1] = off + 1
end
end
# the logic for this test is from
# https://github.com/ebiggers/libdeflate/blob/dc76454a39e7e83b68c3704b6e3784654f8d5ac5/lib/deflate_decompress.c#L791
if code_space > UInt64(1)<<32
# This can never be valid
throw(DecompressionError("overfull code"))
elseif code_space < UInt64(1)<<32
# This can be valid in some special cases described in the RFC
# https://github.com/ebiggers/libdeflate/blob/dc76454a39e7e83b68c3704b6e3784654f8d5ac5/lib/deflate_decompress.c#L809-L839
if !iszero(code_space) # no codes is valid if no distance codes are used.
if code_space != UInt64(1)<<31 || num_ops_per_num_bit[1] != 1 # one code encoded with one bit is valid.
throw(DecompressionError("incomplete code"))
else
# pad out huffman tree like in libdeflate
# This ensures that all codes can be decoded without error
# later on.
num_ops_per_num_bit[1] = 2
op_offset_per_num_bit .= 3
op_offset_per_num_bit[1] = 1
sorted_ops[2] = sorted_ops[1]
end
end
end
end

# Using algorithm from https://github.com/GunnarFarneback/Inflate.jl/blob/cc77be73388f4160d187ab0c3fdaa3df13aa7f3b/src/Inflate.jl#L134-L145
Expand All @@ -63,5 +88,5 @@ function get_op(bits::UInt16, tree::HuffmanTree)::Tuple{UInt16, UInt8}
end
v -= tree.num_ops_per_num_bit[nbits]
end
error("incomplete code table")
error("incomplete code table") # This should never happen because of other checks
end
44 changes: 37 additions & 7 deletions src/stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ function read_input_bits!(s::StreamState)::Bool
elseif BTYPE == 0b10
s.in_mode = NUM_CODES
else
error("invalid block compression mode 3")
throw(DecompressionError("invalid block compression mode 3"))
end
end
elseif s.in_mode == NON_COMPRESSED_LENS
Expand All @@ -214,7 +214,7 @@ function read_input_bits!(s::StreamState)::Bool
s.mode = COPY_OUT
s.in_mode = HEADER_BITS
if len nlen != 0xffff
error("corrupted copy lengths")
throw(DecompressionError("corrupted copy lengths"))
end
s.len = len
end
Expand All @@ -241,6 +241,9 @@ function read_input_bits!(s::StreamState)::Bool
s.clen_num_bits_per_op[1 + order[i]] = x & 0b111
x >>= 0x03
end
if all(iszero, s.clen_num_bits_per_op)
throw(DecompressionError("no code for clen"))
end
parse_huffman!(s.clen_tree, s.clen_num_bits_per_op)
s.lit_len_dist_num_bits_per_op .= 0x00
s.num_bits_per_op_idx = 1
Expand All @@ -251,6 +254,8 @@ function read_input_bits!(s::StreamState)::Bool
local op, nbits = get_op(s.in_buf%UInt16, s.clen_tree)
local i = s.num_bits_per_op_idx
local n::Int
local max_i = s.nlit + s.ndist
@assert i max_i
if !consume!(s, nbits)
return false
end
Expand All @@ -263,6 +268,11 @@ function read_input_bits!(s::StreamState)::Bool
if !consume!(s, 0x02) # The next 2 bits indicate repeat length
return false
end
if isone(i)
throw(DecompressionError("no previous code length to repeat"))
elseif i + n - 1 > max_i
throw(DecompressionError("too many code lengths"))
end
s.lit_len_dist_num_bits_per_op[i:i + n - 1] .= s.lit_len_dist_num_bits_per_op[i-1]
i += n
elseif op == 0x0011
Expand All @@ -271,6 +281,9 @@ function read_input_bits!(s::StreamState)::Bool
if !consume!(s, 0x03) # (3 bits of length)
return false
end
if i + n - 1 > max_i
throw(DecompressionError("too many code lengths"))
end
s.lit_len_dist_num_bits_per_op[i:i + n - 1] .= 0x00
i += n
elseif op == 0x0012
Expand All @@ -279,14 +292,29 @@ function read_input_bits!(s::StreamState)::Bool
if !consume!(s, 0x07) # (7 bits of length)
return false
end
if i + n - 1 > max_i
throw(DecompressionError("too many code lengths"))
end
s.lit_len_dist_num_bits_per_op[i:i + n - 1] .= 0x00
i += n
else
error("unreachable")
end
if i > s.nlit + s.ndist
parse_huffman!(s.lit_len_tree, view(s.lit_len_dist_num_bits_per_op, 1:Int(s.nlit)))
parse_huffman!(s.dist_tree, view(s.lit_len_dist_num_bits_per_op, (Int(s.nlit)+1):(Int(s.nlit+s.ndist))))
if i > max_i
local lit_len_num_bits_per_op = view(s.lit_len_dist_num_bits_per_op, 1:Int(s.nlit))
local dist_num_bits_per_op = view(s.lit_len_dist_num_bits_per_op, Int(s.nlit+1):Int(s.nlit+s.ndist))
if iszero(lit_len_num_bits_per_op[1 + 256])
throw(DecompressionError("no code for end-of-block"))
end
# if there are no dist codes, there also cannot be any len codes
if all(iszero, dist_num_bits_per_op)
local last_lit_len_op = something(findlast(!iszero, lit_len_num_bits_per_op))
if last_lit_len_op > 1 + 256
throw(DecompressionError("no codes for distances, but there is a code for length"))
end
end
parse_huffman!(s.lit_len_tree, lit_len_num_bits_per_op)
parse_huffman!(s.dist_tree, dist_num_bits_per_op)
s.in_mode = LIT_LEN_DIST_OP
else
s.num_bits_per_op_idx = i
Expand Down Expand Up @@ -331,7 +359,9 @@ function read_input_bits!(s::StreamState)::Bool
end
else
# unknown op
error("unknown op")
# if the fixed Huffman codes are used
# op 286 and op 287 are invalid but can be encoded.
throw(DecompressionError("unknown len op"))
end
# read dist
op, nbits = get_op(s.in_buf%UInt16, s.dist_tree)
Expand Down Expand Up @@ -442,7 +472,7 @@ this can error if `dist` goes before the start of the out buffer.
"""
function copy_from_output!(out_ptr::Ptr{UInt8}, s::StreamState, n_copy::Int64, dist::UInt32)::Nothing
if dist > BUFFER_SIZE || iszero(dist) || (!s.out_full && s.out_offset < dist)
error("cannot read past beginning of out buffer dist: $(dist)")
throw(DecompressionError("cannot read before beginning of out buffer"))
end
for i in 1:n_copy
x = s.out_buf[begin + (s.out_offset - dist%UInt16)]
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ include("tests_from_inflate.jl")
@test decompress_bytes(p7zip_compress(d)) == d
@test de64compress(p7zip_64compress(d)) == d
@test de64compress_bytes(p7zip_64compress(d)) == d
@test_throws ErrorException de64compress(p7zip_64compress(d)[begin:end-1])
@test_throws DecompressionError de64compress(p7zip_64compress(d)[begin:end-1])

for n in 65536-1000:65536+1000
d = [thing; zeros(UInt8, n); thing]
@test decompress(p7zip_64compress(d)) == d
@test_throws ErrorException de64compress(p7zip_64compress(d)[begin:end-1])
@test_throws DecompressionError de64compress(p7zip_64compress(d)[begin:end-1])
end

for n in [0:1000; 1000000;]
Expand All @@ -41,7 +41,7 @@ include("tests_from_inflate.jl")
@test decompress_bytes(p7zip_compress(d)) == d
@test de64compress(p7zip_64compress(d)) == d
@test de64compress_bytes(p7zip_64compress(d)) == d
@test_throws ErrorException de64compress(p7zip_64compress(d)[begin:end-1])
@test_throws DecompressionError de64compress(p7zip_64compress(d)[begin:end-1])
end
end

Expand Down
7 changes: 3 additions & 4 deletions test/tests_from_deflate64-rs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using Pkg.Artifacts: @artifact_str, ensure_artifact_installed

include("utils.jl")

#TODO Use custom errors for invalid deflate data.

@testset "tests from deflate64-rs" begin
ensure_artifact_installed("deflate64-rs", joinpath(@__DIR__,"Artifacts.toml"))
Expand All @@ -17,13 +16,13 @@ include("utils.jl")
@test de64compress(c) == u

c = read(joinpath(test_assets,"issue-23/raw_deflate64_index_out_of_bounds"))
@test_throws ErrorException("incomplete code table") de64compress(c)
@test_throws DecompressionError("incomplete code") de64compress(c)

c = read(joinpath(test_assets,"issue-25/deflate64_not_enough_space.zip"))[31:end]
@test_throws ErrorException("cannot read past beginning of out buffer dist: 65536") de64compress(c)
@test_throws DecompressionError("cannot read before beginning of out buffer") de64compress(c)

c = read(joinpath(test_assets,"issue-29/raw.zip"))[122:end]
@test_throws ErrorException("incomplete code table") de64compress(c)
@test_throws DecompressionError("incomplete code") de64compress(c)

c = read(joinpath(test_assets,"deflate64.zip"))[41:40+2669743]
stream = Deflate64DecompressorStream(IOBuffer(c))
Expand Down
8 changes: 4 additions & 4 deletions test/tests_from_inflate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ end
d3 = UInt8[0xed, 0x1c, 0xed, 0x72,
0xdb, 0x48, 0xf2, 0x3f] # incomplete code table
for d in [d1, d2, d3]
@test_throws ErrorException decompress(d)
@test_throws ErrorException decompress_bytes(d)
@test_throws ErrorException de64compress(d)
@test_throws ErrorException de64compress_bytes(d)
@test_throws DecompressionError decompress(d)
@test_throws DecompressionError decompress_bytes(d)
@test_throws DecompressionError de64compress(d)
@test_throws DecompressionError de64compress_bytes(d)
end
end

Expand Down

0 comments on commit 16dfc95

Please sign in to comment.