SageMakerノートブックインスタンスを定期/自動停止したい

AWS

この記事を読むのに必要な時間は 約12分 です。

Pythonで機械学習するのに、AWSが提供しているAmazon SageMakerによくお世話になっています。
ノートブックインスタンスを使うと、Pythonの実行環境としてJupyter(Lab)がサクッと起動して使えるようになるのですが、停止を忘れると ずーっと 課金されてしまうことにも…
(チャリンチャリン コワイコワイ)

ということで(どういうことで?)
起動したままのノートブックインスタンスを見つけ出して、未操作が続いている場合は自動停止させる Lambda関数 を構築しました。

(準備)メール通知用のSNSトピックを作ります

自動停止の実行結果をメール通知で受け取りたいので、Amazon SNSのトピックを作成します。
<こちらのリンクからトピックの作成>を始めます。

しんじ
しんじ

作成時の設定は、最低限でOKなので、
 タイプ:スタンダード
 名前:sagemaker-autostop-notebook-topic
と入力し、一番下までスクロールして「トピックの作成」をクリックします。

Screenshot
しんじ
しんじ

無事にトピックが作成できました。
詳細の ARN の文字列を後で使うので、メモしておきましょう。

続いて、トピックが作成されたら、サブスクリプションの作成を行います
右下にある「サブスクリプションの作成」をクリック

Screenshot
しんじ
しんじ

プロトコルを「Eメール」に設定して、
エンドポイントには、受信用のメールアドレスを入力して、
サブスクリプションの作成をクリック

しんじ
しんじ

すると、エンドポイントに指定したアドレスにこんなメールが届きます。

メール本文に Confirm subscription というリンクがありますので、これをクリックします

しんじ
しんじ

リンクをクリックした先で、こんな表示が出ればOKです

以上で、Amazon SNSでの設定は完了です。

Lambda関数を作ります

しんじ
しんじ

それでは、事前準備が完了したので、Lambda関数を作っていきましょう!

まずは、下の図を参考にトリガーを追加します。
(EventBridgeに新規ルールを作成します)

次に、設定→環境変数を、次の表のとおりに設定します。
ちなみに、後でも説明しますが、ノートブックインスタンスのタグに「auto_stop: False」と設定すると、この自動停止操作の対象外になります。

キー
STOP_HOUR4
TAG_KEYauto_stop
TOPIC_ARNarn:aws:sns:ap-northeast-1:<ここにアカウントID>:sagemaker-autostop-notebook-topic
※最後の項目はAmazon SNSで設定したトピック名

コードソースに以下のコードをコピペします。
コピペしたら、忘れずに「Deploy」ボタンをクリックしましょう。

import boto3
import json
import os
import logging
from datetime import datetime, timezone, timedelta


# 環境変数を読み込む
TAG_KEY = os.environ['TAG_KEY']    # 'auto_stop'
STOP_HOUR = int(os.environ['STOP_HOUR'])    # 6
TOPIC_ARN = os.environ['TOPIC_ARN']


