1 #include "util.hpp"
2 
3 #include <arpa/inet.h>
4 #include <dlfcn.h>
5 #include <ifaddrs.h>
6 #include <linux/netlink.h>
7 #include <linux/rtnetlink.h>
8 #include <net/ethernet.h>
9 #include <net/if.h>
10 #include <netinet/in.h>
11 #include <sys/ioctl.h>
12 #include <sys/socket.h>
13 #include <sys/types.h>
14 #include <unistd.h>
15 
16 #include <cstdarg>
17 #include <cstdio>
18 #include <cstring>
19 #include <map>
20 #include <queue>
21 #include <stdexcept>
22 #include <stdplus/raw.hpp>
23 #include <string>
24 #include <string_view>
25 #include <vector>
26 
/* Number of slots in the static pool backing the mocked getifaddrs() list */
#define MAX_IFADDRS 5

/* Debug switch; nothing in this part of the file reads it */
int debugging = 0;

/* Data for mocking getifaddrs */

/* One pool slot: the list node plus the socket addresses it can point at */
struct ifaddr_storage
{
    ifaddrs ifaddr;         /* list node handed out through getifaddrs() */
    sockaddr_storage addr;  /* backing store for ifaddr.ifa_addr */
    sockaddr_storage mask;  /* backing store for ifaddr.ifa_netmask */
    sockaddr_storage bcast; /* reserved; mock_addIP() does not fill it */
};

/* Statically allocated pool of slots, used in order from index 0 */
ifaddr_storage mock_ifaddr_storage[MAX_IFADDRS];

/* Head of the mock address list; nullptr while no address is registered */
ifaddrs* mock_ifaddrs = nullptr;

/* Number of pool slots currently in use */
int ifaddr_count = 0;
43 
44 /* Stub library functions */
/* Mocked freeifaddrs(): the list handed out by the mocked getifaddrs() lives
 * in static storage, so there is nothing to release. */
void freeifaddrs(ifaddrs*)
{
    // Intentionally empty
}
49 
/* Queued mock netlink replies, keyed by the socket file descriptor that
 * socket() registered; recvmsg() drains the queue for its descriptor. */
std::map<int, std::queue<std::string>> mock_rtnetlinks;

/* Properties of one mock network interface, as registered by mock_addIF() */
struct MockInfo
{
    /* Interface index; mock_addIF() rejects 0 */
    unsigned idx;
    /* Flags reported for SIOCGIFFLAGS and in the RTM_NEWLINK dump */
    unsigned flags;
    /* Hardware address; when empty, SIOCGIFHWADDR fails with EOPNOTSUPP */
    std::optional<ether_addr> mac;
    /* MTU; when empty, SIOCGIFMTU fails with EOPNOTSUPP */
    std::optional<unsigned> mtu;
};

/* Mock interfaces keyed by interface name */
std::map<std::string, MockInfo> mock_if;
/* Reverse lookup: interface index -> interface name */
std::map<int, std::string> mock_if_indextoname;
62 
63 void mock_clear()
64 {
65     mock_ifaddrs = nullptr;
66     ifaddr_count = 0;
67     mock_rtnetlinks.clear();
68     mock_if.clear();
69     mock_if_indextoname.clear();
70 }
71 
72 void mock_addIF(const std::string& name, unsigned idx, unsigned flags,
73                 std::optional<ether_addr> mac, std::optional<unsigned> mtu)
74 {
75     if (idx == 0)
76     {
77         throw std::invalid_argument("Bad interface index");
78     }
79 
80     mock_if.emplace(
81         name, MockInfo{.idx = idx, .flags = flags, .mac = mac, .mtu = mtu});
82     mock_if_indextoname.emplace(idx, name);
83 }
84 
85 void mock_addIP(const char* name, const char* addr, const char* mask)
86 {
87     struct ifaddrs* ifaddr = &mock_ifaddr_storage[ifaddr_count].ifaddr;
88 
89     struct sockaddr_in* in =
90         reinterpret_cast<sockaddr_in*>(&mock_ifaddr_storage[ifaddr_count].addr);
91     struct sockaddr_in* mask_in =
92         reinterpret_cast<sockaddr_in*>(&mock_ifaddr_storage[ifaddr_count].mask);
93 
94     in->sin_family = AF_INET;
95     in->sin_port = 0;
96     in->sin_addr.s_addr = inet_addr(addr);
97 
98     mask_in->sin_family = AF_INET;
99     mask_in->sin_port = 0;
100     mask_in->sin_addr.s_addr = inet_addr(mask);
101 
102     ifaddr->ifa_next = nullptr;
103     ifaddr->ifa_name = const_cast<char*>(name);
104     ifaddr->ifa_flags = 0;
105     ifaddr->ifa_addr = reinterpret_cast<struct sockaddr*>(in);
106     ifaddr->ifa_netmask = reinterpret_cast<struct sockaddr*>(mask_in);
107     ifaddr->ifa_data = nullptr;
108 
109     if (ifaddr_count > 0)
110         mock_ifaddr_storage[ifaddr_count - 1].ifaddr.ifa_next = ifaddr;
111     ifaddr_count++;
112     mock_ifaddrs = &mock_ifaddr_storage[0].ifaddr;
113 }
114 
/* Abort the process unless `msg` has the shape the netlink mocks support:
 * a name buffer of exactly sizeof(sockaddr_nl) bytes whose family is
 * AF_NETLINK, and exactly one iovec entry. */
