1# SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause)
2
3from argparse import ArgumentParser
4from argparse import FileType
5import os
6import sys
7import tpm2
8from tpm2 import ProtocolError
9import unittest
10import logging
11import struct
12
class SmokeTest(unittest.TestCase):
    """Basic seal/unseal and command-validation tests using tpm2.Client.

    Each test gets a fresh client and a root key created in setUp(); both are
    released in tearDown().
    """

    def setUp(self):
        self.client = tpm2.Client()
        try:
            self.root_key = self.client.create_root_key()
        except:
            # unittest does not call tearDown() when setUp() raises, so the
            # client must be released here or it leaks.
            self.client.close()
            raise

    def tearDown(self):
        self.client.flush_context(self.root_key)
        self.client.close()

    def _trial_policy_digest(self, pcrs):
        """Return the digest of a PolicyPCR(pcrs) + PolicyPassword policy.

        The digest is computed in a trial session, which is always flushed,
        even if building the policy fails.
        """
        handle = self.client.start_auth_session(tpm2.TPM2_SE_TRIAL)
        try:
            self.client.policy_pcr(handle, pcrs)
            self.client.policy_password(handle)
            return self.client.get_policy_digest(handle)
        finally:
            self.client.flush_context(handle)

    def _unseal_with_policy(self, blob, auth, pcrs):
        """Satisfy PolicyPCR(pcrs) + PolicyPassword and unseal *blob*.

        Returns the unsealed data. On any exception the policy session is
        flushed and the exception is re-raised. On success the session is
        deliberately NOT flushed here. NOTE(review): presumably unseal() does
        not set continueSession, so the TPM releases the session itself.
        Confirm against tpm2.Client.unseal.
        """
        handle = self.client.start_auth_session(tpm2.TPM2_SE_POLICY)
        try:
            self.client.policy_pcr(handle, pcrs)
            self.client.policy_password(handle)
            return self.client.unseal(self.root_key, blob, auth, handle)
        except:
            self.client.flush_context(handle)
            raise

    def test_seal_with_auth(self):
        data = 'X' * 64
        auth = 'A' * 15

        blob = self.client.seal(self.root_key, data, auth, None)
        result = self.client.unseal(self.root_key, blob, auth, None)
        self.assertEqual(data, result)

    def test_seal_with_policy(self):
        data = 'X' * 64
        auth = 'A' * 15
        pcrs = [16]

        policy_dig = self._trial_policy_digest(pcrs)
        blob = self.client.seal(self.root_key, data, auth, policy_dig)

        result = self._unseal_with_policy(blob, auth, pcrs)
        self.assertEqual(data, result)

    def test_unseal_with_wrong_auth(self):
        data = 'X' * 64
        auth = 'A' * 20

        blob = self.client.seal(self.root_key, data, auth, None)
        # Same length as the real auth value, but the last byte differs.
        with self.assertRaises(ProtocolError) as cm:
            self.client.unseal(self.root_key, blob, auth[:-1] + 'B', None)

        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_AUTH_FAIL)

    def test_unseal_with_wrong_policy(self):
        data = 'X' * 64
        auth = 'A' * 17
        pcrs = [16]

        policy_dig = self._trial_policy_digest(pcrs)
        blob = self.client.seal(self.root_key, data, auth, policy_dig)

        # Extend first a PCR that is not part of the policy and try to unseal.
        # This should succeed.
        ds = tpm2.get_digest_size(tpm2.TPM2_ALG_SHA1)
        self.client.extend_pcr(1, 'X' * ds)

        result = self._unseal_with_policy(blob, auth, pcrs)
        self.assertEqual(data, result)

        # Then, extend a PCR that is part of the policy and try to unseal.
        # This should fail.
        self.client.extend_pcr(16, 'X' * ds)

        with self.assertRaises(ProtocolError) as cm:
            self._unseal_with_policy(blob, auth, pcrs)

        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_POLICY_FAIL)

    def test_seal_with_too_long_auth(self):
        ds = tpm2.get_digest_size(tpm2.TPM2_ALG_SHA1)
        data = 'X' * 64
        # One byte longer than the SHA-1 digest size.
        auth = 'A' * (ds + 1)

        with self.assertRaises(ProtocolError) as cm:
            self.client.seal(self.root_key, data, auth, None)

        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_SIZE)

    def test_too_short_cmd(self):
        # The header claims one byte more than is actually sent, so the
        # command is truncated relative to its declared size.
        fmt = '>HIII'
        cmd = struct.pack(fmt,
                          tpm2.TPM2_ST_NO_SESSIONS,
                          struct.calcsize(fmt) + 1,
                          tpm2.TPM2_CC_FLUSH_CONTEXT,
                          0xDEADBEEF)

        # Only an IOError counts as a rejection. Any other exception now
        # surfaces instead of being silently swallowed.
        with self.assertRaises(IOError):
            self.client.send_cmd(cmd)
