equinor / ert

ERT - Ensemble based Reservoir Tool - is designed for running ensembles of dynamical models such as reservoir models, in order to do sensitivity analysis and data assimilation. ERT supports data assimilation using the Ensemble Smoother (ES), Ensemble Smoother with Multiple Data Assimilation (ES-MDA) and Iterative Ensemble Smoother (IES).
https://ert.readthedocs.io/en/latest/
GNU General Public License v3.0
101 stars 104 forks source link

Explore dask communication layer tls providing encrypted communication #6534

Closed xjules closed 1 month ago

xjules commented 10 months ago

By default, Dask uses plain TCP sockets for communication. However, it also provides a secure transport via TLS. We should find out how this can be set up in our infrastructure. More info here: https://distributed.dask.org/en/latest/communications.html

jonathan-eq commented 10 months ago

This is the suggested way to add TLS encryption:

from distributed import Client
from distributed.security import Security

# PEM files for the client side of the TLS handshake: the CA certificate
# used to verify the other end, plus this client's own certificate and key.
tls_options = {
    "tls_ca_file": "cluster_ca.pem",
    "tls_client_cert": "cli_cert.pem",
    "tls_client_key": "cli_key.pem",
}
sec = Security(require_encryption=True, **tls_options)

# The literal `...` is a placeholder for the real Client arguments
# (e.g. the scheduler address).
client = Client(..., security=sec)
jonathan-eq commented 9 months ago

That did not work with a local cluster, but I found a workaround.

jonathan-eq commented 9 months ago

First, we need certificates, private keys, and a CA certificate to verify the server. I used this script to generate all of them.

import contextlib
import os
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from functools import partial
from pathlib import Path

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from yaml import dump as _dump_yaml
from yaml import load as _load_yaml

try:
    from yaml import CLoader as _yaml_Loader
except ImportError:
    from yaml import Loader as _yaml_Loader

def load_yaml(filename):
    """Parse the YAML document stored in *filename* and return the result.

    The file is read as UTF-8 explicitly, so the result does not depend on
    the platform's locale encoding.
    """
    # NOTE(review): ``_yaml_Loader`` is PyYAML's (C)Loader, which is not a
    # safe loader and can construct arbitrary Python objects from tagged
    # YAML. That is tolerable for the locally generated config this script
    # reads back, but switch to a SafeLoader before using it on untrusted files.
    with open(filename, "r", encoding="utf-8") as f:
        return _load_yaml(f, Loader=_yaml_Loader)

def dump_yaml(obj, filename):
    """Serialise *obj* as YAML into *filename*, overwriting any existing content.

    Written as UTF-8, matching the encoding ``load_yaml`` reads with.
    """
    with open(filename, "w", encoding="utf-8") as f:
        _dump_yaml(obj, f)

def create_tls_identities(parent_path: str):
    """Generate a private CA plus client, scheduler and worker identities.

    Each identity's private key, public key and certificate are written as
    PEM files into ``<parent_path>/tls_identities`` with owner-only
    permissions. A ``distributed.comm`` config fragment pointing at those
    files is then merged into ``<parent_path>/tls_identities/output-tls-config``.

    Args:
        parent_path: Existing directory under which ``tls_identities`` is created.

    Raises:
        FileNotFoundError: If *parent_path* does not exist.
    """
    parent = Path(parent_path)
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if not parent.exists():
        raise FileNotFoundError(f"parent path does not exist: {parent_path}")
    new_path = Path(parent, "tls_identities")
    new_path.mkdir(exist_ok=True)

    def secure_opener(path, flags):
        # The mode given to os.open only takes effect when the file is newly
        # created. Re-apply it so that a pre-existing, more permissive key
        # file is tightened as well.
        fd = os.open(path, flags, 0o600)
        if hasattr(os, "fchmod"):  # os.fchmod is unavailable on Windows
            try:
                os.fchmod(fd, 0o600)
            except OSError:
                os.close(fd)
                raise
        return fd

    s_open_wb = partial(open, mode='wb', opener=secure_opener)

    def dump(name, constructor):
        """Build an identity via *constructor* and write its three PEM files."""
        identity = constructor(name)

        identity.private_key_path = str(Path(new_path, f'{name}-private-key.pem'))
        identity.public_key_path = str(Path(new_path, f'{name}-public-key.pem'))
        identity.cert_path = str(Path(new_path, f'{name}-cert.pem'))

        with contextlib.ExitStack() as exit_stack:
            private_key_file = exit_stack.enter_context(
                s_open_wb(identity.private_key_path)
            )
            public_key_file = exit_stack.enter_context(
                s_open_wb(identity.public_key_path)
            )
            cert_file = exit_stack.enter_context(s_open_wb(identity.cert_path))

            private_key_file.write(identity.private_key_bytes)
            public_key_file.write(identity.public_key_bytes)
            cert_file.write(identity.cert_bytes)

        return identity

    ca = dump('xun-private-ca', create_certificate_authority)
    # create_client only marks the certificate as non-CA, so the same factory
    # is used for the client, scheduler and worker leaf identities.
    client = dump('xun-dask-client', partial(create_client, ca=ca))
    scheduler = dump('xun-dask-scheduler', partial(create_client, ca=ca))
    worker = dump('xun-dask-worker', partial(create_client, ca=ca))
    update_config({
            'distributed': {
                'comm': {
                    'require-encryption': True,
                    'tls': {
                        'ca-file': ca.cert_path,
                        'client': {
                            'key': client.private_key_path,
                            'cert': client.cert_path,
                        },
                        'scheduler': {
                            'key': scheduler.private_key_path,
                            'cert': scheduler.cert_path,
                        },
                        'worker': {
                            'key': worker.private_key_path,
                            'cert': worker.cert_path,
                        },
                    }
                },
            },
        }, Path(new_path, "output-tls-config"))