void validateMsgHdr(const struct msghdr* msg)
{
    const auto name_len = msg->msg_namelen;
    if (name_len != sizeof(sockaddr_nl))
    {
        fprintf(stderr, "bad namelen: %u\n", name_len);
        abort();
    }

    // Only inspected once the length check above has passed
    const auto* peer = static_cast<const sockaddr_nl*>(msg->msg_name);
    if (peer->nl_family != AF_NETLINK)
    {
        fprintf(stderr, "recvmsg bad family data\n");
        abort();
    }

    const bool single_buffer = msg->msg_iovlen == 1;
    if (!single_buffer)
    {
        fprintf(stderr, "recvmsg unsupported iov configuration\n");
        abort();
    }
}
134 
135 ssize_t sendmsg_link_dump(std::queue<std::string>& msgs, std::string_view in)
136 {
137     const ssize_t ret = in.size();
138     const auto& hdrin = stdplus::raw::copyFrom<nlmsghdr>(in);
139     if (hdrin.nlmsg_type != RTM_GETLINK)
140     {
141         return 0;
142     }
143 
144     for (const auto& [name, i] : mock_if)
145     {
146         ifinfomsg info{};
147         info.ifi_index = i.idx;
148         info.ifi_flags = i.flags;
149         nlmsghdr hdr{};
150         hdr.nlmsg_len = NLMSG_LENGTH(sizeof(info));
151         hdr.nlmsg_type = RTM_NEWLINK;
152         hdr.nlmsg_flags = NLM_F_MULTI;
153         auto& out = msgs.emplace(hdr.nlmsg_len, '\0');
154         memcpy(out.data(), &hdr, sizeof(hdr));
155         memcpy(NLMSG_DATA(out.data()), &info, sizeof(info));
156     }
157 
158     nlmsghdr hdr{};
159     hdr.nlmsg_len = NLMSG_LENGTH(0);
160     hdr.nlmsg_type = NLMSG_DONE;
161     hdr.nlmsg_flags = NLM_F_MULTI;
162     auto& out = msgs.emplace(hdr.nlmsg_len, '\0');
163     memcpy(out.data(), &hdr, sizeof(hdr));
164     return ret;
165 }
166 
/* Queue a generic acknowledgement: an NLMSG_ERROR message whose nlmsgerr
 * payload is all zeroes (error code 0).
 *
 * @param msgs Reply queue of the socket the request was sent on
 * @param in   Raw request bytes; only their size is used
 * @return in.size()
 */
