1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4"""
5This takes a crashing qtest trace and tries to remove superfluous operations
6"""
7
8import sys
9import os
10import subprocess
11import time
12import struct
13
# Runtime configuration; QEMU_PATH, QEMU_ARGS and CRASH_TOKEN are filled in
# from the environment by the __main__ section.
QEMU_ARGS = None
QEMU_PATH = None
TIMEOUT = 5  # seconds; recomputed from the first crashing replay
CRASH_TOKEN = None

# Minimization levels
M1 = False  # try removing IO commands iteratively
M2 = False  # try setting bits in operand of write/out to zero

# write{b,w,l,q} suffix -> (operand size in bytes, struct format character)
write_suffix_lookup = dict(
    b=(1, "B"),
    w=(2, "H"),
    l=(4, "L"),
    q=(8, "Q"),
)
27
def usage():
    """Print the command-line help and exit.

    sys.exit() with a string argument writes it to stderr and exits with
    status 1.
    """
    sys.exit("""\
Usage:

QEMU_PATH="/path/to/qemu" QEMU_ARGS="args" {} [Options] input_trace output_trace

By default, will try to use the second-to-last line in the output to identify
whether the crash occurred. Optionally, manually set a string that identifies the
crash by setting CRASH_TOKEN=

Options:

-M1: enable a loop around the remove minimizer, which may help decrease some
     timing dependent instructions. Off by default.
-M2: try setting bits in operand of write/out to zero. Off by default.

""".format(sys.argv[0]))
45
# Printed once the crash token has been chosen, to explain a limitation of
# recognizing crashes by a fixed string.
deduplication_note = """\n\
Note: While trimming the input, the mutated trace sometimes triggers a
different type of crash that indicates the same bug. The minimizer cannot
recognize such a crash as the same one, so it will not remove the operation
that leads to it. In the future, we may use a more sophisticated crash
deduplication method.
\n"""
52
def _crash_token_from_output(output):
    """Derive the default crash token from QEMU's output, or return None.

    The token is the first three whitespace-separated words of the
    second-to-last output line. None is returned when there are fewer than
    two lines or that line is blank. An empty token is a substring of every
    line, so it would make every later replay count as a crash.
    """
    lines = output.splitlines()
    if len(lines) < 2:
        return None
    return " ".join(lines[-2].split()[:3]) or None


def check_if_trace_crashes(trace, path):
    """Replay *trace* through QEMU and report whether it crashes.

    The trace is written to *path*, which is overwritten, and fed to
    QEMU_PATH/QEMU_ARGS on stdin. ``timeout -s 9 TIMEOUT`` bounds the run,
    and stdout and stderr are read together.

    On the first call (CRASH_TOKEN unset), the complete output is collected
    and CRASH_TOKEN is derived from it; the trace is then taken to crash.
    Later calls scan the output line by line: a line containing "CLOSED"
    means no crash, and a line containing CRASH_TOKEN means the crash was
    reproduced.

    Args:
        trace: sequence of trace lines, including their newlines.
        path: file the trace is written to before the replay.

    Returns:
        True if the crash was reproduced (or, on the first call, a crash
        token could be derived), False otherwise.
    """
    with open(path, "w") as tracefile:
        tracefile.write("".join(trace))

    # QEMU_ARGS is one string of arguments, so the command goes through the
    # shell. The `timeout -s 9` wrapper also ends QEMU runs that are still
    # going when this function returns early from the loop below.
    command = ("timeout -s 9 {timeout}s {qemu_path} {qemu_args} 2>&1 "
               "< {trace_path}").format(timeout=TIMEOUT,
                                        qemu_path=QEMU_PATH,
                                        qemu_args=QEMU_ARGS,
                                        trace_path=path)
    rc = subprocess.Popen(command,
                          shell=True,
                          stdin=subprocess.PIPE,
                          stdout=subprocess.PIPE,
                          encoding="utf-8")
    global CRASH_TOKEN
    if CRASH_TOKEN is None:
        try:
            # Same limit as the shell-level timeout above.
            outs, _ = rc.communicate(timeout=TIMEOUT)
        except subprocess.TimeoutExpired:
            print("subprocess.TimeoutExpired")
            return False
        token = _crash_token_from_output(outs)
        if token is None:
            print("Could not derive a crash token from QEMU's output; "
                  "set CRASH_TOKEN to a string that identifies the crash.")
            return False
        CRASH_TOKEN = token
        print("Identifying Crashes by this string: {}".format(CRASH_TOKEN))
        print(deduplication_note)
        return True

    for line in iter(rc.stdout.readline, ""):
        if "CLOSED" in line:
            return False
        if CRASH_TOKEN in line:
            return True

    print("\nWarning:")
    print("  There is no 'CLOSED' or CRASH_TOKEN in the stdout of subprocess.")
    print("  Usually this indicates a different type of crash.\n")
    return False
