xref: /openbmc/openbmc/poky/meta/lib/oeqa/core/target/serial.py (revision 8460358c3d24c71d9d38fd126c745854a6301564)
1#
2# SPDX-License-Identifier: MIT
3#
4
5import base64
6import logging
7import os
8from threading import Lock
9from . import OETarget
10
class OESerialTarget(OETarget):
    """
        OETarget that drives the device under test through a serial console.

        The console is opened by running serialcontrol_cmd (plus optional
        serialcontrol_extra_args) under a local bash. Commands are typed at the
        target shell and their output is captured until the prompt pattern
        reappears. Files are moved as base32 text through the same console.
    """

    def __init__(self, logger, target_ip, server_ip, server_port=0,
                 timeout=300, serialcontrol_cmd=None, serialcontrol_extra_args=None,
                 serialcontrol_ps1=None, serialcontrol_connect_timeout=None,
                 machine=None, **kwargs):
        """
            logger:                         Logger to use. If falsy, a 'target' logger
                                            writing to ./remoteTarget.log is created.
            target_ip, server_ip, server_port:
                                            Stored as attributes only.
            timeout:                        Default per-command timeout in seconds.
            serialcontrol_cmd:              Command that opens the serial console (required).
            serialcontrol_extra_args:       Extra arguments appended to serialcontrol_cmd.
            serialcontrol_ps1:              Prompt pattern of the target shell.
            serialcontrol_connect_timeout:  Seconds to wait for a prompt (default 10).
            machine:                        Used to build the fallback prompt
                                            'root@<machine>:.*# ' when no PS1 is given.

            Raises ValueError when no prompt can be determined or when
            serialcontrol_cmd is missing.
        """
        if not logger:
            logger = logging.getLogger('target')
            logger.setLevel(logging.INFO)
            filePath = os.path.join(os.getcwd(), 'remoteTarget.log')
            fileHandler = logging.FileHandler(filePath, 'w', 'utf-8')
            formatter = logging.Formatter(
                        '%(asctime)s.%(msecs)03d %(levelname)s: %(message)s',
                        '%H:%M:%S')
            fileHandler.setFormatter(formatter)
            logger.addHandler(fileHandler)

        super(OESerialTarget, self).__init__(logger)

        if serialcontrol_ps1:
            self.target_ps1 = serialcontrol_ps1
        elif machine:
            # fallback to a default value which assumes root@machine
            self.target_ps1 = f'root@{machine}:.*# '
        else:
            raise ValueError("Unable to determine shell command prompt (PS1) format.")

        if not serialcontrol_cmd:
            raise ValueError("Unable to determine serial control command.")

        if serialcontrol_extra_args:
            self.connection_script = f'{serialcontrol_cmd} {serialcontrol_extra_args}'
        else:
            self.connection_script = serialcontrol_cmd

        if serialcontrol_connect_timeout:
            self.connect_timeout = serialcontrol_connect_timeout
        else:
            self.connect_timeout = 10 # default to 10s connection timeout

        self.default_command_timeout = timeout
        self.ip = target_ip
        self.server_ip = server_ip
        self.server_port = server_port
        self.conn = None
        # Serialises all use of the single serial connection across threads.
        self.mutex = Lock()

    def start(self, **kwargs):
        pass

    def stop(self, **kwargs):
        pass

    def get_connection(self):
        """
            Return the SerialConnection, opening it on first use.
            Callers are expected to hold self.mutex.
        """
        if self.conn is None:
            self.conn = SerialConnection(self.connection_script,
                                         self.target_ps1,
                                         self.connect_timeout,
                                         self.default_command_timeout)

        return self.conn

    def run(self, cmd, timeout=None):
        """
            Runs command on target over the provided serial connection.
            The first call will open the connection, and subsequent
            calls will re-use the same connection to send new commands.

            cmd:        Command to run on target.
            timeout:    <value>:    Kill command after <val> seconds.
                        None:       Kill command default value seconds.
                        0:          No timeout, runs until return.

            Returns (exit status as int, captured output). On timeout the
            status is 255 and the output is empty.
        """
        # Lock needed to avoid multiple threads running commands concurrently
        # A serial connection can only be used by one caller at a time
        with self.mutex:
            conn = self.get_connection()

            self.logger.debug(f"[Running]$ {cmd}")
            # Run the command, then echo $? to get the command's return code
            try:
                output = conn.run_command(cmd, timeout)
                status = conn.run_command("echo $?")
                self.logger.debug(f"   [stdout]: {output}")
                self.logger.debug(f"   [ret code]: {status}\n\n")
            except SerialTimeoutException as e:
                self.logger.debug(e)
                output = ""
                status = 255

            # Return to $HOME after each command to simulate a stateless SSH connection
            conn.run_command('cd "$HOME"')

        return (int(status), output)

    def copyTo(self, localSrc, remoteDst):
        """
            Copies files by converting them to base 32, then transferring
            the ASCII text to the target, and decoding it in place on the
            target.

            If remoteDst is an existing directory on the target, the file is
            written inside it under its local basename.

            On a 115k baud serial connection, this method transfers at
            roughly 30kbps.
        """
        import shlex # local import keeps this method self-contained

        with open(localSrc, 'rb') as file:
            data = file.read()

        b32 = base64.b32encode(data).decode('utf-8')

        # To avoid shell line limits, send a chunk at a time
        SPLIT_LEN = 512
        lines = [b32[i:i+SPLIT_LEN] for i in range(0, len(b32), SPLIT_LEN)]

        with self.mutex:
            conn = self.get_connection()

            filename = os.path.basename(localSrc)
            # The temp path embeds a host filename, which may contain spaces
            # or shell metacharacters, so quote it for the remote shell.
            TEMP = shlex.quote(f'/tmp/{filename}.b32')

            # Create or empty out the temp file
            conn.run_command(f'echo -n "" > {TEMP}')

            # base32 text uses only [A-Z2-7=], so chunks need no quoting.
            for line in lines:
                conn.run_command(f'echo -n {line} >> {TEMP}')

            # Check to see whether the remoteDst is a directory. Use POSIX
            # `[` rather than bash-only `[[` so this also works when the
            # target shell is not bash.
            is_directory = conn.run_command(f'[ -d {remoteDst} ]; echo $?')
            if int(is_directory) == 0:
                # append the localSrc filename to the end of remoteDst
                remoteDst = os.path.join(remoteDst, filename)

            conn.run_command(f'base32 -d {TEMP} > {remoteDst}')
            conn.run_command(f'rm {TEMP}')

        return 0, 'Success'

    def copyFrom(self, remoteSrc, localDst):
        """
            Copies files by converting them to base 32 on the target, then
            transferring the ASCII text to the host. That text is then
            decoded here and written out to the destination.

            On a 115k baud serial connection, this method transfers at
            roughly 30kbps.
        """
        with self.mutex:
            b32 = self.get_connection().run_command(f'base32 {remoteSrc}')

            # base32 wraps its output; drop the CRLF line breaks from the tty.
            data = base64.b32decode(b32.replace('\r\n', ''))

            # If the local path is a directory, get the filename from
            # the remoteSrc path and append it to localDst
            if os.path.isdir(localDst):
                filename = os.path.basename(remoteSrc)
                localDst = os.path.join(localDst, filename)

            with open(localDst, 'wb') as file:
                file.write(data)

        return 0, 'Success'

    @staticmethod
    def _remote_path_for(localBase, localPath, remoteBase):
        """
            Return the path under remoteBase that mirrors localPath's
            position relative to localBase (remoteBase itself when
            localPath is localBase). Uses os.path.relpath so only the leading
            localBase is stripped, never a repeated occurrence deeper in the path.
        """
        rel = os.path.relpath(localPath, localBase)
        if rel == os.curdir:
            return remoteBase
        return os.path.join(remoteBase, rel)

    def copyDirTo(self, localSrc, remoteDst):
        """
            Copy recursively localSrc directory to remoteDst in target.
        """

        for root, dirs, files in os.walk(localSrc):
            remoteRoot = self._remote_path_for(localSrc, root, remoteDst)

            # Create directories in the target as needed
            for d in dirs:
                newDir = os.path.join(remoteRoot, d)
                cmd = "mkdir -p %s" % newDir
                self.run(cmd)

            # Copy files into the target
            for f in files:
                srcFile = os.path.join(root, f)
                dstFile = os.path.join(remoteRoot, f)
                self.copyTo(srcFile, dstFile)

    def deleteFiles(self, remotePath, files):
        """
            Deletes files in target's remotePath.

            files may be a single name or an iterable of names. Names are
            passed to the remote `rm` unquoted, so shell globs still expand.
            Nothing is run when there are no files.
        """

        if isinstance(files, (str, bytes, os.PathLike)):
            files = [files]

        paths = [os.path.join(remotePath, f) for f in files]
        if not paths:
            # A bare `rm` would only print a usage error on the target.
            return

        cmd = "rm %s" % " ".join(paths)
        self.run(cmd)

    def deleteDir(self, remotePath):
        """
            Deletes target's remotePath directory.
        """

        cmd = "rmdir %s" % remotePath
        self.run(cmd)

    def deleteDirStructure(self, localPath, remotePath):
        """
        Delete recursively localPath structure directory in target's remotePath.

        This function is useful to delete a package that is installed in the
        device under test (DUT) and the host running the test has such package
        extracted in tmp directory.

        Example:
            pwd: /home/user/tmp
            tree:   .
                    └── work
                        ├── dir1
                        │   └── file1
                        └── dir2

            localpath = "/home/user/tmp" and remotepath = "/home/user"

            With the above variables this function will try to delete the
            directory in the DUT in this order:
                /home/user/work/dir1/file1
                /home/user/work/dir1        (if dir is empty)
                /home/user/work/dir2        (if dir is empty)
                /home/user/work             (if dir is empty)
        """

        # Bottom-up walk so directories are emptied before rmdir is attempted.
        for root, dirs, files in os.walk(localPath, topdown=False):
            remoteDir = self._remote_path_for(localPath, root, remotePath)

            # Delete files first
            self.deleteFiles(remoteDir, files)

            # Remove dirs if empty
            for d in dirs:
                self.deleteDir(os.path.join(remoteDir, d))
