1# SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause)
2
3from argparse import ArgumentParser
4from argparse import FileType
5import os
6import sys
7import tpm2
8from tpm2 import ProtocolError
9import unittest
10import logging
11import struct
12
class SmokeTest(unittest.TestCase):
    """Basic functional tests for a TPM 2.0 device driven through tpm2.Client.

    Every test gets a fresh client plus a freshly created root key.  The key
    is flushed in tearDown() and the client is closed afterwards, even when
    setUp() or tearDown() fails part-way.
    """

    def setUp(self):
        self.client = tpm2.Client()
        # Registered before create_root_key() so the device is closed even if
        # setUp() fails (unittest does not run tearDown() in that case).
        # Cleanups run after tearDown(), so the key is flushed before close.
        self.addCleanup(self.client.close)
        self.root_key = self.client.create_root_key()

    def tearDown(self):
        # close() is handled by the cleanup registered in setUp(), which runs
        # even if this flush raises.
        self.client.flush_context(self.root_key)

    # -- helpers ------------------------------------------------------------

    @staticmethod
    def _get_random_cmd(num_bytes=0x20):
        """Build a raw, session-less TPM2_GetRandom command.

        :param num_bytes: number of random bytes requested.
        :returns: the packed command bytes (header + 16-bit byte count).
        """
        fmt = '>HIIH'
        return struct.pack(fmt,
                           tpm2.TPM2_ST_NO_SESSIONS,
                           struct.calcsize(fmt),
                           tpm2.TPM2_CC_GET_RANDOM,
                           num_bytes)

    def _policy_digest(self, pcrs):
        """Return the digest of the policy PolicyPCR(pcrs) + PolicyPassword.

        The digest is computed in a trial session, which is always flushed.
        """
        handle = self.client.start_auth_session(tpm2.TPM2_SE_TRIAL)
        try:
            self.client.policy_pcr(handle, pcrs)
            self.client.policy_password(handle)
            return self.client.get_policy_digest(handle)
        finally:
            self.client.flush_context(handle)

    def _unseal_with_policy(self, blob, auth, pcrs):
        """Unseal *blob* using a policy session satisfying
        PolicyPCR(pcrs) + PolicyPassword.

        The session is flushed here only when something raises; the exception
        is then re-raised unchanged.
        NOTE(review): on success the session is deliberately not flushed,
        presumably because a successful unseal consumes it -- confirm against
        tpm2.Client.unseal.
        """
        handle = self.client.start_auth_session(tpm2.TPM2_SE_POLICY)
        try:
            self.client.policy_pcr(handle, pcrs)
            self.client.policy_password(handle)
            return self.client.unseal(self.root_key, blob, auth, handle)
        except BaseException:
            self.client.flush_context(handle)
            raise

    # -- seal / unseal ------------------------------------------------------

    def test_seal_with_auth(self):
        data = ('X' * 64).encode()
        auth = ('A' * 15).encode()

        blob = self.client.seal(self.root_key, data, auth, None)
        result = self.client.unseal(self.root_key, blob, auth, None)
        self.assertEqual(data, result)

    def test_seal_with_policy(self):
        data = ('X' * 64).encode()
        auth = ('A' * 15).encode()
        pcrs = [16]

        policy_dig = self._policy_digest(pcrs)
        blob = self.client.seal(self.root_key, data, auth, policy_dig)

        result = self._unseal_with_policy(blob, auth, pcrs)
        self.assertEqual(data, result)

    def test_unseal_with_wrong_auth(self):
        data = ('X' * 64).encode()
        auth = ('A' * 20).encode()

        blob = self.client.seal(self.root_key, data, auth, None)
        # Same length as the real auth value, last byte changed.
        wrong_auth = auth[:-1] + b'B'

        with self.assertRaises(ProtocolError) as ctx:
            self.client.unseal(self.root_key, blob, wrong_auth, None)
        self.assertEqual(ctx.exception.rc, tpm2.TPM2_RC_AUTH_FAIL)

    def test_unseal_with_wrong_policy(self):
        data = ('X' * 64).encode()
        auth = ('A' * 17).encode()
        pcrs = [16]

        policy_dig = self._policy_digest(pcrs)
        blob = self.client.seal(self.root_key, data, auth, policy_dig)

        ds = tpm2.get_digest_size(tpm2.TPM2_ALG_SHA1)
        extend_value = ('X' * ds).encode()

        # Extending a PCR that is not part of the policy must not prevent
        # unsealing.
        self.client.extend_pcr(1, extend_value)
        result = self._unseal_with_policy(blob, auth, pcrs)
        self.assertEqual(data, result)

        # Extending a PCR that is part of the policy must make unsealing fail.
        self.client.extend_pcr(pcrs[0], extend_value)
        with self.assertRaises(ProtocolError) as ctx:
            self._unseal_with_policy(blob, auth, pcrs)
        self.assertEqual(ctx.exception.rc, tpm2.TPM2_RC_POLICY_FAIL)

    def test_seal_with_too_long_auth(self):
        ds = tpm2.get_digest_size(tpm2.TPM2_ALG_SHA1)
        data = ('X' * 64).encode()
        # One byte longer than a SHA-1 digest.
        auth = ('A' * (ds + 1)).encode()

        with self.assertRaises(ProtocolError) as ctx:
            self.client.seal(self.root_key, data, auth, None)
        self.assertEqual(ctx.exception.rc, tpm2.TPM2_RC_SIZE)

    # -- raw device I/O -----------------------------------------------------

    def test_too_short_cmd(self):
        fmt = '>HIII'
        # The header's size field claims one byte more than is actually sent;
        # the write is expected to be rejected with an IOError.
        cmd = struct.pack(fmt,
                          tpm2.TPM2_ST_NO_SESSIONS,
                          struct.calcsize(fmt) + 1,
                          tpm2.TPM2_CC_FLUSH_CONTEXT,
                          0xDEADBEEF)

        with self.assertRaises(IOError):
            self.client.send_cmd(cmd)

    def test_read_partial_resp(self):
        self.client.tpm.write(self._get_random_cmd(0x20))

        # Read only the 10-byte response header first; bytes 2..6 hold the
        # total response size.
        hdr = self.client.tpm.read(10)
        sz = struct.unpack('>I', hdr[2:6])[0]
        # Then read the remainder of the response.
        rsp = self.client.tpm.read()

        self.assertEqual(sz, 10 + 2 + 32)
        self.assertEqual(len(rsp), 2 + 32)

    def test_read_partial_overwrite(self):
        cmd = self._get_random_cmd(0x20)
        self.client.tpm.write(cmd)

        # Read only part of the response.
        rsp1 = self.client.tpm.read(15)

        # Send a new command without consuming the rest of the first response.
        self.client.tpm.write(cmd)

        # The test expects a complete response to the second command.
        rsp2 = self.client.tpm.read()

        self.assertEqual(len(rsp1), 15)
        self.assertEqual(len(rsp2), 10 + 2 + 32)

    def test_send_two_cmds(self):
        cmd = self._get_random_cmd(0x20)

        # The first write must succeed; it is kept outside the try so that a
        # failure here cannot be mistaken for the expected rejection.
        self.client.tpm.write(cmd)

        rejected = False
        try:
            # Expected to fail (-EBUSY) while the first response is pending.
            self.client.tpm.write(cmd)
        except IOError:
            rejected = True
        finally:
            # Drain the pending response so the device is usable in tearDown.
            self.client.tpm.read()

        self.assertTrue(rejected)