def lambda_handler(event, context):
    # Loggerの設定
    logger = logging.getLogger()
    logger.setLevel("INFO")
    
    # 環境変数を出力しておく
    logger.info('## ENVIRONMENT VARIABLES ##')
    logger.info(os.environ['TAG_KEY'])
    logger.info(os.environ['STOP_HOUR'])

    # 連続稼働の上限(秒)
    limit_time = 60 * 60 * STOP_HOUR
    # limit_time = 60    # テスト用1分で停止対象…!?
    jst = timezone(timedelta(hours=9), 'JST')

    # Client
    client = boto3.client('sagemaker')    # SageMaker
    cw_logs_client = boto3.client('logs')    # CloudWatch
    sns_client = boto3.client('sns')    # 

    # --------------------------------------------------------------------
    # 1. 全てのノートブックインスタンスのリストを取得する
    # 初期リクエスト
    response = client.list_notebook_instances()
    notebook_instances = response.get('NotebookInstances', [])
    # NextTokenが存在する限り、追加のノートブックインスタンスを取得
    while 'NextToken' in response:
        response = client.list_notebook_instances(NextToken=response['NextToken'])
        notebook_instances.extend(response.get('NotebookInstances', []))
    logger.info(f"ノートブックインスタンスの総数: {len(notebook_instances)}")

    # --------------------------------------------------------------------
    # 2. 全てのノートブックインスタンスのアイドル時間をチェックする
    in_operation_longtime = []    # アイドル時間の上限を超えたインスタンスのリスト
    notebook_counter = 0
    for idx, notebook_instance in enumerate(notebook_instances):
        instance_name = notebook_instance['NotebookInstanceName']
        last_modified_time = notebook_instance['LastModifiedTime']
        if not notebook_instance['NotebookInstanceStatus'] == 'InService':
            continue
        else:
            notebook_counter += 1
        try:
            logs = cw_logs_client.get_log_events(
                logGroupName="/aws/sagemaker/NotebookInstances",
                logStreamName=f"{instance_name}/jupyter.log",
                limit=10,
                startFromHead=False,
                startTime=int(last_modified_time.timestamp()) * 1000,
                endTime=int(datetime.now().timestamp()) * 1000,
            )['events']
            logger.info(f"[debug]インスタンス名:{instance_name} logs:{logs}")
        except:
            logger.error(f"ログ取得ERROR インスタンス名:{instance_name}")
            continue
        if len(logs) == 0:
            timestamp = last_modified_time.timestamp()
        else:
            target_logs = logs[-1]
            timestamp = int(str(target_logs['timestamp'])[:10])
        diff = datetime.now() - datetime.fromtimestamp(timestamp)
        logger.info(f"[debug]経過時間: {diff}")

        tags = client.list_tags(ResourceArn=notebook_instance['NotebookInstanceArn'])['Tags']
        logger.info(f"[debug]タグ情報: {tags}")
        # logger.info(f"[debug]diff:{diff.seconds} <-> limit:{limit_time}")
        if diff.seconds >= limit_time:
            # Tag情報があれば記録する
            # tags = client.list_tags(ResourceArn=notebook_instance['NotebookInstanceArn'])['Tags']
            # logger.info(f"[debug] tags: {tags}")
            if len(tags):
                # logger.info(f"[debug]タグのデータ型 [tags type] key:{type(tags[0]['Key'])} value:{type(tags[0]['Value'])}")
                if (tags[0]['Key'] == TAG_KEY) & (tags[0]['Value'] == "False"):
                    tag_value = False
                else:
                    tag_value = True
            else:
                tag_value = True
            in_operation_longtime.append({
                "idx": idx,
                "instance_name": instance_name,
                "last_modified_time": (last_modified_time.astimezone(tz=jst)).strftime("%Y-%m-%d %H:%I:%S"),
                "seconds": diff.seconds,
                "auto_stop_flag": tag_value
            })
            logger.info(f"{instance_name}が{STOP_HOUR}時間以上(約{int(diff.total_seconds() // (60 * 60))}時間)起動中です >>> 停止フラグ:{tag_value}")
    logger.info(f"起動中のノートブックインスタンスの数は{notebook_counter}です")
    logger.info(f"[debug]長時間起動しているノートブックインスタンス:{in_operation_longtime}")

    # --------------------------------------------------------------------
    # 3. 停止対象のノートブックインスタンスを停止する
    for target_instance in in_operation_longtime:
        idx = target_instance['idx']
        instance_name = target_instance['instance_name']
        notebook_instance = notebook_instances[idx]
        if target_instance['auto_stop_flag']:
            logger.info(f"{idx}> {instance_name}を停止します")
            response = client.stop_notebook_instance(NotebookInstanceName=instance_name)
            logger.info(response)
        else:
            logger.info(f"{idx}> {instance_name}はタグ判定で停止外です")

    # --------------------------------------------------------------------
    # 4. 実行結果をメール通知する
    instances_str = "\n".join([str(instance) for instance in in_operation_longtime])    # 本文用に文字を成形
    SUBJECT = "ノートブックインスタンス自動停止処理通知"
    MESSAGE = f"AWS SageMakerノートブックインスタンスについて、長時間アイドル状態のインスタンスを停止させる処理を実行しました。\n\n"\
                f"<実行結果>\n起動中のノートブックインスタンスの数は{notebook_counter}です\n"\
                f"停止対象のノートブックインスタンス:\n{instances_str}\n※`auto_stop_flag == False`のインスタンスは除きます"
    # SNSトピックにメッセージを公開
    response = sns_client.publish(
        TopicArn=TOPIC_ARN,
        Message=MESSAGE,
        Subject=SUBJECT
    )
    logger.info(f"通知メールの送信結果: {response}")

    return {
        'statusCode': 200,
        'body': json.dumps('done')
    }
しんじ
しんじ

これで全ての準備が完了です!
毎日10時・22時になったら、停止し忘れたノートブックインスタンスを自動で停止してくれます。
ここまでの構築、お疲れさまでした!!

(おまけ)自動停止したくないノートブックインスタンスの設定

最後におまけですが、どうしても自動停止したくないノートブックインスタンスもあるかもしれません。
そういうときは、ノートブックインスタンスの設定画面で、下までスクロールすると「タグ」の設定項目があります。
「編集」を押して設定画面に入り、
 キー: auto_stop
 値: False
と設定すると、自動停止の対象外にできますので、適宜設定してみてください。

しんじ
しんじ

今回は以上になります。
ここまで読んでいただき、本当にありがとうございました!

コメント

タイトルとURLをコピーしました