xref: /openbmc/linux/drivers/vdpa/mlx5/net/mlx5_vnet.c (revision 35f752be)
1 // SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
2 /* Copyright (c) 2020 Mellanox Technologies Ltd. */
3 
4 #include <linux/module.h>
5 #include <linux/vdpa.h>
6 #include <linux/vringh.h>
7 #include <uapi/linux/virtio_net.h>
8 #include <uapi/linux/virtio_ids.h>
9 #include <linux/virtio_config.h>
10 #include <linux/auxiliary_bus.h>
11 #include <linux/mlx5/cq.h>
12 #include <linux/mlx5/qp.h>
13 #include <linux/mlx5/device.h>
14 #include <linux/mlx5/driver.h>
15 #include <linux/mlx5/vport.h>
16 #include <linux/mlx5/fs.h>
17 #include <linux/mlx5/mlx5_ifc_vdpa.h>
18 #include "mlx5_vdpa.h"
19 
20 MODULE_AUTHOR("Eli Cohen <eli@mellanox.com>");
21 MODULE_DESCRIPTION("Mellanox VDPA driver");
22 MODULE_LICENSE("Dual BSD/GPL");
23 
/* Recover the enclosing struct mlx5_vdpa_net from a pointer to its embedded
 * mvdev member.
 */
#define to_mlx5_vdpa_ndev(__mvdev)                                             \
	container_of(__mvdev, struct mlx5_vdpa_net, mvdev)
/* Recover the enclosing struct mlx5_vdpa_dev from a pointer to its vdev
 * member.
 */
#define to_mvdev(__vdev) container_of((__vdev), struct mlx5_vdpa_dev, vdev)
27 
/* Set of virtio feature bits this file recognises. print_features() warns
 * about any bit that falls outside this mask.
 */
#define VALID_FEATURES_MASK                                                                        \
	(BIT_ULL(VIRTIO_NET_F_CSUM) | BIT_ULL(VIRTIO_NET_F_GUEST_CSUM) |                                   \
	 BIT_ULL(VIRTIO_NET_F_CTRL_GUEST_OFFLOADS) | BIT_ULL(VIRTIO_NET_F_MTU) | BIT_ULL(VIRTIO_NET_F_MAC) |   \
	 BIT_ULL(VIRTIO_NET_F_GUEST_TSO4) | BIT_ULL(VIRTIO_NET_F_GUEST_TSO6) |                             \
	 BIT_ULL(VIRTIO_NET_F_GUEST_ECN) | BIT_ULL(VIRTIO_NET_F_GUEST_UFO) | BIT_ULL(VIRTIO_NET_F_HOST_TSO4) | \
	 BIT_ULL(VIRTIO_NET_F_HOST_TSO6) | BIT_ULL(VIRTIO_NET_F_HOST_ECN) | BIT_ULL(VIRTIO_NET_F_HOST_UFO) |   \
	 BIT_ULL(VIRTIO_NET_F_MRG_RXBUF) | BIT_ULL(VIRTIO_NET_F_STATUS) | BIT_ULL(VIRTIO_NET_F_CTRL_VQ) |      \
	 BIT_ULL(VIRTIO_NET_F_CTRL_RX) | BIT_ULL(VIRTIO_NET_F_CTRL_VLAN) |                                 \
	 BIT_ULL(VIRTIO_NET_F_CTRL_RX_EXTRA) | BIT_ULL(VIRTIO_NET_F_GUEST_ANNOUNCE) |                      \
	 BIT_ULL(VIRTIO_NET_F_MQ) | BIT_ULL(VIRTIO_NET_F_CTRL_MAC_ADDR) | BIT_ULL(VIRTIO_NET_F_HASH_REPORT) |  \
	 BIT_ULL(VIRTIO_NET_F_RSS) | BIT_ULL(VIRTIO_NET_F_RSC_EXT) | BIT_ULL(VIRTIO_NET_F_STANDBY) |           \
	 BIT_ULL(VIRTIO_NET_F_SPEED_DUPLEX) | BIT_ULL(VIRTIO_F_NOTIFY_ON_EMPTY) |                          \
	 BIT_ULL(VIRTIO_F_ANY_LAYOUT) | BIT_ULL(VIRTIO_F_VERSION_1) | BIT_ULL(VIRTIO_F_ACCESS_PLATFORM) |      \
	 BIT_ULL(VIRTIO_F_RING_PACKED) | BIT_ULL(VIRTIO_F_ORDER_PLATFORM) | BIT_ULL(VIRTIO_F_SR_IOV))
42 
/* Set of virtio device status bits this file recognises. print_status()
 * warns about any bit that falls outside this mask.
 */
#define VALID_STATUS_MASK                                                                          \
	(VIRTIO_CONFIG_S_ACKNOWLEDGE | VIRTIO_CONFIG_S_DRIVER | VIRTIO_CONFIG_S_DRIVER_OK |        \
	 VIRTIO_CONFIG_S_FEATURES_OK | VIRTIO_CONFIG_S_NEEDS_RESET | VIRTIO_CONFIG_S_FAILED)
46 
/* Per-net-device object numbers.
 *
 * tisn is filled in by create_tis() and programmed into TX virtqueues; tdn is
 * the transport domain the TIS is created in. The users of tirn, rqtn and
 * valid are not visible in this part of the file.
 */
struct mlx5_vdpa_net_resources {
	u32 tisn;
	u32 tdn;
	u32 tirn;
	u32 rqtn;
	bool valid;
};
54 
/* Backing storage of a completion queue, set up by cq_frag_buf_alloc(). */
struct mlx5_vdpa_cq_buf {
	/* fragment buffer control, initialised with mlx5_init_fbc() */
	struct mlx5_frag_buf_ctrl fbc;
	/* the fragmented buffer holding the CQEs */
	struct mlx5_frag_buf frag_buf;
	/* size of one CQE in bytes (MLX5_VDPA_CQE_SIZE) */
	int cqe_size;
	/* number of CQEs the buffer was allocated for */
	int nent;
};
61 
/* Completion queue used for virtqueue notifications. */
struct mlx5_vdpa_cq {
	/* core CQ object; mlx5_vdpa_cq_comp() is installed as its handler */
	struct mlx5_core_cq mcq;
	/* CQE storage */
	struct mlx5_vdpa_cq_buf buf;
	/* doorbell record: db.db is set_ci_db, db.db + 1 is arm_db */
	struct mlx5_db db;
	/* number of entries; get_sw_cqe() uses (cqe - 1) as an index mask */
	int cqe;
};
68 
/* One of the three umem regions handed to the device for each virtqueue. */
struct mlx5_vdpa_umem {
	struct mlx5_frag_buf_ctrl fbc;
	/* memory backing the umem, allocated by umem_frag_buf_alloc() */
	struct mlx5_frag_buf frag_buf;
	/* size in bytes as computed by umem_size() */
	int size;
	/* umem id returned by the CREATE_UMEM command */
	u32 id;
};
75 
/* One end of the QP pair that forms a virtqueue's notification channel. */
struct mlx5_vdpa_qp {
	/* core QP object; qpn and uid are filled in by qp_create() */
	struct mlx5_core_qp mqp;
	/* receive queue buffer; only allocated when fw is false */
	struct mlx5_frag_buf frag_buf;
	/* doorbell record; only allocated when fw is false */
	struct mlx5_db db;
	/* running count of posted receives, written to db by rx_post() */
	u16 head;
	/* true for the firmware-side QP, false for the driver-side QP */
	bool fw;
};
83 
/* Saved copy of a virtqueue's configuration.
 *
 * The fields mirror those of struct mlx5_vdpa_virtqueue. The code that fills
 * and consumes this structure is not visible in this part of the file;
 * presumably it is used to re-create a virtqueue with its previous state -
 * TODO confirm against the save/restore helpers.
 */
struct mlx5_vq_restore_info {
	u32 num_ent;
	u64 desc_addr;
	u64 device_addr;
	u64 driver_addr;
	u16 avail_index;
	u16 used_index;
	bool ready;
	struct vdpa_callback cb;
	bool restore;
};
95 
/* Driver state of a single virtqueue. */
struct mlx5_vdpa_virtqueue {
	bool ready;
	/* ring addresses; programmed as desc_addr, used_addr and
	 * available_addr respectively by create_virtqueue()
	 */
	u64 desc_addr;
	u64 device_addr;
	u64 driver_addr;
	/* queue size in entries */
	u32 num_ent;
	/* invoked from mlx5_vdpa_handle_completions() when set */
	struct vdpa_callback event_cb;

	/* Resources for implementing the notification channel from the device
	 * to the driver. fwqp is the firmware end of an RC connection; the
	 * other end is vqqp used by the driver. cq is where completions are
	 * reported.
	 */
	struct mlx5_vdpa_cq cq;
	struct mlx5_vdpa_qp fwqp;
	struct mlx5_vdpa_qp vqqp;

	/* umem resources are required for the virtqueue operation. Their use
	 * is internal and they must be provided by the driver.
	 */
	struct mlx5_vdpa_umem umem1;
	struct mlx5_vdpa_umem umem2;
	struct mlx5_vdpa_umem umem3;

	bool initialized;
	/* queue index; odd indices are treated as TX by vq_is_tx() */
	int index;
	/* object id returned by the device when the virtqueue is created */
	u32 virtq_id;
	/* owning net device */
	struct mlx5_vdpa_net *ndev;
	/* programmed as hw_available_index / hw_used_index on creation */
	u16 avail_idx;
	u16 used_idx;
	int fw_state;

	/* keep last in the struct */
	struct mlx5_vq_restore_info ri;
};
131 
/* Upper bound on the number of virtqueues; it sizes the vqs[] array in
 * struct mlx5_vdpa_net.
 *
 * We will remove this limitation once mlx5_vdpa_alloc_resources()
 * provides for driver space allocation
 */
#define MLX5_MAX_SUPPORTED_VQS 16
136 
/* Net-device level state of the mlx5 vDPA driver. */
struct mlx5_vdpa_net {
	/* embedded generic device; see to_mlx5_vdpa_ndev() */
	struct mlx5_vdpa_dev mvdev;
	struct mlx5_vdpa_net_resources res;
	struct virtio_net_config config;
	struct mlx5_vdpa_virtqueue vqs[MLX5_MAX_SUPPORTED_VQS];

	/* Serialize vq resources creation and destruction. This is required
	 * since memory map might change and we need to destroy and create
	 * resources while the driver is operational.
	 */
	struct mutex reslock;
	struct mlx5_flow_table *rxft;
	struct mlx5_fc *rx_counter;
	struct mlx5_flow_handle *rx_rule;
	bool setup;
	u16 mtu;
};
154 
/* Forward declarations; the definitions are not in this part of the file. */
static void free_resources(struct mlx5_vdpa_net *ndev);
static void init_mvqs(struct mlx5_vdpa_net *ndev);
static int setup_driver(struct mlx5_vdpa_net *ndev);
static void teardown_driver(struct mlx5_vdpa_net *ndev);

/* When false, print_status() and print_features() skip their per-bit dumps. */
static bool mlx5_vdpa_debug;
161 
/* Log the name of feature bit number _feature if it is set.
 * Relies on variables named features and mvdev in the caller's scope.
 */
