1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* 3 * (C) 2012 by Pablo Neira Ayuso <pablo@netfilter.org> 4 * (C) 2012 by Vyatta Inc. <http://www.vyatta.com> 5 */ 6 7 #include <linux/types.h> 8 #include <linux/netfilter.h> 9 #include <linux/skbuff.h> 10 #include <linux/vmalloc.h> 11 #include <linux/stddef.h> 12 #include <linux/err.h> 13 #include <linux/percpu.h> 14 #include <linux/kernel.h> 15 #include <linux/netdevice.h> 16 #include <linux/slab.h> 17 #include <linux/export.h> 18 19 #include <net/netfilter/nf_conntrack.h> 20 #include <net/netfilter/nf_conntrack_core.h> 21 #include <net/netfilter/nf_conntrack_extend.h> 22 #include <net/netfilter/nf_conntrack_timeout.h> 23 24 struct nf_ct_timeout * 25 (*nf_ct_timeout_find_get_hook)(struct net *net, const char *name) __read_mostly; 26 EXPORT_SYMBOL_GPL(nf_ct_timeout_find_get_hook); 27 28 void (*nf_ct_timeout_put_hook)(struct nf_ct_timeout *timeout) __read_mostly; 29 EXPORT_SYMBOL_GPL(nf_ct_timeout_put_hook); 30 31 static int untimeout(struct nf_conn *ct, void *timeout) 32 { 33 struct nf_conn_timeout *timeout_ext = nf_ct_timeout_find(ct); 34 35 if (timeout_ext && (!timeout || timeout_ext->timeout == timeout)) 36 RCU_INIT_POINTER(timeout_ext->timeout, NULL); 37 38 /* We are not intended to delete this conntrack. */ 39 return 0; 40 } 41 42 void nf_ct_untimeout(struct net *net, struct nf_ct_timeout *timeout) 43 { 44 nf_ct_iterate_cleanup_net(net, untimeout, timeout, 0, 0); 45 } 46 EXPORT_SYMBOL_GPL(nf_ct_untimeout); 47 48 static void __nf_ct_timeout_put(struct nf_ct_timeout *timeout) 49 { 50 typeof(nf_ct_timeout_put_hook) timeout_put; 51 52 timeout_put = rcu_dereference(nf_ct_timeout_put_hook); 53 if (timeout_put) 54 timeout_put(timeout); 55 } 56 57 int nf_ct_set_timeout(struct net *net, struct nf_conn *ct, 58 u8 l3num, u8 l4num, const char *timeout_name) 59 { 60 typeof(nf_ct_timeout_find_get_hook) timeout_find_get; 61 struct nf_ct_timeout *timeout; 62 struct nf_conn_timeout *timeout_ext; 63 const char *errmsg = NULL; 64 int ret = 0; 65 66 rcu_read_lock(); 67 timeout_find_get = rcu_dereference(nf_ct_timeout_find_get_hook); 68 if (!timeout_find_get) { 69 ret = -ENOENT; 70 errmsg = "Timeout policy base is empty"; 71 goto out; 72 } 73 74 timeout = timeout_find_get(net, timeout_name); 75 if (!timeout) { 76 ret = -ENOENT; 77 pr_info_ratelimited("No such timeout policy \"%s\"\n", 78 timeout_name); 79 goto out; 80 } 81 82 if (timeout->l3num != l3num) { 83 ret = -EINVAL; 84 pr_info_ratelimited("Timeout policy `%s' can only be used by " 85 "L%d protocol number %d\n", 86 timeout_name, 3, timeout->l3num); 87 goto err_put_timeout; 88 } 89 /* Make sure the timeout policy matches any existing protocol tracker, 90 * otherwise default to generic. 91 */ 92 if (timeout->l4proto->l4proto != l4num) { 93 ret = -EINVAL; 94 pr_info_ratelimited("Timeout policy `%s' can only be used by " 95 "L%d protocol number %d\n", 96 timeout_name, 4, timeout->l4proto->l4proto); 97 goto err_put_timeout; 98 } 99 timeout_ext = nf_ct_timeout_ext_add(ct, timeout, GFP_ATOMIC); 100 if (!timeout_ext) { 101 ret = -ENOMEM; 102 goto err_put_timeout; 103 } 104 105 rcu_read_unlock(); 106 return ret; 107 108 err_put_timeout: 109 __nf_ct_timeout_put(timeout); 110 out: 111 rcu_read_unlock(); 112 if (errmsg) 113 pr_info_ratelimited("%s\n", errmsg); 114 return ret; 115 } 116 EXPORT_SYMBOL_GPL(nf_ct_set_timeout); 117 118 void nf_ct_destroy_timeout(struct nf_conn *ct) 119 { 120 struct nf_conn_timeout *timeout_ext; 121 typeof(nf_ct_timeout_put_hook) timeout_put; 122 123 rcu_read_lock(); 124 timeout_put = rcu_dereference(nf_ct_timeout_put_hook); 125 126 if (timeout_put) { 127 timeout_ext = nf_ct_timeout_find(ct); 128 if (timeout_ext) { 129 timeout_put(timeout_ext->timeout); 130 RCU_INIT_POINTER(timeout_ext->timeout, NULL); 131 } 132 } 133 rcu_read_unlock(); 134 } 135 EXPORT_SYMBOL_GPL(nf_ct_destroy_timeout); 136 137 static const struct nf_ct_ext_type timeout_extend = { 138 .len = sizeof(struct nf_conn_timeout), 139 .align = __alignof__(struct nf_conn_timeout), 140 .id = NF_CT_EXT_TIMEOUT, 141 }; 142 143 int nf_conntrack_timeout_init(void) 144 { 145 int ret = nf_ct_extend_register(&timeout_extend); 146 if (ret < 0) 147 pr_err("nf_ct_timeout: Unable to register timeout extension.\n"); 148 return ret; 149 } 150 151 void nf_conntrack_timeout_fini(void) 152 { 153 nf_ct_extend_unregister(&timeout_extend); 154 } 155