xref: /openbmc/openbmc/poky/meta/lib/oeqa/core/target/ssh.py (revision 169d7bcc)
1#
2# Copyright (C) 2016 Intel Corporation
3#
4# SPDX-License-Identifier: MIT
5#
6
7import os
8import time
9import select
10import logging
11import subprocess
12import codecs
13
14from . import OETarget
15
class OESSHTarget(OETarget):
    """
        Target reached over SSH: commands are executed with ssh and files
        are transferred with scp.
    """

    def __init__(self, logger, ip, server_ip, timeout=300, user='root',
                 port=None, server_port=0, **kwargs):
        """
            logger:      Logger to use. When falsy, the 'target' logger is
                         configured to write to ./remoteTarget.log.
            ip:          Address of the target.
            server_ip:   Stored as self.server_ip (not used by this class).
            timeout:     Default command timeout in seconds used by run().
            user:        User to log in as on the target.
            port:        SSH port of the target (int or str); when not
                         given, ssh/scp use their default port.
            server_port: Stored as self.server_port (not used by this class).
            kwargs:      Accepted and ignored.
        """
        if not logger:
            logger = logging.getLogger('target')
            logger.setLevel(logging.INFO)
            filePath = os.path.join(os.getcwd(), 'remoteTarget.log')
            # getLogger('target') returns the same logger object on every
            # call, so attach the file handler only once. Attaching it for
            # every new target would duplicate each log line and leak one
            # open file per instance.
            alreadyAttached = any(
                isinstance(h, logging.FileHandler) and
                h.baseFilename == os.path.abspath(filePath)
                for h in logger.handlers)
            if not alreadyAttached:
                fileHandler = logging.FileHandler(filePath, 'w', 'utf-8')
                formatter = logging.Formatter(
                            '%(asctime)s.%(msecs)03d %(levelname)s: %(message)s',
                            '%H:%M:%S')
                fileHandler.setFormatter(formatter)
                logger.addHandler(fileHandler)

        super(OESSHTarget, self).__init__(logger)
        self.ip = ip
        self.server_ip = server_ip
        self.server_port = server_port
        self.timeout = timeout
        self.user = user
        ssh_options = [
                '-o', 'ServerAliveCountMax=2',
                '-o', 'ServerAliveInterval=30',
                '-o', 'UserKnownHostsFile=/dev/null',
                '-o', 'StrictHostKeyChecking=no',
                '-o', 'LogLevel=ERROR'
                ]
        scp_options = [
                '-r'
        ]
        self.ssh = ['ssh', '-l', self.user ] + ssh_options
        self.scp = ['scp'] + ssh_options + scp_options
        if port:
            # Every argv element handed to subprocess (and joined for the
            # debug log in _run) must be a string; an int port would raise
            # TypeError there.
            self.ssh = self.ssh + [ '-p', str(port) ]
            self.scp = self.scp + [ '-P', str(port) ]

    def start(self, **kwargs):
        pass

    def stop(self, **kwargs):
        pass

    @staticmethod
    def _relativePath(path, base):
        """
            Return path relative to base, or '' when path is base itself.

            Unlike str.replace(base, ''), this strips only the leading base
            component, so a repeated occurrence of base's text deeper in
            the path is preserved.
        """
        rel = os.path.relpath(path, base)
        return '' if rel == os.curdir else rel

    def _run(self, command, timeout=None, ignore_status=True):
        """
            Runs command (an argv list) locally using SSHCall.

            timeout:       Passed through to SSHCall.
            ignore_status: When False, a non-zero exit status raises
                           AssertionError including the command output.

            Returns a (status, output) tuple.
        """
        self.logger.debug("[Running]$ %s" % " ".join(command))

        starttime = time.time()
        status, output = SSHCall(command, self.logger, timeout)
        self.logger.debug("[Command returned '%d' after %.2f seconds]"
                 "" % (status, time.time() - starttime))

        if status and not ignore_status:
            raise AssertionError("Command '%s' returned non-zero exit "
                                 "status %d:\n%s" % (command, status, output))

        return (status, output)

    def run(self, command, timeout=None, ignore_status=True):
        """
            Runs command in target.

            command:       Command to run on target.
            timeout:       <value>:    Kill command after <val> seconds.
                           None:       Kill command default value seconds.
                           0:          No timeout, runs until return.
            ignore_status: When False, a non-zero exit status raises
                           AssertionError.

            Returns a (status, output) tuple.
        """
        # Use a fixed PATH so the command resolves regardless of how the
        # remote login shell initializes its environment.
        targetCmd = 'export PATH=/usr/sbin:/sbin:/usr/bin:/bin; %s' % command
        sshCmd = self.ssh + [self.ip, targetCmd]

        if timeout:
            processTimeout = timeout
        elif timeout==0:
            processTimeout = None
        else:
            processTimeout = self.timeout

        status, output = self._run(sshCmd, processTimeout, ignore_status)
        self.logger.debug('Command: %s\nStatus: %d Output:  %s\n' % (command, status, output))

        return (status, output)

    def copyTo(self, localSrc, remoteDst):
        """
            Copy file to target.

            If local file is symlink, recreate symlink in target instead of
            copying the file it points to. A failing scp raises
            AssertionError. Returns a (status, output) tuple.
        """
        if os.path.islink(localSrc):
            link = os.readlink(localSrc)
            dstDir, dstBase = os.path.split(remoteDst)
            sshCmd = 'cd %s; ln -s %s %s' % (dstDir, link, dstBase)
            return self.run(sshCmd)

        remotePath = '%s@%s:%s' % (self.user, self.ip, remoteDst)
        scpCmd = self.scp + [localSrc, remotePath]
        return self._run(scpCmd, ignore_status=False)

    def copyFrom(self, remoteSrc, localDst, warn_on_failure=False):
        """
            Copy file from target.

            warn_on_failure: When True, a failing scp only logs a warning;
                             otherwise it raises AssertionError.

            Returns a (status, output) tuple.
        """
        remotePath = '%s@%s:%s' % (self.user, self.ip, remoteSrc)
        scpCmd = self.scp + [remotePath, localDst]
        (status, output) = self._run(scpCmd, ignore_status=warn_on_failure)
        if warn_on_failure and status:
            self.logger.warning("Copy returned non-zero exit status %d:\n%s" % (status, output))
        return (status, output)

    def copyDirTo(self, localSrc, remoteDst):
        """
            Copy recursively localSrc directory to remoteDst in target.

            Each subdirectory is created with 'mkdir -p' and each file is
            copied individually with copyTo (so symlinks are recreated).
        """

        for root, dirs, files in os.walk(localSrc):
            # Create directories in the target as needed
            for d in dirs:
                relDir = self._relativePath(os.path.join(root, d), localSrc)
                newDir = os.path.join(remoteDst, relDir)
                cmd = "mkdir -p %s" % newDir
                self.run(cmd)

            # Copy files into the target
            for f in files:
                srcFile = os.path.join(root, f)
                dstFile = os.path.join(remoteDst,
                                       self._relativePath(srcFile, localSrc))
                self.copyTo(srcFile, dstFile)

    def deleteFiles(self, remotePath, files):
        """
            Deletes files in target's remotePath.

            files: A single file name or a list of file names, relative to
                   remotePath. An empty list is a no-op.
        """

        if not isinstance(files, list):
            files = [files]

        if not files:
            # A bare 'rm' with no operands would only fail on the target.
            return

        cmd = "rm %s" % " ".join(os.path.join(remotePath, f) for f in files)
        self.run(cmd)


    def deleteDir(self, remotePath):
        """
            Deletes target's remotePath directory.

            Uses rmdir, so it only succeeds when the directory is empty; the
            exit status is not checked.
        """

        cmd = "rmdir %s" % remotePath
        self.run(cmd)


    def deleteDirStructure(self, localPath, remotePath):
        """
        Delete recursively localPath structure directory in target's remotePath.

        This function is very useful to delete a package that is installed in
        the DUT and the host running the test has such package extracted in tmp
        directory.

        Example:
            pwd: /home/user/tmp
            tree:   .
                    └── work
                        ├── dir1
                        │   └── file1
                        └── dir2

            localpath = "/home/user/tmp" and remotepath = "/home/user"

            With the above variables this function will try to delete the
            directory in the DUT in this order:
                /home/user/work/dir1/file1
                /home/user/work/dir1        (if dir is empty)
                /home/user/work/dir2        (if dir is empty)
                /home/user/work             (if dir is empty)
        """

        # Bottom-up walk so every directory is visited after its contents.
        for root, dirs, files in os.walk(localPath, topdown=False):
            # Delete files first
            remoteDir = os.path.join(remotePath,
                                     self._relativePath(root, localPath))
            self.deleteFiles(remoteDir, files)

            # Remove dirs if empty
            for d in dirs:
                relDir = self._relativePath(os.path.join(root, d), localPath)
                self.deleteDir(os.path.join(remotePath, relDir))