251
class SerialTimeoutException(Exception):
    """
        Signals that the target shell prompt did not come back within the
        allowed time. The message is kept on .msg and is what str() returns.
    """

    def __init__(self, msg):
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return self.msg
257
class SerialConnection:
    """
        Interactive shell session on the target, opened by running a serial
        console command under a local bash through pexpect. The output of a
        command is whatever was printed before the prompt pattern next matched.
    """

    def __init__(self, script, target_prompt, connect_timeout, default_command_timeout):
        """
            script:                   Command line that opens the console (run via bash -c).
            target_prompt:            Prompt pattern handed to pexpect's expect().
            connect_timeout:          Seconds to wait for a prompt when seeking a clean shell.
            default_command_timeout:  Seconds used by run_command when timeout is None.
        """
        import pexpect # limiting scope to avoid build dependency
        self.prompt = target_prompt
        self.connect_timeout = connect_timeout
        self.default_command_timeout = default_command_timeout
        self.conn = pexpect.spawn('/bin/bash', ['-c', script], encoding='utf8')
        self._seek_to_clean_shell()
        # With terminal echo off, captured output never contains the command
        # text that was just sent.
        self.run_command('stty -echo')

    def _seek_to_clean_shell(self):
        """
            Wait until the shell is sitting at an empty prompt, so the next
            command's output is captured cleanly.
        """
        import pexpect # limiting scope to avoid build dependency
        candidates = [self.prompt, pexpect.TIMEOUT]
        saw_prompt = self.conn.expect(candidates, timeout=self.connect_timeout) != 1
        if saw_prompt:
            return

        # Nothing showed up in time: press Enter to make the shell reprint
        # its prompt, then block until it does. No explicit timeout is given
        # here, so the spawn's default timeout applies.
        self.conn.sendline("")
        self.conn.expect(self.prompt)

    def run_command(self, cmd, timeout=None):
        """
            Send cmd to the target shell and return what it printed before
            the prompt came back (one trailing CRLF removed).

            cmd:        Command to run on target.
            timeout:    <value>:    Give up after <value> seconds.
                        None:       Use default_command_timeout.
                        0:          Wait forever.

            Raises SerialTimeoutException (after sending Ctrl+C and
            re-synchronising with the prompt) if the prompt does not return.
        """
        import pexpect # limiting scope to avoid build dependency
        # Translate the OETarget timeout convention into pexpect's, where
        # None means "no timeout".
        if timeout == 0:
            timeout = None
        elif timeout is None:
            timeout = self.default_command_timeout

        self.conn.sendline(cmd)
        hit = self.conn.expect([self.prompt, pexpect.TIMEOUT], timeout=timeout)

        if hit != 1:
            return self.conn.before.removesuffix('\r\n')

        # Interrupt the stuck command and get back to a usable prompt
        # before reporting the timeout.
        self.conn.send('\003')
        self._seek_to_clean_shell()
        raise SerialTimeoutException(f'Timeout executing: {cmd} after {timeout}s')
315
316