1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use crate::exceptions::PyBaseExceptionRef;
use crate::function::OptionalArg;
use crate::obj::objbytes::PyBytesRef;
use crate::pyobject::{ItemProtocol, PyObjectRef, PyResult};
use crate::types::create_type;
use crate::vm::VirtualMachine;
use adler32::RollingAdler32 as Adler32;
use crc32fast::Hasher as Crc32;
use flate2::{write::ZlibEncoder, Compression, Decompress, FlushDecompress, Status};
use libz_sys as libz;
use std::io::Write;
const MAX_WBITS: u8 = 15;
const DEF_BUF_SIZE: usize = 16 * 1024;
/// Build the `zlib` module object.
///
/// Registers the checksum functions (`crc32`, `adler32`), the one-shot
/// `compress`/`decompress` functions, a module-specific `error` exception
/// type derived from `Exception`, and a handful of integer constants.
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
    let context = &vm.ctx;
    let error_type = create_type(
        "error",
        &context.types.type_type,
        &context.exceptions.exception_type,
    );

    py_module!(vm, "zlib", {
        // Checksums and one-shot (de)compression.
        "crc32" => context.new_function(zlib_crc32),
        "adler32" => context.new_function(zlib_adler32),
        "compress" => context.new_function(zlib_compress),
        "decompress" => context.new_function(zlib_decompress),
        // Exception raised by this module.
        "error" => error_type,
        // Compression levels, re-exported from libz.
        "Z_DEFAULT_COMPRESSION" => context.new_int(libz::Z_DEFAULT_COMPRESSION),
        "Z_NO_COMPRESSION" => context.new_int(libz::Z_NO_COMPRESSION),
        "Z_BEST_SPEED" => context.new_int(libz::Z_BEST_SPEED),
        "Z_BEST_COMPRESSION" => context.new_int(libz::Z_BEST_COMPRESSION),
        // Defaults used by `decompress`.
        "DEF_BUF_SIZE" => context.new_int(DEF_BUF_SIZE),
        "MAX_WBITS" => context.new_int(MAX_WBITS),
    })
}
/// Compute the Adler-32 checksum of `data` (`zlib.adler32`).
///
/// `begin_state` is the running checksum to continue from. It defaults to `1`,
/// the Adler-32 starting value. It is reduced to its low 32 bits, so the
/// unsigned result of a previous call (which may exceed `i32::MAX`) can be
/// passed straight back in. Negative values behave exactly as before.
///
/// Returns the checksum as a non-negative Python int.
fn zlib_adler32(data: PyBytesRef, begin_state: OptionalArg<i64>, vm: &VirtualMachine) -> PyResult {
    let data = data.get_value();
    // Intentional truncation: keep only the low 32 bits. For any i32 input this
    // equals the old `i32 as u32` cast.
    let begin_state = begin_state.unwrap_or(1) as u32;
    let mut hasher = Adler32::from_value(begin_state);
    hasher.update_buffer(data);
    let checksum: u32 = hasher.hash();
    Ok(vm.new_int(checksum))
}
/// Compute the CRC-32 checksum of `data` (`zlib.crc32`).
///
/// `begin_state` is the running CRC to continue from and defaults to `0`.
/// It is reduced to its low 32 bits, so the unsigned result of a previous call
/// (frequently above `i32::MAX`) can be passed straight back in. Negative
/// values behave exactly as before.
///
/// Returns the checksum as a non-negative Python int.
fn zlib_crc32(data: PyBytesRef, begin_state: OptionalArg<i64>, vm: &VirtualMachine) -> PyResult {
    let data = data.get_value();
    // Intentional truncation: keep only the low 32 bits. For any i32 input this
    // equals the old `i32 as u32` cast.
    let begin_state = begin_state.unwrap_or(0) as u32;
    let mut hasher = Crc32::new_with_initial(begin_state);
    hasher.update(data);
    let checksum: u32 = hasher.finalize();
    Ok(vm.new_int(checksum))
}
fn zlib_compress(data: PyBytesRef, level: OptionalArg<i32>, vm: &VirtualMachine) -> PyResult {
let input_bytes = data.get_value();
let level = level.unwrap_or(libz::Z_DEFAULT_COMPRESSION);
let compression = match level {
valid_level @ libz::Z_NO_COMPRESSION..=libz::Z_BEST_COMPRESSION => {
Compression::new(valid_level as u32)
}
libz::Z_DEFAULT_COMPRESSION => Compression::default(),
_ => return Err(zlib_error("Bad compression level", vm)),
};
let mut encoder = ZlibEncoder::new(Vec::new(), compression);
encoder.write_all(input_bytes).unwrap();
let encoded_bytes = encoder.finish().unwrap();
Ok(vm.ctx.new_bytes(encoded_bytes))
}
/// Decompress a zlib-wrapped deflate stream (`zlib.decompress`).
///
/// * `wbits`: base-two logarithm of the window size, defaulting to
///   `MAX_WBITS`. Only `9..=15` is supported by the flate2 constructor used here.
/// * `bufsize`: initial capacity of the output buffer, defaulting to
///   `DEF_BUF_SIZE`. The buffer grows as needed, so this only affects how
///   often it is reallocated, never the result.
///
/// Any bytes after the end of the compressed stream are ignored.
///
/// # Errors
/// Raises `zlib.error` when `wbits` is out of range, when the input is
/// corrupt, or when the input ends before the stream is complete.
fn zlib_decompress(
    data: PyBytesRef,
    wbits: OptionalArg<u8>,
    bufsize: OptionalArg<usize>,
    vm: &VirtualMachine,
) -> PyResult {
    let encoded_bytes = data.get_value();
    let wbits = wbits.unwrap_or(MAX_WBITS);
    let bufsize = bufsize.unwrap_or(DEF_BUF_SIZE);

    // `Decompress::new_with_window_bits` asserts 9..=15 and would otherwise
    // panic the interpreter on user input.
    if !(9..=MAX_WBITS).contains(&wbits) {
        return Err(zlib_error("wbits must be between 9 and 15", vm));
    }
    let mut decompressor = Decompress::new_with_window_bits(true, wbits);
    let mut decoded_bytes = Vec::with_capacity(bufsize);

    loop {
        // `decompress_vec` only writes into the Vec's spare capacity, so the
        // stream is driven incrementally, resuming at the first unconsumed byte.
        let consumed = decompressor.total_in() as usize;
        let status = decompressor
            .decompress_vec(
                &encoded_bytes[consumed..],
                &mut decoded_bytes,
                FlushDecompress::None,
            )
            .map_err(|_| zlib_error("invalid input data", vm))?;
        match status {
            Status::StreamEnd => break,
            // The output buffer filled up: enlarge it (at least doubling) and continue.
            Status::Ok | Status::BufError if decoded_bytes.len() == decoded_bytes.capacity() => {
                let extra = decoded_bytes.capacity().max(DEF_BUF_SIZE);
                decoded_bytes.reserve(extra);
            }
            // Output space remained but the stream did not end, so the input ran out.
            Status::Ok | Status::BufError => {
                return Err(zlib_error("inconsistent or truncated state", vm));
            }
        }
    }
    Ok(vm.ctx.new_bytes(decoded_bytes))
}
/// Build a `zlib.error` exception carrying `message`.
///
/// The exception type is looked up as `sys.modules["zlib"].error` on every
/// call. If that lookup fails (the module was removed from `sys.modules`, its
/// `error` attribute was deleted, or it is not a class), this falls back to
/// plain `Exception` rather than panicking the interpreter.
fn zlib_error(message: &str, vm: &VirtualMachine) -> PyBaseExceptionRef {
    let error_type = vm
        .get_attribute(vm.sys_module.clone(), "modules")
        .and_then(|modules| modules.get_item("zlib", vm))
        .and_then(|module| vm.get_attribute(module, "error"))
        .ok()
        .and_then(|error| error.downcast().ok())
        .unwrap_or_else(|| vm.ctx.exceptions.exception_type.clone());
    vm.new_exception_msg(error_type, message.to_owned())
}