#define MLX5_LOG_VIO_FLAG(_feature)                                                                \
	do {                                                                                       \
		if (features & BIT_ULL(_feature))                                                  \
			mlx5_vdpa_info(mvdev, "%s\n", #_feature);                                  \
	} while (0)

/* Log the name of status mask _status if any of its bits is set.
 * Relies on variables named status and mvdev in the caller's scope.
 */
#define MLX5_LOG_VIO_STAT(_status)                                                                 \
	do {                                                                                       \
		if (status & (_status))                                                            \
			mlx5_vdpa_info(mvdev, "%s\n", #_status);                                   \
	} while (0)
173 
174 static inline u32 mlx5_vdpa_max_qps(int max_vqs)
175 {
176 	return max_vqs / 2;
177 }
178 
179 static void print_status(struct mlx5_vdpa_dev *mvdev, u8 status, bool set)
180 {
181 	if (status & ~VALID_STATUS_MASK)
182 		mlx5_vdpa_warn(mvdev, "Warning: there are invalid status bits 0x%x\n",
183 			       status & ~VALID_STATUS_MASK);
184 
185 	if (!mlx5_vdpa_debug)
186 		return;
187 
188 	mlx5_vdpa_info(mvdev, "driver status %s", set ? "set" : "get");
189 	if (set && !status) {
190 		mlx5_vdpa_info(mvdev, "driver resets the device\n");
191 		return;
192 	}
193 
194 	MLX5_LOG_VIO_STAT(VIRTIO_CONFIG_S_ACKNOWLEDGE);
195 	MLX5_LOG_VIO_STAT(VIRTIO_CONFIG_S_DRIVER);
196 	MLX5_LOG_VIO_STAT(VIRTIO_CONFIG_S_DRIVER_OK);
197 	MLX5_LOG_VIO_STAT(VIRTIO_CONFIG_S_FEATURES_OK);
198 	MLX5_LOG_VIO_STAT(VIRTIO_CONFIG_S_NEEDS_RESET);
199 	MLX5_LOG_VIO_STAT(VIRTIO_CONFIG_S_FAILED);
200 }
201 
202 static void print_features(struct mlx5_vdpa_dev *mvdev, u64 features, bool set)
203 {
204 	if (features & ~VALID_FEATURES_MASK)
205 		mlx5_vdpa_warn(mvdev, "There are invalid feature bits 0x%llx\n",
206 			       features & ~VALID_FEATURES_MASK);
207 
208 	if (!mlx5_vdpa_debug)
209 		return;
210 
211 	mlx5_vdpa_info(mvdev, "driver %s feature bits:\n", set ? "sets" : "reads");
212 	if (!features)
213 		mlx5_vdpa_info(mvdev, "all feature bits are cleared\n");
214 
215 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_CSUM);
216 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_GUEST_CSUM);
217 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_CTRL_GUEST_OFFLOADS);
218 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_MTU);
219 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_MAC);
220 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_GUEST_TSO4);
221 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_GUEST_TSO6);
222 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_GUEST_ECN);
223 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_GUEST_UFO);
224 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_HOST_TSO4);
225 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_HOST_TSO6);
226 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_HOST_ECN);
227 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_HOST_UFO);
228 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_MRG_RXBUF);
229 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_STATUS);
230 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_CTRL_VQ);
231 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_CTRL_RX);
232 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_CTRL_VLAN);
233 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_CTRL_RX_EXTRA);
234 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_GUEST_ANNOUNCE);
235 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_MQ);
236 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_CTRL_MAC_ADDR);
237 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_HASH_REPORT);
238 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_RSS);
239 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_RSC_EXT);
240 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_STANDBY);
241 	MLX5_LOG_VIO_FLAG(VIRTIO_NET_F_SPEED_DUPLEX);
242 	MLX5_LOG_VIO_FLAG(VIRTIO_F_NOTIFY_ON_EMPTY);
243 	MLX5_LOG_VIO_FLAG(VIRTIO_F_ANY_LAYOUT);
244 	MLX5_LOG_VIO_FLAG(VIRTIO_F_VERSION_1);
245 	MLX5_LOG_VIO_FLAG(VIRTIO_F_ACCESS_PLATFORM);
246 	MLX5_LOG_VIO_FLAG(VIRTIO_F_RING_PACKED);
247 	MLX5_LOG_VIO_FLAG(VIRTIO_F_ORDER_PLATFORM);
248 	MLX5_LOG_VIO_FLAG(VIRTIO_F_SR_IOV);
249 }
250 
251 static int create_tis(struct mlx5_vdpa_net *ndev)
252 {
253 	struct mlx5_vdpa_dev *mvdev = &ndev->mvdev;
254 	u32 in[MLX5_ST_SZ_DW(create_tis_in)] = {};
255 	void *tisc;
256 	int err;
257 
258 	tisc = MLX5_ADDR_OF(create_tis_in, in, ctx);
259 	MLX5_SET(tisc, tisc, transport_domain, ndev->res.tdn);
260 	err = mlx5_vdpa_create_tis(mvdev, in, &ndev->res.tisn);
261 	if (err)
262 		mlx5_vdpa_warn(mvdev, "create TIS (%d)\n", err);
263 
264 	return err;
265 }
266 
267 static void destroy_tis(struct mlx5_vdpa_net *ndev)
268 {
269 	mlx5_vdpa_destroy_tis(&ndev->mvdev, ndev->res.tisn);
270 }
271 
/* Size in bytes of one completion queue entry, and its base-2 logarithm. */
#define MLX5_VDPA_CQE_SIZE 64
#define MLX5_VDPA_LOG_CQE_SIZE ilog2(MLX5_VDPA_CQE_SIZE)
274 
275 static int cq_frag_buf_alloc(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_cq_buf *buf, int nent)
276 {
277 	struct mlx5_frag_buf *frag_buf = &buf->frag_buf;
278 	u8 log_wq_stride = MLX5_VDPA_LOG_CQE_SIZE;
279 	u8 log_wq_sz = MLX5_VDPA_LOG_CQE_SIZE;
280 	int err;
281 
282 	err = mlx5_frag_buf_alloc_node(ndev->mvdev.mdev, nent * MLX5_VDPA_CQE_SIZE, frag_buf,
283 				       ndev->mvdev.mdev->priv.numa_node);
284 	if (err)
285 		return err;
286 
287 	mlx5_init_fbc(frag_buf->frags, log_wq_stride, log_wq_sz, &buf->fbc);
288 
289 	buf->cqe_size = MLX5_VDPA_CQE_SIZE;
290 	buf->nent = nent;
291 
292 	return 0;
293 }
294 
295 static int umem_frag_buf_alloc(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_umem *umem, int size)
296 {
297 	struct mlx5_frag_buf *frag_buf = &umem->frag_buf;
298 
299 	return mlx5_frag_buf_alloc_node(ndev->mvdev.mdev, size, frag_buf,
300 					ndev->mvdev.mdev->priv.numa_node);
301 }
302 
303 static void cq_frag_buf_free(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_cq_buf *buf)
304 {
305 	mlx5_frag_buf_free(ndev->mvdev.mdev, &buf->frag_buf);
306 }
307 
308 static void *get_cqe(struct mlx5_vdpa_cq *vcq, int n)
309 {
310 	return mlx5_frag_buf_get_wqe(&vcq->buf.fbc, n);
311 }
312 
313 static void cq_frag_buf_init(struct mlx5_vdpa_cq *vcq, struct mlx5_vdpa_cq_buf *buf)
314 {
315 	struct mlx5_cqe64 *cqe64;
316 	void *cqe;
317 	int i;
318 
319 	for (i = 0; i < buf->nent; i++) {
320 		cqe = get_cqe(vcq, i);
321 		cqe64 = cqe;
322 		cqe64->op_own = MLX5_CQE_INVALID << 4;
323 	}
324 }
325 
326 static void *get_sw_cqe(struct mlx5_vdpa_cq *cq, int n)
327 {
328 	struct mlx5_cqe64 *cqe64 = get_cqe(cq, n & (cq->cqe - 1));
329 
330 	if (likely(get_cqe_opcode(cqe64) != MLX5_CQE_INVALID) &&
331 	    !((cqe64->op_own & MLX5_CQE_OWNER_MASK) ^ !!(n & cq->cqe)))
332 		return cqe64;
333 
334 	return NULL;
335 }
336 
337 static void rx_post(struct mlx5_vdpa_qp *vqp, int n)
338 {
339 	vqp->head += n;
340 	vqp->db.db[0] = cpu_to_be32(vqp->head);
341 }
342 
343 static void qp_prepare(struct mlx5_vdpa_net *ndev, bool fw, void *in,
344 		       struct mlx5_vdpa_virtqueue *mvq, u32 num_ent)
345 {
346 	struct mlx5_vdpa_qp *vqp;
347 	__be64 *pas;
348 	void *qpc;
349 
350 	vqp = fw ? &mvq->fwqp : &mvq->vqqp;
351 	MLX5_SET(create_qp_in, in, uid, ndev->mvdev.res.uid);
352 	qpc = MLX5_ADDR_OF(create_qp_in, in, qpc);
353 	if (vqp->fw) {
354 		/* Firmware QP is allocated by the driver for the firmware's
355 		 * use so we can skip part of the params as they will be chosen by firmware
356 		 */
357 		qpc = MLX5_ADDR_OF(create_qp_in, in, qpc);
358 		MLX5_SET(qpc, qpc, rq_type, MLX5_ZERO_LEN_RQ);
359 		MLX5_SET(qpc, qpc, no_sq, 1);
360 		return;
361 	}
362 
363 	MLX5_SET(qpc, qpc, st, MLX5_QP_ST_RC);
364 	MLX5_SET(qpc, qpc, pm_state, MLX5_QP_PM_MIGRATED);
365 	MLX5_SET(qpc, qpc, pd, ndev->mvdev.res.pdn);
366 	MLX5_SET(qpc, qpc, mtu, MLX5_QPC_MTU_256_BYTES);
367 	MLX5_SET(qpc, qpc, uar_page, ndev->mvdev.res.uar->index);
368 	MLX5_SET(qpc, qpc, log_page_size, vqp->frag_buf.page_shift - MLX5_ADAPTER_PAGE_SHIFT);
369 	MLX5_SET(qpc, qpc, no_sq, 1);
370 	MLX5_SET(qpc, qpc, cqn_rcv, mvq->cq.mcq.cqn);
371 	MLX5_SET(qpc, qpc, log_rq_size, ilog2(num_ent));
372 	MLX5_SET(qpc, qpc, rq_type, MLX5_NON_ZERO_RQ);
373 	pas = (__be64 *)MLX5_ADDR_OF(create_qp_in, in, pas);
374 	mlx5_fill_page_frag_array(&vqp->frag_buf, pas);
375 }
376 
377 static int rq_buf_alloc(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_qp *vqp, u32 num_ent)
378 {
379 	return mlx5_frag_buf_alloc_node(ndev->mvdev.mdev,
380 					num_ent * sizeof(struct mlx5_wqe_data_seg), &vqp->frag_buf,
381 					ndev->mvdev.mdev->priv.numa_node);
382 }
383 
384 static void rq_buf_free(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_qp *vqp)
385 {
386 	mlx5_frag_buf_free(ndev->mvdev.mdev, &vqp->frag_buf);
387 }
388 
389 static int qp_create(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq,
390 		     struct mlx5_vdpa_qp *vqp)
391 {
392 	struct mlx5_core_dev *mdev = ndev->mvdev.mdev;
393 	int inlen = MLX5_ST_SZ_BYTES(create_qp_in);
394 	u32 out[MLX5_ST_SZ_DW(create_qp_out)] = {};
395 	void *qpc;
396 	void *in;
397 	int err;
398 
399 	if (!vqp->fw) {
400 		vqp = &mvq->vqqp;
401 		err = rq_buf_alloc(ndev, vqp, mvq->num_ent);
402 		if (err)
403 			return err;
404 
405 		err = mlx5_db_alloc(ndev->mvdev.mdev, &vqp->db);
406 		if (err)
407 			goto err_db;
408 		inlen += vqp->frag_buf.npages * sizeof(__be64);
409 	}
410 
411 	in = kzalloc(inlen, GFP_KERNEL);
412 	if (!in) {
413 		err = -ENOMEM;
414 		goto err_kzalloc;
415 	}
416 
417 	qp_prepare(ndev, vqp->fw, in, mvq, mvq->num_ent);
418 	qpc = MLX5_ADDR_OF(create_qp_in, in, qpc);
419 	MLX5_SET(qpc, qpc, st, MLX5_QP_ST_RC);
420 	MLX5_SET(qpc, qpc, pm_state, MLX5_QP_PM_MIGRATED);
421 	MLX5_SET(qpc, qpc, pd, ndev->mvdev.res.pdn);
422 	MLX5_SET(qpc, qpc, mtu, MLX5_QPC_MTU_256_BYTES);
423 	if (!vqp->fw)
424 		MLX5_SET64(qpc, qpc, dbr_addr, vqp->db.dma);
425 	MLX5_SET(create_qp_in, in, opcode, MLX5_CMD_OP_CREATE_QP);
426 	err = mlx5_cmd_exec(mdev, in, inlen, out, sizeof(out));
427 	kfree(in);
428 	if (err)
429 		goto err_kzalloc;
430 
431 	vqp->mqp.uid = ndev->mvdev.res.uid;
432 	vqp->mqp.qpn = MLX5_GET(create_qp_out, out, qpn);
433 
434 	if (!vqp->fw)
435 		rx_post(vqp, mvq->num_ent);
436 
437 	return 0;
438 
439 err_kzalloc:
440 	if (!vqp->fw)
441 		mlx5_db_free(ndev->mvdev.mdev, &vqp->db);
442 err_db:
443 	if (!vqp->fw)
444 		rq_buf_free(ndev, vqp);
445 
446 	return err;
447 }
448 
449 static void qp_destroy(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_qp *vqp)
450 {
451 	u32 in[MLX5_ST_SZ_DW(destroy_qp_in)] = {};
452 
453 	MLX5_SET(destroy_qp_in, in, opcode, MLX5_CMD_OP_DESTROY_QP);
454 	MLX5_SET(destroy_qp_in, in, qpn, vqp->mqp.qpn);
455 	MLX5_SET(destroy_qp_in, in, uid, ndev->mvdev.res.uid);
456 	if (mlx5_cmd_exec_in(ndev->mvdev.mdev, destroy_qp, in))
457 		mlx5_vdpa_warn(&ndev->mvdev, "destroy qp 0x%x\n", vqp->mqp.qpn);
458 	if (!vqp->fw) {
459 		mlx5_db_free(ndev->mvdev.mdev, &vqp->db);
460 		rq_buf_free(ndev, vqp);
461 	}
462 }
463 
464 static void *next_cqe_sw(struct mlx5_vdpa_cq *cq)
465 {
466 	return get_sw_cqe(cq, cq->mcq.cons_index);
467 }
468 
469 static int mlx5_vdpa_poll_one(struct mlx5_vdpa_cq *vcq)
470 {
471 	struct mlx5_cqe64 *cqe64;
472 
473 	cqe64 = next_cqe_sw(vcq);
474 	if (!cqe64)
475 		return -EAGAIN;
476 
477 	vcq->mcq.cons_index++;
478 	return 0;
479 }
480 
481 static void mlx5_vdpa_handle_completions(struct mlx5_vdpa_virtqueue *mvq, int num)
482 {
483 	mlx5_cq_set_ci(&mvq->cq.mcq);
484 
485 	/* make sure CQ cosumer update is visible to the hardware before updating
486 	 * RX doorbell record.
487 	 */
488 	dma_wmb();
489 	rx_post(&mvq->vqqp, num);
490 	if (mvq->event_cb.callback)
491 		mvq->event_cb.callback(mvq->event_cb.private);
492 }
493 
494 static void mlx5_vdpa_cq_comp(struct mlx5_core_cq *mcq, struct mlx5_eqe *eqe)
495 {
496 	struct mlx5_vdpa_virtqueue *mvq = container_of(mcq, struct mlx5_vdpa_virtqueue, cq.mcq);
497 	struct mlx5_vdpa_net *ndev = mvq->ndev;
498 	void __iomem *uar_page = ndev->mvdev.res.uar->map;
499 	int num = 0;
500 
501 	while (!mlx5_vdpa_poll_one(&mvq->cq)) {
502 		num++;
503 		if (num > mvq->num_ent / 2) {
504 			/* If completions keep coming while we poll, we want to
505 			 * let the hardware know that we consumed them by
506 			 * updating the doorbell record.  We also let vdpa core
507 			 * know about this so it passes it on the virtio driver
508 			 * on the guest.
509 			 */
510 			mlx5_vdpa_handle_completions(mvq, num);
511 			num = 0;
512 		}
513 	}
514 
515 	if (num)
516 		mlx5_vdpa_handle_completions(mvq, num);
517 
518 	mlx5_cq_arm(&mvq->cq.mcq, MLX5_CQ_DB_REQ_NOT, uar_page, mvq->cq.mcq.cons_index);
519 }
520 
521 static int cq_create(struct mlx5_vdpa_net *ndev, u16 idx, u32 num_ent)
522 {
523 	struct mlx5_vdpa_virtqueue *mvq = &ndev->vqs[idx];
524 	struct mlx5_core_dev *mdev = ndev->mvdev.mdev;
525 	void __iomem *uar_page = ndev->mvdev.res.uar->map;
526 	u32 out[MLX5_ST_SZ_DW(create_cq_out)];
527 	struct mlx5_vdpa_cq *vcq = &mvq->cq;
528 	unsigned int irqn;
529 	__be64 *pas;
530 	int inlen;
531 	void *cqc;
532 	void *in;
533 	int err;
534 	int eqn;
535 
536 	err = mlx5_db_alloc(mdev, &vcq->db);
537 	if (err)
538 		return err;
539 
540 	vcq->mcq.set_ci_db = vcq->db.db;
541 	vcq->mcq.arm_db = vcq->db.db + 1;
542 	vcq->mcq.cqe_sz = 64;
543 
544 	err = cq_frag_buf_alloc(ndev, &vcq->buf, num_ent);
545 	if (err)
546 		goto err_db;
547 
548 	cq_frag_buf_init(vcq, &vcq->buf);
549 
550 	inlen = MLX5_ST_SZ_BYTES(create_cq_in) +
551 		MLX5_FLD_SZ_BYTES(create_cq_in, pas[0]) * vcq->buf.frag_buf.npages;
552 	in = kzalloc(inlen, GFP_KERNEL);
553 	if (!in) {
554 		err = -ENOMEM;
555 		goto err_vzalloc;
556 	}
557 
558 	MLX5_SET(create_cq_in, in, uid, ndev->mvdev.res.uid);
559 	pas = (__be64 *)MLX5_ADDR_OF(create_cq_in, in, pas);
560 	mlx5_fill_page_frag_array(&vcq->buf.frag_buf, pas);
561 
562 	cqc = MLX5_ADDR_OF(create_cq_in, in, cq_context);
563 	MLX5_SET(cqc, cqc, log_page_size, vcq->buf.frag_buf.page_shift - MLX5_ADAPTER_PAGE_SHIFT);
564 
565 	/* Use vector 0 by default. Consider adding code to choose least used
566 	 * vector.
567 	 */
568 	err = mlx5_vector2eqn(mdev, 0, &eqn, &irqn);
569 	if (err)
570 		goto err_vec;
571 
572 	cqc = MLX5_ADDR_OF(create_cq_in, in, cq_context);
573 	MLX5_SET(cqc, cqc, log_cq_size, ilog2(num_ent));
574 	MLX5_SET(cqc, cqc, uar_page, ndev->mvdev.res.uar->index);
575 	MLX5_SET(cqc, cqc, c_eqn, eqn);
576 	MLX5_SET64(cqc, cqc, dbr_addr, vcq->db.dma);
577 
578 	err = mlx5_core_create_cq(mdev, &vcq->mcq, in, inlen, out, sizeof(out));
579 	if (err)
580 		goto err_vec;
581 
582 	vcq->mcq.comp = mlx5_vdpa_cq_comp;
583 	vcq->cqe = num_ent;
584 	vcq->mcq.set_ci_db = vcq->db.db;
585 	vcq->mcq.arm_db = vcq->db.db + 1;
586 	mlx5_cq_arm(&mvq->cq.mcq, MLX5_CQ_DB_REQ_NOT, uar_page, mvq->cq.mcq.cons_index);
587 	kfree(in);
588 	return 0;
589 
590 err_vec:
591 	kfree(in);
592 err_vzalloc:
593 	cq_frag_buf_free(ndev, &vcq->buf);
594 err_db:
595 	mlx5_db_free(ndev->mvdev.mdev, &vcq->db);
596 	return err;
597 }
598 
599 static void cq_destroy(struct mlx5_vdpa_net *ndev, u16 idx)
600 {
601 	struct mlx5_vdpa_virtqueue *mvq = &ndev->vqs[idx];
602 	struct mlx5_core_dev *mdev = ndev->mvdev.mdev;
603 	struct mlx5_vdpa_cq *vcq = &mvq->cq;
604 
605 	if (mlx5_core_destroy_cq(mdev, &vcq->mcq)) {
606 		mlx5_vdpa_warn(&ndev->mvdev, "destroy CQ 0x%x\n", vcq->mcq.cqn);
607 		return;
608 	}
609 	cq_frag_buf_free(ndev, &vcq->buf);
610 	mlx5_db_free(ndev->mvdev.mdev, &vcq->db);
611 }
612 
613 static int umem_size(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq, int num,
614 		     struct mlx5_vdpa_umem **umemp)
615 {
616 	struct mlx5_core_dev *mdev = ndev->mvdev.mdev;
617 	int p_a;
618 	int p_b;
619 
620 	switch (num) {
621 	case 1:
622 		p_a = MLX5_CAP_DEV_VDPA_EMULATION(mdev, umem_1_buffer_param_a);
623 		p_b = MLX5_CAP_DEV_VDPA_EMULATION(mdev, umem_1_buffer_param_b);
624 		*umemp = &mvq->umem1;
625 		break;
626 	case 2:
627 		p_a = MLX5_CAP_DEV_VDPA_EMULATION(mdev, umem_2_buffer_param_a);
628 		p_b = MLX5_CAP_DEV_VDPA_EMULATION(mdev, umem_2_buffer_param_b);
629 		*umemp = &mvq->umem2;
630 		break;
631 	case 3:
632 		p_a = MLX5_CAP_DEV_VDPA_EMULATION(mdev, umem_3_buffer_param_a);
633 		p_b = MLX5_CAP_DEV_VDPA_EMULATION(mdev, umem_3_buffer_param_b);
634 		*umemp = &mvq->umem3;
635 		break;
636 	}
637 	return p_a * mvq->num_ent + p_b;
638 }
639 
640 static void umem_frag_buf_free(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_umem *umem)
641 {
642 	mlx5_frag_buf_free(ndev->mvdev.mdev, &umem->frag_buf);
643 }
644 
645 static int create_umem(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq, int num)
646 {
647 	int inlen;
648 	u32 out[MLX5_ST_SZ_DW(create_umem_out)] = {};
649 	void *um;
650 	void *in;
651 	int err;
652 	__be64 *pas;
653 	int size;
654 	struct mlx5_vdpa_umem *umem;
655 
656 	size = umem_size(ndev, mvq, num, &umem);
657 	if (size < 0)
658 		return size;
659 
660 	umem->size = size;
661 	err = umem_frag_buf_alloc(ndev, umem, size);
662 	if (err)
663 		return err;
664 
665 	inlen = MLX5_ST_SZ_BYTES(create_umem_in) + MLX5_ST_SZ_BYTES(mtt) * umem->frag_buf.npages;
666 
667 	in = kzalloc(inlen, GFP_KERNEL);
668 	if (!in) {
669 		err = -ENOMEM;
670 		goto err_in;
671 	}
672 
673 	MLX5_SET(create_umem_in, in, opcode, MLX5_CMD_OP_CREATE_UMEM);
674 	MLX5_SET(create_umem_in, in, uid, ndev->mvdev.res.uid);
675 	um = MLX5_ADDR_OF(create_umem_in, in, umem);
676 	MLX5_SET(umem, um, log_page_size, umem->frag_buf.page_shift - MLX5_ADAPTER_PAGE_SHIFT);
677 	MLX5_SET64(umem, um, num_of_mtt, umem->frag_buf.npages);
678 
679 	pas = (__be64 *)MLX5_ADDR_OF(umem, um, mtt[0]);
680 	mlx5_fill_page_frag_array_perm(&umem->frag_buf, pas, MLX5_MTT_PERM_RW);
681 
682 	err = mlx5_cmd_exec(ndev->mvdev.mdev, in, inlen, out, sizeof(out));
683 	if (err) {
684 		mlx5_vdpa_warn(&ndev->mvdev, "create umem(%d)\n", err);
685 		goto err_cmd;
686 	}
687 
688 	kfree(in);
689 	umem->id = MLX5_GET(create_umem_out, out, umem_id);
690 
691 	return 0;
692 
693 err_cmd:
694 	kfree(in);
695 err_in:
696 	umem_frag_buf_free(ndev, umem);
697 	return err;
698 }
699 
700 static void umem_destroy(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq, int num)
701 {
702 	u32 in[MLX5_ST_SZ_DW(destroy_umem_in)] = {};
703 	u32 out[MLX5_ST_SZ_DW(destroy_umem_out)] = {};
704 	struct mlx5_vdpa_umem *umem;
705 
706 	switch (num) {
707 	case 1:
708 		umem = &mvq->umem1;
709 		break;
710 	case 2:
711 		umem = &mvq->umem2;
712 		break;
713 	case 3:
714 		umem = &mvq->umem3;
715 		break;
716 	}
717 
718 	MLX5_SET(destroy_umem_in, in, opcode, MLX5_CMD_OP_DESTROY_UMEM);
719 	MLX5_SET(destroy_umem_in, in, umem_id, umem->id);
720 	if (mlx5_cmd_exec(ndev->mvdev.mdev, in, sizeof(in), out, sizeof(out)))
721 		return;
722 
723 	umem_frag_buf_free(ndev, umem);
724 }
725 
/* Create umems 1 to 3 of a virtqueue.
 *
 * Returns 0 on success. If one creation fails, the umems created before it
 * are destroyed in reverse order and the error is returned.
 */
