# xref: /openbmc/openbmc/poky/meta/lib/oeqa/utils/sshcontrol.py (revision 96e4b4e121e0e2da1535d7d537d6a982a6ff5bc0)
#
# Copyright (C) 2013 Intel Corporation
#
# SPDX-License-Identifier: MIT
#

# Provides classes for setting up ssh connections,
# running commands and copying files to/from a target.
# It's used by testimage.bbclass and tests in lib/oeqa/runtime.

import codecs
import os
import select
import subprocess
import time


class SSHProcess(object):
    """
    Run a command in a child process and capture its combined output.

    Keyword arguments given to the constructor override the default
    subprocess.Popen options in self.defaultopts. After run() the exit status
    and the decoded, right-stripped output are stored in self.status and
    self.output.
    """

    def __init__(self, **options):
        self.defaultopts = {
            "stdout": subprocess.PIPE,
            # Fold stderr into stdout so callers get a single stream.
            "stderr": subprocess.STDOUT,
            "stdin": None,
            "shell": False,
            "bufsize": -1,
            # The child calls setsid() before exec'ing the command.
            "start_new_session": True,
        }
        self.options = dict(self.defaultopts)
        self.options.update(options)
        self.status = None
        self.output = None
        self.process = None
        self.starttime = None
        self.logfile = None

        # Unset DISPLAY which means we won't trigger SSH_ASKPASS
        env = os.environ.copy()
        env.pop("DISPLAY", None)
        self.options['env'] = env

    def log(self, msg):
        """Append msg verbatim (no newline added) to self.logfile, if set."""
        if self.logfile:
            with open(self.logfile, "a") as f:
                f.write("%s" % msg)

    def _run(self, command, timeout=None, logfile=None):
        """
        Start command and collect its output into self.status/self.output.

        With a truthy timeout, the child is terminated (then killed) once it
        has produced no output for `timeout` seconds, and a note saying so is
        appended to the output. Otherwise it simply runs to completion.
        """
        self.logfile = logfile
        self.starttime = time.time()
        output = ''
        # Decode incrementally: os.read() can split a multi-byte UTF-8
        # sequence across two chunks, which a per-chunk decode() would reject.
        # Undecodable bytes become U+FFFD instead of aborting the run.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        self.process = subprocess.Popen(command, **self.options)
        if timeout:
            endtime = self.starttime + timeout
            eof = False
            while time.time() < endtime and not eof:
                # Never block in select() past the inactivity deadline.
                select_timeout = max(0, min(5, endtime - time.time()))
                try:
                    if select.select([self.process.stdout], [], [], select_timeout)[0] != []:
                        data = os.read(self.process.stdout.fileno(), 1024)
                        if not data:
                            # The pipe is at EOF. We are only done once the
                            # child has actually exited; until then keep
                            # polling (bounded by endtime).
                            self.process.poll()
                            if self.process.returncode is not None:
                                eof = True
                        else:
                            data = decoder.decode(data)
                            output += data
                            self.log(data)
                            # Output was seen, so restart the inactivity clock.
                            endtime = time.time() + timeout
                except InterruptedError:
                    continue

            # Flush any bytes the decoder held back (incomplete trailing sequence).
            tail = decoder.decode(b"", final=True)
            if tail:
                output += tail
                self.log(tail)

            # process hasn't returned yet
            if not eof:
                self.process.terminate()
                try:
                    self.process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    try:
                        self.process.kill()
                    except OSError:
                        pass
                lastline = "\nProcess killed - no output for %d seconds. Total running time: %d seconds." % (timeout, time.time() - self.starttime)
                self.log(lastline)
                output += lastline

            # Close our end of the pipe on every path (it used to leak when
            # the process had to be killed).
            self.process.stdout.close()
        else:
            output = self.process.communicate()[0]
            if output is None:
                # stdout was not captured (caller overrode the option).
                output = ''
            elif isinstance(output, bytes):
                # Return text like the timeout path does. Skip decoding if the
                # caller asked Popen for text mode.
                output = decoder.decode(output, final=True)
            self.log(output.rstrip())

        self.status = self.process.wait()
        self.output = output.rstrip()

    def run(self, command, timeout=None, logfile=None):
        """
        Run command (an argv list) and return (status, output).

        timeout - seconds of allowed inactivity; falsy means no limit
        logfile - optional file that output is appended to as it arrives
        """
        try:
            self._run(command, timeout, logfile)
        except BaseException:
            # Need to guard against a SystemExit or other exception occurring
            # whilst running and ensure we don't leave a process behind. If
            # Popen itself failed there is no process, and the original error
            # must propagate unchanged.
            if self.process is not None and self.process.poll() is None:
                self.process.kill()
                self.status = self.process.wait()
            raise
        return (self.status, self.output)

