Module brevettiai.interfaces.sagemaker
Expand source code
import json
import logging
import os
import sys
from brevettiai.io.credentials import Credentials, LoginError
log = logging.getLogger(__name__)
# File from which SageMaker-provided hyperparameters are read by default.
SAGEMAKER_HYPERPARAMETER_PATH = "/opt/ml/input/config/hyperparameters.json"


def load_hyperparameters_cmd_args(hyperparameter_path=SAGEMAKER_HYPERPARAMETER_PATH):
    """Append hyperparameters from a JSON file to ``sys.argv`` as ``--key value`` pairs.

    This lets an ordinary command-line argument parser consume hyperparameters
    that were delivered as a JSON object. A missing or unreadable file is not
    an error: it is logged and ``sys.argv`` is left untouched.

    Args:
        hyperparameter_path: Path of a JSON object mapping hyperparameter
            names to values.

    Raises:
        ValueError: if the file can be opened but does not contain valid JSON
            (``json.JSONDecodeError`` is a subclass of ``ValueError``).
    """
    try:
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(hyperparameter_path, "r", encoding="utf-8") as hyper:
            hyper_parameters = json.load(hyper)
    except IOError:
        log.info("No hyper parameters found!")
        return

    log.info("Loaded hyper parameters %s", json.dumps(hyper_parameters))
    extra_args = []
    for name, value in hyper_parameters.items():
        # argv entries must be strings (argparse does string operations on
        # them). Strings pass through unchanged; any other JSON value is
        # re-serialised as JSON text (e.g. 3 -> "3", True -> "true").
        extra_args += ["--" + name, value if isinstance(value, str) else json.dumps(value)]
    sys.argv += extra_args
    log.info("Added hyper parameters to sys.argv")
class SagemakerCredentials(Credentials):
    """Credential manager for code running inside a SageMaker training job.

    Credentials come from ``fetch_aws_credentials``. The requested resource is
    not used to select them, so every resource receives the same AWS
    credentials.
    """

    def get_credentials(self, resource_id, resource_type="dataset", mode="r"):
        """Return AWS credentials for the current SageMaker training job.

        Args:
            resource_id: Identifier of the requested resource (unused).
            resource_type: Kind of the requested resource (unused).
            mode: Requested access mode (unused).

        Returns:
            The dict produced by ``fetch_aws_credentials``.

        Raises:
            LoginError: if ``TRAINING_JOB_ARN`` is not set in the environment,
                or if fetching the credentials fails for any reason. In the
                latter case the original exception is chained as ``__cause__``.
        """
        if "TRAINING_JOB_ARN" not in os.environ:
            raise LoginError("Error logging in with Sagemaker credentials, Not on Sagemaker")
        try:
            return fetch_aws_credentials()
        except Exception as err:
            # Chain the underlying failure (e.g. missing boto3, no credentials)
            # so it is not lost when converting to LoginError.
            raise LoginError("Error logging in with Sagemaker credentials") from err

    def set_credentials(self, type, user, secret, **kwargs):
        """Do nothing: this manager does not store credentials; all arguments are ignored."""
        pass
def fetch_aws_credentials():
    """Collect AWS credentials and the S3 endpoint for the current job.

    The region is the fourth ``:``-separated field of the ``TRAINING_JOB_ARN``
    environment variable. When that variable is unset, the default
    ``"arn:aws:sagemaker:eu-west-1:xx"`` is used, i.e. region ``eu-west-1``.

    Side effect: sets ``S3_USE_HTTPS`` and ``S3_VERIFY_SSL`` to ``"1"`` in
    ``os.environ``.

    Returns:
        dict with keys ``access_key``, ``secret_key``, ``region``,
        ``session_token`` and ``endpoint`` (the S3 client's endpoint URL with
        its scheme stripped).

    Raises:
        ValueError: if ``TRAINING_JOB_ARN`` has fewer than four ``:``-separated fields.
        RuntimeError: if boto3 cannot locate any AWS credentials.
    """
    import boto3

    arn = os.environ.get("TRAINING_JOB_ARN", "arn:aws:sagemaker:eu-west-1:xx")
    arn_fields = arn.split(":")
    if len(arn_fields) < 4:
        raise ValueError(f"Malformed TRAINING_JOB_ARN: {arn!r}")
    region = arn_fields[3]

    sess = boto3.session.Session()
    credentials = sess.get_credentials()
    if credentials is None:
        raise RuntimeError("No AWS credentials found for the boto3 session")
    # Take one consistent snapshot. The returned object may be refreshable
    # (e.g. role-based credentials), and reading access_key / secret_key /
    # token one at a time could straddle a refresh and mix two credential sets.
    frozen = credentials.get_frozen_credentials()

    s3 = sess.client("s3", region_name=region)
    os.environ["S3_USE_HTTPS"] = "1"
    os.environ["S3_VERIFY_SSL"] = "1"
    return dict(
        access_key=frozen.access_key,
        secret_key=frozen.secret_key,
        region=region,
        session_token=frozen.token,
        endpoint=s3.meta.endpoint_url.split("://", 1)[-1],
    )