89
90
def split_write_hint(newtrace, i):
    """Predict the write at index *i* from a regular run of earlier writes.

    Scan backwards from ``i - 1`` for the HINT_LEN closest
    ``write ADDR SIZE DATA`` commands. Lines blanked to "" are skipped, and
    any other command ends the search with no hint. If those writes share
    the same SIZE and their addresses are evenly spaced, the run counts as a
    hint.

    Returns:
        (addr, size): the address one stride past the most recent write,
        and the common SIZE. None when there is no such regular run.
    """
    HINT_LEN = 3  # > 2
    if i < HINT_LEN:
        return None

    # Gather the preceding writes, most recent first.
    recent = []
    pos = i - 1
    while pos >= 0 and len(recent) < HINT_LEN:
        line = newtrace[pos]
        pos -= 1
        if line.startswith("write "):
            recent.append(line)
        elif line != "":
            return None
    if len(recent) < HINT_LEN:
        return None

    def field(line, idx):
        # Parse the idx-th whitespace-separated hex field of a trace line.
        return int(line.split()[idx], 16)

    size = field(recent[0], 2)
    if any(field(w, 2) != size for w in recent[1:]):
        return None

    stride = field(recent[0], 1) - field(recent[1], 1)
    if any(field(newer, 1) - field(older, 1) != stride
           for newer, older in zip(recent[1:], recent[2:])):
        return None

    return (field(recent[0], 1) + stride, size)
