1 /*
2  * Copyright 2019 Google Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "net_handler.hpp"
18 
19 #include <errno.h>
20 #include <netinet/in.h>
21 #include <poll.h>
22 #include <sys/socket.h>
23 #include <unistd.h>
24 
25 #include <cstdio>
26 
27 namespace ipmi_flash
28 {
29 
open()30 bool NetDataHandler::open()
31 {
32     listenFd.reset(::socket(AF_INET6, SOCK_STREAM, 0));
33     if (*listenFd < 0)
34     {
35         std::perror("Failed to create socket");
36         (void)listenFd.release();
37         return false;
38     }
39 
40     struct sockaddr_in6 listenAddr;
41     listenAddr.sin6_family = AF_INET6;
42     listenAddr.sin6_port = htons(listenPort);
43     listenAddr.sin6_flowinfo = 0;
44     listenAddr.sin6_addr = in6addr_any;
45     listenAddr.sin6_scope_id = 0;
46 
47     if (::bind(*listenFd, (struct sockaddr*)&listenAddr, sizeof(listenAddr)) <
48         0)
49     {
50         std::perror("Failed to bind");
51         return false;
52     }
53 
54     if (::listen(*listenFd, 1) < 0)
55     {
56         std::perror("Failed to listen");
57         return false;
58     }
59     return true;
60 }
61 
close()62 bool NetDataHandler::close()
63 {
64     connFd.reset();
65     listenFd.reset();
66 
67     return true;
68 }
69 
copyFrom(std::uint32_t length)70 std::vector<std::uint8_t> NetDataHandler::copyFrom(std::uint32_t length)
71 {
72     if (!connFd)
73     {
74         struct pollfd fds;
75         fds.fd = *listenFd;
76         fds.events = POLLIN;
77 
78         int ret = ::poll(&fds, 1, timeoutS * 1000);
79         if (ret < 0)
80         {
81             std::perror("Failed to poll");
82             return std::vector<uint8_t>();
83         }
84         else if (ret == 0)
85         {
86             fprintf(stderr, "Timed out waiting for connection\n");
87             return std::vector<uint8_t>();
88         }
89         else if (fds.revents != POLLIN)
90         {
91             fprintf(stderr, "Invalid poll state: 0x%x\n", fds.revents);
92             return std::vector<uint8_t>();
93         }
94 
95         connFd.reset(::accept(*listenFd, nullptr, nullptr));
96         if (*connFd < 0)
97         {
98             std::perror("Failed to accept connection");
99             (void)connFd.release();
100             return std::vector<uint8_t>();
101         }
102 
103         struct timeval tv = {};
104         tv.tv_sec = timeoutS;
105         if (setsockopt(*connFd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0)
106         {
107             std::perror("Failed to set receive timeout");
108             return std::vector<uint8_t>();
109         }
110     }
111 
112     std::vector<std::uint8_t> data(length);
113 
114     std::uint32_t bytesRead = 0;
115     ssize_t ret;
116     do
117     {
118         ret = read(*connFd, data.data() + bytesRead, length - bytesRead);
119         if (ret < 0)
120         {
121             if (errno == EINTR || errno == EAGAIN)
122                 continue;
123             std::perror("Failed to read from socket");
124             break;
125         }
126 
127         bytesRead += ret;
128     } while (ret > 0 && bytesRead < length);
129 
130     if (bytesRead != length)
131     {
132         fprintf(stderr,
133                 "Couldn't read full expected amount. Wanted %u but got %u\n",
134                 length, bytesRead);
135         data.resize(bytesRead);
136     }
137 
138     return data;
139 }
140 
writeMeta(const std::vector<std::uint8_t> &)141 bool NetDataHandler::writeMeta(const std::vector<std::uint8_t>&)
142 {
143     // TODO: have the host tool send the expected IP address that it will
144     // connect from
145     return true;
146 }
147 
readMeta()148 std::vector<std::uint8_t> NetDataHandler::readMeta()
149 {
150     return std::vector<std::uint8_t>();
151 }
152 
153 } // namespace ipmi_flash
154