xref: /openbmc/openbmc/poky/meta/lib/oeqa/core/target/ssh.py (revision 8460358c3d24c71d9d38fd126c745854a6301564)
1#
2# Copyright (C) 2016 Intel Corporation
3#
4# SPDX-License-Identifier: MIT
5#
6
7import os
8import time
9import select
10import logging
11import subprocess
12import codecs
13
14from . import OETarget
15
class OESSHTarget(OETarget):
    """
        Target reached over SSH: commands are run with ssh and files are
        transferred with scp.
    """

    def __init__(self, logger, ip, server_ip, timeout=300, user='root',
                 port=None, server_port=0, **kwargs):
        """
            logger:      Logger to use. If falsy, the 'target' logger is used,
                         set to INFO, and given (once) a file handler writing
                         to remoteTarget.log in the current working directory.
            ip:          Address of the target, used as the ssh/scp host.
            server_ip:   Stored as self.server_ip; not used by this class.
            timeout:     Default timeout in seconds used by run().
            user:        Login user on the target.
            port:        Optional SSH port (str or int), passed as '-p' to ssh
                         and '-P' to scp.
            server_port: Stored as self.server_port; not used by this class.
            kwargs:      Accepted and ignored by this class.
        """
        if not logger:
            logger = logging.getLogger('target')
            logger.setLevel(logging.INFO)
            filePath = os.path.join(os.getcwd(), 'remoteTarget.log')
            # getLogger() returns the same object for every instance; attach
            # the file handler only once, otherwise each new target would
            # duplicate every log line (and truncate the file again).
            already_attached = any(
                    isinstance(h, logging.FileHandler)
                    and h.baseFilename == os.path.abspath(filePath)
                    for h in logger.handlers)
            if not already_attached:
                fileHandler = logging.FileHandler(filePath, 'w', 'utf-8')
                formatter = logging.Formatter(
                            '%(asctime)s.%(msecs)03d %(levelname)s: %(message)s',
                            '%H:%M:%S')
                fileHandler.setFormatter(formatter)
                logger.addHandler(fileHandler)

        super(OESSHTarget, self).__init__(logger)
        self.ip = ip
        self.server_ip = server_ip
        self.server_port = server_port
        self.timeout = timeout
        self.user = user
        ssh_options = [
                '-o', 'ServerAliveCountMax=2',
                '-o', 'ServerAliveInterval=30',
                '-o', 'UserKnownHostsFile=/dev/null',
                '-o', 'StrictHostKeyChecking=no',
                '-o', 'LogLevel=ERROR'
                ]
        scp_options = [
                '-r'
        ]
        self.ssh = ['ssh', '-l', self.user ] + ssh_options
        self.scp = ['scp'] + ssh_options + scp_options
        if port:
            # Popen arguments must be strings; accept an int port as well.
            self.ssh = self.ssh + [ '-p', str(port) ]
            self.scp = self.scp + [ '-P', str(port) ]

    def start(self, **kwargs):
        pass

    def stop(self, **kwargs):
        pass

    @staticmethod
    def _relative_to(path, base):
        """
            Return path with the leading base removed and leading '/' stripped.

            Paths yielded by os.walk(base) always start with base, so only
            that prefix is removed (str.replace would also delete any later
            occurrence of base inside the path).
        """
        if path.startswith(base):
            path = path[len(base):]
        return path.lstrip("/")

    def _run(self, command, timeout=None, ignore_status=True, raw=False):
        """
            Run a complete ssh/scp argument list locally via SSHCall.

            command:       Argument list (list of str).
            timeout:       Passed to SSHCall; None means no timeout.
            ignore_status: If False, raise AssertionError on non-zero status.
            raw:           Passed to SSHCall (bytes output, stderr not captured).

            Returns (status, output).
        """
        self.logger.debug("[Running]$ %s" % " ".join(command))

        starttime = time.time()
        status, output = SSHCall(command, self.logger, timeout, raw)
        self.logger.debug("[Command returned '%d' after %.2f seconds]"
                 "" % (status, time.time() - starttime))

        if status and not ignore_status:
            raise AssertionError("Command '%s' returned non-zero exit "
                                 "status %d:\n%s" % (command, status, output))

        return (status, output)

    def run(self, command, timeout=None, ignore_status=True, raw=False):
        """
            Runs command in target.

            command:    Command to run on target (a shell string; PATH is set
                        to /usr/sbin:/sbin:/usr/bin:/bin first).
            timeout:    <value>:    Kill command after <value> seconds without
                                    any output.
                        None:       Use the target's default (self.timeout).
                        0:          No timeout, runs until return.

            Returns (status, output).
        """
        targetCmd = 'export PATH=/usr/sbin:/sbin:/usr/bin:/bin; %s' % command
        sshCmd = self.ssh + [self.ip, targetCmd]

        if timeout:
            processTimeout = timeout
        elif timeout == 0:
            processTimeout = None
        else:
            processTimeout = self.timeout

        status, output = self._run(sshCmd, processTimeout, ignore_status, raw)
        # Avoid dumping very large outputs into the debug log.
        if len(output) > (64 * 1024):
            self.logger.debug('Command: %s\nStatus: %d Output length:  %s\n' % (command, status, len(output)))
        else:
            self.logger.debug('Command: %s\nStatus: %d Output:  %s\n' % (command, status, output))

        return (status, output)

    def copyTo(self, localSrc, remoteDst):
        """
            Copy file to target.

            If local file is symlink, recreate symlink in target (with the
            same link text) instead of copying its contents.

            Returns (status, output); raises AssertionError if scp fails.
        """
        if os.path.islink(localSrc):
            link = os.readlink(localSrc)
            dstDir, dstBase = os.path.split(remoteDst)
            sshCmd = 'cd %s; ln -s %s %s' % (dstDir, link, dstBase)
            return self.run(sshCmd)

        else:
            remotePath = '%s@%s:%s' % (self.user, self.ip, remoteDst)
            scpCmd = self.scp + [localSrc, remotePath]
            return self._run(scpCmd, ignore_status=False)

    def copyFrom(self, remoteSrc, localDst, warn_on_failure=False):
        """
            Copy file from target.

            warn_on_failure: If True, log a warning on failure instead of
                             raising AssertionError.

            Returns (status, output).
        """
        remotePath = '%s@%s:%s' % (self.user, self.ip, remoteSrc)
        scpCmd = self.scp + [remotePath, localDst]
        (status, output) = self._run(scpCmd, ignore_status=warn_on_failure)
        if warn_on_failure and status:
            self.logger.warning("Copy returned non-zero exit status %d:\n%s" % (status, output))
        return (status, output)

    def copyDirTo(self, localSrc, remoteDst):
        """
            Copy recursively localSrc directory to remoteDst in target.

            Directories are created with 'mkdir -p'; files are copied with
            copyTo (so symlinks are recreated as symlinks).
        """

        for root, dirs, files in os.walk(localSrc):
            # Create directories in the target as needed
            for d in dirs:
                tmpDir = self._relative_to(os.path.join(root, d), localSrc)
                newDir = os.path.join(remoteDst, tmpDir)
                cmd = "mkdir -p %s" % newDir
                self.run(cmd)

            # Copy files into the target
            for f in files:
                srcFile = os.path.join(root, f)
                dstFile = os.path.join(remoteDst, self._relative_to(srcFile, localSrc))
                self.copyTo(srcFile, dstFile)

    def deleteFiles(self, remotePath, files):
        """
            Deletes files in target's remotePath.

            files: A file name or a list of file names relative to remotePath.
                   An empty list is a no-op.
        """

        if not isinstance(files, list):
            files = [files]

        if not files:
            # Nothing to delete: skip the ssh round-trip (a bare "rm" fails).
            return

        cmd = " ".join(["rm"] + [os.path.join(remotePath, f) for f in files])
        self.run(cmd)


    def deleteDir(self, remotePath):
        """
            Deletes target's remotePath directory (rmdir: only if empty).
        """

        cmd = "rmdir %s" % remotePath
        self.run(cmd)


    def deleteDirStructure(self, localPath, remotePath):
        """
        Delete recursively localPath structure directory in target's remotePath.

        This function is very useful to delete a package that is installed in
        the DUT and the host running the test has such package extracted in tmp
        directory.

        Example:
            pwd: /home/user/tmp
            tree:   .
                    └── work
                        ├── dir1
                        │   └── file1
                        └── dir2

            localPath = "/home/user/tmp" and remotePath = "/home/user"

            With the above variables this function will try to delete the
            directory in the DUT in this order:
                /home/user/work/dir1/file1
                /home/user/work/dir1        (if dir is empty)
                /home/user/work/dir2        (if dir is empty)
                /home/user/work             (if dir is empty)
        """

        # Bottom-up walk so children are emptied before their parents.
        for root, dirs, files in os.walk(localPath, topdown=False):
            # Delete files first
            remoteDir = os.path.join(remotePath, self._relative_to(root, localPath))
            self.deleteFiles(remoteDir, files)

            # Remove dirs if empty
            for d in dirs:
                tmpDir = self._relative_to(os.path.join(root, d), localPath)
                remoteDir = os.path.join(remotePath, tmpDir)
                self.deleteDir(remoteDir)
