1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use crate::exceptions::PyBaseExceptionRef;
use crate::function::OptionalArg;
use crate::obj::objbytes::PyBytesRef;
use crate::pyobject::{ItemProtocol, PyObjectRef, PyResult};
use crate::types::create_type;
use crate::vm::VirtualMachine;

use adler32::RollingAdler32 as Adler32;
use crc32fast::Hasher as Crc32;
use flate2::{write::ZlibEncoder, Compression, Decompress, FlushDecompress, Status};
use libz_sys as libz;

use std::io::Write;

// copied from zlibmodule.c (commit 530f506ac91338)
const MAX_WBITS: u8 = 15;
const DEF_BUF_SIZE: usize = 16 * 1024;

/// Build the `zlib` module object.
///
/// The module exposes the checksum functions, one-shot compression and
/// decompression, the module-specific `error` exception type, and a handful of
/// integer constants.
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
    let py_ctx = &vm.ctx;

    // `zlib.error` is created as a direct subclass of `Exception`.
    let error_type = create_type(
        "error",
        &py_ctx.types.type_type,
        &py_ctx.exceptions.exception_type,
    );

    py_module!(vm, "zlib", {
        // Functions.
        "crc32" => py_ctx.new_function(zlib_crc32),
        "adler32" => py_ctx.new_function(zlib_adler32),
        "compress" => py_ctx.new_function(zlib_compress),
        "decompress" => py_ctx.new_function(zlib_decompress),
        // Exception type.
        "error" => error_type,
        // Constants re-exported from libz and from this module.
        "Z_DEFAULT_COMPRESSION" => py_ctx.new_int(libz::Z_DEFAULT_COMPRESSION),
        "Z_NO_COMPRESSION" => py_ctx.new_int(libz::Z_NO_COMPRESSION),
        "Z_BEST_SPEED" => py_ctx.new_int(libz::Z_BEST_SPEED),
        "Z_BEST_COMPRESSION" => py_ctx.new_int(libz::Z_BEST_COMPRESSION),
        "DEF_BUF_SIZE" => py_ctx.new_int(DEF_BUF_SIZE),
        "MAX_WBITS" => py_ctx.new_int(MAX_WBITS),
    })
}

/// Compute an Adler-32 checksum of data.
///
/// `begin_state` is the running checksum to continue from (default `1`). This
/// lets a stream be checksummed in pieces by feeding the previous result back
/// in. The result is always an unsigned 32-bit value.
fn zlib_adler32(data: PyBytesRef, begin_state: OptionalArg<i64>, vm: &VirtualMachine) -> PyResult {
    let data = data.get_value();

    // Previous checksums are unsigned 32-bit values and routinely exceed
    // `i32::MAX`, so accept a wider integer. Like CPython's bitwise `unsigned int`
    // converter, keep only the low 32 bits (the `as` cast truncates), which also
    // preserves the old handling of negative inputs.
    let begin_state = begin_state.unwrap_or(1) as u32;

    let mut hasher = Adler32::from_value(begin_state);
    hasher.update_buffer(data);

    let checksum: u32 = hasher.hash();

    Ok(vm.new_int(checksum))
}

/// Compute a CRC-32 checksum of data.
///
/// `begin_state` is the running checksum to continue from (default `0`). This
/// lets a stream be checksummed in pieces by feeding the previous result back
/// in. The result is always an unsigned 32-bit value.
fn zlib_crc32(data: PyBytesRef, begin_state: OptionalArg<i64>, vm: &VirtualMachine) -> PyResult {
    let data = data.get_value();

    // CRC-32 values frequently exceed `i32::MAX`, so chaining a previous result
    // must not overflow the argument conversion. Like CPython's bitwise
    // `unsigned int` converter, keep only the low 32 bits (the `as` cast
    // truncates), which also preserves the old handling of negative inputs.
    let begin_state = begin_state.unwrap_or(0) as u32;

    let mut hasher = Crc32::new_with_initial(begin_state);
    hasher.update(data);

    let checksum: u32 = hasher.finalize();

    Ok(vm.new_int(checksum))
}

