1# SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause)
2
3from argparse import ArgumentParser
4from argparse import FileType
5import os
6import sys
7import tpm2
8from tpm2 import ProtocolError
9import unittest
10import logging
11import struct
12
class SmokeTest(unittest.TestCase):
    """Smoke tests for a plain (non-space) tpm2.Client.

    setUp() opens a client and creates a root key for every test;
    tearDown() flushes that key, and the client itself is closed by a
    cleanup registered in setUp().
    """

    def setUp(self):
        self.client = tpm2.Client()
        # Registered before create_root_key() so the device is closed even
        # when setUp() fails part-way: unittest skips tearDown() in that case
        # but still runs cleanups (which otherwise run after tearDown()).
        self.addCleanup(self.client.close)
        self.root_key = self.client.create_root_key()

    def tearDown(self):
        self.client.flush_context(self.root_key)

    @staticmethod
    def _get_random_cmd(num_bytes=0x20):
        """Build a raw, session-less TPM2_GetRandom command for *num_bytes*."""
        fmt = '>HIIH'
        return struct.pack(fmt,
                           tpm2.TPM2_ST_NO_SESSIONS,
                           struct.calcsize(fmt),
                           tpm2.TPM2_CC_GET_RANDOM,
                           num_bytes)

    def _pcr_password_policy_digest(self, pcrs):
        """Return the digest of a PolicyPCR(pcrs) + PolicyPassword policy.

        The digest is computed in a trial session, which is always flushed.
        """
        handle = self.client.start_auth_session(tpm2.TPM2_SE_TRIAL)
        try:
            self.client.policy_pcr(handle, pcrs)
            self.client.policy_password(handle)
            return self.client.get_policy_digest(handle)
        finally:
            self.client.flush_context(handle)

    def _unseal_with_pcr_password_policy(self, blob, auth, pcrs):
        """Unseal *blob* through a PolicyPCR(pcrs) + PolicyPassword session.

        Returns the unsealed data. If any step raises, the policy session is
        flushed and the exception is re-raised unchanged. On success the
        session is not flushed here.
        NOTE(review): presumably the TPM releases it as part of the unseal
        (continueSession not set) -- confirm against tpm2.Client.unseal.
        """
        handle = self.client.start_auth_session(tpm2.TPM2_SE_POLICY)
        try:
            self.client.policy_pcr(handle, pcrs)
            self.client.policy_password(handle)
            return self.client.unseal(self.root_key, blob, auth, handle)
        except BaseException:
            self.client.flush_context(handle)
            raise

    def test_seal_with_auth(self):
        data = 'X' * 64
        auth = 'A' * 15

        blob = self.client.seal(self.root_key, data, auth, None)
        result = self.client.unseal(self.root_key, blob, auth, None)
        self.assertEqual(data, result)

    def test_seal_with_policy(self):
        data = 'X' * 64
        auth = 'A' * 15
        pcrs = [16]

        policy_dig = self._pcr_password_policy_digest(pcrs)
        blob = self.client.seal(self.root_key, data, auth, policy_dig)

        result = self._unseal_with_pcr_password_policy(blob, auth, pcrs)
        self.assertEqual(data, result)

    def test_unseal_with_wrong_auth(self):
        data = 'X' * 64
        auth = 'A' * 20

        blob = self.client.seal(self.root_key, data, auth, None)
        # Flip the last character of the auth value.
        with self.assertRaises(ProtocolError) as cm:
            self.client.unseal(self.root_key, blob, auth[:-1] + 'B', None)
        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_AUTH_FAIL)

    def test_unseal_with_wrong_policy(self):
        data = 'X' * 64
        auth = 'A' * 17
        pcrs = [16]

        policy_dig = self._pcr_password_policy_digest(pcrs)
        blob = self.client.seal(self.root_key, data, auth, policy_dig)

        # First extend a PCR that is not part of the policy and unseal.
        # This must succeed.
        ds = tpm2.get_digest_size(tpm2.TPM2_ALG_SHA1)
        self.client.extend_pcr(1, 'X' * ds)

        result = self._unseal_with_pcr_password_policy(blob, auth, pcrs)
        self.assertEqual(data, result)

        # Then extend the PCR that is part of the policy and unseal again.
        # This must fail with TPM2_RC_POLICY_FAIL.
        self.client.extend_pcr(pcrs[0], 'X' * ds)

        with self.assertRaises(ProtocolError) as cm:
            self._unseal_with_pcr_password_policy(blob, auth, pcrs)
        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_POLICY_FAIL)

    def test_seal_with_too_long_auth(self):
        ds = tpm2.get_digest_size(tpm2.TPM2_ALG_SHA1)
        data = 'X' * 64
        # One byte longer than a SHA-1 digest.
        auth = 'A' * (ds + 1)

        with self.assertRaises(ProtocolError) as cm:
            self.client.seal(self.root_key, data, auth, None)
        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_SIZE)

    def test_too_short_cmd(self):
        # The header claims one byte more than is actually sent, so the
        # command is expected to be rejected with IOError.
        fmt = '>HIII'
        cmd = struct.pack(fmt,
                          tpm2.TPM2_ST_NO_SESSIONS,
                          struct.calcsize(fmt) + 1,
                          tpm2.TPM2_CC_FLUSH_CONTEXT,
                          0xDEADBEEF)

        with self.assertRaises(IOError):
            self.client.send_cmd(cmd)

    def test_read_partial_resp(self):
        self.client.tpm.write(self._get_random_cmd(0x20))

        # Read only the 10-byte response header, which carries the total size
        # at offset 2...
        hdr = self.client.tpm.read(10)
        sz = struct.unpack('>I', hdr[2:6])[0]
        # ...then the rest: 2 + 32 bytes (presumably the TPM2B size field
        # followed by the 32 requested random bytes).
        rsp = self.client.tpm.read()

        self.assertEqual(sz, 10 + 2 + 32)
        self.assertEqual(len(rsp), 2 + 32)

    def test_read_partial_overwrite(self):
        cmd = self._get_random_cmd(0x20)
        self.client.tpm.write(cmd)

        # Read only part of the response.
        rsp1 = self.client.tpm.read(15)

        # Send a new command; the next read is expected to return the complete
        # response to this new command, not the rest of the first one.
        self.client.tpm.write(cmd)
        rsp2 = self.client.tpm.read()

        self.assertEqual(len(rsp1), 15)
        self.assertEqual(len(rsp2), 10 + 2 + 32)

    def test_send_two_cmds(self):
        cmd = self._get_random_cmd(0x20)
        self.client.tpm.write(cmd)

        try:
            # The first response has not been read yet, so a second write is
            # expected to be refused with IOError (-EBUSY).
            with self.assertRaises(IOError):
                self.client.tpm.write(cmd)
        finally:
            # Always read back the first command's response so it is not
            # left pending on the device.
            self.client.tpm.read()
