1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* 3 * IP Payload Compression Protocol (IPComp) - RFC3173. 4 * 5 * Copyright (c) 2003 James Morris <jmorris@intercode.com.au> 6 * Copyright (c) 2003-2008 Herbert Xu <herbert@gondor.apana.org.au> 7 * 8 * Todo: 9 * - Tunable compression parameters. 10 * - Compression stats. 11 * - Adaptive compression. 12 */ 13 14 #include <linux/crypto.h> 15 #include <linux/err.h> 16 #include <linux/list.h> 17 #include <linux/module.h> 18 #include <linux/mutex.h> 19 #include <linux/percpu.h> 20 #include <linux/slab.h> 21 #include <linux/smp.h> 22 #include <linux/vmalloc.h> 23 #include <net/ip.h> 24 #include <net/ipcomp.h> 25 #include <net/xfrm.h> 26 27 struct ipcomp_tfms { 28 struct list_head list; 29 struct crypto_comp * __percpu *tfms; 30 int users; 31 }; 32 33 static DEFINE_MUTEX(ipcomp_resource_mutex); 34 static void * __percpu *ipcomp_scratches; 35 static int ipcomp_scratch_users; 36 static LIST_HEAD(ipcomp_tfms_list); 37 38 static int ipcomp_decompress(struct xfrm_state *x, struct sk_buff *skb) 39 { 40 struct ipcomp_data *ipcd = x->data; 41 const int plen = skb->len; 42 int dlen = IPCOMP_SCRATCH_SIZE; 43 const u8 *start = skb->data; 44 const int cpu = get_cpu(); 45 u8 *scratch = *per_cpu_ptr(ipcomp_scratches, cpu); 46 struct crypto_comp *tfm = *per_cpu_ptr(ipcd->tfms, cpu); 47 int err = crypto_comp_decompress(tfm, start, plen, scratch, &dlen); 48 int len; 49 50 if (err) 51 goto out; 52 53 if (dlen < (plen + sizeof(struct ip_comp_hdr))) { 54 err = -EINVAL; 55 goto out; 56 } 57 58 len = dlen - plen; 59 if (len > skb_tailroom(skb)) 60 len = skb_tailroom(skb); 61 62 __skb_put(skb, len); 63 64 len += plen; 65 skb_copy_to_linear_data(skb, scratch, len); 66 67 while ((scratch += len, dlen -= len) > 0) { 68 skb_frag_t *frag; 69 struct page *page; 70 71 err = -EMSGSIZE; 72 if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS)) 73 goto out; 74 75 frag = skb_shinfo(skb)->frags + skb_shinfo(skb)->nr_frags; 76 page = alloc_page(GFP_ATOMIC); 77 78 err = -ENOMEM; 79 if (!page) 80 goto out; 81 82 __skb_frag_set_page(frag, page); 83 84 len = PAGE_SIZE; 85 if (dlen < len) 86 len = dlen; 87 88 skb_frag_off_set(frag, 0); 89 skb_frag_size_set(frag, len); 90 memcpy(skb_frag_address(frag), scratch, len); 91 92 skb->truesize += len; 93 skb->data_len += len; 94 skb->len += len; 95 96 skb_shinfo(skb)->nr_frags++; 97 } 98 99 err = 0; 100 101 out: 102 put_cpu(); 103 return err; 104 } 105 106 int ipcomp_input(struct xfrm_state *x, struct sk_buff *skb) 107 { 108 int nexthdr; 109 int err = -ENOMEM; 110 struct ip_comp_hdr *ipch; 111 112 if (skb_linearize_cow(skb)) 113 goto out; 114 115 skb->ip_summed = CHECKSUM_NONE; 116 117 /* Remove ipcomp header and decompress original payload */ 118 ipch = (void *)skb->data; 119 nexthdr = ipch->nexthdr; 120 121 skb->transport_header = skb->network_header + sizeof(*ipch); 122 __skb_pull(skb, sizeof(*ipch)); 123 err = ipcomp_decompress(x, skb); 124 if (err) 125 goto out; 126 127 err = nexthdr; 128 129 out: 130 return err; 131 } 132 EXPORT_SYMBOL_GPL(ipcomp_input); 133 134 static int ipcomp_compress(struct xfrm_state *x, struct sk_buff *skb) 135 { 136 struct ipcomp_data *ipcd = x->data; 137 const int plen = skb->len; 138 int dlen = IPCOMP_SCRATCH_SIZE; 139 u8 *start = skb->data; 140 struct crypto_comp *tfm; 141 u8 *scratch; 142 int err; 143 144 local_bh_disable(); 145 scratch = *this_cpu_ptr(ipcomp_scratches); 146 tfm = *this_cpu_ptr(ipcd->tfms); 147 err = crypto_comp_compress(tfm, start, plen, scratch, &dlen); 148 if (err) 149 goto out; 150 151 if ((dlen + sizeof(struct ip_comp_hdr)) >= plen) { 152 err = -EMSGSIZE; 153 goto out; 154 } 155 156 memcpy(start + sizeof(struct ip_comp_hdr), scratch, dlen); 157 local_bh_enable(); 158 159 pskb_trim(skb, dlen + sizeof(struct ip_comp_hdr)); 160 return 0; 161 162 out: 163 local_bh_enable(); 164 return err; 165 } 166 167 int ipcomp_output(struct xfrm_state *x, struct sk_buff *skb) 168 { 169 int err; 170 struct ip_comp_hdr *ipch; 171 struct ipcomp_data *ipcd = x->data; 172 173 if (skb->len < ipcd->threshold) { 174 /* Don't bother compressing */ 175 goto out_ok; 176 } 177 178 if (skb_linearize_cow(skb)) 179 goto out_ok; 180 181 err = ipcomp_compress(x, skb); 182 183 if (err) { 184 goto out_ok; 185 } 186 187 /* Install ipcomp header, convert into ipcomp datagram. */ 188 ipch = ip_comp_hdr(skb); 189 ipch->nexthdr = *skb_mac_header(skb); 190 ipch->flags = 0; 191 ipch->cpi = htons((u16 )ntohl(x->id.spi)); 192 *skb_mac_header(skb) = IPPROTO_COMP; 193 out_ok: 194 skb_push(skb, -skb_network_offset(skb)); 195 return 0; 196 } 197 EXPORT_SYMBOL_GPL(ipcomp_output); 198 199 static void ipcomp_free_scratches(void) 200 { 201 int i; 202 void * __percpu *scratches; 203 204 if (--ipcomp_scratch_users) 205 return; 206 207 scratches = ipcomp_scratches; 208 if (!scratches) 209 return; 210 211 for_each_possible_cpu(i) 212 vfree(*per_cpu_ptr(scratches, i)); 213 214 free_percpu(scratches); 215 } 216 217 static void * __percpu *ipcomp_alloc_scratches(void) 218 { 219 void * __percpu *scratches; 220 int i; 221 222 if (ipcomp_scratch_users++) 223 return ipcomp_scratches; 224 225 scratches = alloc_percpu(void *); 226 if (!scratches) 227 return NULL; 228 229 ipcomp_scratches = scratches; 230 231 for_each_possible_cpu(i) { 232 void *scratch; 233 234 scratch = vmalloc_node(IPCOMP_SCRATCH_SIZE, cpu_to_node(i)); 235 if (!scratch) 236 return NULL; 237 *per_cpu_ptr(scratches, i) = scratch; 238 } 239 240 return scratches; 241 } 242 243 static void ipcomp_free_tfms(struct crypto_comp * __percpu *tfms) 244 { 245 struct ipcomp_tfms *pos; 246 int cpu; 247 248 list_for_each_entry(pos, &ipcomp_tfms_list, list) { 249 if (pos->tfms == tfms) 250 break; 251 } 252 253 WARN_ON(!pos); 254 255 if (--pos->users) 256 return; 257 258 list_del(&pos->list); 259 kfree(pos); 260 261 if (!tfms) 262 return; 263 264 for_each_possible_cpu(cpu) { 265 struct crypto_comp *tfm = *per_cpu_ptr(tfms, cpu); 266 crypto_free_comp(tfm); 267 } 268 free_percpu(tfms); 269 } 270 271 static struct crypto_comp * __percpu *ipcomp_alloc_tfms(const char *alg_name) 272 { 273 struct ipcomp_tfms *pos; 274 struct crypto_comp * __percpu *tfms; 275 int cpu; 276 277 278 list_for_each_entry(pos, &ipcomp_tfms_list, list) { 279 struct crypto_comp *tfm; 280 281 /* This can be any valid CPU ID so we don't need locking. */ 282 tfm = this_cpu_read(*pos->tfms); 283 284 if (!strcmp(crypto_comp_name(tfm), alg_name)) { 285 pos->users++; 286 return pos->tfms; 287 } 288 } 289 290 pos = kmalloc(sizeof(*pos), GFP_KERNEL); 291 if (!pos) 292 return NULL; 293 294 pos->users = 1; 295 INIT_LIST_HEAD(&pos->list); 296 list_add(&pos->list, &ipcomp_tfms_list); 297 298 pos->tfms = tfms = alloc_percpu(struct crypto_comp *); 299 if (!tfms) 300 goto error; 301 302 for_each_possible_cpu(cpu) { 303 struct crypto_comp *tfm = crypto_alloc_comp(alg_name, 0, 304 CRYPTO_ALG_ASYNC); 305 if (IS_ERR(tfm)) 306 goto error; 307 *per_cpu_ptr(tfms, cpu) = tfm; 308 } 309 310 return tfms; 311 312 error: 313 ipcomp_free_tfms(tfms); 314 return NULL; 315 } 316 317 static void ipcomp_free_data(struct ipcomp_data *ipcd) 318 { 319 if (ipcd->tfms) 320 ipcomp_free_tfms(ipcd->tfms); 321 ipcomp_free_scratches(); 322 } 323 324 void ipcomp_destroy(struct xfrm_state *x) 325 { 326 struct ipcomp_data *ipcd = x->data; 327 if (!ipcd) 328 return; 329 xfrm_state_delete_tunnel(x); 330 mutex_lock(&ipcomp_resource_mutex); 331 ipcomp_free_data(ipcd); 332 mutex_unlock(&ipcomp_resource_mutex); 333 kfree(ipcd); 334 } 335 EXPORT_SYMBOL_GPL(ipcomp_destroy); 336 337 int ipcomp_init_state(struct xfrm_state *x) 338 { 339 int err; 340 struct ipcomp_data *ipcd; 341 struct xfrm_algo_desc *calg_desc; 342 343 err = -EINVAL; 344 if (!x->calg) 345 goto out; 346 347 if (x->encap) 348 goto out; 349 350 err = -ENOMEM; 351 ipcd = kzalloc(sizeof(*ipcd), GFP_KERNEL); 352 if (!ipcd) 353 goto out; 354 355 mutex_lock(&ipcomp_resource_mutex); 356 if (!ipcomp_alloc_scratches()) 357 goto error; 358 359 ipcd->tfms = ipcomp_alloc_tfms(x->calg->alg_name); 360 if (!ipcd->tfms) 361 goto error; 362 mutex_unlock(&ipcomp_resource_mutex); 363 364 calg_desc = xfrm_calg_get_byname(x->calg->alg_name, 0); 365 BUG_ON(!calg_desc); 366 ipcd->threshold = calg_desc->uinfo.comp.threshold; 367 x->data = ipcd; 368 err = 0; 369 out: 370 return err; 371 372 error: 373 ipcomp_free_data(ipcd); 374 mutex_unlock(&ipcomp_resource_mutex); 375 kfree(ipcd); 376 goto out; 377 } 378 EXPORT_SYMBOL_GPL(ipcomp_init_state); 379 380 MODULE_LICENSE("GPL"); 381 MODULE_DESCRIPTION("IP Payload Compression Protocol (IPComp) - RFC3173"); 382 MODULE_AUTHOR("James Morris <jmorris@intercode.com.au>"); 383