Functions
def fetch_aws_credentials()
-
Expand source code
def fetch_aws_credentials():
    """Collect AWS credentials and the S3 endpoint for the current job.

    The region is the fourth ``:``-separated field of the ``TRAINING_JOB_ARN``
    environment variable. When that variable is unset, the default
    ``"arn:aws:sagemaker:eu-west-1:xx"`` is used, i.e. region ``eu-west-1``.

    Side effect: sets ``S3_USE_HTTPS`` and ``S3_VERIFY_SSL`` to ``"1"`` in
    ``os.environ``.

    Returns:
        dict with keys ``access_key``, ``secret_key``, ``region``,
        ``session_token`` and ``endpoint`` (the S3 client's endpoint URL with
        its scheme stripped).

    Raises:
        ValueError: if ``TRAINING_JOB_ARN`` has fewer than four ``:``-separated fields.
        RuntimeError: if boto3 cannot locate any AWS credentials.
    """
    import boto3

    arn = os.environ.get("TRAINING_JOB_ARN", "arn:aws:sagemaker:eu-west-1:xx")
    arn_fields = arn.split(":")
    if len(arn_fields) < 4:
        raise ValueError(f"Malformed TRAINING_JOB_ARN: {arn!r}")
    region = arn_fields[3]

    sess = boto3.session.Session()
    credentials = sess.get_credentials()
    if credentials is None:
        raise RuntimeError("No AWS credentials found for the boto3 session")
    # Take one consistent snapshot. The returned object may be refreshable
    # (e.g. role-based credentials), and reading access_key / secret_key /
    # token one at a time could straddle a refresh and mix two credential sets.
    frozen = credentials.get_frozen_credentials()

    s3 = sess.client("s3", region_name=region)
    os.environ["S3_USE_HTTPS"] = "1"
    os.environ["S3_VERIFY_SSL"] = "1"
    return dict(
        access_key=frozen.access_key,
        secret_key=frozen.secret_key,
        region=region,
        session_token=frozen.token,
        endpoint=s3.meta.endpoint_url.split("://", 1)[-1],
    )
def load_hyperparameters_cmd_args(hyperparameter_path='/opt/ml/input/config/hyperparameters.json')
-
Expand source code
def load_hyperparameters_cmd_args(hyperparameter_path=SAGEMAKER_HYPERPARAMETER_PATH):
    """Append hyperparameters from a JSON file to ``sys.argv`` as ``--key value`` pairs.

    This lets an ordinary command-line argument parser consume hyperparameters
    that were delivered as a JSON object. A missing or unreadable file is not
    an error: it is logged and ``sys.argv`` is left untouched.

    Args:
        hyperparameter_path: Path of a JSON object mapping hyperparameter
            names to values.

    Raises:
        ValueError: if the file can be opened but does not contain valid JSON
            (``json.JSONDecodeError`` is a subclass of ``ValueError``).
    """
    try:
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(hyperparameter_path, "r", encoding="utf-8") as hyper:
            hyper_parameters = json.load(hyper)
    except IOError:
        log.info("No hyper parameters found!")
        return

    log.info("Loaded hyper parameters %s", json.dumps(hyper_parameters))
    extra_args = []
    for name, value in hyper_parameters.items():
        # argv entries must be strings (argparse does string operations on
        # them). Strings pass through unchanged; any other JSON value is
        # re-serialised as JSON text (e.g. 3 -> "3", True -> "true").
        extra_args += ["--" + name, value if isinstance(value, str) else json.dumps(value)]
    sys.argv += extra_args
    log.info("Added hyper parameters to sys.argv")
Classes
class SagemakerCredentials
-
Credential manager for SageMaker training jobs; a concrete implementation of the abstract credential-manager interface (`Credentials`).
Expand source code
class SagemakerCredentials(Credentials):
    """Credential manager for code running inside a SageMaker training job.

    Credentials come from ``fetch_aws_credentials``. The requested resource is
    not used to select them, so every resource receives the same AWS
    credentials.
    """

    def get_credentials(self, resource_id, resource_type="dataset", mode="r"):
        """Return AWS credentials for the current SageMaker training job.

        Args:
            resource_id: Identifier of the requested resource (unused).
            resource_type: Kind of the requested resource (unused).
            mode: Requested access mode (unused).

        Returns:
            The dict produced by ``fetch_aws_credentials``.

        Raises:
            LoginError: if ``TRAINING_JOB_ARN`` is not set in the environment,
                or if fetching the credentials fails for any reason. In the
                latter case the original exception is chained as ``__cause__``.
        """
        if "TRAINING_JOB_ARN" not in os.environ:
            raise LoginError("Error logging in with Sagemaker credentials, Not on Sagemaker")
        try:
            return fetch_aws_credentials()
        except Exception as err:
            # Chain the underlying failure (e.g. missing boto3, no credentials)
            # so it is not lost when converting to LoginError.
            raise LoginError("Error logging in with Sagemaker credentials") from err

    def set_credentials(self, type, user, secret, **kwargs):
        """Do nothing: this manager does not store credentials; all arguments are ignored."""
        pass
Ancestors
- Credentials
- abc.ABC
Methods
def get_credentials(self, resource_id, resource_type='dataset', mode='r')
-
Expand source code
def get_credentials(self, resource_id, resource_type="dataset", mode="r"):
    """Return AWS credentials for the current SageMaker training job.

    Args:
        resource_id: Identifier of the requested resource (unused).
        resource_type: Kind of the requested resource (unused).
        mode: Requested access mode (unused).

    Returns:
        The dict produced by ``fetch_aws_credentials``.

    Raises:
        LoginError: if ``TRAINING_JOB_ARN`` is not set in the environment,
            or if fetching the credentials fails for any reason. In the
            latter case the original exception is chained as ``__cause__``.
    """
    if "TRAINING_JOB_ARN" not in os.environ:
        raise LoginError("Error logging in with Sagemaker credentials, Not on Sagemaker")
    try:
        return fetch_aws_credentials()
    except Exception as err:
        # Chain the underlying failure (e.g. missing boto3, no credentials)
        # so it is not lost when converting to LoginError.
        raise LoginError("Error logging in with Sagemaker credentials") from err
def set_credentials(self, type, user, secret, **kwargs)
-
Expand source code
def set_credentials(self, type, user, secret, **kwargs):
    """Accept and ignore credentials.

    This manager does not store credentials, so every argument is ignored
    and the method returns ``None``.
    """