Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/424 #431

Merged
merged 10 commits into from
Jul 3, 2020
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ cma

apscheduler
pycryptodome
paramiko
sshpubkeys
PyNaCl
requests
requests_toolbelt
Expand Down
113 changes: 102 additions & 11 deletions studio/encrypted_payload_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from Crypto.Hash import SHA256
import nacl.secret
import nacl.utils
import nacl.signing
import paramiko
import base64
import json
from sshpubkeys import SSHKey

from .payload_builder import PayloadBuilder
from studio import logs
Expand All @@ -15,28 +18,71 @@ class EncryptedPayloadBuilder(PayloadBuilder):
Implementation for experiment payload builder
using public key RSA encryption.
"""
def __init__(self, name: str, keypath: str):
def __init__(self, name: str,
receiver_keypath: str,
sender_keypath: str = None):
"""
param: name - payload builder name
param: keypath - file path to .pem file with public key
param: receiver_keypath - file path to .pem file
with recipient public key
param: sender_keypath - file path to .pem file
with sender private key
"""
super(EncryptedPayloadBuilder, self).__init__(name)

# XXX Set logger verbosity level here
self.logger = logs.getLogger(self.__class__.__name__)

self.key_path = keypath
self.recipient_key_path = receiver_keypath
self.recipient_key = None
try:
self.recipient_key = RSA.import_key(open(self.key_path).read())
self.recipient_key =\
RSA.import_key(open(self.recipient_key_path).read())
except:
self.logger.error(
"FAILED to import recipient public key from: {0}".format(self.key_path))
return
msg = "FAILED to import recipient public key from: {0}"\
.format(self.recipient_key_path)
self.logger.error(msg)
raise ValueError(msg)

self.sender_key_path = sender_keypath
self.sender_key = None
self.sender_fingerprint = None

if self.sender_key_path is None:
self.logger.error("Signing key path must be specified for encrypted payloads. ABORTING.")
raise ValueError()

# We expect ed25519 signing key in "private key" format
try:
self.sender_key =\
paramiko.Ed25519Key(filename=self.sender_key_path)

if self.sender_key is None:
self._raise_error("Failed to import private signing key. ABORTING.")
except:
self._raise_error("FAILED to open/read private signing key file: {0}"\
.format(self.sender_key_path))

self.sender_fingerprint = \
self._get_fingerprint(self.sender_key)

self.simple_builder =\
UnencryptedPayloadBuilder("simple-builder-for-encryptor")

def _raise_error(self, msg: str):
self.logger.error(msg)
raise ValueError(msg)

def _get_fingerprint(self, signing_key):
ssh_key = SSHKey("ssh-ed25519 {0}"
.format(signing_key.get_base64()))
try:
ssh_key.parse()
except:
self._raise_error("INVALID signing key type. ABORTING.")

return ssh_key.hash_sha256() # SHA256:xyz

def _import_rsa_key(self, key_path: str):
key = None
try:
Expand All @@ -47,6 +93,13 @@ def _import_rsa_key(self, key_path: str):
key = None
return key

def _rsa_encrypt_data_to_base64(self, key, data):
# Encrypt byte data with RSA key
cipher_rsa = PKCS1_OAEP.new(key=key, hashAlgo=SHA256)
encrypted_data = cipher_rsa.encrypt(data)
encrypted_data_base64 = base64.b64encode(encrypted_data)
return encrypted_data_base64

def _encrypt_str(self, workload: str):
# Generate one-time symmetric session key:
session_key = nacl.utils.random(32)
Expand All @@ -58,12 +111,39 @@ def _encrypt_str(self, workload: str):
encrypted_data_text = base64.b64encode(encrypted_data)

# Encrypt the session key with the public RSA key
cipher_rsa = PKCS1_OAEP.new(key=self.recipient_key, hashAlgo=SHA256)
encrypted_session_key = cipher_rsa.encrypt(session_key)
encrypted_session_key_text = base64.b64encode(encrypted_session_key)
encrypted_session_key_text =\
self._rsa_encrypt_data_to_base64(self.recipient_key, session_key)

return encrypted_session_key_text, encrypted_data_text

def _verify_signature(self, data, msg):
if msg.get_text() != "ssh-ed25519":
return False

try:
self.sender_key._signing_key.verify_key.verify(data, msg.get_binary())
except:
return False
else:
return True

def _sign_payload(self, encrypted_payload):
"""
encrypted_payload - base64 representation of the encrypted payload.
returns: base64-encoded signature
"""
sign_message = self.sender_key.sign_ssh_data(encrypted_payload)

# Verify what we generated just in case:
verify_message = paramiko.Message(sign_message.asbytes())
verify_res = self._verify_signature(encrypted_payload, verify_message)

if not verify_res:
self._raise_error("FAILED to verify signed data. ABORTING.")

result = base64.b64encode(sign_message.asbytes())
return result

def _decrypt_data(self, private_key_path, encrypted_key_text, encrypted_data_text):
private_key = self._import_rsa_key(private_key_path)
if private_key is None:
Expand Down Expand Up @@ -92,6 +172,7 @@ def _decrypt_data(self, private_key_path, encrypted_key_text, encrypted_data_tex
def construct(self, experiment, config, packages):
unencrypted_payload =\
self.simple_builder.construct(experiment, config, packages)
unencrypted_payload_str = json.dumps(unencrypted_payload)

# Construct payload template:
encrypted_payload = {
Expand All @@ -108,7 +189,7 @@ def construct(self, experiment, config, packages):
}

# Now fill it up with experiment properties:
enc_key, enc_payload = self._encrypt_str(json.dumps(unencrypted_payload))
enc_key, enc_payload = self._encrypt_str(unencrypted_payload_str)

encrypted_payload["message"]["experiment"]["status"] =\
experiment.status
Expand All @@ -122,6 +203,16 @@ def construct(self, experiment, config, packages):
experiment.resources_needed
encrypted_payload["message"]["payload"] =\
"{0},{1}".format(enc_key.decode("utf-8"), enc_payload.decode("utf-8"))
if self.sender_key is not None:
# Generate sender/workload signature:
final_payload = encrypted_payload["message"]["payload"]
payload_signature = self._sign_payload(final_payload.encode("utf-8"))
encrypted_payload["message"]["signature"] =\
"{0}".format(payload_signature.decode("utf-8"))
encrypted_payload["message"]["fingerprint"] =\
"{0}".format(self.sender_fingerprint)

print(json.dumps(encrypted_payload, indent=4))

return encrypted_payload

7 changes: 6 additions & 1 deletion studio/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .local_db_provider import LocalDbProvider
from .s3_provider import S3Provider
from .gs_provider import GSProvider
from .model_setup import setup_model
from .model_setup import setup_model, get_model_db_provider
from . import logs

def get_config(config_file=None):
Expand Down Expand Up @@ -59,6 +59,11 @@ def replace_with_env(config):
.format(config_paths))

def get_db_provider(config=None, blocking_auth=True):

db_provider = get_model_db_provider()
if not db_provider is None:
return db_provider

if not config:
config = get_config()
verbose = parse_verbosity(config.get('verbose'))
Expand Down
4 changes: 2 additions & 2 deletions studio/model_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
def setup_model(db_provider, artifact_store):
_model_setup = { DB_KEY: db_provider, STORE_KEY: artifact_store }

def get_db_provider():
def get_model_db_provider():
if _model_setup is None:
return None
return _model_setup.get(DB_KEY, None)

def get_artifact_store():
def get_model_artifact_store():
if _model_setup is None:
return None
return _model_setup.get(STORE_KEY, None)
15 changes: 11 additions & 4 deletions studio/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,12 +597,16 @@ def submit_experiments(

payload_builder = UnencryptedPayloadBuilder("simple-payload")
# Are we using experiment payload encryption?
key_path = config.get('public_key_path')
if key_path is not None:
logger.info("Using RSA public key path: {0}".format(key_path))
public_key_path = config.get('public_key_path', None)
if public_key_path is not None:
logger.info("Using RSA public key path: {0}".format(public_key_path))
signing_key_path = config.get('signing_key_path', None)
if signing_key_path is not None:
logger.info("Using RSA signing key path: {0}".format(signing_key_path))
payload_builder = \
EncryptedPayloadBuilder(
"cs-rsa-encryptor [{0}]".format(key_path), key_path)
"cs-rsa-encryptor [{0}]".format(public_key_path),
public_key_path, signing_key_path)

start_time = time.time()

Expand Down Expand Up @@ -637,6 +641,9 @@ def submit_experiments(

for experiment in experiments:
payload = payload_builder.construct(experiment, config, python_pkg)

print(json.dumps(payload, indent=4))

queue.enqueue(json.dumps(payload))
logger.info("studio run: submitted experiment " + experiment.key)

Expand Down