def update_config(config, path):
    """Merge *config* into the YAML file at *path*, creating it if necessary.

    Nested mappings are merged recursively. Settings already in the file
    under the same section (e.g. other ``distributed.comm`` keys) therefore
    survive, and values from *config* win on conflicts. If the existing file
    does not hold a mapping (for example, it is empty), it is replaced by
    *config*.

    Args:
        config: Mapping to merge into the file.
        path: ``pathlib.Path`` of the YAML file; missing parent directories
            are created.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    if path.exists():
        old_config = load_yaml(path)
        # A plain ``dict.update`` would replace whole nested sections and
        # silently drop sibling keys, so merge recursively instead.
        if isinstance(old_config, dict):
            config = _deep_merge(old_config, config)

    dump_yaml(config, path)


def _deep_merge(base, override):
    """Return a new dict: *base* with *override* merged in recursively."""
    merged = dict(base)
    for key, value in override.items():
        current = merged.get(key)
        if isinstance(current, dict) and isinstance(value, dict):
            merged[key] = _deep_merge(current, value)
        else:
            merged[key] = value
    return merged

class Identity:
    """An RSA private key paired with the X.509 certificate issued for it.

    ``create_tls_identities`` later attaches ``private_key_path``,
    ``public_key_path`` and ``cert_path`` to instances once the PEM files
    have been written.
    """

    def __init__(self, key: rsa.RSAPrivateKey, cert: x509.Certificate):
        # ``key`` and ``private_key`` deliberately alias the same object:
        # ``ca.key`` is read when signing, ``private_key`` by the PEM export.
        self.key = self.private_key = key
        self.cert = cert
        self.public_key = key.public_key()

    @property
    def private_key_bytes(self):
        """The private key as unencrypted PKCS#8 PEM bytes."""
        pem = serialization.Encoding.PEM
        pkcs8 = serialization.PrivateFormat.PKCS8
        no_password = serialization.NoEncryption()
        return self.private_key.private_bytes(
            encoding=pem, format=pkcs8, encryption_algorithm=no_password
        )

    @property
    def public_key_bytes(self):
        """The public key as SubjectPublicKeyInfo PEM bytes."""
        pem = serialization.Encoding.PEM
        spki = serialization.PublicFormat.SubjectPublicKeyInfo
        return self.public_key.public_bytes(encoding=pem, format=spki)

    @property
    def cert_bytes(self):
        """The certificate as PEM bytes."""
        pem = serialization.Encoding.PEM
        return self.cert.public_bytes(pem)

def create_certificate_authority(
    common_name='xun-private-ca', *, validity_days: int = 60
) -> Identity:
    """Create a self-signed RSA-2048 certificate authority.

    The CA may sign only end-entity certificates (``path_length=0``). Its key
    usage is restricted to certificate and CRL signing.

    Args:
        common_name: CN used as both subject and issuer; also added as a DNS
            subject alternative name.
        validity_days: Certificate lifetime in days, counted from now.

    Returns:
        An ``Identity`` holding the CA private key and its certificate.
    """
    ca_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    xun_name = x509.Name(
        [x509.NameAttribute(NameOID.COMMON_NAME, common_name)]
    )

    # Timezone-aware UTC; ``datetime.utcnow()`` is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    # Backdate slightly so hosts whose clocks run a little behind (e.g.
    # remote workers) do not reject the certificate as "not yet valid".
    not_before = now - timedelta(minutes=5)

    ca_cert = (x509.CertificateBuilder()
        .subject_name(xun_name)
        .issuer_name(xun_name)

        .public_key(ca_key.public_key())
        .serial_number(x509.random_serial_number())

        .add_extension(
            x509.SubjectAlternativeName([x509.DNSName(common_name)]),
            critical=False)
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=0),
            critical=True)
        .add_extension(
            x509.KeyUsage(digital_signature=False,
                          content_commitment=False,
                          key_encipherment=False,
                          data_encipherment=False,
                          key_agreement=False,
                          key_cert_sign=True,
                          crl_sign=True,
                          encipher_only=False,
                          decipher_only=False),
            critical=True)

        .not_valid_before(not_before)
        .not_valid_after(now + timedelta(days=validity_days))

        .sign(ca_key, hashes.SHA256())
    )

    return Identity(key=ca_key, cert=ca_cert)

