1 #include "util.hpp" 2 3 #include <arpa/inet.h> 4 #include <dlfcn.h> 5 #include <ifaddrs.h> 6 #include <linux/netlink.h> 7 #include <linux/rtnetlink.h> 8 #include <net/ethernet.h> 9 #include <net/if.h> 10 #include <netinet/in.h> 11 #include <sys/ioctl.h> 12 #include <sys/socket.h> 13 #include <sys/types.h> 14 #include <unistd.h> 15 16 #include <cstdarg> 17 #include <cstdio> 18 #include <cstring> 19 #include <map> 20 #include <queue> 21 #include <stdexcept> 22 #include <stdplus/raw.hpp> 23 #include <string> 24 #include <string_view> 25 #include <vector> 26 27 #define MAX_IFADDRS 5 28 29 int debugging = false; 30 31 /* Data for mocking getifaddrs */ 32 struct ifaddr_storage 33 { 34 struct ifaddrs ifaddr; 35 struct sockaddr_storage addr; 36 struct sockaddr_storage mask; 37 struct sockaddr_storage bcast; 38 } mock_ifaddr_storage[MAX_IFADDRS]; 39 40 struct ifaddrs* mock_ifaddrs = nullptr; 41 42 int ifaddr_count = 0; 43 44 /* Stub library functions */ 45 void freeifaddrs(ifaddrs* /*ifp*/) 46 { 47 return; 48 } 49 50 std::map<int, std::queue<std::string>> mock_rtnetlinks; 51 52 std::map<std::string, int> mock_if_nametoindex; 53 std::map<int, std::string> mock_if_indextoname; 54 std::map<std::string, ether_addr> mock_macs; 55 56 void mock_clear() 57 { 58 mock_ifaddrs = nullptr; 59 ifaddr_count = 0; 60 mock_rtnetlinks.clear(); 61 mock_if_nametoindex.clear(); 62 mock_if_indextoname.clear(); 63 mock_macs.clear(); 64 } 65 66 void mock_addIF(const std::string& name, int idx, const ether_addr& mac) 67 { 68 if (idx == 0) 69 { 70 throw std::invalid_argument("Bad interface index"); 71 } 72 73 mock_if_nametoindex[name] = idx; 74 mock_if_indextoname[idx] = name; 75 mock_macs[name] = mac; 76 } 77 78 void mock_addIP(const char* name, const char* addr, const char* mask, 79 unsigned int flags) 80 { 81 struct ifaddrs* ifaddr = &mock_ifaddr_storage[ifaddr_count].ifaddr; 82 83 struct sockaddr_in* in = 84 reinterpret_cast<sockaddr_in*>(&mock_ifaddr_storage[ifaddr_count].addr); 85 struct sockaddr_in* mask_in = 86 reinterpret_cast<sockaddr_in*>(&mock_ifaddr_storage[ifaddr_count].mask); 87 88 in->sin_family = AF_INET; 89 in->sin_port = 0; 90 in->sin_addr.s_addr = inet_addr(addr); 91 92 mask_in->sin_family = AF_INET; 93 mask_in->sin_port = 0; 94 mask_in->sin_addr.s_addr = inet_addr(mask); 95 96 ifaddr->ifa_next = nullptr; 97 ifaddr->ifa_name = const_cast<char*>(name); 98 ifaddr->ifa_flags = flags; 99 ifaddr->ifa_addr = reinterpret_cast<struct sockaddr*>(in); 100 ifaddr->ifa_netmask = reinterpret_cast<struct sockaddr*>(mask_in); 101 ifaddr->ifa_data = nullptr; 102 103 if (ifaddr_count > 0) 104 mock_ifaddr_storage[ifaddr_count - 1].ifaddr.ifa_next = ifaddr; 105 ifaddr_count++; 106 mock_ifaddrs = &mock_ifaddr_storage[0].ifaddr; 107 } 108 109 void validateMsgHdr(const struct msghdr* msg) 110 { 111 if (msg->msg_namelen != sizeof(sockaddr_nl)) 112 { 113 fprintf(stderr, "bad namelen: %u\n", msg->msg_namelen); 114 abort(); 115 } 116 const auto& from = *reinterpret_cast<sockaddr_nl*>(msg->msg_name); 117 if (from.nl_family != AF_NETLINK) 118 { 119 fprintf(stderr, "recvmsg bad family data\n"); 120 abort(); 121 } 122 if (msg->msg_iovlen != 1) 123 { 124 fprintf(stderr, "recvmsg unsupported iov configuration\n"); 125 abort(); 126 } 127 } 128 129 ssize_t sendmsg_link_dump(std::queue<std::string>& msgs, std::string_view in) 130 { 131 const ssize_t ret = in.size(); 132 const auto& hdrin = stdplus::raw::copyFrom<nlmsghdr>(in); 133 if (hdrin.nlmsg_type != RTM_GETLINK) 134 { 135 return 0; 136 } 137 138 for (const auto& [name, idx] : mock_if_nametoindex) 139 { 140 ifinfomsg info{}; 141 info.ifi_index = idx; 142 nlmsghdr hdr{}; 143 hdr.nlmsg_len = NLMSG_LENGTH(sizeof(info)); 144 hdr.nlmsg_type = RTM_NEWLINK; 145 hdr.nlmsg_flags = NLM_F_MULTI; 146 auto& out = msgs.emplace(hdr.nlmsg_len, '\0'); 147 memcpy(out.data(), &hdr, sizeof(hdr)); 148 memcpy(NLMSG_DATA(out.data()), &info, sizeof(info)); 149 } 150 151 nlmsghdr hdr{}; 152 hdr.nlmsg_len = NLMSG_LENGTH(0); 153 hdr.nlmsg_type = NLMSG_DONE; 154 hdr.nlmsg_flags = NLM_F_MULTI; 155 auto& out = msgs.emplace(hdr.nlmsg_len, '\0'); 156 memcpy(out.data(), &hdr, sizeof(hdr)); 157 return ret; 158 } 159 160 ssize_t sendmsg_ack(std::queue<std::string>& msgs, std::string_view in) 161 { 162 nlmsgerr ack{}; 163 nlmsghdr hdr{}; 164 hdr.nlmsg_len = NLMSG_LENGTH(sizeof(ack)); 165 hdr.nlmsg_type = NLMSG_ERROR; 166 auto& out = msgs.emplace(hdr.nlmsg_len, '\0'); 167 memcpy(out.data(), &hdr, sizeof(hdr)); 168 memcpy(NLMSG_DATA(out.data()), &ack, sizeof(ack)); 169 return in.size(); 170 } 171 172 extern "C" { 173 174 int getifaddrs(ifaddrs** ifap) 175 { 176 *ifap = mock_ifaddrs; 177 if (mock_ifaddrs == nullptr) 178 return -1; 179 return (0); 180 } 181 182 unsigned if_nametoindex(const char* ifname) 183 { 184 auto it = mock_if_nametoindex.find(ifname); 185 if (it == mock_if_nametoindex.end()) 186 { 187 errno = ENXIO; 188 return 0; 189 } 190 return it->second; 191 } 192 193 char* if_indextoname(unsigned ifindex, char* ifname) 194 { 195 auto it = mock_if_indextoname.find(ifindex); 196 if (it == mock_if_indextoname.end()) 197 { 198 errno = ENXIO; 199 return NULL; 200 } 201 return std::strcpy(ifname, it->second.c_str()); 202 } 203 204 int ioctl(int fd, unsigned long int request, ...) 205 { 206 va_list vl; 207 va_start(vl, request); 208 void* data = va_arg(vl, void*); 209 va_end(vl); 210 211 if (request == SIOCGIFHWADDR) 212 { 213 auto req = reinterpret_cast<ifreq*>(data); 214 auto it = mock_macs.find(req->ifr_name); 215 if (it == mock_macs.end()) 216 { 217 errno = ENXIO; 218 return -1; 219 } 220 std::memcpy(req->ifr_hwaddr.sa_data, &it->second, sizeof(it->second)); 221 return 0; 222 } 223 224 static auto real_ioctl = 225 reinterpret_cast<decltype(&ioctl)>(dlsym(RTLD_NEXT, "ioctl")); 226 return real_ioctl(fd, request, data); 227 } 228 229 int socket(int domain, int type, int protocol) 230 { 231 static auto real_socket = 232 reinterpret_cast<decltype(&socket)>(dlsym(RTLD_NEXT, "socket")); 233 int fd = real_socket(domain, type, protocol); 234 if (domain == AF_NETLINK && !(type & SOCK_RAW)) 235 { 236 fprintf(stderr, "Netlink sockets must be RAW\n"); 237 abort(); 238 } 239 if (domain == AF_NETLINK && protocol == NETLINK_ROUTE) 240 { 241 mock_rtnetlinks[fd] = {}; 242 } 243 return fd; 244 } 245 246 int close(int fd) 247 { 248 auto it = mock_rtnetlinks.find(fd); 249 if (it != mock_rtnetlinks.end()) 250 { 251 mock_rtnetlinks.erase(it); 252 } 253 254 static auto real_close = 255 reinterpret_cast<decltype(&close)>(dlsym(RTLD_NEXT, "close")); 256 return real_close(fd); 257 } 258 259 ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags) 260 { 261 auto it = mock_rtnetlinks.find(sockfd); 262 if (it == mock_rtnetlinks.end()) 263 { 264 static auto real_sendmsg = 265 reinterpret_cast<decltype(&sendmsg)>(dlsym(RTLD_NEXT, "sendmsg")); 266 return real_sendmsg(sockfd, msg, flags); 267 } 268 auto& msgs = it->second; 269 270 validateMsgHdr(msg); 271 if (!msgs.empty()) 272 { 273 fprintf(stderr, "Unread netlink responses\n"); 274 abort(); 275 } 276 277 ssize_t ret; 278 std::string_view iov(reinterpret_cast<char*>(msg->msg_iov[0].iov_base), 279 msg->msg_iov[0].iov_len); 280 281 ret = sendmsg_link_dump(msgs, iov); 282 if (ret != 0) 283 { 284 return ret; 285 } 286 287 ret = sendmsg_ack(msgs, iov); 288 if (ret != 0) 289 { 290 return ret; 291 } 292 293 errno = ENOSYS; 294 return -1; 295 } 296 297 ssize_t recvmsg(int sockfd, struct msghdr* msg, int flags) 298 { 299 auto it = mock_rtnetlinks.find(sockfd); 300 if (it == mock_rtnetlinks.end()) 301 { 302 static auto real_recvmsg = 303 reinterpret_cast<decltype(&recvmsg)>(dlsym(RTLD_NEXT, "recvmsg")); 304 return real_recvmsg(sockfd, msg, flags); 305 } 306 auto& msgs = it->second; 307 308 validateMsgHdr(msg); 309 constexpr size_t required_buf_size = 8192; 310 if (msg->msg_iov[0].iov_len < required_buf_size) 311 { 312 fprintf(stderr, "recvmsg iov too short: %zu\n", 313 msg->msg_iov[0].iov_len); 314 abort(); 315 } 316 if (msgs.empty()) 317 { 318 fprintf(stderr, "No pending netlink responses\n"); 319 abort(); 320 } 321 322 ssize_t ret = 0; 323 auto data = reinterpret_cast<char*>(msg->msg_iov[0].iov_base); 324 while (!msgs.empty()) 325 { 326 const auto& msg = msgs.front(); 327 if (NLMSG_ALIGN(ret) + msg.size() > required_buf_size) 328 { 329 break; 330 } 331 ret = NLMSG_ALIGN(ret); 332 memcpy(data + ret, msg.data(), msg.size()); 333 ret += msg.size(); 334 msgs.pop(); 335 } 336 return ret; 337 } 338 339 } // extern "C" 340