1 /* SPDX-License-Identifier: (GPL-2.0+ OR BSD-3-Clause) */
2 /*
3  * Copyright (c) 2017-2018 Mellanox Technologies. All rights reserved.
4  */
5 
6 #include <rdma/ib_verbs.h>
7 #include <rdma/restrack.h>
8 #include <linux/mutex.h>
9 #include <linux/sched/task.h>
10 #include <linux/uaccess.h>
11 #include <linux/pid_namespace.h>
12 
/**
 * rdma_restrack_init() - initialize a resource tracking root
 * @res: resource tracking root to initialize
 *
 * Only the rwsem that protects the root's hash table is initialized here;
 * the hash table itself is not touched by this function.
 * NOTE(review): presumably res->hash relies on the containing structure
 * being zero-allocated — confirm with the caller.
 */
void rdma_restrack_init(struct rdma_restrack_root *res)
{
	init_rwsem(&res->rwsem);
}
17 
18 void rdma_restrack_clean(struct rdma_restrack_root *res)
19 {
20 	WARN_ON_ONCE(!hash_empty(res->hash));
21 }
22 
23 int rdma_restrack_count(struct rdma_restrack_root *res,
24 			enum rdma_restrack_type type,
25 			struct pid_namespace *ns)
26 {
27 	struct rdma_restrack_entry *e;
28 	u32 cnt = 0;
29 
30 	down_read(&res->rwsem);
31 	hash_for_each_possible(res->hash, e, node, type) {
32 		if (ns == &init_pid_ns ||
33 		    (!rdma_is_kernel_res(e) &&
34 		     ns == task_active_pid_ns(e->task)))
35 			cnt++;
36 	}
37 	up_read(&res->rwsem);
38 	return cnt;
39 }
40 EXPORT_SYMBOL(rdma_restrack_count);
41 
42 static void set_kern_name(struct rdma_restrack_entry *res)
43 {
44 	enum rdma_restrack_type type = res->type;
45 	struct ib_qp *qp;
46 
47 	if (type != RDMA_RESTRACK_QP)
48 		/* PD and CQ types already have this name embedded in */
49 		return;
50 
51 	qp = container_of(res, struct ib_qp, res);
52 	if (!qp->pd) {
53 		WARN_ONCE(true, "XRC QPs are not supported\n");
54 		/* Survive, despite the programmer's error */
55 		res->kern_name = " ";
56 		return;
57 	}
58 
59 	res->kern_name = qp->pd->res.kern_name;
60 }
61 
62 static struct ib_device *res_to_dev(struct rdma_restrack_entry *res)
63 {
64 	enum rdma_restrack_type type = res->type;
65 	struct ib_device *dev;
66 	struct ib_xrcd *xrcd;
67 	struct ib_pd *pd;
68 	struct ib_cq *cq;
69 	struct ib_qp *qp;
70 
71 	switch (type) {
72 	case RDMA_RESTRACK_PD:
73 		pd = container_of(res, struct ib_pd, res);
74 		dev = pd->device;
75 		break;
76 	case RDMA_RESTRACK_CQ:
77 		cq = container_of(res, struct ib_cq, res);
78 		dev = cq->device;
79 		break;
80 	case RDMA_RESTRACK_QP:
81 		qp = container_of(res, struct ib_qp, res);
82 		dev = qp->device;
83 		break;
84 	case RDMA_RESTRACK_XRCD:
85 		xrcd = container_of(res, struct ib_xrcd, res);
86 		dev = xrcd->device;
87 		break;
88 	default:
89 		WARN_ONCE(true, "Wrong resource tracking type %u\n", type);
90 		return NULL;
91 	}
92 
93 	return dev;
94 }
95 
96 void rdma_restrack_add(struct rdma_restrack_entry *res)
97 {
98 	struct ib_device *dev = res_to_dev(res);
99 
100 	if (!dev)
101 		return;
102 
103 	if (!uaccess_kernel()) {
104 		get_task_struct(current);
105 		res->task = current;
106 		res->kern_name = NULL;
107 	} else {
108 		set_kern_name(res);
109 		res->task = NULL;
110 	}
111 
112 	kref_init(&res->kref);
113 	init_completion(&res->comp);
114 	res->valid = true;
115 
116 	down_write(&dev->res.rwsem);
117 	hash_add(dev->res.hash, &res->node, res->type);
118 	up_write(&dev->res.rwsem);
119 }
120 EXPORT_SYMBOL(rdma_restrack_add);
121 
122 int __must_check rdma_restrack_get(struct rdma_restrack_entry *res)
123 {
124 	return kref_get_unless_zero(&res->kref);
125 }
126 EXPORT_SYMBOL(rdma_restrack_get);
127 
128 static void restrack_release(struct kref *kref)
129 {
130 	struct rdma_restrack_entry *res;
131 
132 	res = container_of(kref, struct rdma_restrack_entry, kref);
133 	complete(&res->comp);
134 }
135 
136 int rdma_restrack_put(struct rdma_restrack_entry *res)
137 {
138 	return kref_put(&res->kref, restrack_release);
139 }
140 EXPORT_SYMBOL(rdma_restrack_put);
141 
142 void rdma_restrack_del(struct rdma_restrack_entry *res)
143 {
144 	struct ib_device *dev;
145 
146 	if (!res->valid)
147 		return;
148 
149 	dev = res_to_dev(res);
150 	if (!dev)
151 		return;
152 
153 	rdma_restrack_put(res);
154 
155 	wait_for_completion(&res->comp);
156 
157 	down_write(&dev->res.rwsem);
158 	hash_del(&res->node);
159 	res->valid = false;
160 	if (res->task)
161 		put_task_struct(res->task);
162 	up_write(&dev->res.rwsem);
163 }
164 EXPORT_SYMBOL(rdma_restrack_del);
165