223
class SpaceTest(unittest.TestCase):
    """Tests for clients opened in their own space (tpm2.Client.FLAG_SPACE)."""

    def setUp(self):
        # NOTE(review): logging.basicConfig() does nothing once the root
        # logger already has handlers, so only the first test class that
        # configures logging in a run actually gets its own log file.
        logging.basicConfig(filename='SpaceTest.log', level=logging.DEBUG)

    def _open_space(self):
        """Open a FLAG_SPACE client that is closed automatically after the test."""
        space = tpm2.Client(tpm2.Client.FLAG_SPACE)
        self.addCleanup(space.close)
        return space

    def test_make_two_spaces(self):
        log = logging.getLogger(__name__)
        log.debug("test_make_two_spaces")

        space1 = self._open_space()
        root1 = space1.create_root_key()
        space2 = self._open_space()
        root2 = space2.create_root_key()
        root3 = space2.create_root_key()

        log.debug("%08x", root1)
        log.debug("%08x", root2)
        log.debug("%08x", root3)

    def test_flush_context(self):
        log = logging.getLogger(__name__)
        log.debug("test_flush_context")

        space1 = self._open_space()
        root1 = space1.create_root_key()
        log.debug("%08x", root1)

        space1.flush_context(root1)

    def test_get_handles(self):
        log = logging.getLogger(__name__)
        log.debug("test_get_handles")

        space1 = self._open_space()
        space1.create_root_key()
        space2 = self._open_space()
        space2.create_root_key()
        space2.create_root_key()

        # Only the two transient objects created in space2 are expected to
        # be listed for space2; the key in space1 must not show up.
        handles = space2.get_cap(tpm2.TPM2_CAP_HANDLES, tpm2.HR_TRANSIENT)

        self.assertEqual(len(handles), 2)

        log.debug("%08x", handles[0])
        log.debug("%08x", handles[1])

    def test_invalid_cc(self):
        log = logging.getLogger(__name__)
        log.debug(sys._getframe().f_code.co_name)

        # One below the first valid command code.
        TPM2_CC_INVALID = tpm2.TPM2_CC_FIRST - 1

        space1 = self._open_space()
        root1 = space1.create_root_key()
        log.debug("%08x", root1)

        fmt = '>HII'
        cmd = struct.pack(fmt, tpm2.TPM2_ST_NO_SESSIONS, struct.calcsize(fmt),
                          TPM2_CC_INVALID)

        # The expected rc carries the TSS2_RESMGR_TPM_RC_LAYER bits, i.e.
        # (presumably) the resource manager rejects the unknown command code.
        with self.assertRaises(ProtocolError) as cm:
            space1.send_cmd(cmd)
        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_COMMAND_CODE |
                         tpm2.TSS2_RESMGR_TPM_RC_LAYER)
291
class AsyncTest(unittest.TestCase):
    """Tests for a tpm2.Client opened with FLAG_NONBLOCK."""

    def setUp(self):
        # NOTE(review): logging.basicConfig() does nothing once the root
        # logger already has handlers, so this file is only used if no other
        # test class configured logging first.
        logging.basicConfig(filename='AsyncTest.log', level=logging.DEBUG)

    def test_async(self):
        log = logging.getLogger(__name__)
        log.debug(sys._getframe().f_code.co_name)

        async_client = tpm2.Client(tpm2.Client.FLAG_NONBLOCK)
        # Close the client even if get_cap() raises.
        self.addCleanup(async_client.close)

        log.debug("Calling get_cap in a NON_BLOCKING mode")
        async_client.get_cap(tpm2.TPM2_CAP_HANDLES, tpm2.HR_LOADED_SESSION)
304