def create_client(common_name, ca, *, validity_days: int = 60) -> Identity:
    """Create an RSA-2048 end-entity certificate for *common_name*, signed by *ca*.

    Args:
        common_name: CN placed in the certificate subject.
        ca: ``Identity`` of the signing certificate authority; its ``key``
            signs the certificate and its certificate subject becomes the issuer.
        validity_days: Certificate lifetime in days, counted from now.

    Returns:
        An ``Identity`` holding the new private key and certificate.
    """
    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    name = x509.Name(
        [x509.NameAttribute(NameOID.COMMON_NAME, common_name)]
    )

    # Timezone-aware UTC; ``datetime.utcnow()`` is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    # Backdate slightly to tolerate small clock differences between hosts.
    not_before = now - timedelta(minutes=5)

    cert = (x509.CertificateBuilder()
        .subject_name(name)
        # Copy the CA's complete subject. Chain validation requires the
        # issuer to match it exactly; rebuilding it from the CN alone would
        # break for a CA whose subject has more attributes.
        .issuer_name(ca.cert.subject)

        .public_key(key.public_key())
        .serial_number(x509.random_serial_number())

        .add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True)

        .not_valid_before(not_before)
        .not_valid_after(now + timedelta(days=validity_days))

        .sign(ca.key, hashes.SHA256())
    )

    return Identity(key=key, cert=cert)

if __name__ == "__main__":
    # Generate the identities in the directory that contains this script.
    create_tls_identities(str(Path(__file__).parent))

Then we have to update Dask's default config. This took a long while to figure out (it might be related to https://github.com/dask/distributed/issues/2815). The file is located at venv/lib/python3.xx/site-packages/distributed/distributed.yaml. We add the output of the script under distributed.comm.tls, so it looks like this:

---
tls:
      ciphers: null     # Allowed ciphers, specified as an OpenSSL cipher string.
      min-version: 1.2  # The minimum TLS version supported.
      max-version: null # The maximum TLS version supported.
      ca-file: <DIR>/dask/tls_identities/xun-private-ca-cert.pem
      scheduler:
        cert:  <DIR>/dask/tls_identities/xun-dask-scheduler-cert.pem
        key:  <DIR>/dask/tls_identities/xun-dask-scheduler-private-key.pem
      worker:
        cert:  <DIR>/dask/tls_identities/xun-dask-worker-cert.pem
        key:  <DIR>/dask/tls_identities/xun-dask-worker-private-key.pem
      client:
        cert:  <DIR>/dask/tls_identities/xun-dask-client-cert.pem
        key:  <DIR>/dask/tls_identities/xun-dask-client-private-key.pem
---

Running the code this way worked:

from dask.distributed import Client, Security

from pathlib import Path
import yaml
import asyncio
import time
from pathlib import Path

async def main():
    """Run a few demo tasks on a TLS-secured Dask cluster.

    Reads the ``output-tls-config`` fragment from the ``tls_identities``
    directory next to this script and builds a ``Security`` object that
    requires encryption. It then creates a ``Client`` with no scheduler
    address, so the client starts its own local cluster.
    """
    config_path = (
        Path(__file__).parent.absolute() / "tls_identities" / "output-tls-config"
    )
    # ``CSafeLoader`` exists only when PyYAML was built with libyaml; fall
    # back to the pure-Python safe loader instead of raising AttributeError.
    safe_loader = getattr(yaml, "CSafeLoader", yaml.SafeLoader)
    with open(config_path, mode="rt", encoding="utf-8") as f:
        yaml_dict: dict = yaml.load(f, safe_loader)

    tls_section = yaml_dict["distributed"]["comm"]["tls"]

    common_security = Security(
        require_encryption=True,
        tls_ca_file=tls_section["ca-file"],
        tls_scheduler_cert=tls_section["scheduler"]["cert"],
        tls_scheduler_key=tls_section["scheduler"]["key"],
        tls_worker_cert=tls_section["worker"]["cert"],
        tls_worker_key=tls_section["worker"]["key"],
        tls_client_cert=tls_section["client"]["cert"],
        tls_client_key=tls_section["client"]["key"]
    )

    def say_hi(name: str) -> str:
        # Sleeps one second per character, so tasks have different durations.
        time.sleep(len(name))
        return f"Hello {name}"

    # The context manager closes the client even if a task raises. A
    # trailing ``client.close()`` would be skipped in that case.
    with Client(security=common_security) as client:
        futures = client.map(say_hi, ["John", "Peter", "Leif", "Geir", "Jonathan", "Cesc Fabregas", "Ondrazhek - The Cobra"])
        print(client.gather(futures))


if __name__ == "__main__":
    asyncio.run(main())
jonathan-eq commented 9 months ago

This solution worked for a LocalCluster; in theory, a remote cluster should not cause any additional problems.

oyvindeide commented 1 month ago

Since it was moved to the Done column I am closing this issue as completed.