160
class SpaceTest(unittest.TestCase):
    """Tests for clients opened with tpm2.Client.FLAG_SPACE.

    NOTE(review): FLAG_SPACE presumably selects a resource-managed device,
    giving each client its own handle space. Confirm in tpm2.py.
    """

    def setUp(self):
        # basicConfig() does nothing once the root logger has handlers, so
        # calling it before every test is harmless.
        logging.basicConfig(filename='SpaceTest.log', level=logging.DEBUG)

    def _open_space(self):
        """Open a FLAG_SPACE client that is closed automatically after the test.

        The original tests never closed their clients, which leaked the open
        device handles (and whatever was loaded in them) past the test.
        """
        space = tpm2.Client(tpm2.Client.FLAG_SPACE)
        self.addCleanup(space.close)
        return space

    def test_make_two_spaces(self):
        log = logging.getLogger(__name__)
        log.debug("test_make_two_spaces")

        space1 = self._open_space()
        root1 = space1.create_root_key()
        space2 = self._open_space()
        root2 = space2.create_root_key()
        root3 = space2.create_root_key()

        log.debug("%08x", root1)
        log.debug("%08x", root2)
        log.debug("%08x", root3)

    def test_flush_context(self):
        log = logging.getLogger(__name__)
        log.debug("test_flush_context")

        space1 = self._open_space()
        root1 = space1.create_root_key()
        log.debug("%08x", root1)

        space1.flush_context(root1)

    def test_get_handles(self):
        log = logging.getLogger(__name__)
        log.debug("test_get_handles")

        space1 = self._open_space()
        space1.create_root_key()
        space2 = self._open_space()
        space2.create_root_key()
        space2.create_root_key()

        # Only the two keys created in space2 should be visible to space2.
        # The key in space1 must not show up.
        handles = space2.get_cap(tpm2.TPM2_CAP_HANDLES, tpm2.HR_TRANSIENT)

        self.assertEqual(len(handles), 2)

        log.debug("%08x", handles[0])
        log.debug("%08x", handles[1])

    def test_invalid_cc(self):
        log = logging.getLogger(__name__)
        log.debug(sys._getframe().f_code.co_name)

        # One below the first defined command code, so never a valid command.
        TPM2_CC_INVALID = tpm2.TPM2_CC_FIRST - 1

        space1 = self._open_space()
        root1 = space1.create_root_key()
        log.debug("%08x", root1)

        fmt = '>HII'
        cmd = struct.pack(fmt, tpm2.TPM2_ST_NO_SESSIONS, struct.calcsize(fmt),
                          TPM2_CC_INVALID)

        with self.assertRaises(ProtocolError) as cm:
            space1.send_cmd(cmd)

        # The error code is expected to carry the resource-manager layer bits.
        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_COMMAND_CODE |
                         tpm2.TSS2_RESMGR_TPM_RC_LAYER)
228