135b1b538SChuck Lever // SPDX-License-Identifier: GPL-2.0-only
235b1b538SChuck Lever /*
335b1b538SChuck Lever * Handle the TLS Alert protocol
435b1b538SChuck Lever *
535b1b538SChuck Lever * Author: Chuck Lever <chuck.lever@oracle.com>
635b1b538SChuck Lever *
735b1b538SChuck Lever * Copyright (c) 2023, Oracle and/or its affiliates.
835b1b538SChuck Lever */
935b1b538SChuck Lever
1035b1b538SChuck Lever #include <linux/types.h>
1135b1b538SChuck Lever #include <linux/socket.h>
1235b1b538SChuck Lever #include <linux/kernel.h>
1335b1b538SChuck Lever #include <linux/module.h>
1435b1b538SChuck Lever #include <linux/skbuff.h>
1535b1b538SChuck Lever #include <linux/inet.h>
1635b1b538SChuck Lever
1735b1b538SChuck Lever #include <net/sock.h>
1835b1b538SChuck Lever #include <net/handshake.h>
1935b1b538SChuck Lever #include <net/tls.h>
2035b1b538SChuck Lever #include <net/tls_prot.h>
2135b1b538SChuck Lever
2235b1b538SChuck Lever #include "handshake.h"
2335b1b538SChuck Lever
24*b470985cSChuck Lever #include <trace/events/handshake.h>
25*b470985cSChuck Lever
2635b1b538SChuck Lever /**
2735b1b538SChuck Lever * tls_alert_send - send a TLS Alert on a kTLS socket
2835b1b538SChuck Lever * @sock: open kTLS socket to send on
2935b1b538SChuck Lever * @level: TLS Alert level
3035b1b538SChuck Lever * @description: TLS Alert description
3135b1b538SChuck Lever *
3235b1b538SChuck Lever * Returns zero on success or a negative errno.
3335b1b538SChuck Lever */
tls_alert_send(struct socket * sock,u8 level,u8 description)3435b1b538SChuck Lever int tls_alert_send(struct socket *sock, u8 level, u8 description)
3535b1b538SChuck Lever {
3635b1b538SChuck Lever u8 record_type = TLS_RECORD_TYPE_ALERT;
3735b1b538SChuck Lever u8 buf[CMSG_SPACE(sizeof(record_type))];
3835b1b538SChuck Lever struct msghdr msg = { 0 };
3935b1b538SChuck Lever struct cmsghdr *cmsg;
4035b1b538SChuck Lever struct kvec iov;
4135b1b538SChuck Lever u8 alert[2];
4235b1b538SChuck Lever int ret;
4335b1b538SChuck Lever
44*b470985cSChuck Lever trace_tls_alert_send(sock->sk, level, description);
45*b470985cSChuck Lever
4635b1b538SChuck Lever alert[0] = level;
4735b1b538SChuck Lever alert[1] = description;
4835b1b538SChuck Lever iov.iov_base = alert;
4935b1b538SChuck Lever iov.iov_len = sizeof(alert);
5035b1b538SChuck Lever
5135b1b538SChuck Lever memset(buf, 0, sizeof(buf));
5235b1b538SChuck Lever msg.msg_control = buf;
5335b1b538SChuck Lever msg.msg_controllen = sizeof(buf);
5435b1b538SChuck Lever msg.msg_flags = MSG_DONTWAIT;
5535b1b538SChuck Lever
5635b1b538SChuck Lever cmsg = CMSG_FIRSTHDR(&msg);
5735b1b538SChuck Lever cmsg->cmsg_level = SOL_TLS;
5835b1b538SChuck Lever cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
5935b1b538SChuck Lever cmsg->cmsg_len = CMSG_LEN(sizeof(record_type));
6035b1b538SChuck Lever memcpy(CMSG_DATA(cmsg), &record_type, sizeof(record_type));
6135b1b538SChuck Lever
6235b1b538SChuck Lever iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iov, 1, iov.iov_len);
6335b1b538SChuck Lever ret = sock_sendmsg(sock, &msg);
6435b1b538SChuck Lever return ret < 0 ? ret : 0;
6535b1b538SChuck Lever }
6639d0e38dSChuck Lever
6739d0e38dSChuck Lever /**
6839d0e38dSChuck Lever * tls_get_record_type - Look for TLS RECORD_TYPE information
6939d0e38dSChuck Lever * @sk: socket (for IP address information)
7039d0e38dSChuck Lever * @cmsg: incoming message to be parsed
7139d0e38dSChuck Lever *
7239d0e38dSChuck Lever * Returns zero or a TLS_RECORD_TYPE value.
7339d0e38dSChuck Lever */
tls_get_record_type(const struct sock * sk,const struct cmsghdr * cmsg)7439d0e38dSChuck Lever u8 tls_get_record_type(const struct sock *sk, const struct cmsghdr *cmsg)
7539d0e38dSChuck Lever {
7639d0e38dSChuck Lever u8 record_type;
7739d0e38dSChuck Lever
7839d0e38dSChuck Lever if (cmsg->cmsg_level != SOL_TLS)
7939d0e38dSChuck Lever return 0;
8039d0e38dSChuck Lever if (cmsg->cmsg_type != TLS_GET_RECORD_TYPE)
8139d0e38dSChuck Lever return 0;
8239d0e38dSChuck Lever
8339d0e38dSChuck Lever record_type = *((u8 *)CMSG_DATA(cmsg));
84*b470985cSChuck Lever trace_tls_contenttype(sk, record_type);
8539d0e38dSChuck Lever return record_type;
8639d0e38dSChuck Lever }
8739d0e38dSChuck Lever EXPORT_SYMBOL(tls_get_record_type);
8839d0e38dSChuck Lever
8939d0e38dSChuck Lever /**
9039d0e38dSChuck Lever * tls_alert_recv - Parse TLS Alert messages
9139d0e38dSChuck Lever * @sk: socket (for IP address information)
9239d0e38dSChuck Lever * @msg: incoming message to be parsed
9339d0e38dSChuck Lever * @level: OUT - TLS AlertLevel value
9439d0e38dSChuck Lever * @description: OUT - TLS AlertDescription value
9539d0e38dSChuck Lever *
9639d0e38dSChuck Lever */
tls_alert_recv(const struct sock * sk,const struct msghdr * msg,u8 * level,u8 * description)9739d0e38dSChuck Lever void tls_alert_recv(const struct sock *sk, const struct msghdr *msg,
9839d0e38dSChuck Lever u8 *level, u8 *description)
9939d0e38dSChuck Lever {
10039d0e38dSChuck Lever const struct kvec *iov;
10139d0e38dSChuck Lever u8 *data;
10239d0e38dSChuck Lever
10339d0e38dSChuck Lever iov = msg->msg_iter.kvec;
10439d0e38dSChuck Lever data = iov->iov_base;
10539d0e38dSChuck Lever *level = data[0];
10639d0e38dSChuck Lever *description = data[1];
107*b470985cSChuck Lever
108*b470985cSChuck Lever trace_tls_alert_recv(sk, *level, *description);
10939d0e38dSChuck Lever }
11039d0e38dSChuck Lever EXPORT_SYMBOL(tls_alert_recv);
111