singe/thirdparty/openssl/oqs-provider/oqs-template/generate.py
2023-11-16 22:15:24 -06:00

246 lines
10 KiB
Python

#!/usr/bin/env python3
import copy
import glob
import os
import shutil
import subprocess
import sys

import jinja2
import jinja2.ext
import yaml
# For generated files the copyright message can be adapted, see
# https://github.com/open-quantum-safe/oqs-provider/issues/2#issuecomment-920904048:
# the SPDX identifier leads and the OpenSSL copyright notice is deleted.
def fixup_copyright(filename):
    """Rewrite *filename* in place with a leading SPDX line and no copyright notice.

    A "// SPDX-License-Identifier: Apache-2.0 AND MIT" line followed by a blank
    line is written first.  Lines are then copied, except that everything from
    the first line containing " * Copyright" up to (not including) the next line
    containing "*/" is dropped; that closing line is kept.  Detection stops at
    the first line containing "*/" anywhere in the file, so a notice that comes
    after an earlier comment close is left untouched.

    The result is written to ``<filename>.new`` and then moved over *filename*.
    """
    tmpname = filename + ".new"
    with open(filename, "r") as origfile, open(tmpname, "w") as newfile:
        newfile.write("// SPDX-License-Identifier: Apache-2.0 AND MIT\n\n")
        skipline = False
        checkline = True
        for line in origfile:
            if checkline and " * Copyright" in line:
                skipline = True
            if "*/" in line:
                # End of a comment: stop skipping and never look for another notice.
                skipline = False
                checkline = False
            if not skipline:
                newfile.write(line)
    # os.replace overwrites the existing destination on every platform; os.rename
    # raises on Windows because *filename* still exists at this point.
    os.replace(tmpname, filename)
def get_kem_nistlevel(alg):
    """Return the claimed NIST security level of KEM *alg* from the liboqs datasheets.

    *alg* is a KEM entry from generate.yml.  Its 'family' selects the datasheet
    ``$LIBOQS_SRC_DIR/docs/algorithms/kem/<name>.yml``.  Its 'name_group' is
    compared, ignoring case, '_' and '-', with each datasheet 'parameter-sets'
    name after a few family-specific renames.

    Returns the first matching parameter set's 'claimed-nist-level', or None if
    no parameter set matches.  Exits the script with status 1 if LIBOQS_SRC_DIR
    is not set.
    """
    if 'LIBOQS_SRC_DIR' not in os.environ:
        print("Must include LIBOQS_SRC_DIR in environment", file=sys.stderr)
        sys.exit(1)
    # Family names in generate.yml whose liboqs datasheet file name differs;
    # every other family uses its lower-cased name.
    datasheet_aliases = {
        'CRYSTALS-Kyber': 'kyber',
        'SIDH': 'sike',
        'NTRU-Prime': 'ntruprime',
    }
    datasheetname = datasheet_aliases.get(alg['family'], alg['family'].lower())
    # load datasheet
    algymlfilename = os.path.join(os.environ['LIBOQS_SRC_DIR'], 'docs', 'algorithms', 'kem', '{:s}.yml'.format(datasheetname))
    algyml = yaml.safe_load(file_get_contents(algymlfilename, encoding='utf-8'))

    def matches(name, alg):
        """Return True if datasheet parameter-set *name* denotes KEM *alg*."""
        def simplify(s):
            return s.lower().replace('_', '').replace('-', '')
        # Hacks bridging datasheet naming and generate.yml naming.
        if 'FrodoKEM' in name:
            name = name.replace('FrodoKEM', 'Frodo')
        if 'Saber-KEM' in name:
            name = name.replace('-KEM', '')
        if '-90s' in name:
            name = name.replace('-90s', '').replace('Kyber', 'Kyber90s')
        return simplify(name) == simplify(alg['name_group'])

    # find the variant that matches
    for variant in algyml['parameter-sets']:
        if matches(variant['name'], alg):
            return variant['claimed-nist-level']
    return None