208
def SSHCall(command, logger, timeout=None, **opts):
    """
        Run command locally and capture its combined stdout/stderr.

        command:  Argument list for subprocess.Popen (e.g. an ssh or scp
                  invocation).
        logger:   Logger used for debug tracing.
        timeout:  <value>:  Kill the process once it has produced no output
                            for <value> seconds (checked in steps of up to
                            5 seconds).
                  None/0:   No timeout, wait until the process exits.
        opts:     Extra keyword arguments that override the default
                  subprocess.Popen options (except 'env', which is always
                  the current environment without DISPLAY).

        Returns (returncode, output). output is decoded as UTF-8 with
        undecodable bytes dropped and trailing whitespace stripped. When the
        process is killed on timeout, a notice is appended to output, and
        a returncode of 0 is reported as 255.

        Any exception (including KeyboardInterrupt/SystemExit) kills the
        child process, if one was started, before being re-raised.
    """

    def run():
        nonlocal output
        nonlocal process
        output_raw = b''
        lastline = ''
        starttime = time.time()
        process = subprocess.Popen(command, **options)
        has_timeout = False
        if timeout:
            endtime = starttime + timeout
            eof = False
            os.set_blocking(process.stdout.fileno(), False)
            while not has_timeout and not eof:
                try:
                    logger.debug('Waiting for process output: time: %s, endtime: %s' % (time.time(), endtime))
                    if select.select([process.stdout], [], [], 5)[0] != []:
                        # wait a bit for more data, tries to avoid reading single characters
                        time.sleep(0.2)
                        data = process.stdout.read()
                        # On a non-blocking pipe read() returns None when no
                        # data is available yet; only b'' signals EOF.
                        if data == b'':
                            eof = True
                        elif data:
                            output_raw += data
                            # ignore errors to capture as much as possible
                            logger.debug('Partial data from SSH call:\n%s' % data.decode('utf-8', errors='ignore'))
                            endtime = time.time() + timeout
                except InterruptedError:
                    logger.debug('InterruptedError')
                except BlockingIOError:
                    logger.debug('BlockingIOError')

                # Checked on every iteration, including after a transient
                # error, so repeated errors cannot keep the loop alive forever.
                if time.time() >= endtime:
                    logger.debug('SSHCall has timeout! Time: %s, endtime: %s' % (time.time(), endtime))
                    has_timeout = True

            process.stdout.close()

            # process hasn't returned yet
            if not eof:
                process.terminate()
                time.sleep(5)
                try:
                    process.kill()
                except OSError:
                    logger.debug('OSError when killing process')
                elapsed = time.time() - starttime
                lastline = ("\nProcess killed - no output for %d seconds. Total"
                            " running time: %d seconds." % (timeout, elapsed))
                logger.debug('Received data from SSH call:\n%s ' % lastline)
                process.wait()

        else:
            output_raw = process.communicate()[0]

        output = output_raw.decode('utf-8', errors='ignore')
        logger.debug('Data from SSH call:\n%s' % output.rstrip())
        # Append the kill notice only after decoding; adding it before the
        # assignment above would silently discard it.
        output += lastline

        # timeout or not, make sure process exits and is not hanging
        if process.returncode is None:
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                try:
                    process.kill()
                except OSError:
                    logger.debug('OSError')
                process.wait()

        if has_timeout:
            # Version of openssh before 8.6_p1 returns error code 0 when killed
            # by a signal, when the timeout occurs we will receive a 0 error
            # code because the process is been terminated and it's wrong because
            # that value means success, but the process timed out.
            # Afterwards, from version 8.6_p1 onwards, the returned code is 255.
            # Fix this behaviour by checking the return code
            if process.returncode == 0:
                process.returncode = 255

    options = {
        "stdout": subprocess.PIPE,
        "stderr": subprocess.STDOUT,
        "stdin": None,
        "shell": False,
        "bufsize": -1,
        "start_new_session": True,
    }
    options.update(opts)
    output = ''
    process = None

    # Unset DISPLAY which means we won't trigger SSH_ASKPASS
    env = os.environ.copy()
    if "DISPLAY" in env:
        del env['DISPLAY']
    options['env'] = env

    try:
        run()
    except BaseException:
        # Need to guard against a SystemExit or other exception occurring
        # whilst running and ensure we don't leave a process behind.
        # process is still None when Popen itself failed; don't mask the
        # original error with an AttributeError in that case.
        if process is not None:
            if process.poll() is None:
                process.kill()
            if process.returncode is None:
                process.wait()
        logger.debug('Something went wrong, killing SSH process')
        raise

    return (process.returncode, output.rstrip())
324