static int umems_create(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
{
	int num;
	int err;

	for (num = 1; num <= 3; num++) {
		err = create_umem(ndev, mvq, num);
		if (!err)
			continue;

		/* unwind umems num - 1 down to 1 */
		while (--num > 0)
			umem_destroy(ndev, mvq, num);

		return err;
	}

	return 0;
}
744 
/* Destroy umems 3, 2 and 1 of a virtqueue, in that order. */
static void umems_destroy(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
{
	int num = 3;

	while (num > 0) {
		umem_destroy(ndev, mvq, num);
		num--;
	}
}
752 
753 static int get_queue_type(struct mlx5_vdpa_net *ndev)
754 {
755 	u32 type_mask;
756 
757 	type_mask = MLX5_CAP_DEV_VDPA_EMULATION(ndev->mvdev.mdev, virtio_queue_type);
758 
759 	/* prefer split queue */
760 	if (type_mask & MLX5_VIRTIO_EMULATION_CAP_VIRTIO_QUEUE_TYPE_PACKED)
761 		return MLX5_VIRTIO_EMULATION_VIRTIO_QUEUE_TYPE_PACKED;
762 
763 	WARN_ON(!(type_mask & MLX5_VIRTIO_EMULATION_CAP_VIRTIO_QUEUE_TYPE_SPLIT));
764 
765 	return MLX5_VIRTIO_EMULATION_VIRTIO_QUEUE_TYPE_SPLIT;
766 }
767 
768 static bool vq_is_tx(u16 idx)
769 {
770 	return idx % 2;
771 }
772 
773 static u16 get_features_12_3(u64 features)
774 {
775 	return (!!(features & BIT_ULL(VIRTIO_NET_F_HOST_TSO4)) << 9) |
776 	       (!!(features & BIT_ULL(VIRTIO_NET_F_HOST_TSO6)) << 8) |
777 	       (!!(features & BIT_ULL(VIRTIO_NET_F_CSUM)) << 7) |
778 	       (!!(features & BIT_ULL(VIRTIO_NET_F_GUEST_CSUM)) << 6);
779 }
780 
781 static int create_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
782 {
783 	int inlen = MLX5_ST_SZ_BYTES(create_virtio_net_q_in);
784 	u32 out[MLX5_ST_SZ_DW(create_virtio_net_q_out)] = {};
785 	void *obj_context;
786 	void *cmd_hdr;
787 	void *vq_ctx;
788 	void *in;
789 	int err;
790 
791 	err = umems_create(ndev, mvq);
792 	if (err)
793 		return err;
794 
795 	in = kzalloc(inlen, GFP_KERNEL);
796 	if (!in) {
797 		err = -ENOMEM;
798 		goto err_alloc;
799 	}
800 
801 	cmd_hdr = MLX5_ADDR_OF(create_virtio_net_q_in, in, general_obj_in_cmd_hdr);
802 
803 	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, opcode, MLX5_CMD_OP_CREATE_GENERAL_OBJECT);
804 	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, obj_type, MLX5_OBJ_TYPE_VIRTIO_NET_Q);
805 	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, uid, ndev->mvdev.res.uid);
806 
807 	obj_context = MLX5_ADDR_OF(create_virtio_net_q_in, in, obj_context);
808 	MLX5_SET(virtio_net_q_object, obj_context, hw_available_index, mvq->avail_idx);
809 	MLX5_SET(virtio_net_q_object, obj_context, hw_used_index, mvq->used_idx);
810 	MLX5_SET(virtio_net_q_object, obj_context, queue_feature_bit_mask_12_3,
811 		 get_features_12_3(ndev->mvdev.actual_features));
812 	vq_ctx = MLX5_ADDR_OF(virtio_net_q_object, obj_context, virtio_q_context);
813 	MLX5_SET(virtio_q, vq_ctx, virtio_q_type, get_queue_type(ndev));
814 
815 	if (vq_is_tx(mvq->index))
816 		MLX5_SET(virtio_net_q_object, obj_context, tisn_or_qpn, ndev->res.tisn);
817 
818 	MLX5_SET(virtio_q, vq_ctx, event_mode, MLX5_VIRTIO_Q_EVENT_MODE_QP_MODE);
819 	MLX5_SET(virtio_q, vq_ctx, queue_index, mvq->index);
820 	MLX5_SET(virtio_q, vq_ctx, event_qpn_or_msix, mvq->fwqp.mqp.qpn);
821 	MLX5_SET(virtio_q, vq_ctx, queue_size, mvq->num_ent);
822 	MLX5_SET(virtio_q, vq_ctx, virtio_version_1_0,
823 		 !!(ndev->mvdev.actual_features & BIT_ULL(VIRTIO_F_VERSION_1)));
824 	MLX5_SET64(virtio_q, vq_ctx, desc_addr, mvq->desc_addr);
825 	MLX5_SET64(virtio_q, vq_ctx, used_addr, mvq->device_addr);
826 	MLX5_SET64(virtio_q, vq_ctx, available_addr, mvq->driver_addr);
827 	MLX5_SET(virtio_q, vq_ctx, virtio_q_mkey, ndev->mvdev.mr.mkey.key);
828 	MLX5_SET(virtio_q, vq_ctx, umem_1_id, mvq->umem1.id);
829 	MLX5_SET(virtio_q, vq_ctx, umem_1_size, mvq->umem1.size);
830 	MLX5_SET(virtio_q, vq_ctx, umem_2_id, mvq->umem2.id);
831 	MLX5_SET(virtio_q, vq_ctx, umem_2_size, mvq->umem1.size);
832 	MLX5_SET(virtio_q, vq_ctx, umem_3_id, mvq->umem3.id);
833 	MLX5_SET(virtio_q, vq_ctx, umem_3_size, mvq->umem1.size);
834 	MLX5_SET(virtio_q, vq_ctx, pd, ndev->mvdev.res.pdn);
835 	if (MLX5_CAP_DEV_VDPA_EMULATION(ndev->mvdev.mdev, eth_frame_offload_type))
836 		MLX5_SET(virtio_q, vq_ctx, virtio_version_1_0, 1);
837 
838 	err = mlx5_cmd_exec(ndev->mvdev.mdev, in, inlen, out, sizeof(out));
839 	if (err)
840 		goto err_cmd;
841 
842 	kfree(in);
843 	mvq->virtq_id = MLX5_GET(general_obj_out_cmd_hdr, out, obj_id);
844 
845 	return 0;
846 
847 err_cmd:
848 	kfree(in);
849 err_alloc:
850 	umems_destroy(ndev, mvq);
851 	return err;
852 }
853 
854 static void destroy_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
855 {
856 	u32 in[MLX5_ST_SZ_DW(destroy_virtio_net_q_in)] = {};
857 	u32 out[MLX5_ST_SZ_DW(destroy_virtio_net_q_out)] = {};
858 
859 	MLX5_SET(destroy_virtio_net_q_in, in, general_obj_out_cmd_hdr.opcode,
860 		 MLX5_CMD_OP_DESTROY_GENERAL_OBJECT);
861 	MLX5_SET(destroy_virtio_net_q_in, in, general_obj_out_cmd_hdr.obj_id, mvq->virtq_id);
862 	MLX5_SET(destroy_virtio_net_q_in, in, general_obj_out_cmd_hdr.uid, ndev->mvdev.res.uid);
863 	MLX5_SET(destroy_virtio_net_q_in, in, general_obj_out_cmd_hdr.obj_type,
864 		 MLX5_OBJ_TYPE_VIRTIO_NET_Q);
865 	if (mlx5_cmd_exec(ndev->mvdev.mdev, in, sizeof(in), out, sizeof(out))) {
866 		mlx5_vdpa_warn(&ndev->mvdev, "destroy virtqueue 0x%x\n", mvq->virtq_id);
867 		return;
868 	}
869 	umems_destroy(ndev, mvq);
870 }
871 
872 static u32 get_rqpn(struct mlx5_vdpa_virtqueue *mvq, bool fw)
873 {
874 	return fw ? mvq->vqqp.mqp.qpn : mvq->fwqp.mqp.qpn;
875 }
876 
877 static u32 get_qpn(struct mlx5_vdpa_virtqueue *mvq, bool fw)
878 {
879 	return fw ? mvq->fwqp.mqp.qpn : mvq->vqqp.mqp.qpn;
880 }
881 
882 static void alloc_inout(struct mlx5_vdpa_net *ndev, int cmd, void **in, int *inlen, void **out,
883 			int *outlen, u32 qpn, u32 rqpn)
884 {
885 	void *qpc;
886 	void *pp;
887 
888 	switch (cmd) {
889 	case MLX5_CMD_OP_2RST_QP:
890 		*inlen = MLX5_ST_SZ_BYTES(qp_2rst_in);
891 		*outlen = MLX5_ST_SZ_BYTES(qp_2rst_out);
892 		*in = kzalloc(*inlen, GFP_KERNEL);
893 		*out = kzalloc(*outlen, GFP_KERNEL);
894 		if (!*in || !*out)
895 			goto outerr;
896 
897 		MLX5_SET(qp_2rst_in, *in, opcode, cmd);
898 		MLX5_SET(qp_2rst_in, *in, uid, ndev->mvdev.res.uid);
899 		MLX5_SET(qp_2rst_in, *in, qpn, qpn);
900 		break;
901 	case MLX5_CMD_OP_RST2INIT_QP:
902 		*inlen = MLX5_ST_SZ_BYTES(rst2init_qp_in);
903 		*outlen = MLX5_ST_SZ_BYTES(rst2init_qp_out);
904 		*in = kzalloc(*inlen, GFP_KERNEL);
905 		*out = kzalloc(MLX5_ST_SZ_BYTES(rst2init_qp_out), GFP_KERNEL);
906 		if (!*in || !*out)
907 			goto outerr;
908 
909 		MLX5_SET(rst2init_qp_in, *in, opcode, cmd);
910 		MLX5_SET(rst2init_qp_in, *in, uid, ndev->mvdev.res.uid);
911 		MLX5_SET(rst2init_qp_in, *in, qpn, qpn);
912 		qpc = MLX5_ADDR_OF(rst2init_qp_in, *in, qpc);
913 		MLX5_SET(qpc, qpc, remote_qpn, rqpn);
914 		MLX5_SET(qpc, qpc, rwe, 1);
915 		pp = MLX5_ADDR_OF(qpc, qpc, primary_address_path);
916 		MLX5_SET(ads, pp, vhca_port_num, 1);
917 		break;
918 	case MLX5_CMD_OP_INIT2RTR_QP:
919 		*inlen = MLX5_ST_SZ_BYTES(init2rtr_qp_in);
920 		*outlen = MLX5_ST_SZ_BYTES(init2rtr_qp_out);
921 		*in = kzalloc(*inlen, GFP_KERNEL);
922 		*out = kzalloc(MLX5_ST_SZ_BYTES(init2rtr_qp_out), GFP_KERNEL);
923 		if (!*in || !*out)
924 			goto outerr;
925 
926 		MLX5_SET(init2rtr_qp_in, *in, opcode, cmd);
927 		MLX5_SET(init2rtr_qp_in, *in, uid, ndev->mvdev.res.uid);
928 		MLX5_SET(init2rtr_qp_in, *in, qpn, qpn);
929 		qpc = MLX5_ADDR_OF(rst2init_qp_in, *in, qpc);
930 		MLX5_SET(qpc, qpc, mtu, MLX5_QPC_MTU_256_BYTES);
931 		MLX5_SET(qpc, qpc, log_msg_max, 30);
932 		MLX5_SET(qpc, qpc, remote_qpn, rqpn);
933 		pp = MLX5_ADDR_OF(qpc, qpc, primary_address_path);
934 		MLX5_SET(ads, pp, fl, 1);
935 		break;
936 	case MLX5_CMD_OP_RTR2RTS_QP:
937 		*inlen = MLX5_ST_SZ_BYTES(rtr2rts_qp_in);
938 		*outlen = MLX5_ST_SZ_BYTES(rtr2rts_qp_out);
939 		*in = kzalloc(*inlen, GFP_KERNEL);
940 		*out = kzalloc(MLX5_ST_SZ_BYTES(rtr2rts_qp_out), GFP_KERNEL);
941 		if (!*in || !*out)
942 			goto outerr;
943 
944 		MLX5_SET(rtr2rts_qp_in, *in, opcode, cmd);
945 		MLX5_SET(rtr2rts_qp_in, *in, uid, ndev->mvdev.res.uid);
946 		MLX5_SET(rtr2rts_qp_in, *in, qpn, qpn);
947 		qpc = MLX5_ADDR_OF(rst2init_qp_in, *in, qpc);
948 		pp = MLX5_ADDR_OF(qpc, qpc, primary_address_path);
949 		MLX5_SET(ads, pp, ack_timeout, 14);
950 		MLX5_SET(qpc, qpc, retry_count, 7);
951 		MLX5_SET(qpc, qpc, rnr_retry, 7);
952 		break;
953 	default:
954 		goto outerr_nullify;
955 	}
956 
957 	return;
958 
959 outerr:
960 	kfree(*in);
961 	kfree(*out);
962 outerr_nullify:
963 	*in = NULL;
964 	*out = NULL;
965 }
966 
/* Release the input and output command buffers handed out by alloc_inout().
 * Both are passed to kfree(), which accepts NULL.
 */