def get_sig_nistlevel(family, alg):
    """Return the claimed NIST security level of signature variant *alg*.

    *family* is a 'sigs' entry from generate.yml.  Its 'family' name selects the
    datasheet ``$LIBOQS_SRC_DIR/docs/algorithms/sig/<name>.yml``.  *alg* is one
    of its variants.  Its 'name' is compared with each datasheet
    'parameter-sets' name, ignoring case and the characters '_', '-' and '+'.

    Returns the first matching parameter set's 'claimed-nist-level', or None if
    no parameter set matches.  Exits the script with status 1 if LIBOQS_SRC_DIR
    is not set.
    """
    if 'LIBOQS_SRC_DIR' not in os.environ:
        print("Must include LIBOQS_SRC_DIR in environment", file=sys.stderr)
        sys.exit(1)
    # Family names in generate.yml whose liboqs datasheet file name differs
    # (all SPHINCS flavours share one datasheet); every other family uses its
    # lower-cased name.
    datasheet_aliases = {
        'CRYSTALS-Dilithium': 'dilithium',
        'SPHINCS-Haraka': 'sphincs',
        'SPHINCS-SHA2': 'sphincs',
        'SPHINCS-SHAKE': 'sphincs',
    }
    datasheetname = datasheet_aliases.get(family['family'], family['family'].lower())
    # load datasheet
    algymlfilename = os.path.join(os.environ['LIBOQS_SRC_DIR'], 'docs', 'algorithms', 'sig', '{:s}.yml'.format(datasheetname))
    algyml = yaml.safe_load(file_get_contents(algymlfilename, encoding='utf-8'))

    def matches(name, alg):
        """Return True if datasheet parameter-set *name* denotes variant *alg*."""
        def simplify(s):
            return s.lower().replace('_', '').replace('-', '').replace('+', '')
        return simplify(name) == simplify(alg['name'])

    # find the variant that matches
    for variant in algyml['parameter-sets']:
        if matches(variant['name'], alg):
            return variant['claimed-nist-level']
    return None
def nist_to_bits(nistlevel):
    """Translate a NIST security category into classical bits of security.

    Categories 1 and 2 give 128, 3 and 4 give 192, 5 gives 256; any other
    value (including None) gives None.
    """
    level_table = (
        ((1, 2), 128),
        ((3, 4), 192),
        ((5,), 256),
    )
    for categories, bits in level_table:
        if nistlevel in categories:
            return bits
    return None
def complete_config(config):
    """Annotate the KEMs and signature variants of *config* with security levels.

    For every KEM, the claimed NIST level (via get_kem_nistlevel) is converted
    with nist_to_bits and stored as kem['bit_security'].  A hybrid entry pairing
    kem['nid_hybrid'] with the NIST curve of matching strength (p256/p384/p521)
    is then inserted at the front of kem['hybrids'], which must already exist
    (load_config creates it).  A KEM whose level cannot be determined aborts
    the script with status 1.

    For every signature variant the bit level is stored as sig['security'];
    when no level is found a message is printed and 0 is used instead.

    *config* is modified in place and also returned.
    """
    # Classical curve whose strength matches each bit-security level.
    hybrid_group_for_bits = {128: 'p256', 192: 'p384', 256: 'p521'}
    for kem in config['kems']:
        bits_level = nist_to_bits(get_kem_nistlevel(kem))
        if bits_level is None:
            print("Cannot find security level for {:s} {:s}".format(kem['family'], kem['name_group']),
                  file=sys.stderr)
            sys.exit(1)
        kem['bit_security'] = bits_level
        # now add hybrid_nid to hybrid_groups
        hybrid_group = hybrid_group_for_bits.get(bits_level)
        if hybrid_group is None:
            print("Warning: Unknown bit level for %s. Cannot assign hybrid." % (kem['name_group']),
                  file=sys.stderr)
            sys.exit(1)
        # The NIST-curve hybrid goes first, ahead of any extra hybrids.
        kem['hybrids'].insert(0, {
            'hybrid_group': hybrid_group,
            'bit_security': bits_level,
            'nid': kem['nid_hybrid'],
        })
    for famsig in config['sigs']:
        for sig in famsig['variants']:
            bits_level = nist_to_bits(get_sig_nistlevel(famsig, sig))
            if bits_level is None:
                print("Cannot find security level for {:s} {:s}. Setting to 0.".format(famsig['family'], sig['name']))
                bits_level = 0
            sig['security'] = bits_level
    return config