126
127
def remove_lines(newtrace, outpath):
    """Shrink *newtrace* in place while it keeps reproducing the crash.

    At each position i, the following reductions are tried in order:

    1. Blank out ``remove_step`` lines starting at i. The step doubles after
       every successful removal and drops back to 1 after a failure.
    2. Rewrite a ``write{b,w,l,q} ADDR VALUE`` command as a generic
       ``write ADDR SIZE DATA`` command, trying little- and then big-endian
       byte order.
    3. Narrow a ``write ADDR SIZE DATA`` command, first to the region
       suggested by split_write_hint(), otherwise by splitting it into two
       writes around a moving pivot.

    Removed lines are replaced by "" rather than deleted. A successful split
    inserts one extra line after i.

    Args:
        newtrace: list of trace lines, modified in place.
        outpath: file each candidate trace is written to for replay.
    """
    remove_step = 1
    i = 0
    while i < len(newtrace):
        # 1.) Try to remove lines completely and reproduce the crash.
        # If it works, we're done.
        # Fall back to a single line when a chunk of remove_step lines
        # starting at i would run past the end of the trace. A chunk that
        # ends exactly at the end is still valid, hence ">" rather than ">=".
        if (i+remove_step) > len(newtrace):
            remove_step = 1
        prior = newtrace[i:i+remove_step]
        for j in range(i, i+remove_step):
            newtrace[j] = ""
        print("Removing {lines} ...\n".format(lines=prior))
        if check_if_trace_crashes(newtrace, outpath):
            i += remove_step
            # Double the number of lines to remove for next round
            remove_step *= 2
            continue
        # Failed to remove multiple IOs, fast recovery
        if remove_step > 1:
            for j in range(i, i+remove_step):
                newtrace[j] = prior[j-i]
            remove_step = 1
            continue
        newtrace[i] = prior[0] # remove_step = 1

        # 2.) Try to replace write{bwlq} commands with a write addr, len
        # command. Since this can require swapping endianness, try both LE and
        # BE options. We do this, so we can "trim" the writes in (3)

        if (newtrace[i].startswith("write") and not
            newtrace[i].startswith("write ")):
            suffix = newtrace[i].split()[0][-1]
            assert(suffix in write_suffix_lookup)
            addr = int(newtrace[i].split()[1], 16)
            value = int(newtrace[i].split()[2], 16)
            for endianness in ['<', '>']:
                data = struct.pack("{end}{size}".format(end=endianness,
                                   size=write_suffix_lookup[suffix][1]),
                                   value)
                newtrace[i] = "write {addr} {size} 0x{data}\n".format(
                    addr=hex(addr),
                    size=hex(write_suffix_lookup[suffix][0]),
                    data=data.hex())
                if check_if_trace_crashes(newtrace, outpath):
                    break
            else:
                # Neither byte order reproduced the crash.
                newtrace[i] = prior[0]

        # 3.) If it is a qtest write command: write addr len data, try to split
        # it into two separate write commands. If splitting the data operand
        # from length/2^n bytes to the left does not work, try to move the pivot
        # to the right side, then add one to n, until length/2^n == 0. The idea
        # is to prune unnecessary bytes from long writes, while accommodating
        # arbitrary MemoryRegion access sizes and alignments.

        # This algorithm will fail under some rare situations.
        # e.g., xxxxxxxxxuxxxxxx (u is the unnecessary byte)

        if newtrace[i].startswith("write "):
            addr = int(newtrace[i].split()[1], 16)
            length = int(newtrace[i].split()[2], 16)
            data = newtrace[i].split()[3][2:]
            if length > 1:

                # Can we get a hint from previous writes?
                hint = split_write_hint(newtrace, i)
                if hint is not None:
                    hint_addr = hint[0]
                    hint_len = hint[1]
                    if hint_addr >= addr and hint_addr+hint_len <= addr+length:
                        newtrace[i] = "write {addr} {size} 0x{data}\n".format(
                            addr=hex(hint_addr),
                            size=hex(hint_len),
                            data=data[(hint_addr-addr)*2:\
                                (hint_addr-addr)*2+hint_len*2])
                        if check_if_trace_crashes(newtrace, outpath):
                            # next round
                            i += 1
                            continue
                        newtrace[i] = prior[0]

                # Try splitting it using a binary approach
                leftlength = length // 2
                rightlength = length - leftlength
                newtrace.insert(i+1, "")
                power = 1
                # Remember whether a split reproduced the crash, instead of
                # replaying the trace again after the loop.
                split_crashes = False
                while leftlength > 0:
                    newtrace[i] = "write {addr} {size} 0x{data}\n".format(
                            addr=hex(addr),
                            size=hex(leftlength),
                            data=data[:leftlength*2])
                    newtrace[i+1] = "write {addr} {size} 0x{data}\n".format(
                            addr=hex(addr+leftlength),
                            size=hex(rightlength),
                            data=data[leftlength*2:])
                    if check_if_trace_crashes(newtrace, outpath):
                        split_crashes = True
                        break
                    # move the pivot to right side
                    if leftlength < rightlength:
                        rightlength, leftlength = leftlength, rightlength
                        continue
                    power += 1
                    leftlength = length // (2 ** power)
                    rightlength = length - leftlength
                if split_crashes:
                    # Stay at index i (the i += 1 below undoes this) so the
                    # new left half is minimized in turn; the right half now
                    # sits at i+1 and is visited next.
                    i -= 1
                else:
                    newtrace[i] = prior[0]
                    del newtrace[i+1]
        i += 1
