1 #include "util.hpp"
2 
3 #include <arpa/inet.h>
4 #include <dlfcn.h>
5 #include <ifaddrs.h>
6 #include <linux/netlink.h>
7 #include <linux/rtnetlink.h>
8 #include <net/ethernet.h>
9 #include <net/if.h>
10 #include <netinet/in.h>
11 #include <sys/ioctl.h>
12 #include <sys/socket.h>
13 #include <sys/types.h>
14 #include <unistd.h>
15 
16 #include <cstdarg>
17 #include <cstdio>
18 #include <cstring>
19 #include <map>
20 #include <queue>
21 #include <stdexcept>
22 #include <stdplus/raw.hpp>
23 #include <string>
24 #include <string_view>
25 #include <vector>
26 
27 #define MAX_IFADDRS 5
28 
29 int debugging = false;
30 
31 /* Data for mocking getifaddrs */
32 struct ifaddr_storage
33 {
34     struct ifaddrs ifaddr;
35     struct sockaddr_storage addr;
36     struct sockaddr_storage mask;
37     struct sockaddr_storage bcast;
38 } mock_ifaddr_storage[MAX_IFADDRS];
39 
40 struct ifaddrs* mock_ifaddrs = nullptr;
41 
42 int ifaddr_count = 0;
43 
44 /* Stub library functions */
45 void freeifaddrs(ifaddrs* /*ifp*/)
46 {
47     return;
48 }
49 
50 std::map<int, std::queue<std::string>> mock_rtnetlinks;
51 
52 std::map<std::string, int> mock_if_nametoindex;
53 std::map<int, std::string> mock_if_indextoname;
54 std::map<std::string, ether_addr> mock_macs;
55 
56 void mock_clear()
57 {
58     mock_ifaddrs = nullptr;
59     ifaddr_count = 0;
60     mock_rtnetlinks.clear();
61     mock_if_nametoindex.clear();
62     mock_if_indextoname.clear();
63     mock_macs.clear();
64 }
65 
66 void mock_addIF(const std::string& name, int idx, const ether_addr& mac)
67 {
68     if (idx == 0)
69     {
70         throw std::invalid_argument("Bad interface index");
71     }
72 
73     mock_if_nametoindex[name] = idx;
74     mock_if_indextoname[idx] = name;
75     mock_macs[name] = mac;
76 }
77 
78 void mock_addIP(const char* name, const char* addr, const char* mask,
79                 unsigned int flags)
80 {
81     struct ifaddrs* ifaddr = &mock_ifaddr_storage[ifaddr_count].ifaddr;
82 
83     struct sockaddr_in* in =
84         reinterpret_cast<sockaddr_in*>(&mock_ifaddr_storage[ifaddr_count].addr);
85     struct sockaddr_in* mask_in =
86         reinterpret_cast<sockaddr_in*>(&mock_ifaddr_storage[ifaddr_count].mask);
87 
88     in->sin_family = AF_INET;
89     in->sin_port = 0;
90     in->sin_addr.s_addr = inet_addr(addr);
91 
92     mask_in->sin_family = AF_INET;
93     mask_in->sin_port = 0;
94     mask_in->sin_addr.s_addr = inet_addr(mask);
95 
96     ifaddr->ifa_next = nullptr;
97     ifaddr->ifa_name = const_cast<char*>(name);
98     ifaddr->ifa_flags = flags;
99     ifaddr->ifa_addr = reinterpret_cast<struct sockaddr*>(in);
100     ifaddr->ifa_netmask = reinterpret_cast<struct sockaddr*>(mask_in);
101     ifaddr->ifa_data = nullptr;
102 
103     if (ifaddr_count > 0)
104         mock_ifaddr_storage[ifaddr_count - 1].ifaddr.ifa_next = ifaddr;
105     ifaddr_count++;
106     mock_ifaddrs = &mock_ifaddr_storage[0].ifaddr;
107 }
108 
109 void validateMsgHdr(const struct msghdr* msg)
110 {
111     if (msg->msg_namelen != sizeof(sockaddr_nl))
112     {
113         fprintf(stderr, "bad namelen: %u\n", msg->msg_namelen);
114         abort();
115     }
116     const auto& from = *reinterpret_cast<sockaddr_nl*>(msg->msg_name);
117     if (from.nl_family != AF_NETLINK)
118     {
119         fprintf(stderr, "recvmsg bad family data\n");
120         abort();
121     }
122     if (msg->msg_iovlen != 1)
123     {
124         fprintf(stderr, "recvmsg unsupported iov configuration\n");
125         abort();
126     }
127 }
128 
129 ssize_t sendmsg_link_dump(std::queue<std::string>& msgs, std::string_view in)
130 {
131     const ssize_t ret = in.size();
132     const auto& hdrin = stdplus::raw::copyFrom<nlmsghdr>(in);
133     if (hdrin.nlmsg_type != RTM_GETLINK)
134     {
135         return 0;
136     }
137 
138     for (const auto& [name, idx] : mock_if_nametoindex)
139     {
140         ifinfomsg info{};
141         info.ifi_index = idx;
142         nlmsghdr hdr{};
143         hdr.nlmsg_len = NLMSG_LENGTH(sizeof(info));
144         hdr.nlmsg_type = RTM_NEWLINK;
145         hdr.nlmsg_flags = NLM_F_MULTI;
146         auto& out = msgs.emplace(hdr.nlmsg_len, '\0');
147         memcpy(out.data(), &hdr, sizeof(hdr));
148         memcpy(NLMSG_DATA(out.data()), &info, sizeof(info));
149     }
150 
151     nlmsghdr hdr{};
152     hdr.nlmsg_len = NLMSG_LENGTH(0);
153     hdr.nlmsg_type = NLMSG_DONE;
154     hdr.nlmsg_flags = NLM_F_MULTI;
155     auto& out = msgs.emplace(hdr.nlmsg_len, '\0');
156     memcpy(out.data(), &hdr, sizeof(hdr));
157     return ret;
158 }
159 
160 ssize_t sendmsg_ack(std::queue<std::string>& msgs, std::string_view in)
161 {
162     nlmsgerr ack{};
163     nlmsghdr hdr{};
164     hdr.nlmsg_len = NLMSG_LENGTH(sizeof(ack));
165     hdr.nlmsg_type = NLMSG_ERROR;
166     auto& out = msgs.emplace(hdr.nlmsg_len, '\0');
167     memcpy(out.data(), &hdr, sizeof(hdr));
168     memcpy(NLMSG_DATA(out.data()), &ack, sizeof(ack));
169     return in.size();
170 }
171 
172 extern "C" {
173 
174 int getifaddrs(ifaddrs** ifap)
175 {
176     *ifap = mock_ifaddrs;
177     if (mock_ifaddrs == nullptr)
178         return -1;
179     return (0);
180 }
181 
182 unsigned if_nametoindex(const char* ifname)
183 {
184     auto it = mock_if_nametoindex.find(ifname);
185     if (it == mock_if_nametoindex.end())
186     {
187         errno = ENXIO;
188         return 0;
189     }
190     return it->second;
191 }
192 
193 char* if_indextoname(unsigned ifindex, char* ifname)
194 {
195     auto it = mock_if_indextoname.find(ifindex);
196     if (it == mock_if_indextoname.end())
197     {
198         errno = ENXIO;
199         return NULL;
200     }
201     return std::strcpy(ifname, it->second.c_str());
202 }
203 
204 int ioctl(int fd, unsigned long int request, ...)
205 {
206     va_list vl;
207     va_start(vl, request);
208     void* data = va_arg(vl, void*);
209     va_end(vl);
210 
211     if (request == SIOCGIFHWADDR)
212     {
213         auto req = reinterpret_cast<ifreq*>(data);
214         auto it = mock_macs.find(req->ifr_name);
215         if (it == mock_macs.end())
216         {
217             errno = ENXIO;
218             return -1;
219         }
220         std::memcpy(req->ifr_hwaddr.sa_data, &it->second, sizeof(it->second));
221         return 0;
222     }
223 
224     static auto real_ioctl =
225         reinterpret_cast<decltype(&ioctl)>(dlsym(RTLD_NEXT, "ioctl"));
226     return real_ioctl(fd, request, data);
227 }
228 
229 int socket(int domain, int type, int protocol)
230 {
231     static auto real_socket =
232         reinterpret_cast<decltype(&socket)>(dlsym(RTLD_NEXT, "socket"));
233     int fd = real_socket(domain, type, protocol);
234     if (domain == AF_NETLINK && !(type & SOCK_RAW))
235     {
236         fprintf(stderr, "Netlink sockets must be RAW\n");
237         abort();
238     }
239     if (domain == AF_NETLINK && protocol == NETLINK_ROUTE)
240     {
241         mock_rtnetlinks[fd] = {};
242     }
243     return fd;
244 }
245 
246 int close(int fd)
247 {
248     auto it = mock_rtnetlinks.find(fd);
249     if (it != mock_rtnetlinks.end())
250     {
251         mock_rtnetlinks.erase(it);
252     }
253 
254     static auto real_close =
255         reinterpret_cast<decltype(&close)>(dlsym(RTLD_NEXT, "close"));
256     return real_close(fd);
257 }
258 
259 ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags)
260 {
261     auto it = mock_rtnetlinks.find(sockfd);
262     if (it == mock_rtnetlinks.end())
263     {
264         static auto real_sendmsg =
265             reinterpret_cast<decltype(&sendmsg)>(dlsym(RTLD_NEXT, "sendmsg"));
266         return real_sendmsg(sockfd, msg, flags);
267     }
268     auto& msgs = it->second;
269 
270     validateMsgHdr(msg);
271     if (!msgs.empty())
272     {
273         fprintf(stderr, "Unread netlink responses\n");
274         abort();
275     }
276 
277     ssize_t ret;
278     std::string_view iov(reinterpret_cast<char*>(msg->msg_iov[0].iov_base),
279                          msg->msg_iov[0].iov_len);
280 
281     ret = sendmsg_link_dump(msgs, iov);
282     if (ret != 0)
283     {
284         return ret;
285     }
286 
287     ret = sendmsg_ack(msgs, iov);
288     if (ret != 0)
289     {
290         return ret;
291     }
292 
293     errno = ENOSYS;
294     return -1;
295 }
296 
297 ssize_t recvmsg(int sockfd, struct msghdr* msg, int flags)
298 {
299     auto it = mock_rtnetlinks.find(sockfd);
300     if (it == mock_rtnetlinks.end())
301     {
302         static auto real_recvmsg =
303             reinterpret_cast<decltype(&recvmsg)>(dlsym(RTLD_NEXT, "recvmsg"));
304         return real_recvmsg(sockfd, msg, flags);
305     }
306     auto& msgs = it->second;
307 
308     validateMsgHdr(msg);
309     constexpr size_t required_buf_size = 8192;
310     if (msg->msg_iov[0].iov_len < required_buf_size)
311     {
312         fprintf(stderr, "recvmsg iov too short: %zu\n",
313                 msg->msg_iov[0].iov_len);
314         abort();
315     }
316     if (msgs.empty())
317     {
318         fprintf(stderr, "No pending netlink responses\n");
319         abort();
320     }
321 
322     ssize_t ret = 0;
323     auto data = reinterpret_cast<char*>(msg->msg_iov[0].iov_base);
324     while (!msgs.empty())
325     {
326         const auto& msg = msgs.front();
327         if (NLMSG_ALIGN(ret) + msg.size() > required_buf_size)
328         {
329             break;
330         }
331         ret = NLMSG_ALIGN(ret);
332         memcpy(data + ret, msg.data(), msg.size());
333         ret += msg.size();
334         msgs.pop();
335     }
336     return ret;
337 }
338 
339 } // extern "C"
340