1 #include "util.hpp" 2 3 #include <arpa/inet.h> 4 #include <dlfcn.h> 5 #include <ifaddrs.h> 6 #include <linux/netlink.h> 7 #include <linux/rtnetlink.h> 8 #include <net/ethernet.h> 9 #include <net/if.h> 10 #include <netinet/in.h> 11 #include <sys/ioctl.h> 12 #include <sys/socket.h> 13 #include <sys/types.h> 14 #include <unistd.h> 15 16 #include <cstdarg> 17 #include <cstdio> 18 #include <cstring> 19 #include <map> 20 #include <queue> 21 #include <stdexcept> 22 #include <stdplus/raw.hpp> 23 #include <string> 24 #include <string_view> 25 #include <vector> 26 27 #define MAX_IFADDRS 5 28 29 int debugging = false; 30 31 /* Data for mocking getifaddrs */ 32 struct ifaddr_storage 33 { 34 struct ifaddrs ifaddr; 35 struct sockaddr_storage addr; 36 struct sockaddr_storage mask; 37 struct sockaddr_storage bcast; 38 } mock_ifaddr_storage[MAX_IFADDRS]; 39 40 struct ifaddrs* mock_ifaddrs = nullptr; 41 42 int ifaddr_count = 0; 43 44 /* Stub library functions */ 45 void freeifaddrs(ifaddrs* /*ifp*/) 46 { 47 return; 48 } 49 50 std::map<int, std::queue<std::string>> mock_rtnetlinks; 51 52 struct MockInfo 53 { 54 unsigned idx; 55 unsigned flags; 56 std::optional<ether_addr> mac; 57 std::optional<unsigned> mtu; 58 }; 59 60 std::map<std::string, MockInfo> mock_if; 61 std::map<int, std::string> mock_if_indextoname; 62 63 void mock_clear() 64 { 65 mock_ifaddrs = nullptr; 66 ifaddr_count = 0; 67 mock_rtnetlinks.clear(); 68 mock_if.clear(); 69 mock_if_indextoname.clear(); 70 } 71 72 void mock_addIF(const std::string& name, unsigned idx, unsigned flags, 73 std::optional<ether_addr> mac, std::optional<unsigned> mtu) 74 { 75 if (idx == 0) 76 { 77 throw std::invalid_argument("Bad interface index"); 78 } 79 80 mock_if.emplace( 81 name, MockInfo{.idx = idx, .flags = flags, .mac = mac, .mtu = mtu}); 82 mock_if_indextoname.emplace(idx, name); 83 } 84 85 void mock_addIP(const char* name, const char* addr, const char* mask) 86 { 87 struct ifaddrs* ifaddr = &mock_ifaddr_storage[ifaddr_count].ifaddr; 88 89 struct sockaddr_in* in = 90 reinterpret_cast<sockaddr_in*>(&mock_ifaddr_storage[ifaddr_count].addr); 91 struct sockaddr_in* mask_in = 92 reinterpret_cast<sockaddr_in*>(&mock_ifaddr_storage[ifaddr_count].mask); 93 94 in->sin_family = AF_INET; 95 in->sin_port = 0; 96 in->sin_addr.s_addr = inet_addr(addr); 97 98 mask_in->sin_family = AF_INET; 99 mask_in->sin_port = 0; 100 mask_in->sin_addr.s_addr = inet_addr(mask); 101 102 ifaddr->ifa_next = nullptr; 103 ifaddr->ifa_name = const_cast<char*>(name); 104 ifaddr->ifa_flags = 0; 105 ifaddr->ifa_addr = reinterpret_cast<struct sockaddr*>(in); 106 ifaddr->ifa_netmask = reinterpret_cast<struct sockaddr*>(mask_in); 107 ifaddr->ifa_data = nullptr; 108 109 if (ifaddr_count > 0) 110 mock_ifaddr_storage[ifaddr_count - 1].ifaddr.ifa_next = ifaddr; 111 ifaddr_count++; 112 mock_ifaddrs = &mock_ifaddr_storage[0].ifaddr; 113 } 114 115 void validateMsgHdr(const struct msghdr* msg) 116 { 117 if (msg->msg_namelen != sizeof(sockaddr_nl)) 118 { 119 fprintf(stderr, "bad namelen: %u\n", msg->msg_namelen); 120 abort(); 121 } 122 const auto& from = *reinterpret_cast<sockaddr_nl*>(msg->msg_name); 123 if (from.nl_family != AF_NETLINK) 124 { 125 fprintf(stderr, "recvmsg bad family data\n"); 126 abort(); 127 } 128 if (msg->msg_iovlen != 1) 129 { 130 fprintf(stderr, "recvmsg unsupported iov configuration\n"); 131 abort(); 132 } 133 } 134 135 ssize_t sendmsg_link_dump(std::queue<std::string>& msgs, std::string_view in) 136 { 137 const ssize_t ret = in.size(); 138 const auto& hdrin = stdplus::raw::copyFrom<nlmsghdr>(in); 139 if (hdrin.nlmsg_type != RTM_GETLINK) 140 { 141 return 0; 142 } 143 144 for (const auto& [name, i] : mock_if) 145 { 146 ifinfomsg info{}; 147 info.ifi_index = i.idx; 148 info.ifi_flags = i.flags; 149 nlmsghdr hdr{}; 150 hdr.nlmsg_len = NLMSG_LENGTH(sizeof(info)); 151 hdr.nlmsg_type = RTM_NEWLINK; 152 hdr.nlmsg_flags = NLM_F_MULTI; 153 auto& out = msgs.emplace(hdr.nlmsg_len, '\0'); 154 memcpy(out.data(), &hdr, sizeof(hdr)); 155 memcpy(NLMSG_DATA(out.data()), &info, sizeof(info)); 156 } 157 158 nlmsghdr hdr{}; 159 hdr.nlmsg_len = NLMSG_LENGTH(0); 160 hdr.nlmsg_type = NLMSG_DONE; 161 hdr.nlmsg_flags = NLM_F_MULTI; 162 auto& out = msgs.emplace(hdr.nlmsg_len, '\0'); 163 memcpy(out.data(), &hdr, sizeof(hdr)); 164 return ret; 165 } 166 167 ssize_t sendmsg_ack(std::queue<std::string>& msgs, std::string_view in) 168 { 169 nlmsgerr ack{}; 170 nlmsghdr hdr{}; 171 hdr.nlmsg_len = NLMSG_LENGTH(sizeof(ack)); 172 hdr.nlmsg_type = NLMSG_ERROR; 173 auto& out = msgs.emplace(hdr.nlmsg_len, '\0'); 174 memcpy(out.data(), &hdr, sizeof(hdr)); 175 memcpy(NLMSG_DATA(out.data()), &ack, sizeof(ack)); 176 return in.size(); 177 } 178 179 extern "C" { 180 181 int getifaddrs(ifaddrs** ifap) 182 { 183 *ifap = mock_ifaddrs; 184 if (mock_ifaddrs == nullptr) 185 return -1; 186 return (0); 187 } 188 189 unsigned if_nametoindex(const char* ifname) 190 { 191 auto it = mock_if.find(ifname); 192 if (it == mock_if.end()) 193 { 194 errno = ENXIO; 195 return 0; 196 } 197 return it->second.idx; 198 } 199 200 char* if_indextoname(unsigned ifindex, char* ifname) 201 { 202 auto it = mock_if_indextoname.find(ifindex); 203 if (it == mock_if_indextoname.end()) 204 { 205 errno = ENXIO; 206 return NULL; 207 } 208 return std::strcpy(ifname, it->second.c_str()); 209 } 210 211 int ioctl(int fd, unsigned long int request, ...) 212 { 213 va_list vl; 214 va_start(vl, request); 215 void* data = va_arg(vl, void*); 216 va_end(vl); 217 218 auto req = reinterpret_cast<ifreq*>(data); 219 if (request == SIOCGIFHWADDR) 220 { 221 auto it = mock_if.find(req->ifr_name); 222 if (it == mock_if.end()) 223 { 224 errno = ENXIO; 225 return -1; 226 } 227 if (!it->second.mac) 228 { 229 errno = EOPNOTSUPP; 230 return -1; 231 } 232 std::memcpy(req->ifr_hwaddr.sa_data, &*it->second.mac, 233 sizeof(*it->second.mac)); 234 return 0; 235 } 236 else if (request == SIOCGIFFLAGS) 237 { 238 auto it = mock_if.find(req->ifr_name); 239 if (it == mock_if.end()) 240 { 241 errno = ENXIO; 242 return -1; 243 } 244 req->ifr_flags = it->second.flags; 245 return 0; 246 } 247 else if (request == SIOCGIFMTU) 248 { 249 auto it = mock_if.find(req->ifr_name); 250 if (it == mock_if.end()) 251 { 252 errno = ENXIO; 253 return -1; 254 } 255 if (!it->second.mtu) 256 { 257 errno = EOPNOTSUPP; 258 return -1; 259 } 260 req->ifr_mtu = *it->second.mtu; 261 return 0; 262 } 263 264 static auto real_ioctl = 265 reinterpret_cast<decltype(&ioctl)>(dlsym(RTLD_NEXT, "ioctl")); 266 return real_ioctl(fd, request, data); 267 } 268 269 int socket(int domain, int type, int protocol) 270 { 271 static auto real_socket = 272 reinterpret_cast<decltype(&socket)>(dlsym(RTLD_NEXT, "socket")); 273 int fd = real_socket(domain, type, protocol); 274 if (domain == AF_NETLINK && !(type & SOCK_RAW)) 275 { 276 fprintf(stderr, "Netlink sockets must be RAW\n"); 277 abort(); 278 } 279 if (domain == AF_NETLINK && protocol == NETLINK_ROUTE) 280 { 281 mock_rtnetlinks[fd] = {}; 282 } 283 return fd; 284 } 285 286 int close(int fd) 287 { 288 auto it = mock_rtnetlinks.find(fd); 289 if (it != mock_rtnetlinks.end()) 290 { 291 mock_rtnetlinks.erase(it); 292 } 293 294 static auto real_close = 295 reinterpret_cast<decltype(&close)>(dlsym(RTLD_NEXT, "close")); 296 return real_close(fd); 297 } 298 299 ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags) 300 { 301 auto it = mock_rtnetlinks.find(sockfd); 302 if (it == mock_rtnetlinks.end()) 303 { 304 static auto real_sendmsg = 305 reinterpret_cast<decltype(&sendmsg)>(dlsym(RTLD_NEXT, "sendmsg")); 306 return real_sendmsg(sockfd, msg, flags); 307 } 308 auto& msgs = it->second; 309 310 validateMsgHdr(msg); 311 if (!msgs.empty()) 312 { 313 fprintf(stderr, "Unread netlink responses\n"); 314 abort(); 315 } 316 317 ssize_t ret; 318 std::string_view iov(reinterpret_cast<char*>(msg->msg_iov[0].iov_base), 319 msg->msg_iov[0].iov_len); 320 321 ret = sendmsg_link_dump(msgs, iov); 322 if (ret != 0) 323 { 324 return ret; 325 } 326 327 ret = sendmsg_ack(msgs, iov); 328 if (ret != 0) 329 { 330 return ret; 331 } 332 333 errno = ENOSYS; 334 return -1; 335 } 336 337 ssize_t recvmsg(int sockfd, struct msghdr* msg, int flags) 338 { 339 auto it = mock_rtnetlinks.find(sockfd); 340 if (it == mock_rtnetlinks.end()) 341 { 342 static auto real_recvmsg = 343 reinterpret_cast<decltype(&recvmsg)>(dlsym(RTLD_NEXT, "recvmsg")); 344 return real_recvmsg(sockfd, msg, flags); 345 } 346 auto& msgs = it->second; 347 348 validateMsgHdr(msg); 349 constexpr size_t required_buf_size = 8192; 350 if (msg->msg_iov[0].iov_len < required_buf_size) 351 { 352 fprintf(stderr, "recvmsg iov too short: %zu\n", 353 msg->msg_iov[0].iov_len); 354 abort(); 355 } 356 if (msgs.empty()) 357 { 358 fprintf(stderr, "No pending netlink responses\n"); 359 abort(); 360 } 361 362 ssize_t ret = 0; 363 auto data = reinterpret_cast<char*>(msg->msg_iov[0].iov_base); 364 while (!msgs.empty()) 365 { 366 const auto& msg = msgs.front(); 367 if (NLMSG_ALIGN(ret) + msg.size() > required_buf_size) 368 { 369 break; 370 } 371 ret = NLMSG_ALIGN(ret); 372 memcpy(data + ret, msg.data(), msg.size()); 373 ret += msg.size(); 374 msgs.pop(); 375 } 376 return ret; 377 } 378 379 } // extern "C" 380