summaryrefslogtreecommitdiff
path: root/tests/scripts/test_psa_constant_names.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/scripts/test_psa_constant_names.py')
-rwxr-xr-xtests/scripts/test_psa_constant_names.py191
1 files changed, 191 insertions, 0 deletions
diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py
new file mode 100755
index 00000000000..6883e279faa
--- /dev/null
+++ b/tests/scripts/test_psa_constant_names.py
@@ -0,0 +1,191 @@
+#!/usr/bin/env python3
+"""Test the program psa_constant_names.
+Gather constant names from header files and test cases. Compile a C program
+to print out their numerical values, feed these numerical values to
+psa_constant_names, and check that the output is the original name.
+Return 0 if all test cases pass, 1 if the output was not always as expected,
+or 1 (with a Python backtrace) if there was an operational error.
+"""
+
+# Copyright The Mbed TLS Contributors
+# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+
+import argparse
+from collections import namedtuple
+import os
+import re
+import subprocess
+import sys
+from typing import Iterable, List, Optional, Tuple
+
+import scripts_path # pylint: disable=unused-import
+from mbedtls_dev import c_build_helper
+from mbedtls_dev.macro_collector import InputsForTest, PSAMacroEnumerator
+from mbedtls_dev import typing_util
+
+def gather_inputs(headers: Iterable[str],
+ test_suites: Iterable[str],
+ inputs_class=InputsForTest) -> PSAMacroEnumerator:
+ """Read the list of inputs to test psa_constant_names with."""
+ inputs = inputs_class()
+ for header in headers:
+ inputs.parse_header(header)
+ for test_cases in test_suites:
+ inputs.parse_test_cases(test_cases)
+ inputs.add_numerical_values()
+ inputs.gather_arguments()
+ return inputs
+
+def run_c(type_word: str,
+ expressions: Iterable[str],
+ include_path: Optional[str] = None,
+ keep_c: bool = False) -> List[str]:
+ """Generate and run a program to print out numerical values of C expressions."""
+ if type_word == 'status':
+ cast_to = 'long'
+ printf_format = '%ld'
+ else:
+ cast_to = 'unsigned long'
+ printf_format = '0x%08lx'
+ return c_build_helper.get_c_expression_values(
+ cast_to, printf_format,
+ expressions,
+ caller='test_psa_constant_names.py for {} values'.format(type_word),
+ file_label=type_word,
+ header='#include <psa/crypto.h>',
+ include_path=include_path,
+ keep_c=keep_c
+ )
+
+NORMALIZE_STRIP_RE = re.compile(r'\s+')
+def normalize(expr: str) -> str:
+ """Normalize the C expression so as not to care about trivial differences.
+
+ Currently "trivial differences" means whitespace.
+ """
+ return re.sub(NORMALIZE_STRIP_RE, '', expr)
+
+ALG_TRUNCATED_TO_SELF_RE = \
+ re.compile(r'PSA_ALG_AEAD_WITH_SHORTENED_TAG\('
+ r'PSA_ALG_(?:CCM|CHACHA20_POLY1305|GCM)'
+ r', *16\)\Z')
+
+def is_simplifiable(expr: str) -> bool:
+ """Determine whether an expression is simplifiable.
+
+ Simplifiable expressions can't be output in their input form, since
+ the output will be the simple form. Therefore they must be excluded
+ from testing.
+ """
+ if ALG_TRUNCATED_TO_SELF_RE.match(expr):
+ return True
+ return False
+
+def collect_values(inputs: InputsForTest,
+ type_word: str,
+ include_path: Optional[str] = None,
+ keep_c: bool = False) -> Tuple[List[str], List[str]]:
+ """Generate expressions using known macro names and calculate their values.
+
+ Return a list of pairs of (expr, value) where expr is an expression and
+ value is a string representation of its integer value.
+ """
+ names = inputs.get_names(type_word)
+ expressions = sorted(expr
+ for expr in inputs.generate_expressions(names)
+ if not is_simplifiable(expr))
+ values = run_c(type_word, expressions,
+ include_path=include_path, keep_c=keep_c)
+ return expressions, values
+
+class Tests:
+ """An object representing tests and their results."""
+
+ Error = namedtuple('Error',
+ ['type', 'expression', 'value', 'output'])
+
+ def __init__(self, options) -> None:
+ self.options = options
+ self.count = 0
+ self.errors = [] #type: List[Tests.Error]
+
+ def run_one(self, inputs: InputsForTest, type_word: str) -> None:
+ """Test psa_constant_names for the specified type.
+
+ Run the program on the names for this type.
+ Use the inputs to figure out what arguments to pass to macros that
+ take arguments.
+ """
+ expressions, values = collect_values(inputs, type_word,
+ include_path=self.options.include,
+ keep_c=self.options.keep_c)
+ output_bytes = subprocess.check_output([self.options.program,
+ type_word] + values)
+ output = output_bytes.decode('ascii')
+ outputs = output.strip().split('\n')
+ self.count += len(expressions)
+ for expr, value, output in zip(expressions, values, outputs):
+ if self.options.show:
+ sys.stdout.write('{} {}\t{}\n'.format(type_word, value, output))
+ if normalize(expr) != normalize(output):
+ self.errors.append(self.Error(type=type_word,
+ expression=expr,
+ value=value,
+ output=output))
+
+ def run_all(self, inputs: InputsForTest) -> None:
+ """Run psa_constant_names on all the gathered inputs."""
+ for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
+ 'key_type', 'key_usage']:
+ self.run_one(inputs, type_word)
+
+ def report(self, out: typing_util.Writable) -> None:
+ """Describe each case where the output is not as expected.
+
+ Write the errors to ``out``.
+ Also write a total.
+ """
+ for error in self.errors:
+ out.write('For {} "{}", got "{}" (value: {})\n'
+ .format(error.type, error.expression,
+ error.output, error.value))
+ out.write('{} test cases'.format(self.count))
+ if self.errors:
+ out.write(', {} FAIL\n'.format(len(self.errors)))
+ else:
+ out.write(' PASS\n')
+
+HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']
+TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']
+
+def main():
+ parser = argparse.ArgumentParser(description=globals()['__doc__'])
+ parser.add_argument('--include', '-I',
+ action='append', default=['include'],
+ help='Directory for header files')
+ parser.add_argument('--keep-c',
+ action='store_true', dest='keep_c', default=False,
+ help='Keep the intermediate C file')
+ parser.add_argument('--no-keep-c',
+ action='store_false', dest='keep_c',
+ help='Don\'t keep the intermediate C file (default)')
+ parser.add_argument('--program',
+ default='programs/psa/psa_constant_names',
+ help='Program to test')
+ parser.add_argument('--show',
+ action='store_true',
+ help='Show tested values on stdout')
+ parser.add_argument('--no-show',
+ action='store_false', dest='show',
+ help='Don\'t show tested values (default)')
+ options = parser.parse_args()
+ headers = [os.path.join(options.include[0], h) for h in HEADERS]
+ inputs = gather_inputs(headers, TEST_SUITES)
+ tests = Tests(options)
+ tests.run_all(inputs)
+ tests.report(sys.stdout)
+ if tests.errors:
+ sys.exit(1)
+
+if __name__ == '__main__':
+ main()