# SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause)

from argparse import ArgumentParser
from argparse import FileType
import errno
import logging
import os
import struct
import sys
import unittest

import tpm2
from tpm2 import ProtocolError


def _get_random_cmd(nbytes=0x20):
    """Build a raw, session-less TPM2_GetRandom command asking for *nbytes*.

    Layout (big-endian): a 10-byte header made of tag (H), total command
    size (I) and command code (I), followed by the 16-bit requested byte
    count (H).
    """
    fmt = '>HIIH'
    return struct.pack(fmt,
                       tpm2.TPM2_ST_NO_SESSIONS,
                       struct.calcsize(fmt),
                       tpm2.TPM2_CC_GET_RANDOM,
                       nbytes)


class SmokeTest(unittest.TestCase):
    """Seal/unseal tests and raw command/response tests on a single client.

    Each test gets a fresh client with a root key; both are released in
    tearDown().
    """

    def setUp(self):
        self.client = tpm2.Client()
        self.root_key = self.client.create_root_key()

    def tearDown(self):
        self.client.flush_context(self.root_key)
        self.client.close()

    def determine_bank_alg(self, mask):
        """Return the first PCR bank algorithm whose selection covers *mask*.

        :param mask: bit mask of PCR indices that must all be selected.
        :returns: the bank algorithm id, or None if no bank covers *mask*.
        """
        pcr_banks = self.client.get_cap_pcrs()
        for bank_alg, pcr_selection in pcr_banks.items():
            if pcr_selection & mask == mask:
                return bank_alg
        return None

    def _trial_policy_digest(self, pcrs, bank_alg):
        """Compute the PCR + password policy digest in a trial session.

        The trial session is always flushed, whether or not computing the
        digest succeeds.
        """
        handle = self.client.start_auth_session(tpm2.TPM2_SE_TRIAL)
        try:
            self.client.policy_pcr(handle, pcrs, bank_alg=bank_alg)
            self.client.policy_password(handle)
            return self.client.get_policy_digest(handle)
        finally:
            self.client.flush_context(handle)

    def _unseal_with_policy(self, blob, auth, pcrs, bank_alg):
        """Satisfy the PCR + password policy in a policy session and unseal.

        :returns: the unsealed data.
        :raises ProtocolError: if the TPM rejects any step.

        The session is flushed only when a step fails, and the exception is
        re-raised. NOTE(review): this presumes a successful unseal already
        ends the policy session on the TPM side; confirm against
        tpm2.Client.unseal.
        """
        handle = self.client.start_auth_session(tpm2.TPM2_SE_POLICY)
        try:
            self.client.policy_pcr(handle, pcrs, bank_alg=bank_alg)
            self.client.policy_password(handle)
            return self.client.unseal(self.root_key, blob, auth, handle)
        except BaseException:
            self.client.flush_context(handle)
            raise

    def test_seal_with_auth(self):
        data = ('X' * 64).encode()
        auth = ('A' * 15).encode()

        blob = self.client.seal(self.root_key, data, auth, None)
        result = self.client.unseal(self.root_key, blob, auth, None)
        self.assertEqual(data, result)

    def test_seal_with_policy(self):
        bank_alg = self.determine_bank_alg(1 << 16)
        self.assertIsNotNone(bank_alg)

        data = ('X' * 64).encode()
        auth = ('A' * 15).encode()
        pcrs = [16]

        policy_dig = self._trial_policy_digest(pcrs, bank_alg)
        blob = self.client.seal(self.root_key, data, auth, policy_dig)

        result = self._unseal_with_policy(blob, auth, pcrs, bank_alg)
        self.assertEqual(data, result)

    def test_unseal_with_wrong_auth(self):
        data = ('X' * 64).encode()
        auth = ('A' * 20).encode()

        blob = self.client.seal(self.root_key, data, auth, None)

        # Same length as the real auth value, last byte changed.
        wrong_auth = auth[:-1] + b'B'
        with self.assertRaises(ProtocolError) as cm:
            self.client.unseal(self.root_key, blob, wrong_auth, None)
        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_AUTH_FAIL)

    def test_unseal_with_wrong_policy(self):
        # PCR 16 is bound by the policy; PCR 1 is extended as a control.
        bank_alg = self.determine_bank_alg((1 << 16) | (1 << 1))
        self.assertIsNotNone(bank_alg)

        data = ('X' * 64).encode()
        auth = ('A' * 17).encode()
        pcrs = [16]

        policy_dig = self._trial_policy_digest(pcrs, bank_alg)
        blob = self.client.seal(self.root_key, data, auth, policy_dig)

        # Extend first a PCR that is not part of the policy and try to
        # unseal. This should succeed.
        ds = tpm2.get_digest_size(bank_alg)
        self.client.extend_pcr(1, ('X' * ds).encode(), bank_alg=bank_alg)

        result = self._unseal_with_policy(blob, auth, pcrs, bank_alg)
        self.assertEqual(data, result)

        # Then, extend a PCR that is part of the policy and try to unseal.
        # This should fail the policy check.
        self.client.extend_pcr(16, ('X' * ds).encode(), bank_alg=bank_alg)

        with self.assertRaises(ProtocolError) as cm:
            self._unseal_with_policy(blob, auth, pcrs, bank_alg)
        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_POLICY_FAIL)

    def test_seal_with_too_long_auth(self):
        ds = tpm2.get_digest_size(tpm2.TPM2_ALG_SHA1)
        data = ('X' * 64).encode()
        auth = ('A' * (ds + 1)).encode()

        with self.assertRaises(ProtocolError) as cm:
            self.client.seal(self.root_key, data, auth, None)
        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_SIZE)

    def test_too_short_cmd(self):
        fmt = '>HIII'
        # The header advertises one byte more than is actually sent, so the
        # write must be rejected with an I/O error.
        cmd = struct.pack(fmt,
                          tpm2.TPM2_ST_NO_SESSIONS,
                          struct.calcsize(fmt) + 1,
                          tpm2.TPM2_CC_FLUSH_CONTEXT,
                          0xDEADBEEF)

        with self.assertRaises(IOError):
            self.client.send_cmd(cmd)

    def test_read_partial_resp(self):
        cmd = _get_random_cmd()
        self.client.tpm.write(cmd)

        # Read only the 10-byte response header first; bytes 2..5 hold the
        # total response size.
        hdr = self.client.tpm.read(10)
        sz = struct.unpack('>I', hdr[2:6])[0]
        # Then read the remainder of the response in a second call.
        rsp = self.client.tpm.read()

        self.assertEqual(sz, 10 + 2 + 32)
        self.assertEqual(len(rsp), 2 + 32)

    def test_read_partial_overwrite(self):
        cmd = _get_random_cmd()
        self.client.tpm.write(cmd)
        # Read part of the response
        rsp1 = self.client.tpm.read(15)

        # Send a new command while the first response is only partly read
        self.client.tpm.write(cmd)

        # Read the whole response
        rsp2 = self.client.tpm.read()

        self.assertEqual(len(rsp1), 15)
        self.assertEqual(len(rsp2), 10 + 2 + 32)

    def test_send_two_cmds(self):
        cmd = _get_random_cmd()
        self.client.tpm.write(cmd)
        try:
            # The first response has not been read yet, so the second write
            # is expected to fail with -EBUSY.
            with self.assertRaises(IOError):
                self.client.tpm.write(cmd)
        finally:
            # Always consume the outstanding response so the client is left
            # usable for tearDown().
            self.client.tpm.read()