238
239
def clear_bits(newtrace, outpath):
    """Try to clear each set bit in the operand of write/out commands.

    For every ``write ADDR SIZE DATA`` and ``out* ADDR VALUE`` line in
    *newtrace*, which is modified in place, the set bits of the last operand
    are cleared one at a time, from most to least significant. A cleared bit
    is kept only if the trace still crashes; otherwise the line is restored.

    The rewritten operand keeps at least its original number of hex digits.
    qtest zero-fills a short ``write`` DATA at the end, so dropping leading
    zero digits (as ``hex()`` does) would move the remaining bytes to
    different offsets rather than just clearing one bit.

    Args:
        newtrace: list of trace lines, modified in place.
        outpath: file each candidate trace is written to for replay.
    """
    i = 0
    while i < len(newtrace):
        line = newtrace[i]
        if not line.startswith(("write ", "out")):
            i += 1
            continue
        # write ADDR SIZE DATA
        # outx ADDR VALUE
        print("\nzero setting bits: {}".format(line))

        fields = line.split()
        prefix = " ".join(fields[:-1])
        data = fields[-1]
        value = int(data, 16)
        digits = data[2:] if data[:2].lower() == "0x" else data
        width = len(digits)

        for bit in reversed(range(value.bit_length())):
            mask = 1 << bit
            if not value & mask:
                continue
            candidate = value & ~mask
            data_try = "{:0{width}x}".format(candidate, width=width)
            # It seems qtest only accepts padded hex-values.
            if len(data_try) % 2 == 1:
                data_try = "0" + data_try

            prior = newtrace[i]
            newtrace[i] = "{prefix} 0x{data_try}\n".format(
                    prefix=prefix,
                    data_try=data_try)
            if check_if_trace_crashes(newtrace, outpath):
                value = candidate
            else:
                newtrace[i] = prior
        i += 1
274
275
def minimize_trace(inpath, outpath):
    """Minimize the crashing qtest trace in *inpath* and leave it in *outpath*.

    The first replay checks that the input crashes (deriving CRASH_TOKEN if
    it is unset) and is timed. TIMEOUT is then set to five times that
    duration for all later replays.

    Lines are removed with remove_lines(). This runs once by default, or
    repeatedly until the trace stops shrinking when M1 or M2 is set. With
    M2, operand bits are then cleared with clear_bits().

    check_if_trace_crashes() writes every candidate to *outpath*, so the
    final verification replay is also what leaves the minimized trace there.

    Exits via sys.exit() if the input does not crash or the minimized trace
    no longer reproduces the crash.
    """
    global TIMEOUT
    with open(inpath) as f:
        trace = f.readlines()
    start = time.time()
    if not check_if_trace_crashes(trace, outpath):
        sys.exit("The input qtest trace didn't cause a crash...")
    end = time.time()
    print("Crashed in {} seconds".format(end-start))
    TIMEOUT = (end-start)*5
    print("Setting the timeout for {} seconds".format(TIMEOUT))

    newtrace = trace[:]

    # remove lines
    old_len = len(newtrace) + 1
    while old_len > len(newtrace):
        old_len = len(newtrace)
        print("trace length = ", old_len)
        remove_lines(newtrace, outpath)
        if not M1 and not M2:
            break
        newtrace = [line for line in newtrace if line != ""]
    # Deliberately not an assert: this replay also writes the minimized trace
    # to outpath, so it must still run under "python -O".
    if not check_if_trace_crashes(newtrace, outpath):
        sys.exit("The minimized trace no longer reproduces the crash.")

    # set bits to zero
    if M2:
        clear_bits(newtrace, outpath)
        # clear_bits() may have left a rejected candidate in outpath.
        if not check_if_trace_crashes(newtrace, outpath):
            sys.exit("The minimized trace no longer reproduces the crash.")
306
307
if __name__ == '__main__':
    if len(sys.argv) < 3:
        usage()
    if "-M1" in sys.argv:
        M1 = True
    if "-M2" in sys.argv:
        M2 = True
    QEMU_PATH = os.getenv("QEMU_PATH")
    QEMU_ARGS = os.getenv("QEMU_ARGS")
    if QEMU_PATH is None or QEMU_ARGS is None:
        usage()
    # if "accel" not in QEMU_ARGS:
    #     QEMU_ARGS += " -accel qtest"
    # An empty CRASH_TOKEN (e.g. "CRASH_TOKEN=") is a substring of every
    # output line and would make every candidate look like a crash. Treat it
    # as unset so that the first replay derives the token instead.
    CRASH_TOKEN = os.getenv("CRASH_TOKEN") or None
    QEMU_ARGS += " -qtest stdio -monitor none -serial none "
    minimize_trace(sys.argv[-2], sys.argv[-1])
324