xref: /openbmc/linux/fs/btrfs/zstd.c (revision e0f6d1a5)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (c) 2016-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  */
7 
8 #include <linux/bio.h>
9 #include <linux/err.h>
10 #include <linux/init.h>
11 #include <linux/kernel.h>
12 #include <linux/mm.h>
13 #include <linux/pagemap.h>
14 #include <linux/refcount.h>
15 #include <linux/sched.h>
16 #include <linux/slab.h>
17 #include <linux/zstd.h>
18 #include "compression.h"
19 
/*
 * Tunables for the btrfs zstd glue:
 *  - ZSTD_BTRFS_MAX_WINDOWLOG: upper bound applied to the compressor's
 *    windowLog.
 *  - ZSTD_BTRFS_MAX_INPUT: 1 << 17 bytes (128KiB); inputs larger than this
 *    trigger a WARN_ON, and the workspace is sized for it.
 *  - ZSTD_BTRFS_DEFAULT_LEVEL: compression level handed to ZSTD_getParams().
 */
#define ZSTD_BTRFS_MAX_WINDOWLOG	17
#define ZSTD_BTRFS_MAX_INPUT		(1 << ZSTD_BTRFS_MAX_WINDOWLOG)
#define ZSTD_BTRFS_DEFAULT_LEVEL	3
23 
24 static ZSTD_parameters zstd_get_btrfs_parameters(size_t src_len)
25 {
26 	ZSTD_parameters params = ZSTD_getParams(ZSTD_BTRFS_DEFAULT_LEVEL,
27 						src_len, 0);
28 
29 	if (params.cParams.windowLog > ZSTD_BTRFS_MAX_WINDOWLOG)
30 		params.cParams.windowLog = ZSTD_BTRFS_MAX_WINDOWLOG;
31 	WARN_ON(src_len > ZSTD_BTRFS_MAX_INPUT);
32 	return params;
33 }
34 
35 struct workspace {
36 	void *mem;
37 	size_t size;
38 	char *buf;
39 	struct list_head list;
40 	ZSTD_inBuffer in_buf;
41 	ZSTD_outBuffer out_buf;
42 };
43 
44 static void zstd_free_workspace(struct list_head *ws)
45 {
46 	struct workspace *workspace = list_entry(ws, struct workspace, list);
47 
48 	kvfree(workspace->mem);
49 	kfree(workspace->buf);
50 	kfree(workspace);
51 }
52 
53 static struct list_head *zstd_alloc_workspace(void)
54 {
55 	ZSTD_parameters params =
56 			zstd_get_btrfs_parameters(ZSTD_BTRFS_MAX_INPUT);
57 	struct workspace *workspace;
58 
59 	workspace = kzalloc(sizeof(*workspace), GFP_KERNEL);
60 	if (!workspace)
61 		return ERR_PTR(-ENOMEM);
62 
63 	workspace->size = max_t(size_t,
64 			ZSTD_CStreamWorkspaceBound(params.cParams),
65 			ZSTD_DStreamWorkspaceBound(ZSTD_BTRFS_MAX_INPUT));
66 	workspace->mem = kvmalloc(workspace->size, GFP_KERNEL);
67 	workspace->buf = kmalloc(PAGE_SIZE, GFP_KERNEL);
68 	if (!workspace->mem || !workspace->buf)
69 		goto fail;
70 
71 	INIT_LIST_HEAD(&workspace->list);
72 
73 	return &workspace->list;
74 fail:
75 	zstd_free_workspace(&workspace->list);
76 	return ERR_PTR(-ENOMEM);
77 }
78 
79 static int zstd_compress_pages(struct list_head *ws,
80 		struct address_space *mapping,
81 		u64 start,
82 		struct page **pages,
83 		unsigned long *out_pages,
84 		unsigned long *total_in,
85 		unsigned long *total_out)
86 {
87 	struct workspace *workspace = list_entry(ws, struct workspace, list);
88 	ZSTD_CStream *stream;
89 	int ret = 0;
90 	int nr_pages = 0;
91 	struct page *in_page = NULL;  /* The current page to read */
92 	struct page *out_page = NULL; /* The current page to write to */
93 	unsigned long tot_in = 0;
94 	unsigned long tot_out = 0;
95 	unsigned long len = *total_out;
96 	const unsigned long nr_dest_pages = *out_pages;
97 	unsigned long max_out = nr_dest_pages * PAGE_SIZE;
98 	ZSTD_parameters params = zstd_get_btrfs_parameters(len);
99 
100 	*out_pages = 0;
101 	*total_out = 0;
102 	*total_in = 0;
103 
104 	/* Initialize the stream */
105 	stream = ZSTD_initCStream(params, len, workspace->mem,
106 			workspace->size);
107 	if (!stream) {
108 		pr_warn("BTRFS: ZSTD_initCStream failed\n");
109 		ret = -EIO;
110 		goto out;
111 	}
112 
113 	/* map in the first page of input data */
114 	in_page = find_get_page(mapping, start >> PAGE_SHIFT);
115 	workspace->in_buf.src = kmap(in_page);
116 	workspace->in_buf.pos = 0;
117 	workspace->in_buf.size = min_t(size_t, len, PAGE_SIZE);
118 
119 
120 	/* Allocate and map in the output buffer */
121 	out_page = alloc_page(GFP_NOFS | __GFP_HIGHMEM);
122 	if (out_page == NULL) {
123 		ret = -ENOMEM;
124 		goto out;
125 	}
126 	pages[nr_pages++] = out_page;
127 	workspace->out_buf.dst = kmap(out_page);
128 	workspace->out_buf.pos = 0;
129 	workspace->out_buf.size = min_t(size_t, max_out, PAGE_SIZE);
130 
131 	while (1) {
132 		size_t ret2;
133 
134 		ret2 = ZSTD_compressStream(stream, &workspace->out_buf,
135 				&workspace->in_buf);
136 		if (ZSTD_isError(ret2)) {
137 			pr_debug("BTRFS: ZSTD_compressStream returned %d\n",
138 					ZSTD_getErrorCode(ret2));
139 			ret = -EIO;
140 			goto out;
141 		}
142 
143 		/* Check to see if we are making it bigger */
144 		if (tot_in + workspace->in_buf.pos > 8192 &&
145 				tot_in + workspace->in_buf.pos <
146 				tot_out + workspace->out_buf.pos) {
147 			ret = -E2BIG;
148 			goto out;
149 		}
150 
151 		/* We've reached the end of our output range */
152 		if (workspace->out_buf.pos >= max_out) {
153 			tot_out += workspace->out_buf.pos;
154 			ret = -E2BIG;
155 			goto out;
156 		}
157 
158 		/* Check if we need more output space */
159 		if (workspace->out_buf.pos == workspace->out_buf.size) {
160 			tot_out += PAGE_SIZE;
161 			max_out -= PAGE_SIZE;
162 			kunmap(out_page);
163 			if (nr_pages == nr_dest_pages) {
164 				out_page = NULL;
165 				ret = -E2BIG;
166 				goto out;
167 			}
168 			out_page = alloc_page(GFP_NOFS | __GFP_HIGHMEM);
169 			if (out_page == NULL) {
170 				ret = -ENOMEM;
171 				goto out;
172 			}
173 			pages[nr_pages++] = out_page;
174 			workspace->out_buf.dst = kmap(out_page);
175 			workspace->out_buf.pos = 0;
176 			workspace->out_buf.size = min_t(size_t, max_out,
177 							PAGE_SIZE);
178 		}
179 
180 		/* We've reached the end of the input */
181 		if (workspace->in_buf.pos >= len) {
182 			tot_in += workspace->in_buf.pos;
183 			break;
184 		}
185 
186 		/* Check if we need more input */
187 		if (workspace->in_buf.pos == workspace->in_buf.size) {
188 			tot_in += PAGE_SIZE;
189 			kunmap(in_page);
190 			put_page(in_page);
191 
192 			start += PAGE_SIZE;
193 			len -= PAGE_SIZE;
194 			in_page = find_get_page(mapping, start >> PAGE_SHIFT);
195 			workspace->in_buf.src = kmap(in_page);
196 			workspace->in_buf.pos = 0;
197 			workspace->in_buf.size = min_t(size_t, len, PAGE_SIZE);
198 		}
199 	}
200 	while (1) {
201 		size_t ret2;
202 
203 		ret2 = ZSTD_endStream(stream, &workspace->out_buf);
204 		if (ZSTD_isError(ret2)) {
205 			pr_debug("BTRFS: ZSTD_endStream returned %d\n",
206 					ZSTD_getErrorCode(ret2));
207 			ret = -EIO;
208 			goto out;
209 		}
210 		if (ret2 == 0) {
211 			tot_out += workspace->out_buf.pos;
212 			break;
213 		}
214 		if (workspace->out_buf.pos >= max_out) {
215 			tot_out += workspace->out_buf.pos;
216 			ret = -E2BIG;
217 			goto out;
218 		}
219 
220 		tot_out += PAGE_SIZE;
221 		max_out -= PAGE_SIZE;
222 		kunmap(out_page);
223 		if (nr_pages == nr_dest_pages) {
224 			out_page = NULL;
225 			ret = -E2BIG;
226 			goto out;
227 		}
228 		out_page = alloc_page(GFP_NOFS | __GFP_HIGHMEM);
229 		if (out_page == NULL) {
230 			ret = -ENOMEM;
231 			goto out;
232 		}
233 		pages[nr_pages++] = out_page;
234 		workspace->out_buf.dst = kmap(out_page);
235 		workspace->out_buf.pos = 0;
236 		workspace->out_buf.size = min_t(size_t, max_out, PAGE_SIZE);
237 	}
238 
239 	if (tot_out >= tot_in) {
240 		ret = -E2BIG;
241 		goto out;
242 	}
243 
244 	ret = 0;
245 	*total_in = tot_in;
246 	*total_out = tot_out;
247 out:
248 	*out_pages = nr_pages;
249 	/* Cleanup */
250 	if (in_page) {
251 		kunmap(in_page);
252 		put_page(in_page);
253 	}
254 	if (out_page)
255 		kunmap(out_page);
256 	return ret;
257 }
258 
259 static int zstd_decompress_bio(struct list_head *ws, struct compressed_bio *cb)
260 {
261 	struct workspace *workspace = list_entry(ws, struct workspace, list);
262 	struct page **pages_in = cb->compressed_pages;
263 	u64 disk_start = cb->start;
264 	struct bio *orig_bio = cb->orig_bio;
265 	size_t srclen = cb->compressed_len;
266 	ZSTD_DStream *stream;
267 	int ret = 0;
268 	unsigned long page_in_index = 0;
269 	unsigned long total_pages_in = DIV_ROUND_UP(srclen, PAGE_SIZE);
270 	unsigned long buf_start;
271 	unsigned long total_out = 0;
272 
273 	stream = ZSTD_initDStream(
274 			ZSTD_BTRFS_MAX_INPUT, workspace->mem, workspace->size);
275 	if (!stream) {
276 		pr_debug("BTRFS: ZSTD_initDStream failed\n");
277 		ret = -EIO;
278 		goto done;
279 	}
280 
281 	workspace->in_buf.src = kmap(pages_in[page_in_index]);
282 	workspace->in_buf.pos = 0;
283 	workspace->in_buf.size = min_t(size_t, srclen, PAGE_SIZE);
284 
285 	workspace->out_buf.dst = workspace->buf;
286 	workspace->out_buf.pos = 0;
287 	workspace->out_buf.size = PAGE_SIZE;
288 
289 	while (1) {
290 		size_t ret2;
291 
292 		ret2 = ZSTD_decompressStream(stream, &workspace->out_buf,
293 				&workspace->in_buf);
294 		if (ZSTD_isError(ret2)) {
295 			pr_debug("BTRFS: ZSTD_decompressStream returned %d\n",
296 					ZSTD_getErrorCode(ret2));
297 			ret = -EIO;
298 			goto done;
299 		}
300 		buf_start = total_out;
301 		total_out += workspace->out_buf.pos;
302 		workspace->out_buf.pos = 0;
303 
304 		ret = btrfs_decompress_buf2page(workspace->out_buf.dst,
305 				buf_start, total_out, disk_start, orig_bio);
306 		if (ret == 0)
307 			break;
308 
309 		if (workspace->in_buf.pos >= srclen)
310 			break;
311 
312 		/* Check if we've hit the end of a frame */
313 		if (ret2 == 0)
314 			break;
315 
316 		if (workspace->in_buf.pos == workspace->in_buf.size) {
317 			kunmap(pages_in[page_in_index++]);
318 			if (page_in_index >= total_pages_in) {
319 				workspace->in_buf.src = NULL;
320 				ret = -EIO;
321 				goto done;
322 			}
323 			srclen -= PAGE_SIZE;
324 			workspace->in_buf.src = kmap(pages_in[page_in_index]);
325 			workspace->in_buf.pos = 0;
326 			workspace->in_buf.size = min_t(size_t, srclen, PAGE_SIZE);
327 		}
328 	}
329 	ret = 0;
330 	zero_fill_bio(orig_bio);
331 done:
332 	if (workspace->in_buf.src)
333 		kunmap(pages_in[page_in_index]);
334 	return ret;
335 }
336 
337 static int zstd_decompress(struct list_head *ws, unsigned char *data_in,
338 		struct page *dest_page,
339 		unsigned long start_byte,
340 		size_t srclen, size_t destlen)
341 {
342 	struct workspace *workspace = list_entry(ws, struct workspace, list);
343 	ZSTD_DStream *stream;
344 	int ret = 0;
345 	size_t ret2;
346 	unsigned long total_out = 0;
347 	unsigned long pg_offset = 0;
348 	char *kaddr;
349 
350 	stream = ZSTD_initDStream(
351 			ZSTD_BTRFS_MAX_INPUT, workspace->mem, workspace->size);
352 	if (!stream) {
353 		pr_warn("BTRFS: ZSTD_initDStream failed\n");
354 		ret = -EIO;
355 		goto finish;
356 	}
357 
358 	destlen = min_t(size_t, destlen, PAGE_SIZE);
359 
360 	workspace->in_buf.src = data_in;
361 	workspace->in_buf.pos = 0;
362 	workspace->in_buf.size = srclen;
363 
364 	workspace->out_buf.dst = workspace->buf;
365 	workspace->out_buf.pos = 0;
366 	workspace->out_buf.size = PAGE_SIZE;
367 
368 	ret2 = 1;
369 	while (pg_offset < destlen
370 	       && workspace->in_buf.pos < workspace->in_buf.size) {
371 		unsigned long buf_start;
372 		unsigned long buf_offset;
373 		unsigned long bytes;
374 
375 		/* Check if the frame is over and we still need more input */
376 		if (ret2 == 0) {
377 			pr_debug("BTRFS: ZSTD_decompressStream ended early\n");
378 			ret = -EIO;
379 			goto finish;
380 		}
381 		ret2 = ZSTD_decompressStream(stream, &workspace->out_buf,
382 				&workspace->in_buf);
383 		if (ZSTD_isError(ret2)) {
384 			pr_debug("BTRFS: ZSTD_decompressStream returned %d\n",
385 					ZSTD_getErrorCode(ret2));
386 			ret = -EIO;
387 			goto finish;
388 		}
389 
390 		buf_start = total_out;
391 		total_out += workspace->out_buf.pos;
392 		workspace->out_buf.pos = 0;
393 
394 		if (total_out <= start_byte)
395 			continue;
396 
397 		if (total_out > start_byte && buf_start < start_byte)
398 			buf_offset = start_byte - buf_start;
399 		else
400 			buf_offset = 0;
401 
402 		bytes = min_t(unsigned long, destlen - pg_offset,
403 				workspace->out_buf.size - buf_offset);
404 
405 		kaddr = kmap_atomic(dest_page);
406 		memcpy(kaddr + pg_offset, workspace->out_buf.dst + buf_offset,
407 				bytes);
408 		kunmap_atomic(kaddr);
409 
410 		pg_offset += bytes;
411 	}
412 	ret = 0;
413 finish:
414 	if (pg_offset < destlen) {
415 		kaddr = kmap_atomic(dest_page);
416 		memset(kaddr + pg_offset, 0, destlen - pg_offset);
417 		kunmap_atomic(kaddr);
418 	}
419 	return ret;
420 }
421 
/*
 * Deliberately a no-op: both arguments are ignored.  Compression parameters
 * in this file are always derived from ZSTD_BTRFS_DEFAULT_LEVEL.
 */
static void zstd_set_level(struct list_head *ws, unsigned int type)
{
	/* nothing to configure */
}
425 
426 const struct btrfs_compress_op btrfs_zstd_compress = {
427 	.alloc_workspace = zstd_alloc_workspace,
428 	.free_workspace = zstd_free_workspace,
429 	.compress_pages = zstd_compress_pages,
430 	.decompress_bio = zstd_decompress_bio,
431 	.decompress = zstd_decompress,
432 	.set_level = zstd_set_level,
433 };
434