Module brevettiai.interfaces.sagemaker
Expand source code
import json
import logging
import os
import sys
from brevettiai.io.credentials import Credentials, LoginError
log = logging.getLogger(__name__)
SAGEMAKER_HYPERPARAMETER_PATH = "/opt/ml/input/config/hyperparameters.json"
def load_hyperparameters_cmd_args(hyperparameter_path=SAGEMAKER_HYPERPARAMETER_PATH):
    try:
        with open(hyperparameter_path, "r") as hyper:
            hyper_parameters = json.load(hyper)
        log.info("Loaded hyper parameters " + json.dumps(hyper_parameters))
        sys.argv += [kki for kk, vv in hyper_parameters.items() for kki in ["--" + kk, vv]]
        log.info("Added hyper parameters to sys.argv")
    except IOError:
        log.info("No hyper parameters found!")
class SagemakerCredentials(Credentials):
    def get_credentials(self, resource_id, resource_type="dataset", mode="r"):
        if not "TRAINING_JOB_ARN" in os.environ:
            raise LoginError(f"Error logging in with Sagemaker credentials, Not on Sagemaker")
        try:
            return fetch_aws_credentials()
        except Exception:
            raise LoginError(f"Error logging in with Sagemaker credentials")
    def set_credentials(self, type, user, secret, **kwargs):
        pass
def fetch_aws_credentials():
    import boto3
    _, partition, service, region, *_ = os.environ.get("TRAINING_JOB_ARN", "arn:aws:sagemaker:eu-west-1:xx").split(":")
    sess = boto3.session.Session()
    cred = sess.get_credentials()
    s3 = sess.client("s3", region_name=region)
    os.environ["S3_USE_HTTPS"] = "1"
    os.environ["S3_VERIFY_SSL"] = "1"
    return dict(
        access_key=cred.access_key,
        secret_key=cred.secret_key,
        region=region,
        session_token=cred.token,
        endpoint=s3.meta.endpoint_url.split("://", 1)[-1]
    )
Functions
def fetch_aws_credentials()- 
Expand source code
def fetch_aws_credentials(): import boto3 _, partition, service, region, *_ = os.environ.get("TRAINING_JOB_ARN", "arn:aws:sagemaker:eu-west-1:xx").split(":") sess = boto3.session.Session() cred = sess.get_credentials() s3 = sess.client("s3", region_name=region) os.environ["S3_USE_HTTPS"] = "1" os.environ["S3_VERIFY_SSL"] = "1" return dict( access_key=cred.access_key, secret_key=cred.secret_key, region=region, session_token=cred.token, endpoint=s3.meta.endpoint_url.split("://", 1)[-1] ) def load_hyperparameters_cmd_args(hyperparameter_path='/opt/ml/input/config/hyperparameters.json')- 
Expand source code
def load_hyperparameters_cmd_args(hyperparameter_path=SAGEMAKER_HYPERPARAMETER_PATH): try: with open(hyperparameter_path, "r") as hyper: hyper_parameters = json.load(hyper) log.info("Loaded hyper parameters " + json.dumps(hyper_parameters)) sys.argv += [kki for kk, vv in hyper_parameters.items() for kki in ["--" + kk, vv]] log.info("Added hyper parameters to sys.argv") except IOError: log.info("No hyper parameters found!") 
Classes
class SagemakerCredentials- 
Abstract class for credential managers
Expand source code
class SagemakerCredentials(Credentials): def get_credentials(self, resource_id, resource_type="dataset", mode="r"): if not "TRAINING_JOB_ARN" in os.environ: raise LoginError(f"Error logging in with Sagemaker credentials, Not on Sagemaker") try: return fetch_aws_credentials() except Exception: raise LoginError(f"Error logging in with Sagemaker credentials") def set_credentials(self, type, user, secret, **kwargs): passAncestors
- Credentials
 - abc.ABC
 
Methods
def get_credentials(self, resource_id, resource_type='dataset', mode='r')- 
Expand source code
def get_credentials(self, resource_id, resource_type="dataset", mode="r"): if not "TRAINING_JOB_ARN" in os.environ: raise LoginError(f"Error logging in with Sagemaker credentials, Not on Sagemaker") try: return fetch_aws_credentials() except Exception: raise LoginError(f"Error logging in with Sagemaker credentials") def set_credentials(self, type, user, secret, **kwargs)- 
Expand source code
def set_credentials(self, type, user, secret, **kwargs): pass