class SpaceTest(unittest.TestCase):
    """Tests for clients opened with tpm2.Client.FLAG_SPACE."""

    def setUp(self):
        # NOTE(review): logging.basicConfig() does nothing once the root
        # logger already has handlers, so the first test class to run picks
        # the log file for the whole run.
        logging.basicConfig(filename='SpaceTest.log', level=logging.DEBUG)

    def _new_space(self):
        """Open a FLAG_SPACE client that is closed when the test finishes."""
        space = tpm2.Client(tpm2.Client.FLAG_SPACE)
        self.addCleanup(space.close)
        return space

    def test_make_two_spaces(self):
        log = logging.getLogger(__name__)
        log.debug("test_make_two_spaces")

        space1 = self._new_space()
        root1 = space1.create_root_key()
        space2 = self._new_space()
        root2 = space2.create_root_key()
        root3 = space2.create_root_key()

        log.debug("%08x", root1)
        log.debug("%08x", root2)
        log.debug("%08x", root3)

    def test_flush_context(self):
        log = logging.getLogger(__name__)
        log.debug("test_flush_context")

        space1 = self._new_space()
        root1 = space1.create_root_key()
        log.debug("%08x", root1)

        space1.flush_context(root1)

    def test_get_handles(self):
        log = logging.getLogger(__name__)
        log.debug("test_get_handles")

        space1 = self._new_space()
        space1.create_root_key()
        space2 = self._new_space()
        space2.create_root_key()
        space2.create_root_key()

        # Only the two keys created in space2 may be visible from space2.
        handles = space2.get_cap(tpm2.TPM2_CAP_HANDLES, tpm2.HR_TRANSIENT)

        self.assertEqual(len(handles), 2)

        for handle in handles:
            log.debug("%08x", handle)

    def test_invalid_cc(self):
        log = logging.getLogger(__name__)
        log.debug(sys._getframe().f_code.co_name)

        invalid_cc = tpm2.TPM2_CC_FIRST - 1

        space1 = self._new_space()
        root1 = space1.create_root_key()
        log.debug("%08x", root1)

        fmt = '>HII'
        cmd = struct.pack(fmt, tpm2.TPM2_ST_NO_SESSIONS, struct.calcsize(fmt),
                          invalid_cc)

        with self.assertRaises(ProtocolError) as cm:
            space1.send_cmd(cmd)
        self.assertEqual(cm.exception.rc,
                         tpm2.TPM2_RC_COMMAND_CODE |
                         tpm2.TSS2_RESMGR_TPM_RC_LAYER)


class AsyncTest(unittest.TestCase):
    """Tests for clients opened with tpm2.Client.FLAG_NONBLOCK."""

    def setUp(self):
        logging.basicConfig(filename='AsyncTest.log', level=logging.DEBUG)

    def test_async(self):
        log = logging.getLogger(__name__)
        log.debug(sys._getframe().f_code.co_name)

        async_client = tpm2.Client(tpm2.Client.FLAG_NONBLOCK)
        self.addCleanup(async_client.close)
        log.debug("Calling get_cap in a NON_BLOCKING mode")
        async_client.get_cap(tpm2.TPM2_CAP_HANDLES, tpm2.HR_LOADED_SESSION)

    def test_flush_invalid_context(self):
        log = logging.getLogger(__name__)
        log.debug(sys._getframe().f_code.co_name)

        async_client = tpm2.Client(tpm2.Client.FLAG_SPACE |
                                   tpm2.Client.FLAG_NONBLOCK)
        self.addCleanup(async_client.close)
        log.debug("Calling flush_context passing in an invalid handle ")
        handle = 0x80123456

        with self.assertRaises(OSError) as cm:
            async_client.flush_context(handle)
        self.assertEqual(cm.exception.errno, errno.EINVAL)