class SSHControl(object):
    """
    Run shell commands on a target over ssh and copy files to/from it with scp.

    The target's host key is neither checked strictly nor remembered
    (UserKnownHostsFile=/dev/null, StrictHostKeyChecking=no).
    """

    def __init__(self, ip, logfile=None, timeout=300, user='root', port=None):
        """
        ip - address of the target
        logfile - if set, commands and their output are appended to this file
        timeout - default "no output" timeout in seconds used by run()
        user - login user on the target
        port - optional ssh port on the target (int or str)
        """
        self.ip = ip
        self.defaulttimeout = timeout
        self.ignore_status = True
        self.logfile = logfile
        self.user = user
        self.ssh_options = [
                '-o', 'UserKnownHostsFile=/dev/null',
                '-o', 'StrictHostKeyChecking=no',
                '-o', 'LogLevel=ERROR'
                ]
        self.ssh = ['ssh', '-l', self.user] + self.ssh_options
        self.scp = ['scp'] + self.ssh_options
        if port:
            # subprocess (and the " ".join() in _internal_run) need string
            # arguments, so accept an int port as well.
            port = str(port)
            self.ssh = self.ssh + ['-p', port]
            self.scp = self.scp + ['-P', port]

    def log(self, msg):
        """Append msg plus a newline to self.logfile, if one is configured."""
        if self.logfile:
            with open(self.logfile, "a") as f:
                f.write("%s\n" % msg)

    def _internal_run(self, command, timeout=None, ignore_status = True):
        """
        Run the local argv list command, logging it and its exit status.

        Returns (status, output). Raises AssertionError on a non-zero status
        unless ignore_status is true.
        """
        self.log("[Running]$ %s" % " ".join(command))

        proc = SSHProcess()
        status, output = proc.run(command, timeout, logfile=self.logfile)

        self.log("[Command returned '%d' after %.2f seconds]" % (status, time.time() - proc.starttime))

        if status and not ignore_status:
            raise AssertionError("Command '%s' returned non-zero exit status %d:\n%s" % (command, status, output))

        return (status, output)

    def run(self, command, timeout=None):
        """
        command - ssh command to run
        timeout=<val> - kill command if there is no output after <val> seconds
        timeout=None - kill command if there is no output after a default value seconds
        timeout=0 - no timeout, let command run until it returns
        """

        command = self.ssh + [self.ip, 'export PATH=/usr/sbin:/sbin:/usr/bin:/bin; ' + command]

        if timeout is None:
            timeout = self.defaulttimeout
        elif timeout == 0:
            timeout = None
        return self._internal_run(command, timeout, self.ignore_status)

    def copy_to(self, localpath, remotepath):
        """
        Copy local file localpath to remotepath on the target (no timeout).

        If localpath is a symlink, its immediate target is copied instead.
        Raises AssertionError if scp fails.
        """
        if os.path.islink(localpath):
            # os.path.join keeps an absolute link target as-is and copes with
            # a bare file name (empty dirname), unlike string concatenation.
            localpath = os.path.join(os.path.dirname(localpath), os.readlink(localpath))
        command = self.scp + [localpath, '%s@%s:%s' % (self.user, self.ip, remotepath)]
        return self._internal_run(command, ignore_status=False)

    def copy_from(self, remotepath, localpath):
        """
        Copy remotepath on the target to local localpath (no timeout).

        Raises AssertionError if scp fails.
        """
        command = self.scp + ['%s@%s:%s' % (self.user, self.ip, remotepath), localpath]
        return self._internal_run(command, ignore_status=False)

    @staticmethod
    def _map_to_remote(localroot, localpath, remotepath):
        """Return remotepath joined with localpath's location relative to localroot."""
        # relpath strips exactly the localroot prefix; str.replace() would
        # also strip any later repetition of it inside the path.
        rel = os.path.relpath(localpath, localroot)
        if rel == os.curdir:
            return remotepath
        return os.path.join(remotepath, rel)

    def copy_dir_to(self, localpath, remotepath):
        """
        Copy recursively localpath directory to remotepath in target.
        """

        for root, dirs, files in os.walk(localpath):
            # Create directories in the target as needed
            for d in dirs:
                new_dir = self._map_to_remote(localpath, os.path.join(root, d), remotepath)
                self.run("mkdir -p %s" % new_dir)

            # Copy files into the target
            for f in files:
                src_file = os.path.join(root, f)
                dst_file = self._map_to_remote(localpath, src_file, remotepath)
                self.copy_to(src_file, dst_file)

    def delete_files(self, remotepath, files):
        """
        Delete files in target's remote path.

        files may be a single name or a list/tuple of names relative to
        remotepath. Nothing is run when there are no files to delete.
        """

        if not isinstance(files, (list, tuple)):
            files = [files]
        if not files:
            # A bare "rm" would only fail on the target.
            return

        cmd = " ".join(["rm"] + [os.path.join(remotepath, f) for f in files])
        self.run(cmd)

    def delete_dir(self, remotepath):
        """
        Delete remotepath directory in target.
        """

        cmd = "rmdir %s" % remotepath
        self.run(cmd)

    def delete_dir_structure(self, localpath, remotepath):
        """
        Delete recursively localpath structure directory in target's remotepath.

        This function is very useful to delete a package that is installed in
        the DUT and the host running the test has such package extracted in tmp
        directory.

        Example:
            pwd: /home/user/tmp
            tree:   .
                    └── work
                        ├── dir1
                        │   └── file1
                        └── dir2

            localpath = "/home/user/tmp" and remotepath = "/home/user"

            With the above variables this function will try to delete the
            directory in the DUT in this order:
                /home/user/work/dir1/file1
                /home/user/work/dir1        (if dir is empty)
                /home/user/work/dir2        (if dir is empty)
                /home/user/work             (if dir is empty)
        """

        # Bottom-up walk so that files and children go before their parents.
        for root, dirs, files in os.walk(localpath, topdown=False):
            # Delete files first
            self.delete_files(self._map_to_remote(localpath, root, remotepath), files)

            # Remove dirs if empty (rmdir fails on a non-empty dir; run()
            # ignores the status)
            for d in dirs:
                remotedir = self._map_to_remote(localpath, os.path.join(root, d), remotepath)
                self.delete_dir(remotedir)
