| // SPDX-License-Identifier: GPL-2.0 |
| |
| #include <string.h> |
| |
| #include "util/compress.h" |
| #include "util/debug.h" |
| |
| int zstd_init(struct zstd_data *data, int level) |
| { |
| size_t ret; |
| |
| data->dstream = ZSTD_createDStream(); |
| if (data->dstream == NULL) { |
| pr_err("Couldn't create decompression stream.\n"); |
| return -1; |
| } |
| |
| ret = ZSTD_initDStream(data->dstream); |
| if (ZSTD_isError(ret)) { |
| pr_err("Failed to initialize decompression stream: %s\n", ZSTD_getErrorName(ret)); |
| return -1; |
| } |
| |
| if (!level) |
| return 0; |
| |
| data->cstream = ZSTD_createCStream(); |
| if (data->cstream == NULL) { |
| pr_err("Couldn't create compression stream.\n"); |
| return -1; |
| } |
| |
| ret = ZSTD_initCStream(data->cstream, level); |
| if (ZSTD_isError(ret)) { |
| pr_err("Failed to initialize compression stream: %s\n", ZSTD_getErrorName(ret)); |
| return -1; |
| } |
| |
| return 0; |
| } |
| |
| int zstd_fini(struct zstd_data *data) |
| { |
| if (data->dstream) { |
| ZSTD_freeDStream(data->dstream); |
| data->dstream = NULL; |
| } |
| |
| if (data->cstream) { |
| ZSTD_freeCStream(data->cstream); |
| data->cstream = NULL; |
| } |
| |
| return 0; |
| } |
| |
| size_t zstd_compress_stream_to_records(struct zstd_data *data, void *dst, size_t dst_size, |
| void *src, size_t src_size, size_t max_record_size, |
| size_t process_header(void *record, size_t increment)) |
| { |
| size_t ret, size, compressed = 0; |
| ZSTD_inBuffer input = { src, src_size, 0 }; |
| ZSTD_outBuffer output; |
| void *record; |
| |
| while (input.pos < input.size) { |
| record = dst; |
| size = process_header(record, 0); |
| compressed += size; |
| dst += size; |
| dst_size -= size; |
| output = (ZSTD_outBuffer){ dst, (dst_size > max_record_size) ? |
| max_record_size : dst_size, 0 }; |
| ret = ZSTD_compressStream(data->cstream, &output, &input); |
| ZSTD_flushStream(data->cstream, &output); |
| if (ZSTD_isError(ret)) { |
| pr_err("failed to compress %ld bytes: %s\n", |
| (long)src_size, ZSTD_getErrorName(ret)); |
| memcpy(dst, src, src_size); |
| return src_size; |
| } |
| size = output.pos; |
| size = process_header(record, size); |
| compressed += size; |
| dst += size; |
| dst_size -= size; |
| } |
| |
| return compressed; |
| } |
| |
| size_t zstd_decompress_stream(struct zstd_data *data, void *src, size_t src_size, |
| void *dst, size_t dst_size) |
| { |
| size_t ret; |
| ZSTD_inBuffer input = { src, src_size, 0 }; |
| ZSTD_outBuffer output = { dst, dst_size, 0 }; |
| |
| while (input.pos < input.size) { |
| ret = ZSTD_decompressStream(data->dstream, &output, &input); |
| if (ZSTD_isError(ret)) { |
| pr_err("failed to decompress (B): %ld -> %ld, dst_size %ld : %s\n", |
| src_size, output.size, dst_size, ZSTD_getErrorName(ret)); |
| break; |
| } |
| output.dst = dst + output.pos; |
| output.size = dst_size - output.pos; |
| } |
| |
| return output.pos; |
| } |