/// Returns a bytes object containing compressed data.
///
/// `level` defaults to `Z_DEFAULT_COMPRESSION`. Otherwise it must lie in
/// `Z_NO_COMPRESSION..=Z_BEST_COMPRESSION`; any other value raises `zlib.error`.
fn zlib_compress(data: PyBytesRef, level: OptionalArg<i32>, vm: &VirtualMachine) -> PyResult {
    let requested = level.unwrap_or(libz::Z_DEFAULT_COMPRESSION);

    let compression = if requested == libz::Z_DEFAULT_COMPRESSION {
        Compression::default()
    } else if (libz::Z_NO_COMPRESSION..=libz::Z_BEST_COMPRESSION).contains(&requested) {
        Compression::new(requested as u32)
    } else {
        return Err(zlib_error("Bad compression level", vm));
    };

    // The sink is an in-memory Vec, so the writes below are not expected to fail.
    let mut encoder = ZlibEncoder::new(Vec::new(), compression);
    encoder.write_all(data.get_value()).unwrap();
    let compressed = encoder.finish().unwrap();

    Ok(vm.ctx.new_bytes(compressed))
}

/// Returns a bytes object containing the uncompressed data.
///
/// * `wbits` selects the stream format and window size (default `MAX_WBITS`):
///   - `9..=15` expects a zlib header and trailer;
///   - `-15..=-9` expects a raw deflate stream;
///   - `0` expects a zlib header and accepts whatever window size it declares.
///
///   Any other value raises `zlib.error`.
/// * `bufsize` is only the initial size of the output buffer (default
///   `DEF_BUF_SIZE`; `0` is treated as `1`). The buffer grows as needed, so it
///   never limits how much data is returned.
///
/// Raises `zlib.error` if the input is malformed or ends before the end of the
/// compressed stream.
fn zlib_decompress(
    data: PyBytesRef,
    wbits: OptionalArg<i32>,
    bufsize: OptionalArg<usize>,
    vm: &VirtualMachine,
) -> PyResult {
    let encoded_bytes = data.get_value();

    // `Decompress::new_with_window_bits` panics for window sizes outside 9..=15,
    // so translate and validate `wbits` here rather than crashing the VM.
    let (zlib_header, window_bits) = match wbits.unwrap_or(i32::from(MAX_WBITS)) {
        // A maximal window can inflate any stream whose header declares a smaller one.
        0 => (true, MAX_WBITS),
        w @ 9..=15 => (true, w as u8),
        w @ -15..=-9 => (false, (-w) as u8),
        _ => {
            return Err(zlib_error(
                "Error -2 while preparing to decompress data: inconsistent stream state",
                vm,
            ))
        }
    };

    // At least one byte of capacity is required for the doubling below to make progress.
    let mut decoded_bytes = Vec::with_capacity(bufsize.unwrap_or(DEF_BUF_SIZE).max(1));
    let mut decompressor = Decompress::new_with_window_bits(zlib_header, window_bits);

    loop {
        // `decompress_vec` only fills the Vec's spare capacity and never grows it,
        // so each pass resumes from wherever the previous one stopped reading.
        let consumed = decompressor.total_in() as usize;
        let status = decompressor
            .decompress_vec(
                &encoded_bytes[consumed..],
                &mut decoded_bytes,
                FlushDecompress::Finish,
            )
            .map_err(|_| zlib_error("invalid input data", vm))?;

        if let Status::StreamEnd = status {
            return Ok(vm.ctx.new_bytes(decoded_bytes));
        }

        if decoded_bytes.len() < decoded_bytes.capacity() {
            // Output space was left over, so inflation stopped for lack of input:
            // the stream is incomplete.
            return Err(zlib_error("inconsistent or truncated state", vm));
        }

        // The output buffer is full: double it and keep inflating.
        let grow_by = decoded_bytes.capacity();
        decoded_bytes.reserve(grow_by);
    }
}

/// Build a `zlib.error` exception carrying `message`.
///
/// The exception type is looked up through `sys.modules["zlib"].error` at call
/// time. Both of those are mutable from Python (the module can be removed from
/// `sys.modules`, or `error` rebound to a non-type). If the lookup fails, fall
/// back to the builtin `Exception` instead of panicking the interpreter while
/// reporting an error.
fn zlib_error(message: &str, vm: &VirtualMachine) -> PyBaseExceptionRef {
    let error_type = vm
        .get_attribute(vm.sys_module.clone(), "modules")
        .and_then(|modules| modules.get_item("zlib", vm))
        .and_then(|module| vm.get_attribute(module, "error"))
        .ok()
        .and_then(|error| error.downcast().ok())
        .unwrap_or_else(|| vm.ctx.exceptions.exception_type.clone());

    vm.new_exception_msg(error_type, message.to_owned())
}