#!/usr/bin/env python3 import copy import glob import jinja2 import jinja2.ext import os import shutil import subprocess import yaml # For files generated, the copyright message can be adapted # see https://github.com/open-quantum-safe/oqs-provider/issues/2#issuecomment-920904048 # SPDX message to be leading, OpenSSL Copyright notice to be deleted def fixup_copyright(filename): with open(filename, "r") as origfile: with open(filename+".new", "w") as newfile: newfile.write("// SPDX-License-Identifier: Apache-2.0 AND MIT\n\n") skipline = False checkline = True for line in origfile: if checkline==True and " * Copyright" in line: skipline=True if "*/" in line: skipline=False checkline=False if not skipline: newfile.write(line) os.rename(filename+".new", filename) def get_kem_nistlevel(alg): if 'LIBOQS_SRC_DIR' not in os.environ: print("Must include LIBOQS_SRC_DIR in environment") exit(1) # translate family names in generate.yml to directory names for liboqs algorithm datasheets if alg['family'] == 'CRYSTALS-Kyber': datasheetname = 'kyber' elif alg['family'] == 'SIDH': datasheetname = 'sike' elif alg['family'] == 'NTRU-Prime': datasheetname = 'ntruprime' else: datasheetname = alg['family'].lower() # load datasheet algymlfilename = os.path.join(os.environ['LIBOQS_SRC_DIR'], 'docs', 'algorithms', 'kem', '{:s}.yml'.format(datasheetname)) algyml = yaml.safe_load(file_get_contents(algymlfilename, encoding='utf-8')) # hacks to match names def matches(name, alg): def simplify(s): return s.lower().replace('_', '').replace('-', '') if 'FrodoKEM' in name: name = name.replace('FrodoKEM', 'Frodo') if 'Saber-KEM' in name: name = name.replace('-KEM', '') if '-90s' in name: name = name.replace('-90s', '').replace('Kyber', 'Kyber90s') if simplify(name) == simplify(alg['name_group']): return True return False # find the variant that matches for variant in algyml['parameter-sets']: if matches(variant['name'], alg): return variant['claimed-nist-level'] return None def get_sig_nistlevel(family, alg): if 'LIBOQS_SRC_DIR' not in os.environ: print("Must include LIBOQS_SRC_DIR in environment") exit(1) # translate family names in generate.yml to directory names for liboqs algorithm datasheets if family['family'] == 'CRYSTALS-Dilithium': datasheetname = 'dilithium' elif family['family'] == 'SPHINCS-Haraka': datasheetname = 'sphincs' elif family['family'] == 'SPHINCS-SHA2': datasheetname = 'sphincs' elif family['family'] == 'SPHINCS-SHAKE': datasheetname = 'sphincs' else: datasheetname = family['family'].lower() # load datasheet algymlfilename = os.path.join(os.environ['LIBOQS_SRC_DIR'], 'docs', 'algorithms', 'sig', '{:s}.yml'.format(datasheetname)) algyml = yaml.safe_load(file_get_contents(algymlfilename, encoding='utf-8')) # hacks to match names def matches(name, alg): def simplify(s): return s.lower().replace('_', '').replace('-', '').replace('+', '') if simplify(name) == simplify(alg['name']): return True return False # find the variant that matches for variant in algyml['parameter-sets']: if matches(variant['name'], alg): return variant['claimed-nist-level'] return None def nist_to_bits(nistlevel): if nistlevel==1 or nistlevel==2: return 128 elif nistlevel==3 or nistlevel==4: return 192 elif nistlevel==5: return 256 else: return None def complete_config(config): for kem in config['kems']: bits_level = nist_to_bits(get_kem_nistlevel(kem)) if bits_level == None: print("Cannot find security level for {:s} {:s}".format(kem['family'], kem['name_group'])) exit(1) kem['bit_security'] = bits_level # now add hybrid_nid to hybrid_groups phyb = {} if (bits_level == 128): phyb['hybrid_group']='p256' elif (bits_level == 192): phyb['hybrid_group']='p384' elif (bits_level == 256): phyb['hybrid_group']='p521' else: print("Warning: Unknown bit level for %s. Cannot assign hybrid." % (kem['group_name'])) exit(1) phyb['bit_security']=bits_level phyb['nid']=kem['nid_hybrid'] kem['hybrids'].insert(0, phyb) for famsig in config['sigs']: for sig in famsig['variants']: bits_level = nist_to_bits(get_sig_nistlevel(famsig, sig)) if bits_level == None: print("Cannot find security level for {:s} {:s}. Setting to 0.".format(famsig['family'], sig['name'])) bits_level = 0 sig['security'] = bits_level return config def run_subprocess(command, outfilename=None, working_dir='.', expected_returncode=0, input=None, ignore_returncode=False): result = subprocess.run( command, input=input, stdout=(open(outfilename, "w") if outfilename!=None else subprocess.PIPE), stderr=subprocess.PIPE, cwd=working_dir, ) if not(ignore_returncode) and (result.returncode != expected_returncode): if outfilename == None: print(result.stdout.decode('utf-8')) assert False, "Got unexpected return code {}".format(result.returncode) # For list.append in Jinja templates Jinja2 = jinja2.Environment(loader=jinja2.FileSystemLoader(searchpath="."),extensions=['jinja2.ext.do']) def file_get_contents(filename, encoding=None): with open(filename, mode='r', encoding=encoding) as fh: return fh.read() def file_put_contents(filename, s, encoding=None): with open(filename, mode='w', encoding=encoding) as fh: fh.write(s) def populate(filename, config, delimiter, overwrite=False): fragments = glob.glob(os.path.join('oqs-template', filename, '*.fragment')) if overwrite == True: source_file = os.path.join('oqs-template', filename, os.path.basename(filename)+ '.base') contents = file_get_contents(source_file) else: contents = file_get_contents(filename) for fragment in fragments: identifier = os.path.splitext(os.path.basename(fragment))[0] if filename.endswith('.md'): identifier_start = '{} OQS_TEMPLATE_FRAGMENT_{}_START -->'.format(delimiter, identifier.upper()) else: identifier_start = '{} OQS_TEMPLATE_FRAGMENT_{}_START'.format(delimiter, identifier.upper()) identifier_end = '{} OQS_TEMPLATE_FRAGMENT_{}_END'.format(delimiter, identifier.upper()) preamble = contents[:contents.find(identifier_start)] postamble = contents[contents.find(identifier_end):] if overwrite == True: contents = preamble + Jinja2.get_template(fragment).render({'config': config}) + postamble.replace(identifier_end + '\n', '') else: contents = preamble + identifier_start + Jinja2.get_template(fragment).render({'config': config}) + postamble file_put_contents(filename, contents) def load_config(include_disabled_sigs=False): config = file_get_contents(os.path.join('oqs-template', 'generate.yml'), encoding='utf-8') config = yaml.safe_load(config) if not include_disabled_sigs: for sig in config['sigs']: sig['variants'] = [variant for variant in sig['variants'] if ('enable' in variant and variant['enable'])] # remove KEMs without NID (old stuff) newkems = [] for kem in config['kems']: if 'nid' in kem: newkems.append(kem) config['kems']=newkems # remove SIGs without OID (old stuff) for sig in config['sigs']: newvars = [] for variant in sig['variants']: if 'oid' in variant: newvars.append(variant) sig['variants']=newvars for kem in config['kems']: kem['hybrids'] = [] try: for extra_nid_current in kem['extra_nids']['current']: extra_hybrid = extra_nid_current if extra_nid_current['hybrid_group'] == "x25519" or extra_nid_current['hybrid_group'] == "p256": extra_hybrid['bit_security']=128 if extra_nid_current['hybrid_group'] == "x448" or extra_nid_current['hybrid_group'] == "p384": extra_hybrid['bit_security']=192 if extra_nid_current['hybrid_group'] == "p521": extra_hybrid['bit_security']=256 kem['hybrids'].append(extra_hybrid) if 'hybrid_group' in extra_nid_current and extra_nid_current['hybrid_group'] in ["x25519", "x448"]: extra_hyb_nid = extra_nid_current['nid'] if 'nid_ecx_hybrid' in kem: print("Warning, duplicate nid_ecx_hybrid for", kem['name_group'], ":", extra_hyb_nid, "in generate.yml,", kem['nid_ecx_hybrid'], "in generate_extras.yml, using generate.yml entry.") kem['nid_ecx_hybrid'] = extra_hyb_nid except KeyError as ke: pass return config # extend config with "hybrid_groups" array: config = load_config() # extend config with "hybrid_groups" array # complete config with "bit_security" and "hybrid_group from # nid_hybrid information config = complete_config(config) populate('test/oqs_test_signatures.c', config, '/////') populate('test/oqs_test_kems.c', config, '/////') populate('test/oqs_test_groups.c', config, '/////') populate('test/oqs_test_endecode.c', config, '/////') populate('oqsprov/oqsencoders.inc', config, '/////') populate('oqsprov/oqsdecoders.inc', config, '/////') populate('oqsprov/oqs_prov.h', config, '/////') populate('oqsprov/oqsprov.c', config, '/////') populate('oqsprov/oqsprov_capabilities.c', config, '/////') populate('oqsprov/oqs_kmgmt.c', config, '/////') populate('oqsprov/oqs_encode_key2any.c', config, '/////') populate('oqsprov/oqs_decode_der2key.c', config, '/////') populate('oqsprov/oqsprov_keys.c', config, '/////') config2 = load_config(include_disabled_sigs=True) config2 = complete_config(config2) populate('ALGORITHMS.md', config2, '