211
def SSHCall(command, logger, timeout=None, raw=False, **opts):
    """
        Run command locally (normally an ssh or scp argument list) and collect
        its output.

        command:  Argument list passed to subprocess.Popen.
        logger:   Logger used for debug messages.
        timeout:  <value>:  Kill the process once it has produced no output
                            for <value> seconds (the deadline is pushed back
                            whenever new output arrives).
                  None/0:   No timeout, wait until the process exits.
        raw:      False: stderr is merged into stdout; the output is decoded
                         as UTF-8 (undecodable bytes dropped) and
                         right-stripped.
                  True:  stdout is returned as bytes; stderr is not captured.
        opts:     Extra keyword arguments overriding the default
                  subprocess.Popen options. DISPLAY is always removed from the
                  environment (the caller's env if given, else os.environ).

        Returns (returncode, output). On timeout a note is appended to the
        output and a return code of 0 is reported as 255.
    """

    def run():
        nonlocal output
        nonlocal process
        output_raw = bytearray()
        starttime = time.time()
        progress = time.time()
        process = subprocess.Popen(command, **options)
        has_timeout = False
        appendline = None
        if timeout:
            endtime = starttime + timeout
            eof = False
            # Non-blocking so read() returns what is buffered instead of
            # waiting for EOF.
            os.set_blocking(process.stdout.fileno(), False)
            while not has_timeout and not eof:
                try:
                    if select.select([process.stdout], [], [], 5)[0] != []:
                        # wait a bit for more data, tries to avoid reading single characters
                        time.sleep(0.2)
                        data = process.stdout.read()
                        if data is None:
                            # A non-blocking read may return None when nothing
                            # is available; that is not EOF.
                            pass
                        elif not data:
                            eof = True
                        else:
                            output_raw.extend(data)
                            # New output resets the inactivity deadline.
                            endtime = time.time() + timeout
                except InterruptedError:
                    logger.debug('InterruptedError')
                    continue
                except BlockingIOError:
                    logger.debug('BlockingIOError')
                    continue

                if time.time() >= endtime:
                    logger.debug('SSHCall has timeout! Time: %s, endtime: %s' % (time.time(), endtime))
                    has_timeout = True

                if time.time() >= (progress + 60):
                    logger.debug('Waiting for process output at time: %s with datasize: %s' % (time.time(), len(output_raw)))
                    progress = time.time()

            process.stdout.close()

            # process hasn't returned yet
            if not eof:
                process.terminate()
                # Give the process up to 5 seconds to exit after SIGTERM
                # before killing it, rather than always sleeping 5 seconds.
                try:
                    process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    try:
                        process.kill()
                    except OSError:
                        logger.debug('OSError when killing process')
                endtime = time.time() - starttime
                appendline = ("\nProcess killed - no output for %d seconds. Total"
                            " running time: %d seconds." % (timeout, endtime))
                logger.debug('Received data from SSH call:\n%s ' % appendline)
                process.wait()

            if raw:
                output = bytes(output_raw)
                if appendline:
                    output += bytes(appendline, "utf-8")
            else:
                output = output_raw.decode('utf-8', errors='ignore')
                if appendline:
                    output += appendline
        else:
            stdout_data = process.communicate()[0]
            output = stdout_data if raw else stdout_data.decode('utf-8', errors='ignore')

        # Very large outputs are not echoed into the debug log.
        if len(output) < (64 * 1024):
            if output.rstrip():
                logger.debug('Data from SSH call:\n%s' % output.rstrip())
            else:
                logger.debug('No output from SSH call')

        # timeout or not, make sure process exits and is not hanging
        if process.returncode is None:
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                try:
                    process.kill()
                except OSError:
                    logger.debug('OSError')
                process.wait()

        if has_timeout:
            # Version of openssh before 8.6_p1 returns error code 0 when killed
            # by a signal, when the timeout occurs we will receive a 0 error
            # code because the process is been terminated and it's wrong because
            # that value means success, but the process timed out.
            # Afterwards, from version 8.6_p1 onwards, the returned code is 255.
            # Fix this behaviour by checking the return code
            if process.returncode == 0:
                process.returncode = 255

    options = {
        "stdout": subprocess.PIPE,
        "stderr": subprocess.STDOUT if not raw else None,
        "stdin": None,
        "shell": False,
        "bufsize": -1,
        "start_new_session": True,
    }
    options.update(opts)
    output = ''
    process = None

    # Unset DISPLAY which means we won't trigger SSH_ASKPASS. Start from the
    # caller's env when one was passed in opts instead of discarding it.
    base_env = options.get('env')
    env = dict(os.environ if base_env is None else base_env)
    env.pop('DISPLAY', None)
    options['env'] = env

    try:
        run()
    except BaseException:
        # Need to guard against a SystemExit or other exception occurring
        # whilst running and ensure we don't leave a process behind.
        # process is still None if Popen itself failed (e.g. binary missing);
        # in that case re-raise the original error untouched.
        if process is not None:
            if process.poll() is None:
                process.kill()
            if process.returncode is None:
                process.wait()
        logger.debug('Something went wrong, killing SSH process')
        raise

    return (process.returncode, output if raw else output.rstrip())
343    return (process.returncode, output if raw else output.rstrip())
344