ssize_t sendmsg_ack(std::queue<std::string>& msgs, std::string_view in)
{
    const nlmsgerr payload{};

    nlmsghdr header{};
    header.nlmsg_type = NLMSG_ERROR;
    header.nlmsg_len = NLMSG_LENGTH(sizeof(payload));

    // Zero-filled buffer: header first, payload at the aligned data offset
    std::string& reply = msgs.emplace(header.nlmsg_len, '\0');
    std::memcpy(reply.data(), &header, sizeof(header));
    std::memcpy(reply.data() + NLMSG_HDRLEN, &payload, sizeof(payload));

    return in.size();
}
178 
179 extern "C" {
180 
181 int getifaddrs(ifaddrs** ifap)
182 {
183     *ifap = mock_ifaddrs;
184     if (mock_ifaddrs == nullptr)
185         return -1;
186     return (0);
187 }
188 
189 unsigned if_nametoindex(const char* ifname)
190 {
191     auto it = mock_if.find(ifname);
192     if (it == mock_if.end())
193     {
194         errno = ENXIO;
195         return 0;
196     }
197     return it->second.idx;
198 }
199 
200 char* if_indextoname(unsigned ifindex, char* ifname)
201 {
202     auto it = mock_if_indextoname.find(ifindex);
203     if (it == mock_if_indextoname.end())
204     {
205         errno = ENXIO;
206         return NULL;
207     }
208     return std::strcpy(ifname, it->second.c_str());
209 }
210 
211 int ioctl(int fd, unsigned long int request, ...)
212 {
213     va_list vl;
214     va_start(vl, request);
215     void* data = va_arg(vl, void*);
216     va_end(vl);
217 
218     auto req = reinterpret_cast<ifreq*>(data);
219     if (request == SIOCGIFHWADDR)
220     {
221         auto it = mock_if.find(req->ifr_name);
222         if (it == mock_if.end())
223         {
224             errno = ENXIO;
225             return -1;
226         }
227         if (!it->second.mac)
228         {
229             errno = EOPNOTSUPP;
230             return -1;
231         }
232         std::memcpy(req->ifr_hwaddr.sa_data, &*it->second.mac,
233                     sizeof(*it->second.mac));
234         return 0;
235     }
236     else if (request == SIOCGIFFLAGS)
237     {
238         auto it = mock_if.find(req->ifr_name);
239         if (it == mock_if.end())
240         {
241             errno = ENXIO;
242             return -1;
243         }
244         req->ifr_flags = it->second.flags;
245         return 0;
246     }
247     else if (request == SIOCGIFMTU)
248     {
249         auto it = mock_if.find(req->ifr_name);
250         if (it == mock_if.end())
251         {
252             errno = ENXIO;
253             return -1;
254         }
255         if (!it->second.mtu)
256         {
257             errno = EOPNOTSUPP;
258             return -1;
259         }
260         req->ifr_mtu = *it->second.mtu;
261         return 0;
262     }
263 
264     static auto real_ioctl =
265         reinterpret_cast<decltype(&ioctl)>(dlsym(RTLD_NEXT, "ioctl"));
266     return real_ioctl(fd, request, data);
267 }
268 
269 int socket(int domain, int type, int protocol)
270 {
271     static auto real_socket =
272         reinterpret_cast<decltype(&socket)>(dlsym(RTLD_NEXT, "socket"));
273     int fd = real_socket(domain, type, protocol);
274     if (domain == AF_NETLINK && !(type & SOCK_RAW))
275     {
276         fprintf(stderr, "Netlink sockets must be RAW\n");
277         abort();
278     }
279     if (domain == AF_NETLINK && protocol == NETLINK_ROUTE)
280     {
281         mock_rtnetlinks[fd] = {};
282     }
283     return fd;
284 }
285 
286 int close(int fd)
287 {
288     auto it = mock_rtnetlinks.find(fd);
289     if (it != mock_rtnetlinks.end())
290     {
291         mock_rtnetlinks.erase(it);
292     }
293 
294     static auto real_close =
295         reinterpret_cast<decltype(&close)>(dlsym(RTLD_NEXT, "close"));
296     return real_close(fd);
297 }
298 
299 ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags)
300 {
301     auto it = mock_rtnetlinks.find(sockfd);
302     if (it == mock_rtnetlinks.end())
303     {
304         static auto real_sendmsg =
305             reinterpret_cast<decltype(&sendmsg)>(dlsym(RTLD_NEXT, "sendmsg"));
306         return real_sendmsg(sockfd, msg, flags);
307     }
308     auto& msgs = it->second;
309 
310     validateMsgHdr(msg);
311     if (!msgs.empty())
312     {
313         fprintf(stderr, "Unread netlink responses\n");
314         abort();
315     }
316 
317     ssize_t ret;
318     std::string_view iov(reinterpret_cast<char*>(msg->msg_iov[0].iov_base),
319                          msg->msg_iov[0].iov_len);
320 
321     ret = sendmsg_link_dump(msgs, iov);
322     if (ret != 0)
323     {
324         return ret;
325     }
326 
327     ret = sendmsg_ack(msgs, iov);
328     if (ret != 0)
329     {
330         return ret;
331     }
332 
333     errno = ENOSYS;
334     return -1;
335 }
336 
337 ssize_t recvmsg(int sockfd, struct msghdr* msg, int flags)
338 {
339     auto it = mock_rtnetlinks.find(sockfd);
340     if (it == mock_rtnetlinks.end())
341     {
342         static auto real_recvmsg =
343             reinterpret_cast<decltype(&recvmsg)>(dlsym(RTLD_NEXT, "recvmsg"));
344         return real_recvmsg(sockfd, msg, flags);
345     }
346     auto& msgs = it->second;
347 
348     validateMsgHdr(msg);
349     constexpr size_t required_buf_size = 8192;
350     if (msg->msg_iov[0].iov_len < required_buf_size)
351     {
352         fprintf(stderr, "recvmsg iov too short: %zu\n",
353                 msg->msg_iov[0].iov_len);
354         abort();
355     }
356     if (msgs.empty())
357     {
358         fprintf(stderr, "No pending netlink responses\n");
359         abort();
360     }
361 
362     ssize_t ret = 0;
363     auto data = reinterpret_cast<char*>(msg->msg_iov[0].iov_base);
364     while (!msgs.empty())
365     {
366         const auto& msg = msgs.front();
367         if (NLMSG_ALIGN(ret) + msg.size() > required_buf_size)
368         {
369             break;
370         }
371         ret = NLMSG_ALIGN(ret);
372         memcpy(data + ret, msg.data(), msg.size());
373         ret += msg.size();
374         msgs.pop();
375     }
376     return ret;
377 }
378 
379 } // extern "C"
380