224
class SpaceTest(unittest.TestCase):
    """Tests for clients opened with tpm2.Client.FLAG_SPACE."""

    def setUp(self):
        # NOTE: logging.basicConfig() is a no-op once the root logger already
        # has handlers, so only the first test class that runs actually picks
        # the log file.
        logging.basicConfig(filename='SpaceTest.log', level=logging.DEBUG)

    def _open_space(self):
        """Open a FLAG_SPACE client that is closed automatically after the test."""
        space = tpm2.Client(tpm2.Client.FLAG_SPACE)
        self.addCleanup(space.close)
        return space

    def test_make_two_spaces(self):
        log = logging.getLogger(__name__)
        log.debug("test_make_two_spaces")

        space1 = self._open_space()
        root1 = space1.create_root_key()
        space2 = self._open_space()
        root2 = space2.create_root_key()
        root3 = space2.create_root_key()

        log.debug("%08x", root1)
        log.debug("%08x", root2)
        log.debug("%08x", root3)

    def test_flush_context(self):
        log = logging.getLogger(__name__)
        log.debug("test_flush_context")

        space1 = self._open_space()
        root1 = space1.create_root_key()
        log.debug("%08x", root1)

        space1.flush_context(root1)

    def test_get_handles(self):
        log = logging.getLogger(__name__)
        log.debug("test_get_handles")

        space1 = self._open_space()
        space1.create_root_key()
        space2 = self._open_space()
        space2.create_root_key()
        space2.create_root_key()

        # Only the two keys created in space2 are expected to be listed,
        # even though space1 is still open and holds a key of its own.
        handles = space2.get_cap(tpm2.TPM2_CAP_HANDLES, tpm2.HR_TRANSIENT)

        self.assertEqual(len(handles), 2)

        log.debug("%08x", handles[0])
        log.debug("%08x", handles[1])

    def test_invalid_cc(self):
        log = logging.getLogger(__name__)
        log.debug(sys._getframe().f_code.co_name)

        # One below the first defined command code.
        TPM2_CC_INVALID = tpm2.TPM2_CC_FIRST - 1

        space1 = self._open_space()
        root1 = space1.create_root_key()
        log.debug("%08x", root1)

        fmt = '>HII'
        cmd = struct.pack(fmt, tpm2.TPM2_ST_NO_SESSIONS, struct.calcsize(fmt),
                          TPM2_CC_INVALID)

        with self.assertRaises(ProtocolError) as ctx:
            space1.send_cmd(cmd)
        self.assertEqual(ctx.exception.rc, tpm2.TPM2_RC_COMMAND_CODE |
                         tpm2.TSS2_RESMGR_TPM_RC_LAYER)
292
class AsyncTest(unittest.TestCase):
    """Tests for a client opened with tpm2.Client.FLAG_NONBLOCK."""

    def setUp(self):
        # NOTE: logging.basicConfig() is a no-op once the root logger already
        # has handlers, so this file is only used if this class runs first.
        logging.basicConfig(filename='AsyncTest.log', level=logging.DEBUG)

    def test_async(self):
        log = logging.getLogger(__name__)
        log.debug(sys._getframe().f_code.co_name)

        async_client = tpm2.Client(tpm2.Client.FLAG_NONBLOCK)
        try:
            log.debug("Calling get_cap in a NON_BLOCKING mode")
            async_client.get_cap(tpm2.TPM2_CAP_HANDLES, tpm2.HR_LOADED_SESSION)
        finally:
            # Close the device even when get_cap() raises.
            async_client.close()
305