xref: /openbmc/linux/tools/perf/util/zstd.c (revision e7bae9bb)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <string.h>
4 
5 #include "util/compress.h"
6 #include "util/debug.h"
7 
8 int zstd_init(struct zstd_data *data, int level)
9 {
10 	size_t ret;
11 
12 	data->dstream = ZSTD_createDStream();
13 	if (data->dstream == NULL) {
14 		pr_err("Couldn't create decompression stream.\n");
15 		return -1;
16 	}
17 
18 	ret = ZSTD_initDStream(data->dstream);
19 	if (ZSTD_isError(ret)) {
20 		pr_err("Failed to initialize decompression stream: %s\n", ZSTD_getErrorName(ret));
21 		return -1;
22 	}
23 
24 	if (!level)
25 		return 0;
26 
27 	data->cstream = ZSTD_createCStream();
28 	if (data->cstream == NULL) {
29 		pr_err("Couldn't create compression stream.\n");
30 		return -1;
31 	}
32 
33 	ret = ZSTD_initCStream(data->cstream, level);
34 	if (ZSTD_isError(ret)) {
35 		pr_err("Failed to initialize compression stream: %s\n", ZSTD_getErrorName(ret));
36 		return -1;
37 	}
38 
39 	return 0;
40 }
41 
42 int zstd_fini(struct zstd_data *data)
43 {
44 	if (data->dstream) {
45 		ZSTD_freeDStream(data->dstream);
46 		data->dstream = NULL;
47 	}
48 
49 	if (data->cstream) {
50 		ZSTD_freeCStream(data->cstream);
51 		data->cstream = NULL;
52 	}
53 
54 	return 0;
55 }
56 
57 size_t zstd_compress_stream_to_records(struct zstd_data *data, void *dst, size_t dst_size,
58 				       void *src, size_t src_size, size_t max_record_size,
59 				       size_t process_header(void *record, size_t increment))
60 {
61 	size_t ret, size, compressed = 0;
62 	ZSTD_inBuffer input = { src, src_size, 0 };
63 	ZSTD_outBuffer output;
64 	void *record;
65 
66 	while (input.pos < input.size) {
67 		record = dst;
68 		size = process_header(record, 0);
69 		compressed += size;
70 		dst += size;
71 		dst_size -= size;
72 		output = (ZSTD_outBuffer){ dst, (dst_size > max_record_size) ?
73 						max_record_size : dst_size, 0 };
74 		ret = ZSTD_compressStream(data->cstream, &output, &input);
75 		ZSTD_flushStream(data->cstream, &output);
76 		if (ZSTD_isError(ret)) {
77 			pr_err("failed to compress %ld bytes: %s\n",
78 				(long)src_size, ZSTD_getErrorName(ret));
79 			memcpy(dst, src, src_size);
80 			return src_size;
81 		}
82 		size = output.pos;
83 		size = process_header(record, size);
84 		compressed += size;
85 		dst += size;
86 		dst_size -= size;
87 	}
88 
89 	return compressed;
90 }
91 
92 size_t zstd_decompress_stream(struct zstd_data *data, void *src, size_t src_size,
93 			      void *dst, size_t dst_size)
94 {
95 	size_t ret;
96 	ZSTD_inBuffer input = { src, src_size, 0 };
97 	ZSTD_outBuffer output = { dst, dst_size, 0 };
98 
99 	while (input.pos < input.size) {
100 		ret = ZSTD_decompressStream(data->dstream, &output, &input);
101 		if (ZSTD_isError(ret)) {
102 			pr_err("failed to decompress (B): %zd -> %zd, dst_size %zd : %s\n",
103 			       src_size, output.size, dst_size, ZSTD_getErrorName(ret));
104 			break;
105 		}
106 		output.dst  = dst + output.pos;
107 		output.size = dst_size - output.pos;
108 	}
109 
110 	return output.pos;
111 }
112