def run_subprocess(command, outfilename=None, working_dir='.', expected_returncode=0, input=None, ignore_returncode=False):
    """Run *command* and check its exit status.

    Args:
        command: argument list (or string) handed to subprocess.run.
        outfilename: if given, the command's stdout is written to this file
            (created or truncated); otherwise stdout is captured.
        working_dir: directory the command runs in.
        expected_returncode: exit status regarded as success.
        input: bytes fed to the command's stdin, or None.
        ignore_returncode: if true, the exit status is not checked.

    Returns:
        The subprocess.CompletedProcess of the finished command.

    Raises:
        AssertionError: the exit status differs from *expected_returncode* and
            *ignore_returncode* is false.  Captured stdout (when not redirected
            to a file) and stderr are printed first.
    """
    def _run(stdout):
        return subprocess.run(
            command,
            input=input,
            stdout=stdout,
            stderr=subprocess.PIPE,
            cwd=working_dir,
        )

    if outfilename is None:
        result = _run(subprocess.PIPE)
    else:
        # Close the output file as soon as the command has finished.
        with open(outfilename, "w") as outfile:
            result = _run(outfile)

    if not ignore_returncode and result.returncode != expected_returncode:
        if outfilename is None:
            print(result.stdout.decode('utf-8', errors='replace'))
        print(result.stderr.decode('utf-8', errors='replace'), file=sys.stderr)
        # Explicit raise instead of `assert False`, which is stripped under -O.
        raise AssertionError("Got unexpected return code {}".format(result.returncode))
    return result
# Shared Jinja2 environment used by populate() to render template fragments.
# Template names are resolved relative to the current working directory, and
# the "do" extension allows expression statements such as list.append inside
# templates.
Jinja2 = jinja2.Environment(
    loader=jinja2.FileSystemLoader(searchpath="."),
    extensions=['jinja2.ext.do'],
)
def file_get_contents(filename, encoding=None):
    """Return the whole text of *filename*.

    *encoding* is forwarded to open(); None selects the platform default.
    """
    with open(filename, 'r', encoding=encoding) as handle:
        text = handle.read()
    return text
def file_put_contents(filename, s, encoding=None):
    """Write string *s* to *filename*, replacing any previous content.

    *encoding* is forwarded to open(); None selects the platform default.
    """
    with open(filename, 'w', encoding=encoding) as out:
        out.write(s)
def populate(filename, config, delimiter, overwrite=False):
    """Splice rendered Jinja2 fragments into *filename* between marker comments.

    Fragments are the files ``oqs-template/<filename>/*.fragment``.  A fragment
    ``foo.fragment`` is rendered with ``{'config': config}`` and placed between
    ``<delimiter> OQS_TEMPLATE_FRAGMENT_FOO_START`` (followed by `` -->`` for
    ``.md`` files) and ``<delimiter> OQS_TEMPLATE_FRAGMENT_FOO_END``.

    If *overwrite* is false, *filename* itself is the input and both markers are
    kept, so the file can be regenerated later.  If *overwrite* is true, the
    input is ``oqs-template/<filename>/<basename>.base``.  In that mode the
    start marker is dropped, and occurrences of the end marker followed by a
    newline are removed from the text after it.  Either way the result is
    written to *filename*.

    Raises:
        ValueError: a fragment's start marker is missing from the input, or
            its end marker does not follow the start marker.
    """
    template_dir = os.path.join('oqs-template', filename)
    if overwrite:
        source_name = os.path.join(template_dir, os.path.basename(filename) + '.base')
    else:
        source_name = filename
    contents = file_get_contents(source_name)

    # Sorted so the processing order does not depend on the filesystem.
    for fragment in sorted(glob.glob(os.path.join(template_dir, '*.fragment'))):
        identifier = os.path.splitext(os.path.basename(fragment))[0].upper()
        identifier_start = '{} OQS_TEMPLATE_FRAGMENT_{}_START'.format(delimiter, identifier)
        if filename.endswith('.md'):
            identifier_start += ' -->'
        identifier_end = '{} OQS_TEMPLATE_FRAGMENT_{}_END'.format(delimiter, identifier)

        # str.find returns -1 for a missing marker; slicing with -1 would
        # silently drop the last character or duplicate content, so fail loudly.
        start = contents.find(identifier_start)
        if start < 0:
            raise ValueError("{}: marker {!r} not found".format(source_name, identifier_start))
        end = contents.find(identifier_end, start)
        if end < 0:
            raise ValueError("{}: marker {!r} not found after {!r}".format(
                source_name, identifier_end, identifier_start))

        # Jinja2 template names must be '/'-separated, even on Windows.
        template_name = fragment.replace(os.sep, '/')
        rendered = Jinja2.get_template(template_name).render({'config': config})
        preamble = contents[:start]
        postamble = contents[end:]
        if overwrite:
            contents = preamble + rendered + postamble.replace(identifier_end + '\n', '')
        else:
            contents = preamble + identifier_start + rendered + postamble
    file_put_contents(filename, contents)