static void free_inout(void *in, void *out)
{
	kfree(out);
	kfree(in);
}
972 
973 /* Two QPs are used by each virtqueue. One is used by the driver and one by
974  * firmware. The fw argument indicates whether the subjected QP is the one used
975  * by firmware.
976  */
977 static int modify_qp(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq, bool fw, int cmd)
978 {
979 	int outlen;
980 	int inlen;
981 	void *out;
982 	void *in;
983 	int err;
984 
985 	alloc_inout(ndev, cmd, &in, &inlen, &out, &outlen, get_qpn(mvq, fw), get_rqpn(mvq, fw));
986 	if (!in || !out)
987 		return -ENOMEM;
988 
989 	err = mlx5_cmd_exec(ndev->mvdev.mdev, in, inlen, out, outlen);
990 	free_inout(in, out);
991 	return err;
992 }
993 
994 static int connect_qps(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
995 {
996 	int err;
997 
998 	err = modify_qp(ndev, mvq, true, MLX5_CMD_OP_2RST_QP);
999 	if (err)
1000 		return err;
1001 
1002 	err = modify_qp(ndev, mvq, false, MLX5_CMD_OP_2RST_QP);
1003 	if (err)
1004 		return err;
1005 
1006 	err = modify_qp(ndev, mvq, true, MLX5_CMD_OP_RST2INIT_QP);
1007 	if (err)
1008 		return err;
1009 
1010 	err = modify_qp(ndev, mvq, false, MLX5_CMD_OP_RST2INIT_QP);
1011 	if (err)
1012 		return err;
1013 
1014 	err = modify_qp(ndev, mvq, true, MLX5_CMD_OP_INIT2RTR_QP);
1015 	if (err)
1016 		return err;
1017 
1018 	err = modify_qp(ndev, mvq, false, MLX5_CMD_OP_INIT2RTR_QP);
1019 	if (err)
1020 		return err;
1021 
1022 	return modify_qp(ndev, mvq, true, MLX5_CMD_OP_RTR2RTS_QP);
1023 }
1024 
/* Attributes of a virtio net queue object as reported by query_virtqueue(). */
struct mlx5_virtq_attr {
	u8 state;		/* "state" field of the queue object */
	u16 available_index;	/* "hw_available_index" field of the queue object */
	u16 used_index;		/* "hw_used_index" field of the queue object */
};
1030 
1031 static int query_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq,
1032 			   struct mlx5_virtq_attr *attr)
1033 {
1034 	int outlen = MLX5_ST_SZ_BYTES(query_virtio_net_q_out);
1035 	u32 in[MLX5_ST_SZ_DW(query_virtio_net_q_in)] = {};
1036 	void *out;
1037 	void *obj_context;
1038 	void *cmd_hdr;
1039 	int err;
1040 
1041 	out = kzalloc(outlen, GFP_KERNEL);
1042 	if (!out)
1043 		return -ENOMEM;
1044 
1045 	cmd_hdr = MLX5_ADDR_OF(query_virtio_net_q_in, in, general_obj_in_cmd_hdr);
1046 
1047 	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, opcode, MLX5_CMD_OP_QUERY_GENERAL_OBJECT);
1048 	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, obj_type, MLX5_OBJ_TYPE_VIRTIO_NET_Q);
1049 	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, obj_id, mvq->virtq_id);
1050 	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, uid, ndev->mvdev.res.uid);
1051 	err = mlx5_cmd_exec(ndev->mvdev.mdev, in, sizeof(in), out, outlen);
1052 	if (err)
1053 		goto err_cmd;
1054 
1055 	obj_context = MLX5_ADDR_OF(query_virtio_net_q_out, out, obj_context);
1056 	memset(attr, 0, sizeof(*attr));
1057 	attr->state = MLX5_GET(virtio_net_q_object, obj_context, state);
1058 	attr->available_index = MLX5_GET(virtio_net_q_object, obj_context, hw_available_index);
1059 	attr->used_index = MLX5_GET(virtio_net_q_object, obj_context, hw_used_index);
1060 	kfree(out);
1061 	return 0;
1062 
1063 err_cmd:
1064 	kfree(out);
1065 	return err;
1066 }
1067 
1068 static int modify_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq, int state)
1069 {
1070 	int inlen = MLX5_ST_SZ_BYTES(modify_virtio_net_q_in);
1071 	u32 out[MLX5_ST_SZ_DW(modify_virtio_net_q_out)] = {};
1072 	void *obj_context;
1073 	void *cmd_hdr;
1074 	void *in;
1075 	int err;
1076 
1077 	in = kzalloc(inlen, GFP_KERNEL);
1078 	if (!in)
1079 		return -ENOMEM;
1080 
1081 	cmd_hdr = MLX5_ADDR_OF(modify_virtio_net_q_in, in, general_obj_in_cmd_hdr);
1082 
1083 	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, opcode, MLX5_CMD_OP_MODIFY_GENERAL_OBJECT);
1084 	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, obj_type, MLX5_OBJ_TYPE_VIRTIO_NET_Q);
1085 	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, obj_id, mvq->virtq_id);
1086 	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, uid, ndev->mvdev.res.uid);
1087 
1088 	obj_context = MLX5_ADDR_OF(modify_virtio_net_q_in, in, obj_context);
1089 	MLX5_SET64(virtio_net_q_object, obj_context, modify_field_select,
1090 		   MLX5_VIRTQ_MODIFY_MASK_STATE);
1091 	MLX5_SET(virtio_net_q_object, obj_context, state, state);
1092 	err = mlx5_cmd_exec(ndev->mvdev.mdev, in, inlen, out, sizeof(out));
1093 	kfree(in);
1094 	if (!err)
1095 		mvq->fw_state = state;
1096 
1097 	return err;
1098 }
1099 
1100 static int setup_vq(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
1101 {
1102 	u16 idx = mvq->index;
1103 	int err;
1104 
1105 	if (!mvq->num_ent)
1106 		return 0;
1107 
1108 	if (mvq->initialized) {
1109 		mlx5_vdpa_warn(&ndev->mvdev, "attempt re init\n");
1110 		return -EINVAL;
1111 	}
1112 
1113 	err = cq_create(ndev, idx, mvq->num_ent);
1114 	if (err)
1115 		return err;
1116 
1117 	err = qp_create(ndev, mvq, &mvq->fwqp);
1118 	if (err)
1119 		goto err_fwqp;
1120 
1121 	err = qp_create(ndev, mvq, &mvq->vqqp);
1122 	if (err)
1123 		goto err_vqqp;
1124 
1125 	err = connect_qps(ndev, mvq);
1126 	if (err)
1127 		goto err_connect;
1128 
1129 	err = create_virtqueue(ndev, mvq);
1130 	if (err)
1131 		goto err_connect;
1132 
1133 	if (mvq->ready) {
1134 		err = modify_virtqueue(ndev, mvq, MLX5_VIRTIO_NET_Q_OBJECT_STATE_RDY);
1135 		if (err) {
1136 			mlx5_vdpa_warn(&ndev->mvdev, "failed to modify to ready vq idx %d(%d)\n",
1137 				       idx, err);
1138 			goto err_connect;
1139 		}
1140 	}
1141 
1142 	mvq->initialized = true;
1143 	return 0;
1144 
1145 err_connect:
1146 	qp_destroy(ndev, &mvq->vqqp);
1147 err_vqqp:
1148 	qp_destroy(ndev, &mvq->fwqp);
1149 err_fwqp:
1150 	cq_destroy(ndev, idx);
1151 	return err;
1152 }
1153 
1154 static void suspend_vq(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
1155 {
1156 	struct mlx5_virtq_attr attr;
1157 
1158 	if (!mvq->initialized)
1159 		return;
1160 
1161 	if (mvq->fw_state != MLX5_VIRTIO_NET_Q_OBJECT_STATE_RDY)
1162 		return;
1163 
1164 	if (modify_virtqueue(ndev, mvq, MLX5_VIRTIO_NET_Q_OBJECT_STATE_SUSPEND))
1165 		mlx5_vdpa_warn(&ndev->mvdev, "modify to suspend failed\n");
1166 
1167 	if (query_virtqueue(ndev, mvq, &attr)) {
1168 		mlx5_vdpa_warn(&ndev->mvdev, "failed to query virtqueue\n");
1169 		return;
1170 	}
1171 	mvq->avail_idx = attr.available_index;
1172 	mvq->used_idx = attr.used_index;
1173 }
1174 
1175 static void suspend_vqs(struct mlx5_vdpa_net *ndev)
1176 {
1177 	int i;
1178 
1179 	for (i = 0; i < MLX5_MAX_SUPPORTED_VQS; i++)
1180 		suspend_vq(ndev, &ndev->vqs[i]);
1181 }
1182 
1183 static void teardown_vq(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
1184 {
1185 	if (!mvq->initialized)
1186 		return;
1187 
1188 	suspend_vq(ndev, mvq);
1189 	destroy_virtqueue(ndev, mvq);
1190 	qp_destroy(ndev, &mvq->vqqp);
1191 	qp_destroy(ndev, &mvq->fwqp);
1192 	cq_destroy(ndev, mvq->index);
1193 	mvq->initialized = false;
1194 }
1195 
1196 static int create_rqt(struct mlx5_vdpa_net *ndev)
1197 {
1198 	int log_max_rqt;
1199 	__be32 *list;
1200 	void *rqtc;
1201 	int inlen;
1202 	void *in;
1203 	int i, j;
1204 	int err;
1205 
1206 	log_max_rqt = min_t(int, 1, MLX5_CAP_GEN(ndev->mvdev.mdev, log_max_rqt_size));
1207 	if (log_max_rqt < 1)
1208 		return -EOPNOTSUPP;
1209 
1210 	inlen = MLX5_ST_SZ_BYTES(create_rqt_in) + (1 << log_max_rqt) * MLX5_ST_SZ_BYTES(rq_num);
1211 	in = kzalloc(inlen, GFP_KERNEL);
1212 	if (!in)
1213 		return -ENOMEM;
1214 
1215 	MLX5_SET(create_rqt_in, in, uid, ndev->mvdev.res.uid);
1216 	rqtc = MLX5_ADDR_OF(create_rqt_in, in, rqt_context);
1217 
1218 	MLX5_SET(rqtc, rqtc, list_q_type, MLX5_RQTC_LIST_Q_TYPE_VIRTIO_NET_Q);
1219 	MLX5_SET(rqtc, rqtc, rqt_max_size, 1 << log_max_rqt);
1220 	MLX5_SET(rqtc, rqtc, rqt_actual_size, 1);
1221 	list = MLX5_ADDR_OF(rqtc, rqtc, rq_num[0]);
1222 	for (i = 0, j = 0; j < ndev->mvdev.max_vqs; j++) {
1223 		if (!ndev->vqs[j].initialized)
1224 			continue;
1225 
1226 		if (!vq_is_tx(ndev->vqs[j].index)) {
1227 			list[i] = cpu_to_be32(ndev->vqs[j].virtq_id);
1228 			i++;
1229 		}
1230 	}
1231 
1232 	err = mlx5_vdpa_create_rqt(&ndev->mvdev, in, inlen, &ndev->res.rqtn);
1233 	kfree(in);
1234 	if (err)
1235 		return err;
1236 
1237 	return 0;
1238 }
1239 
1240 static void destroy_rqt(struct mlx5_vdpa_net *ndev)
1241 {
1242 	mlx5_vdpa_destroy_rqt(&ndev->mvdev, ndev->res.rqtn);
1243 }
1244 
1245 static int create_tir(struct mlx5_vdpa_net *ndev)
1246 {
1247 #define HASH_IP_L4PORTS                                                                            \
1248 	(MLX5_HASH_FIELD_SEL_SRC_IP | MLX5_HASH_FIELD_SEL_DST_IP | MLX5_HASH_FIELD_SEL_L4_SPORT |  \
1249 	 MLX5_HASH_FIELD_SEL_L4_DPORT)
1250 	static const u8 rx_hash_toeplitz_key[] = { 0x2c, 0xc6, 0x81, 0xd1, 0x5b, 0xdb, 0xf4, 0xf7,
1251 						   0xfc, 0xa2, 0x83, 0x19, 0xdb, 0x1a, 0x3e, 0x94,
1252 						   0x6b, 0x9e, 0x38, 0xd9, 0x2c, 0x9c, 0x03, 0xd1,
1253 						   0xad, 0x99, 0x44, 0xa7, 0xd9, 0x56, 0x3d, 0x59,
1254 						   0x06, 0x3c, 0x25, 0xf3, 0xfc, 0x1f, 0xdc, 0x2a };
1255 	void *rss_key;
1256 	void *outer;
1257 	void *tirc;
1258 	void *in;
1259 	int err;
1260 
1261 	in = kzalloc(MLX5_ST_SZ_BYTES(create_tir_in), GFP_KERNEL);
1262 	if (!in)
1263 		return -ENOMEM;
1264 
1265 	MLX5_SET(create_tir_in, in, uid, ndev->mvdev.res.uid);
1266 	tirc = MLX5_ADDR_OF(create_tir_in, in, ctx);
1267 	MLX5_SET(tirc, tirc, disp_type, MLX5_TIRC_DISP_TYPE_INDIRECT);
1268 
1269 	MLX5_SET(tirc, tirc, rx_hash_symmetric, 1);
1270 	MLX5_SET(tirc, tirc, rx_hash_fn, MLX5_RX_HASH_FN_TOEPLITZ);
1271 	rss_key = MLX5_ADDR_OF(tirc, tirc, rx_hash_toeplitz_key);
1272 	memcpy(rss_key, rx_hash_toeplitz_key, sizeof(rx_hash_toeplitz_key));
1273 
1274 	outer = MLX5_ADDR_OF(tirc, tirc, rx_hash_field_selector_outer);
1275 	MLX5_SET(rx_hash_field_select, outer, l3_prot_type, MLX5_L3_PROT_TYPE_IPV4);
1276 	MLX5_SET(rx_hash_field_select, outer, l4_prot_type, MLX5_L4_PROT_TYPE_TCP);
1277 	MLX5_SET(rx_hash_field_select, outer, selected_fields, HASH_IP_L4PORTS);
1278 
1279 	MLX5_SET(tirc, tirc, indirect_table, ndev->res.rqtn);
1280 	MLX5_SET(tirc, tirc, transport_domain, ndev->res.tdn);
1281 
1282 	err = mlx5_vdpa_create_tir(&ndev->mvdev, in, &ndev->res.tirn);
1283 	kfree(in);
1284 	return err;
1285 }
1286 
1287 static void destroy_tir(struct mlx5_vdpa_net *ndev)
1288 {
1289 	mlx5_vdpa_destroy_tir(&ndev->mvdev, ndev->res.tirn);
1290 }
1291 
1292 static int add_fwd_to_tir(struct mlx5_vdpa_net *ndev)
1293 {
1294 	struct mlx5_flow_destination dest[2] = {};
1295 	struct mlx5_flow_table_attr ft_attr = {};
1296 	struct mlx5_flow_act flow_act = {};
1297 	struct mlx5_flow_namespace *ns;
1298 	int err;
1299 
1300 	/* for now, one entry, match all, forward to tir */
1301 	ft_attr.max_fte = 1;
1302 	ft_attr.autogroup.max_num_groups = 1;
1303 
1304 	ns = mlx5_get_flow_namespace(ndev->mvdev.mdev, MLX5_FLOW_NAMESPACE_BYPASS);
1305 	if (!ns) {
1306 		mlx5_vdpa_warn(&ndev->mvdev, "get flow namespace\n");
1307 		return -EOPNOTSUPP;
1308 	}
1309 
1310 	ndev->rxft = mlx5_create_auto_grouped_flow_table(ns, &ft_attr);
1311 	if (IS_ERR(ndev->rxft))
1312 		return PTR_ERR(ndev->rxft);
1313 
1314 	ndev->rx_counter = mlx5_fc_create(ndev->mvdev.mdev, false);
1315 	if (IS_ERR(ndev->rx_counter)) {
1316 		err = PTR_ERR(ndev->rx_counter);
1317 		goto err_fc;
1318 	}
1319 
1320 	flow_act.action = MLX5_FLOW_CONTEXT_ACTION_FWD_DEST | MLX5_FLOW_CONTEXT_ACTION_COUNT;
1321 	dest[0].type = MLX5_FLOW_DESTINATION_TYPE_TIR;
1322 	dest[0].tir_num = ndev->res.tirn;
1323 	dest[1].type = MLX5_FLOW_DESTINATION_TYPE_COUNTER;
1324 	dest[1].counter_id = mlx5_fc_id(ndev->rx_counter);
1325 	ndev->rx_rule = mlx5_add_flow_rules(ndev->rxft, NULL, &flow_act, dest, 2);
1326 	if (IS_ERR(ndev->rx_rule)) {
1327 		err = PTR_ERR(ndev->rx_rule);
1328 		ndev->rx_rule = NULL;
1329 		goto err_rule;
1330 	}
1331 
1332 	return 0;
1333 
1334 err_rule:
1335 	mlx5_fc_destroy(ndev->mvdev.mdev, ndev->rx_counter);
1336 err_fc:
1337 	mlx5_destroy_flow_table(ndev->rxft);
1338 	return err;
1339 }
1340 
1341 static void remove_fwd_to_tir(struct mlx5_vdpa_net *ndev)
1342 {
1343 	if (!ndev->rx_rule)
1344 		return;
1345 
1346 	mlx5_del_flow_rules(ndev->rx_rule);
1347 	mlx5_fc_destroy(ndev->mvdev.mdev, ndev->rx_counter);
1348 	mlx5_destroy_flow_table(ndev->rxft);
1349 
1350 	ndev->rx_rule = NULL;
1351 }
1352 
1353 static void mlx5_vdpa_kick_vq(struct vdpa_device *vdev, u16 idx)
1354 {
1355 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1356 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1357 	struct mlx5_vdpa_virtqueue *mvq = &ndev->vqs[idx];
1358 
1359 	if (unlikely(!mvq->ready))
1360 		return;
1361 
1362 	iowrite16(idx, ndev->mvdev.res.kick_addr);
1363 }
1364 
1365 static int mlx5_vdpa_set_vq_address(struct vdpa_device *vdev, u16 idx, u64 desc_area,
1366 				    u64 driver_area, u64 device_area)
1367 {
1368 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1369 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1370 	struct mlx5_vdpa_virtqueue *mvq = &ndev->vqs[idx];
1371 
1372 	mvq->desc_addr = desc_area;
1373 	mvq->device_addr = device_area;
1374 	mvq->driver_addr = driver_area;
1375 	return 0;
1376 }
1377 
1378 static void mlx5_vdpa_set_vq_num(struct vdpa_device *vdev, u16 idx, u32 num)
1379 {
1380 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1381 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1382 	struct mlx5_vdpa_virtqueue *mvq;
1383 
1384 	mvq = &ndev->vqs[idx];
1385 	mvq->num_ent = num;
1386 }
1387 
1388 static void mlx5_vdpa_set_vq_cb(struct vdpa_device *vdev, u16 idx, struct vdpa_callback *cb)
1389 {
1390 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1391 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1392 	struct mlx5_vdpa_virtqueue *vq = &ndev->vqs[idx];
1393 
1394 	vq->event_cb = *cb;
1395 }
1396 
1397 static void mlx5_vdpa_set_vq_ready(struct vdpa_device *vdev, u16 idx, bool ready)
1398 {
1399 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1400 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1401 	struct mlx5_vdpa_virtqueue *mvq = &ndev->vqs[idx];
1402 
1403 	if (!ready)
1404 		suspend_vq(ndev, mvq);
1405 
1406 	mvq->ready = ready;
1407 }
1408 
1409 static bool mlx5_vdpa_get_vq_ready(struct vdpa_device *vdev, u16 idx)
1410 {
1411 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1412 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1413 	struct mlx5_vdpa_virtqueue *mvq = &ndev->vqs[idx];
1414 
1415 	return mvq->ready;
1416 }
1417 
1418 static int mlx5_vdpa_set_vq_state(struct vdpa_device *vdev, u16 idx,
1419 				  const struct vdpa_vq_state *state)
1420 {
1421 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1422 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1423 	struct mlx5_vdpa_virtqueue *mvq = &ndev->vqs[idx];
1424 
1425 	if (mvq->fw_state == MLX5_VIRTIO_NET_Q_OBJECT_STATE_RDY) {
1426 		mlx5_vdpa_warn(mvdev, "can't modify available index\n");
1427 		return -EINVAL;
1428 	}
1429 
1430 	mvq->used_idx = state->avail_index;
1431 	mvq->avail_idx = state->avail_index;
1432 	return 0;
1433 }
1434 
1435 static int mlx5_vdpa_get_vq_state(struct vdpa_device *vdev, u16 idx, struct vdpa_vq_state *state)
1436 {
1437 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1438 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1439 	struct mlx5_vdpa_virtqueue *mvq = &ndev->vqs[idx];
1440 	struct mlx5_virtq_attr attr;
1441 	int err;
1442 
1443 	/* If the virtq object was destroyed, use the value saved at
1444 	 * the last minute of suspend_vq. This caters for userspace
1445 	 * that cares about emulating the index after vq is stopped.
1446 	 */
1447 	if (!mvq->initialized) {
1448 		/* Firmware returns a wrong value for the available index.
1449 		 * Since both values should be identical, we take the value of
1450 		 * used_idx which is reported correctly.
1451 		 */
1452 		state->avail_index = mvq->used_idx;
1453 		return 0;
1454 	}
1455 
1456 	err = query_virtqueue(ndev, mvq, &attr);
1457 	if (err) {
1458 		mlx5_vdpa_warn(mvdev, "failed to query virtqueue\n");
1459 		return err;
1460 	}
1461 	state->avail_index = attr.used_index;
1462 	return 0;
1463 }
1464 
1465 static u32 mlx5_vdpa_get_vq_align(struct vdpa_device *vdev)
1466 {
1467 	return PAGE_SIZE;
1468 }
1469 
/* Bits of the device_features_bits_mask capability that
 * mlx_to_vritio_features() translates into the virtio net feature bit of the
 * same name.
 */
enum { MLX5_VIRTIO_NET_F_GUEST_CSUM = 1 << 9,
	MLX5_VIRTIO_NET_F_CSUM = 1 << 10,
	MLX5_VIRTIO_NET_F_HOST_TSO6 = 1 << 11,
	MLX5_VIRTIO_NET_F_HOST_TSO4 = 1 << 12,
};
1475 
1476 static u64 mlx_to_vritio_features(u16 dev_features)
1477 {
1478 	u64 result = 0;
1479 
1480 	if (dev_features & MLX5_VIRTIO_NET_F_GUEST_CSUM)
1481 		result |= BIT_ULL(VIRTIO_NET_F_GUEST_CSUM);
1482 	if (dev_features & MLX5_VIRTIO_NET_F_CSUM)
1483 		result |= BIT_ULL(VIRTIO_NET_F_CSUM);
1484 	if (dev_features & MLX5_VIRTIO_NET_F_HOST_TSO6)
1485 		result |= BIT_ULL(VIRTIO_NET_F_HOST_TSO6);
1486 	if (dev_features & MLX5_VIRTIO_NET_F_HOST_TSO4)
1487 		result |= BIT_ULL(VIRTIO_NET_F_HOST_TSO4);
1488 
1489 	return result;
1490 }
1491 
1492 static u64 mlx5_vdpa_get_features(struct vdpa_device *vdev)
1493 {
1494 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1495 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1496 	u16 dev_features;
1497 
1498 	dev_features = MLX5_CAP_DEV_VDPA_EMULATION(mvdev->mdev, device_features_bits_mask);
1499 	ndev->mvdev.mlx_features = mlx_to_vritio_features(dev_features);
1500 	if (MLX5_CAP_DEV_VDPA_EMULATION(mvdev->mdev, virtio_version_1_0))
1501 		ndev->mvdev.mlx_features |= BIT_ULL(VIRTIO_F_VERSION_1);
1502 	ndev->mvdev.mlx_features |= BIT_ULL(VIRTIO_F_ACCESS_PLATFORM);
1503 	print_features(mvdev, ndev->mvdev.mlx_features, false);
1504 	return ndev->mvdev.mlx_features;
1505 }
1506 
1507 static int verify_min_features(struct mlx5_vdpa_dev *mvdev, u64 features)
1508 {
1509 	if (!(features & BIT_ULL(VIRTIO_F_ACCESS_PLATFORM)))
1510 		return -EOPNOTSUPP;
1511 
1512 	return 0;
1513 }
1514 
1515 static int setup_virtqueues(struct mlx5_vdpa_net *ndev)
1516 {
1517 	int err;
1518 	int i;
1519 
1520 	for (i = 0; i < 2 * mlx5_vdpa_max_qps(ndev->mvdev.max_vqs); i++) {
1521 		err = setup_vq(ndev, &ndev->vqs[i]);
1522 		if (err)
1523 			goto err_vq;
1524 	}
1525 
1526 	return 0;
1527 
1528 err_vq:
1529 	for (--i; i >= 0; i--)
1530 		teardown_vq(ndev, &ndev->vqs[i]);
1531 
1532 	return err;
1533 }
1534 
1535 static void teardown_virtqueues(struct mlx5_vdpa_net *ndev)
1536 {
1537 	struct mlx5_vdpa_virtqueue *mvq;
1538 	int i;
1539 
1540 	for (i = ndev->mvdev.max_vqs - 1; i >= 0; i--) {
1541 		mvq = &ndev->vqs[i];
1542 		if (!mvq->initialized)
1543 			continue;
1544 
1545 		teardown_vq(ndev, mvq);
1546 	}
1547 }
1548 
1549 /* TODO: cross-endian support */
1550 static inline bool mlx5_vdpa_is_little_endian(struct mlx5_vdpa_dev *mvdev)
1551 {
1552 	return virtio_legacy_is_little_endian() ||
1553 		(mvdev->actual_features & BIT_ULL(VIRTIO_F_VERSION_1));
1554 }
1555 
1556 static __virtio16 cpu_to_mlx5vdpa16(struct mlx5_vdpa_dev *mvdev, u16 val)
1557 {
1558 	return __cpu_to_virtio16(mlx5_vdpa_is_little_endian(mvdev), val);
1559 }
1560 
1561 static int mlx5_vdpa_set_features(struct vdpa_device *vdev, u64 features)
1562 {
1563 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1564 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1565 	int err;
1566 
1567 	print_features(mvdev, features, true);
1568 
1569 	err = verify_min_features(mvdev, features);
1570 	if (err)
1571 		return err;
1572 
1573 	ndev->mvdev.actual_features = features & ndev->mvdev.mlx_features;
1574 	ndev->config.mtu = cpu_to_mlx5vdpa16(mvdev, ndev->mtu);
1575 	ndev->config.status |= cpu_to_mlx5vdpa16(mvdev, VIRTIO_NET_S_LINK_UP);
1576 	return err;
1577 }
1578 
/* vdpa set_config_cb callback: not implemented. @cb is ignored and a warning
 * is logged.
 */
static void mlx5_vdpa_set_config_cb(struct vdpa_device *vdev, struct vdpa_callback *cb)
{
	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);

	mlx5_vdpa_warn(mvdev, "set config callback not supported\n");
}
1584 
1585 #define MLX5_VDPA_MAX_VQ_ENTRIES 256
1586 static u16 mlx5_vdpa_get_vq_num_max(struct vdpa_device *vdev)
1587 {
1588 	return MLX5_VDPA_MAX_VQ_ENTRIES;
1589 }
1590 
1591 static u32 mlx5_vdpa_get_device_id(struct vdpa_device *vdev)
1592 {
1593 	return VIRTIO_ID_NET;
1594 }
1595 
1596 static u32 mlx5_vdpa_get_vendor_id(struct vdpa_device *vdev)
1597 {
1598 	return PCI_VENDOR_ID_MELLANOX;
1599 }
1600 
1601 static u8 mlx5_vdpa_get_status(struct vdpa_device *vdev)
1602 {
1603 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1604 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1605 
1606 	print_status(mvdev, ndev->mvdev.status, false);
1607 	return ndev->mvdev.status;
1608 }
1609 
1610 static int save_channel_info(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
1611 {
1612 	struct mlx5_vq_restore_info *ri = &mvq->ri;
1613 	struct mlx5_virtq_attr attr;
1614 	int err;
1615 
1616 	if (!mvq->initialized)
1617 		return 0;
1618 
1619 	err = query_virtqueue(ndev, mvq, &attr);
1620 	if (err)
1621 		return err;
1622 
1623 	ri->avail_index = attr.available_index;
1624 	ri->used_index = attr.used_index;
1625 	ri->ready = mvq->ready;
1626 	ri->num_ent = mvq->num_ent;
1627 	ri->desc_addr = mvq->desc_addr;
1628 	ri->device_addr = mvq->device_addr;
1629 	ri->driver_addr = mvq->driver_addr;
1630 	ri->cb = mvq->event_cb;
1631 	ri->restore = true;
1632 	return 0;
1633 }
1634 
1635 static int save_channels_info(struct mlx5_vdpa_net *ndev)
1636 {
1637 	int i;
1638 
1639 	for (i = 0; i < ndev->mvdev.max_vqs; i++) {
1640 		memset(&ndev->vqs[i].ri, 0, sizeof(ndev->vqs[i].ri));
1641 		save_channel_info(ndev, &ndev->vqs[i]);
1642 	}
1643 	return 0;
1644 }
1645 
1646 static void mlx5_clear_vqs(struct mlx5_vdpa_net *ndev)
1647 {
1648 	int i;
1649 
1650 	for (i = 0; i < ndev->mvdev.max_vqs; i++)
1651 		memset(&ndev->vqs[i], 0, offsetof(struct mlx5_vdpa_virtqueue, ri));
1652 }
1653 
1654 static void restore_channels_info(struct mlx5_vdpa_net *ndev)
1655 {
1656 	struct mlx5_vdpa_virtqueue *mvq;
1657 	struct mlx5_vq_restore_info *ri;
1658 	int i;
1659 
1660 	mlx5_clear_vqs(ndev);
1661 	init_mvqs(ndev);
1662 	for (i = 0; i < ndev->mvdev.max_vqs; i++) {
1663 		mvq = &ndev->vqs[i];
1664 		ri = &mvq->ri;
1665 		if (!ri->restore)
1666 			continue;
1667 
1668 		mvq->avail_idx = ri->avail_index;
1669 		mvq->used_idx = ri->used_index;
1670 		mvq->ready = ri->ready;
1671 		mvq->num_ent = ri->num_ent;
1672 		mvq->desc_addr = ri->desc_addr;
1673 		mvq->device_addr = ri->device_addr;
1674 		mvq->driver_addr = ri->driver_addr;
1675 		mvq->event_cb = ri->cb;
1676 	}
1677 }
1678 
1679 static int mlx5_vdpa_change_map(struct mlx5_vdpa_net *ndev, struct vhost_iotlb *iotlb)
1680 {
1681 	int err;
1682 
1683 	suspend_vqs(ndev);
1684 	err = save_channels_info(ndev);
1685 	if (err)
1686 		goto err_mr;
1687 
1688 	teardown_driver(ndev);
1689 	mlx5_vdpa_destroy_mr(&ndev->mvdev);
1690 	err = mlx5_vdpa_create_mr(&ndev->mvdev, iotlb);
1691 	if (err)
1692 		goto err_mr;
1693 
1694 	if (!(ndev->mvdev.status & VIRTIO_CONFIG_S_DRIVER_OK))
1695 		return 0;
1696 
1697 	restore_channels_info(ndev);
1698 	err = setup_driver(ndev);
1699 	if (err)
1700 		goto err_setup;
1701 
1702 	return 0;
1703 
1704 err_setup:
1705 	mlx5_vdpa_destroy_mr(&ndev->mvdev);
1706 err_mr:
1707 	return err;
1708 }
1709 
1710 static int setup_driver(struct mlx5_vdpa_net *ndev)
1711 {
1712 	int err;
1713 
1714 	mutex_lock(&ndev->reslock);
1715 	if (ndev->setup) {
1716 		mlx5_vdpa_warn(&ndev->mvdev, "setup driver called for already setup driver\n");
1717 		err = 0;
1718 		goto out;
1719 	}
1720 	err = setup_virtqueues(ndev);
1721 	if (err) {
1722 		mlx5_vdpa_warn(&ndev->mvdev, "setup_virtqueues\n");
1723 		goto out;
1724 	}
1725 
1726 	err = create_rqt(ndev);
1727 	if (err) {
1728 		mlx5_vdpa_warn(&ndev->mvdev, "create_rqt\n");
1729 		goto err_rqt;
1730 	}
1731 
1732 	err = create_tir(ndev);
1733 	if (err) {
1734 		mlx5_vdpa_warn(&ndev->mvdev, "create_tir\n");
1735 		goto err_tir;
1736 	}
1737 
1738 	err = add_fwd_to_tir(ndev);
1739 	if (err) {
1740 		mlx5_vdpa_warn(&ndev->mvdev, "add_fwd_to_tir\n");
1741 		goto err_fwd;
1742 	}
1743 	ndev->setup = true;
1744 	mutex_unlock(&ndev->reslock);
1745 
1746 	return 0;
1747 
1748 err_fwd:
1749 	destroy_tir(ndev);
1750 err_tir:
1751 	destroy_rqt(ndev);
1752 err_rqt:
1753 	teardown_virtqueues(ndev);
1754 out:
1755 	mutex_unlock(&ndev->reslock);
1756 	return err;
1757 }
1758 
1759 static void teardown_driver(struct mlx5_vdpa_net *ndev)
1760 {
1761 	mutex_lock(&ndev->reslock);
1762 	if (!ndev->setup)
1763 		goto out;
1764 
1765 	remove_fwd_to_tir(ndev);
1766 	destroy_tir(ndev);
1767 	destroy_rqt(ndev);
1768 	teardown_virtqueues(ndev);
1769 	ndev->setup = false;
1770 out:
1771 	mutex_unlock(&ndev->reslock);
1772 }
1773 
1774 static void mlx5_vdpa_set_status(struct vdpa_device *vdev, u8 status)
1775 {
1776 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1777 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1778 	int err;
1779 
1780 	print_status(mvdev, status, true);
1781 	if (!status) {
1782 		mlx5_vdpa_info(mvdev, "performing device reset\n");
1783 		teardown_driver(ndev);
1784 		mlx5_vdpa_destroy_mr(&ndev->mvdev);
1785 		ndev->mvdev.status = 0;
1786 		ndev->mvdev.mlx_features = 0;
1787 		++mvdev->generation;
1788 		return;
1789 	}
1790 
1791 	if ((status ^ ndev->mvdev.status) & VIRTIO_CONFIG_S_DRIVER_OK) {
1792 		if (status & VIRTIO_CONFIG_S_DRIVER_OK) {
1793 			err = setup_driver(ndev);
1794 			if (err) {
1795 				mlx5_vdpa_warn(mvdev, "failed to setup driver\n");
1796 				goto err_setup;
1797 			}
1798 		} else {
1799 			mlx5_vdpa_warn(mvdev, "did not expect DRIVER_OK to be cleared\n");
1800 			return;
1801 		}
1802 	}
1803 
1804 	ndev->mvdev.status = status;
1805 	return;
1806 
1807 err_setup:
1808 	mlx5_vdpa_destroy_mr(&ndev->mvdev);
1809 	ndev->mvdev.status |= VIRTIO_CONFIG_S_FAILED;
1810 }
1811 
1812 static size_t mlx5_vdpa_get_config_size(struct vdpa_device *vdev)
1813 {
1814 	return sizeof(struct virtio_net_config);
1815 }
1816 
1817 static void mlx5_vdpa_get_config(struct vdpa_device *vdev, unsigned int offset, void *buf,
1818 				 unsigned int len)
1819 {
1820 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1821 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1822 
1823 	if (offset + len <= sizeof(struct virtio_net_config))
1824 		memcpy(buf, (u8 *)&ndev->config + offset, len);
1825 }
1826 
/* vdpa set_config callback: writing the config space is not supported, so
 * the request is ignored and nothing is reported to the caller.
 */
static void mlx5_vdpa_set_config(struct vdpa_device *vdev, unsigned int offset, const void *buf,
				 unsigned int len)
{
	/* not supported */
}
1832 
1833 static u32 mlx5_vdpa_get_generation(struct vdpa_device *vdev)
1834 {
1835 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1836 
1837 	return mvdev->generation;
1838 }
1839 
1840 static int mlx5_vdpa_set_map(struct vdpa_device *vdev, struct vhost_iotlb *iotlb)
1841 {
1842 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1843 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
1844 	bool change_map;
1845 	int err;
1846 
1847 	err = mlx5_vdpa_handle_set_map(mvdev, iotlb, &change_map);
1848 	if (err) {
1849 		mlx5_vdpa_warn(mvdev, "set map failed(%d)\n", err);
1850 		return err;
1851 	}
1852 
1853 	if (change_map)
1854 		return mlx5_vdpa_change_map(ndev, iotlb);
1855 
1856 	return 0;
1857 }
1858 
1859 static void mlx5_vdpa_free(struct vdpa_device *vdev)
1860 {
1861 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
1862 	struct mlx5_vdpa_net *ndev;
1863 
1864 	ndev = to_mlx5_vdpa_ndev(mvdev);
1865 
1866 	free_resources(ndev);
1867 	mlx5_vdpa_free_resources(&ndev->mvdev);
1868 	mutex_destroy(&ndev->reslock);
1869 }
1870 
1871 static struct vdpa_notification_area mlx5_get_vq_notification(struct vdpa_device *vdev, u16 idx)
1872 {
1873 	struct vdpa_notification_area ret = {};
1874 
1875 	return ret;
1876 }
1877 
1878 static int mlx5_get_vq_irq(struct vdpa_device *vdv, u16 idx)
1879 {
1880 	return -EOPNOTSUPP;
1881 }
1882 
/* vdpa configuration operations implemented by this driver; passed to
 * vdpa_alloc_device() when a device is added.
 */
static const struct vdpa_config_ops mlx5_vdpa_ops = {
	/* virtqueue configuration and state */
	.set_vq_address = mlx5_vdpa_set_vq_address,
	.set_vq_num = mlx5_vdpa_set_vq_num,
	.kick_vq = mlx5_vdpa_kick_vq,
	.set_vq_cb = mlx5_vdpa_set_vq_cb,
	.set_vq_ready = mlx5_vdpa_set_vq_ready,
	.get_vq_ready = mlx5_vdpa_get_vq_ready,
	.set_vq_state = mlx5_vdpa_set_vq_state,
	.get_vq_state = mlx5_vdpa_get_vq_state,
	.get_vq_notification = mlx5_get_vq_notification,
	.get_vq_irq = mlx5_get_vq_irq,
	/* device-wide properties, features, status and config space */
	.get_vq_align = mlx5_vdpa_get_vq_align,
	.get_features = mlx5_vdpa_get_features,
	.set_features = mlx5_vdpa_set_features,
	.set_config_cb = mlx5_vdpa_set_config_cb,
	.get_vq_num_max = mlx5_vdpa_get_vq_num_max,
	.get_device_id = mlx5_vdpa_get_device_id,
	.get_vendor_id = mlx5_vdpa_get_vendor_id,
	.get_status = mlx5_vdpa_get_status,
	.set_status = mlx5_vdpa_set_status,
	.get_config_size = mlx5_vdpa_get_config_size,
	.get_config = mlx5_vdpa_get_config,
	.set_config = mlx5_vdpa_set_config,
	.get_generation = mlx5_vdpa_get_generation,
	/* memory mapping and release */
	.set_map = mlx5_vdpa_set_map,
	.free = mlx5_vdpa_free,
};
1910 
1911 static int query_mtu(struct mlx5_core_dev *mdev, u16 *mtu)
1912 {
1913 	u16 hw_mtu;
1914 	int err;
1915 
1916 	err = mlx5_query_nic_vport_mtu(mdev, &hw_mtu);
1917 	if (err)
1918 		return err;
1919 
1920 	*mtu = hw_mtu - MLX5V_ETH_HARD_MTU;
1921 	return 0;
1922 }
1923 
1924 static int alloc_resources(struct mlx5_vdpa_net *ndev)
1925 {
1926 	struct mlx5_vdpa_net_resources *res = &ndev->res;
1927 	int err;
1928 
1929 	if (res->valid) {
1930 		mlx5_vdpa_warn(&ndev->mvdev, "resources already allocated\n");
1931 		return -EEXIST;
1932 	}
1933 
1934 	err = mlx5_vdpa_alloc_transport_domain(&ndev->mvdev, &res->tdn);
1935 	if (err)
1936 		return err;
1937 
1938 	err = create_tis(ndev);
1939 	if (err)
1940 		goto err_tis;
1941 
1942 	res->valid = true;
1943 
1944 	return 0;
1945 
1946 err_tis:
1947 	mlx5_vdpa_dealloc_transport_domain(&ndev->mvdev, res->tdn);
1948 	return err;
1949 }
1950 
1951 static void free_resources(struct mlx5_vdpa_net *ndev)
1952 {
1953 	struct mlx5_vdpa_net_resources *res = &ndev->res;
1954 
1955 	if (!res->valid)
1956 		return;
1957 
1958 	destroy_tis(ndev);
1959 	mlx5_vdpa_dealloc_transport_domain(&ndev->mvdev, res->tdn);
1960 	res->valid = false;
1961 }
1962 
1963 static void init_mvqs(struct mlx5_vdpa_net *ndev)
1964 {
1965 	struct mlx5_vdpa_virtqueue *mvq;
1966 	int i;
1967 
1968 	for (i = 0; i < 2 * mlx5_vdpa_max_qps(ndev->mvdev.max_vqs); ++i) {
1969 		mvq = &ndev->vqs[i];
1970 		memset(mvq, 0, offsetof(struct mlx5_vdpa_virtqueue, ri));
1971 		mvq->index = i;
1972 		mvq->ndev = ndev;
1973 		mvq->fwqp.fw = true;
1974 	}
1975 	for (; i < ndev->mvdev.max_vqs; i++) {
1976 		mvq = &ndev->vqs[i];
1977 		memset(mvq, 0, offsetof(struct mlx5_vdpa_virtqueue, ri));
1978 		mvq->index = i;
1979 		mvq->ndev = ndev;
1980 	}
1981 }
1982 
/* Per management device state. */
struct mlx5_vdpa_mgmtdev {
	struct vdpa_mgmt_dev mgtdev;	/* embedded vdpa management device */
	struct mlx5_adev *madev;	/* mlx5 device; its mdev is used when adding a vdpa device */
	/* The one vdpa net device created through this management device, or
	 * NULL when none exists; a second add request is refused.
	 */
	struct mlx5_vdpa_net *ndev;
};
1988 
1989 static int mlx5_vdpa_dev_add(struct vdpa_mgmt_dev *v_mdev, const char *name)
1990 {
1991 	struct mlx5_vdpa_mgmtdev *mgtdev = container_of(v_mdev, struct mlx5_vdpa_mgmtdev, mgtdev);
1992 	struct virtio_net_config *config;
1993 	struct mlx5_vdpa_dev *mvdev;
1994 	struct mlx5_vdpa_net *ndev;
1995 	struct mlx5_core_dev *mdev;
1996 	u32 max_vqs;
1997 	int err;
1998 
1999 	if (mgtdev->ndev)
2000 		return -ENOSPC;
2001 
2002 	mdev = mgtdev->madev->mdev;
2003 	/* we save one virtqueue for control virtqueue should we require it */
2004 	max_vqs = MLX5_CAP_DEV_VDPA_EMULATION(mdev, max_num_virtio_queues);
2005 	max_vqs = min_t(u32, max_vqs, MLX5_MAX_SUPPORTED_VQS);
2006 
2007 	ndev = vdpa_alloc_device(struct mlx5_vdpa_net, mvdev.vdev, mdev->device, &mlx5_vdpa_ops,
2008 				 name);
2009 	if (IS_ERR(ndev))
2010 		return PTR_ERR(ndev);
2011 
2012 	ndev->mvdev.max_vqs = max_vqs;
2013 	mvdev = &ndev->mvdev;
2014 	mvdev->mdev = mdev;
2015 	init_mvqs(ndev);
2016 	mutex_init(&ndev->reslock);
2017 	config = &ndev->config;
2018 	err = query_mtu(mdev, &ndev->mtu);
2019 	if (err)
2020 		goto err_mtu;
2021 
2022 	err = mlx5_query_nic_vport_mac_address(mdev, 0, 0, config->mac);
2023 	if (err)
2024 		goto err_mtu;
2025 
2026 	mvdev->vdev.dma_dev = mdev->device;
2027 	err = mlx5_vdpa_alloc_resources(&ndev->mvdev);
2028 	if (err)
2029 		goto err_mtu;
2030 
2031 	err = alloc_resources(ndev);
2032 	if (err)
2033 		goto err_res;
2034 
2035 	mvdev->vdev.mdev = &mgtdev->mgtdev;
2036 	err = _vdpa_register_device(&mvdev->vdev, 2 * mlx5_vdpa_max_qps(max_vqs));
2037 	if (err)
2038 		goto err_reg;
2039 
2040 	mgtdev->ndev = ndev;
2041 	return 0;
2042 
2043 err_reg:
2044 	free_resources(ndev);
2045 err_res:
2046 	mlx5_vdpa_free_resources(&ndev->mvdev);
2047 err_mtu:
2048 	mutex_destroy(&ndev->reslock);
2049 	put_device(&mvdev->vdev.dev);
2050 	return err;
2051 }
2052 
2053 static void mlx5_vdpa_dev_del(struct vdpa_mgmt_dev *v_mdev, struct vdpa_device *dev)
2054 {
2055 	struct mlx5_vdpa_mgmtdev *mgtdev = container_of(v_mdev, struct mlx5_vdpa_mgmtdev, mgtdev);
2056 
2057 	_vdpa_unregister_device(dev);
2058 	mgtdev->ndev = NULL;
2059 }
2060 
2061 static const struct vdpa_mgmtdev_ops mdev_ops = {
2062 	.dev_add = mlx5_vdpa_dev_add,
2063 	.dev_del = mlx5_vdpa_dev_del,
2064 };
2065 
2066 static struct virtio_device_id id_table[] = {
2067 	{ VIRTIO_ID_NET, VIRTIO_DEV_ANY_ID },
2068 	{ 0 },
2069 };
2070 
2071 static int mlx5v_probe(struct auxiliary_device *adev,
2072 		       const struct auxiliary_device_id *id)
2073 
2074 {
2075 	struct mlx5_adev *madev = container_of(adev, struct mlx5_adev, adev);
2076 	struct mlx5_core_dev *mdev = madev->mdev;
2077 	struct mlx5_vdpa_mgmtdev *mgtdev;
2078 	int err;
2079 
2080 	mgtdev = kzalloc(sizeof(*mgtdev), GFP_KERNEL);
2081 	if (!mgtdev)
2082 		return -ENOMEM;
2083 
2084 	mgtdev->mgtdev.ops = &mdev_ops;
2085 	mgtdev->mgtdev.device = mdev->device;
2086 	mgtdev->mgtdev.id_table = id_table;
2087 	mgtdev->madev = madev;
2088 
2089 	err = vdpa_mgmtdev_register(&mgtdev->mgtdev);
2090 	if (err)
2091 		goto reg_err;
2092 
2093 	dev_set_drvdata(&adev->dev, mgtdev);
2094 
2095 	return 0;
2096 
2097 reg_err:
2098 	kfree(mgtdev);
2099 	return err;
2100 }
2101 
2102 static void mlx5v_remove(struct auxiliary_device *adev)
2103 {
2104 	struct mlx5_vdpa_mgmtdev *mgtdev;
2105 
2106 	mgtdev = dev_get_drvdata(&adev->dev);
2107 	vdpa_mgmtdev_unregister(&mgtdev->mgtdev);
2108 	kfree(mgtdev);
2109 }
2110 
2111 static const struct auxiliary_device_id mlx5v_id_table[] = {
2112 	{ .name = MLX5_ADEV_NAME ".vnet", },
2113 	{},
2114 };
2115 
2116 MODULE_DEVICE_TABLE(auxiliary, mlx5v_id_table);
2117 
2118 static struct auxiliary_driver mlx5v_driver = {
2119 	.name = "vnet",
2120 	.probe = mlx5v_probe,
2121 	.remove = mlx5v_remove,
2122 	.id_table = mlx5v_id_table,
2123 };
2124 
2125 module_auxiliary_driver(mlx5v_driver);
2126