def load_config(include_disabled_sigs=False):
    """Load ``oqs-template/generate.yml`` and prepare it for the templates.

    Args:
        include_disabled_sigs: if false (the default), signature variants whose
            'enable' flag is missing or falsy are dropped.

    Returns:
        The parsed config dict, in which:
          * KEMs without a 'nid' and signature variants without an 'oid'
            (legacy entries) are removed;
          * every KEM has a 'hybrids' list holding the entries of
            kem['extra_nids']['current'] (empty if that key path is absent).
            Entries whose 'hybrid_group' is x25519/p256, x448/p384 or p521 get
            'bit_security' 128, 192 or 256;
          * a KEM with an x25519/x448 extra NID gets 'nid_ecx_hybrid' set to
            that entry's 'nid' (the last such entry wins, with a warning on
            duplicates).

    Raises:
        KeyError: an extra-NID entry lacks 'hybrid_group', or an x25519/x448
            entry lacks 'nid'.  Such malformed entries were previously
            swallowed, silently truncating the KEM's hybrid list.
    """
    config = yaml.safe_load(
        file_get_contents(os.path.join('oqs-template', 'generate.yml'), encoding='utf-8'))

    if not include_disabled_sigs:
        for sig in config['sigs']:
            sig['variants'] = [variant for variant in sig['variants'] if variant.get('enable')]

    # Remove KEMs without NID and signature variants without OID (old stuff).
    config['kems'] = [kem for kem in config['kems'] if 'nid' in kem]
    for sig in config['sigs']:
        sig['variants'] = [variant for variant in sig['variants'] if 'oid' in variant]

    # Classical bit-security equivalent of each known hybrid group; entries with
    # any other group get no 'bit_security' here.
    bits_for_group = {'x25519': 128, 'p256': 128, 'x448': 192, 'p384': 192, 'p521': 256}
    for kem in config['kems']:
        kem['hybrids'] = []
        # Only the optional extra_nids/current lookup may be missing; anything
        # else that fails below is a malformed entry and propagates.
        try:
            extra_nids = kem['extra_nids']['current']
        except KeyError:
            continue
        for extra in extra_nids:
            # The entry dict itself is appended (not a copy), so 'bit_security'
            # is also visible through kem['extra_nids']['current'].
            group = extra['hybrid_group']
            if group in bits_for_group:
                extra['bit_security'] = bits_for_group[group]
            kem['hybrids'].append(extra)
            if group in ("x25519", "x448"):
                extra_hyb_nid = extra['nid']
                if 'nid_ecx_hybrid' in kem:
                    print("Warning, duplicate nid_ecx_hybrid for",
                          kem['name_group'], ":", extra_hyb_nid, "in generate.yml,",
                          kem['nid_ecx_hybrid'], "in generate_extras.yml, using generate.yml entry.")
                kem['nid_ecx_hybrid'] = extra_hyb_nid
    return config
# Main generation run: plain top-level code, executed whenever this file runs.

# Provider sources and tests use the default config, which (per load_config)
# keeps only signature variants marked 'enable'.  complete_config adds the
# security levels and the NIST-curve hybrid derived from each KEM's 'nid_hybrid'.
config = load_config()
config = complete_config(config)
for generated_file in (
    'test/oqs_test_signatures.c',
    'test/oqs_test_kems.c',
    'test/oqs_test_groups.c',
    'test/oqs_test_endecode.c',
    'oqsprov/oqsencoders.inc',
    'oqsprov/oqsdecoders.inc',
    'oqsprov/oqs_prov.h',
    'oqsprov/oqsprov.c',
    'oqsprov/oqsprov_capabilities.c',
    'oqsprov/oqs_kmgmt.c',
    'oqsprov/oqs_encode_key2any.c',
    'oqsprov/oqs_decode_der2key.c',
    'oqsprov/oqsprov_keys.c',
):
    populate(generated_file, config, '/////')

# The Markdown documentation is rendered from a second config that also keeps
# disabled signature variants.
config2 = complete_config(load_config(include_disabled_sigs=True))
for generated_file in ('ALGORITHMS.md', 'README.md'):
    populate(generated_file, config2, '<!---')

print("All files generated")

# Nothing below uses the imported module's name; presumably it is imported for
# its import-time side effects and reads LIBOQS_DOCS_DIR -- TODO confirm.
os.environ["LIBOQS_DOCS_DIR"] = os.path.join(os.environ["LIBOQS_SRC_DIR"], "docs")
import generate_oid_nid_table