This example uses Python. When you create the function, select Python 3.9 as the runtime environment. The following code is an example that you can modify based on your business requirements.
# -*- coding: utf-8 -*-
import importlib
import json
import logging
import os
import subprocess
import sys

try:
    import pymysql
except ImportError:
    # pymysql is not importable from the runtime: install it next to this code, then retry.
    # Use the running interpreter's pip (not whatever `pip` is on PATH) and fail loudly if the
    # install fails, instead of continuing to a confusing second ImportError.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pymysql", "-t", "./"])
    # The import system caches directory listings; refresh them so the new package is found.
    importlib.invalidate_caches()
    import pymysql
from aliyunsdkcore.acs_exception.exceptions import ServerException
from aliyunsdkcore.auth.credentials import StsTokenCredential
from aliyunsdkcore.client import AcsClient
from aliyunsdkkms.request.v20160120.GetRandomPasswordRequest import GetRandomPasswordRequest
from aliyunsdkkms.request.v20160120.GetSecretValueRequest import GetSecretValueRequest
from aliyunsdkkms.request.v20160120.PutSecretValueRequest import PutSecretValueRequest
from aliyunsdkkms.request.v20160120.UpdateSecretVersionStageRequest import UpdateSecretVersionStageRequest
from aliyunsdkrds.request.v20140815.DescribeDBInstancesRequest import DescribeDBInstancesRequest

# Root logger at INFO so the rotation steps' progress messages are emitted.
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def handler(event, context):
    """Function Compute entry point that runs one step of a KMS secret rotation.

    ``event`` is a JSON document carrying ``SecretName``, ``RegionId``, ``Step`` and, optionally,
    ``VersionId``. When no (or an empty) version ID is supplied, the invocation's request ID is
    used instead. ``Step`` selects the phase: ``new``, ``set``, ``test`` or ``end``.

    Returns:
        dict: ``{"VersionId": <version id used for this rotation>}``.

    Raises:
        ValueError: If the secret's type is not ``Generic`` or the step name is unknown.
    """
    payload = json.loads(event)
    secret_name = payload['SecretName']
    region_id = payload['RegionId']
    step = payload['Step']
    version_id = payload.get('VersionId') or context.requestId

    # Authenticate with the temporary STS credentials handed to the function.
    creds = context.credentials
    client = AcsClient(
        region_id=region_id,
        credential=StsTokenCredential(creds.accessKeyId, creds.accessKeySecret, creds.securityToken),
    )
    # Route KMS traffic through the region's VPC endpoint.
    client.add_endpoint(region_id, 'kms', "kms-vpc." + region_id + ".aliyuncs.com")

    secret = get_secret_value(client, secret_name)
    if secret['SecretType'] != "Generic":
        logger.error("Secret %s is not enabled for rotation" % secret_name)
        raise ValueError("Secret %s is not enabled for rotation" % secret_name)

    phases = (("new", new_phase), ("set", set_phase), ("test", test_phase), ("end", end_phase))
    for phase_name, phase in phases:
        if step == phase_name:
            phase(client, secret_name, version_id)
            break
    else:
        logger.error("handler: Invalid step parameter %s for secret %s" % (step, secret_name))
        raise ValueError("Invalid step parameter %s for secret %s" % (step, secret_name))
    return {"VersionId": version_id}


def new_phase(client, secret_name, version_id):
    """Create the ACSPending version of the secret unless it already exists.

    The pending payload is the ACSCurrent payload with ``AccountName`` switched to the alternate
    rotation account and ``AccountPassword`` replaced by a KMS-generated random password. The
    password excludes the characters in the ``EXCLUDE_CHARACTERS`` environment variable. When
    that variable is unset, it excludes slash, at sign, double quote, single quote and backslash.

    Raises:
        ServerException: For any KMS error other than ``Forbidden.ResourceNotFound`` while
            looking up the pending version.
    """
    current_dict = get_secret_dict(client, secret_name, "ACSCurrent")
    try:
        get_secret_dict(client, secret_name, "ACSPending", version_id)
        logger.info("new: Successfully retrieved secret for %s." % secret_name)
    except ServerException as e:
        # Only a missing pending version means "create it"; re-raise every other KMS failure.
        if e.error_code != 'Forbidden.ResourceNotFound':
            raise
        alt_name = get_alt_account_name(current_dict['AccountName'])
        excluded = os.environ.get('EXCLUDE_CHARACTERS', '/@"\'\\')
        new_password = get_random_password(client, excluded)['RandomPassword']
        current_dict['AccountName'] = alt_name
        current_dict['AccountPassword'] = new_password
        put_secret_value(client, secret_name, version_id, json.dumps(current_dict),
                         json.dumps(['ACSPending']))
        logger.info(
            "new: Successfully put secret for secret_name %s and version %s." % (secret_name, version_id))


def set_phase(client, secret_name, version_id):
    """Make the ACSPending credentials valid in the MySQL database.

    Logs in with the master secret named by the current secret's ``MasterSecret`` field. It
    creates the alternate account named in the ACSPending version if that account is missing,
    copies the current account's grants and TLS requirement onto it, and sets its password to
    the pending one.

    Args:
        client: AcsClient used for KMS calls (and passed on for the RDS replica check).
        secret_name: Name of the secret being rotated.
        version_id: Version ID that holds the ACSPending stage.

    Raises:
        ValueError: If the pending account or endpoint is not the expected rotation target, if
            the current or master credentials cannot log in, or if the current endpoint is
            neither the master's endpoint nor verified as an RDS replica of it.
    """
    current_dict = get_secret_dict(client, secret_name, "ACSCurrent")
    pending_dict = get_secret_dict(client, secret_name, "ACSPending", version_id)

    # If the pending credentials already log in, nothing is left to do for this step.
    conn = get_connection(pending_dict)
    if conn:
        conn.close()
        logger.info(
            "set: ACSPending secret is already set as password in MySQL DB for secret secret_name %s." % secret_name)
        return

    # Refuse to modify any account other than the alternate of the current one.
    if get_alt_account_name(current_dict['AccountName']) != pending_dict['AccountName']:
        logger.error("set: Attempting to modify user %s other than current user or rotation %s" % (
            pending_dict['AccountName'], current_dict['AccountName']))
        raise ValueError("Attempting to modify user %s other than current user or rotation %s" % (
            pending_dict['AccountName'], current_dict['AccountName']))
    if current_dict['Endpoint'] != pending_dict['Endpoint']:
        logger.error("set: Attempting to modify user for Endpoint %s other than current Endpoint %s" % (
            pending_dict['Endpoint'], current_dict['Endpoint']))
        raise ValueError("Attempting to modify user for Endpoint %s other than current Endpoint %s" % (
            pending_dict['Endpoint'], current_dict['Endpoint']))

    # The current credentials must still work before anything is changed.
    conn = get_connection(current_dict)
    if not conn:
        logger.error("set: Unable to access the given database using current credentials for secret %s" % secret_name)
        raise ValueError("Unable to access the given database using current credentials for secret %s" % secret_name)
    conn.close()

    master_secret = current_dict['MasterSecret']
    master_dict = get_secret_dict(client, master_secret, "ACSCurrent")
    # is_rds_replica_database takes (client, replica_dict, master_dict). Omitting the client made
    # this call raise TypeError whenever the two endpoints differed.
    if (current_dict['Endpoint'] != master_dict['Endpoint']
            and not is_rds_replica_database(client, current_dict, master_dict)):
        logger.error("set: Current database Endpoint %s is not the same Endpoint as/rds replica of master %s" % (
            current_dict['Endpoint'], master_dict['Endpoint']))
        raise ValueError("Current database Endpoint %s is not the same Endpoint as/rds replica of master %s" % (
            current_dict['Endpoint'], master_dict['Endpoint']))

    conn = get_connection(master_dict)
    if not conn:
        logger.error(
            "set: Unable to access the given database using credentials in master secret secret %s" % master_secret)
        raise ValueError("Unable to access the given database using credentials in master secret secret %s" % master_secret)
    try:
        with conn.cursor() as cur:
            # Create the alternate account the first time it is rotated to.
            cur.execute("SELECT User FROM mysql.user WHERE User = %s", (pending_dict['AccountName'],))
            if cur.rowcount == 0:
                cur.execute("CREATE USER %s IDENTIFIED BY %s",
                            (pending_dict['AccountName'], pending_dict['AccountPassword']))

            # Replay each of the current account's grants onto the alternate account.
            cur.execute("SHOW GRANTS FOR %s", (current_dict['AccountName'],))
            for row in cur.fetchall():
                # Grants mentioning XA_RECOVER_ADMIN are skipped rather than copied.
                if 'XA_RECOVER_ADMIN' in row[0]:
                    continue
                grant = row[0].split(' TO ')
                # % is a special character in Python format strings, which pymysql uses for
                # parameter substitution, so literal % in the grant text must be doubled.
                new_grant_escaped = grant[0].replace('%', '%%')
                cur.execute(new_grant_escaped + " TO %s ", (pending_dict['AccountName'],))

            # Copy the current account's TLS requirement using version-appropriate syntax.
            cur.execute("SELECT VERSION()")
            ver = cur.fetchone()[0]
            escaped_encryption_statement = get_escaped_encryption_statement(ver)
            cur.execute("SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s",
                        (current_dict['AccountName'],))
            tls_options = cur.fetchone()
            ssl_type = tls_options[0]
            if not ssl_type:
                cur.execute(escaped_encryption_statement + " NONE", (pending_dict['AccountName'],))
            elif ssl_type == "ANY":
                cur.execute(escaped_encryption_statement + " SSL", (pending_dict['AccountName'],))
            elif ssl_type == "X509":
                cur.execute(escaped_encryption_statement + " X509", (pending_dict['AccountName'],))
            else:
                cur.execute(escaped_encryption_statement + " CIPHER %s AND ISSUER %s AND SUBJECT %s",
                            (pending_dict['AccountName'], tls_options[1], tls_options[2], tls_options[3]))

            # Finally set the pending password (the syntax depends on the server version).
            password_option = get_password_option(ver)
            cur.execute("SET PASSWORD FOR %s = " + password_option,
                        (pending_dict['AccountName'], pending_dict['AccountPassword']))
            conn.commit()
            logger.info("set: Successfully changed password for %s in MySQL DB for secret secret_name %s." % (
                pending_dict['AccountName'], secret_name))
    finally:
        conn.close()


def test_phase(client, secret_name, version_id):
    """Check that the ACSPending credentials can log in and run a trivial query.

    Raises:
        ValueError: If no connection can be opened with the pending credentials.
    """
    pending_dict = get_secret_dict(client, secret_name, "ACSPending", version_id)
    conn = get_connection(pending_dict)
    if not conn:
        logger.error(
            "test: Unable to access the given database with pending secret of secret secret_name %s" % secret_name)
        raise ValueError("Unable to access the given database with pending secret of secret secret_name %s" % secret_name)

    try:
        with conn.cursor() as cursor:
            cursor.execute("SELECT NOW()")
            conn.commit()
    finally:
        conn.close()
    logger.info("test: Successfully accessed into MySQL DB with ACSPending secret in %s." % secret_name)


def end_phase(client, secret_name, version_id):
    """Finish rotation: move ACSCurrent onto ``version_id``, then remove ACSPending from it."""
    promote = dict(version_stage='ACSCurrent', move_to_version=version_id)
    retire = dict(version_stage='ACSPending', remove_from_version=version_id)
    for stage_change in (promote, retire):
        update_secret_version_stage(client, secret_name, **stage_change)
    logger.info(
        "end: Successfully update ACSCurrent stage to version %s for secret %s." % (version_id, secret_name))


def get_connection(secret_dict):
    """Open a MySQL connection for ``secret_dict``, or return None if it cannot log in.

    ``Port`` defaults to 3306 and ``DBName`` to None. TLS usage comes from get_ssl_config. When
    that reports fall-back is allowed and the first attempt fails, one non-TLS retry is made.
    """
    port = int(secret_dict.get('Port', 3306))
    dbname = secret_dict.get('DBName')
    use_ssl, fall_back = get_ssl_config(secret_dict)

    connection = connect_and_authenticate(secret_dict, port, dbname, use_ssl)
    if not connection and fall_back:
        # Retry without TLS.
        connection = connect_and_authenticate(secret_dict, port, dbname, False)
    return connection


def get_ssl_config(secret_dict):
    """Decide whether to use TLS and whether a non-TLS fall-back is allowed.

    Returns:
        tuple: ``(use_ssl, fall_back)``.
            * An explicit boolean ``SSL`` value, or the strings ``"true"``/``"false"`` in any
              case, is honored with no fall-back.
            * A missing key or any other value means "try TLS, fall back if it fails".
    """
    if 'SSL' not in secret_dict:
        return True, True
    setting = secret_dict['SSL']
    if isinstance(setting, bool):
        return setting, False
    if isinstance(setting, str):
        explicit = {"true": (True, False), "false": (False, False)}
        return explicit.get(setting.lower(), (True, True))
    return True, True


def connect_and_authenticate(secret_dict, port, dbname, use_ssl):
    """Try to open a MySQL connection with the credentials in ``secret_dict``.

    Args:
        secret_dict: Secret payload; ``Endpoint``, ``AccountName`` and ``AccountPassword`` are used.
        port: TCP port of the server.
        dbname: Database to select, or None.
        use_ssl: When true, connect over TLS using the CA bundle at ``/opt/python/certs/cert.pem``.

    Returns:
        An open pymysql connection, or None if pymysql raised OperationalError.
    """
    ssl = {'ca': '/opt/python/certs/cert.pem'} if use_ssl else None
    try:
        conn = pymysql.connect(host=secret_dict['Endpoint'], user=secret_dict['AccountName'],
                               password=secret_dict['AccountPassword'],
                               port=port, database=dbname, connect_timeout=5, ssl=ssl)
    except pymysql.OperationalError as e:
        # Build the text from every arg rather than assuming e.args[1] exists and is a str.
        message = " ".join(str(arg) for arg in e.args)
        # Endpoints are usually DNS names, so Python's ssl module reports "Hostname mismatch";
        # "IP address mismatch" only occurs when connecting to a literal IP.
        if ('certificate verify failed: IP address mismatch' in message
                or 'certificate verify failed: Hostname mismatch' in message):
            logger.error(
                "Hostname verification failed when establishing SSL/TLS Handshake with Endpoint: %s" % secret_dict[
                    'Endpoint'])
        else:
            # Callers treat None as "cannot log in"; record why, without exposing the password.
            logger.info("Unable to connect as user '%s' with Endpoint: '%s': %s" % (
                secret_dict['AccountName'], secret_dict['Endpoint'], message))
        return None
    logger.info("Successfully established %s connection as user '%s' with Endpoint: '%s'" % (
        "SSL/TLS" if use_ssl else "non SSL/TLS", secret_dict['AccountName'], secret_dict['Endpoint']))
    return conn


def get_secret_dict(client, secret_name, stage, version_id=None):
    """Fetch one version/stage of a secret and return its parsed JSON payload.

    A version ID is only sent when it is truthy.

    Raises:
        KeyError: If ``Endpoint``, ``AccountName`` or ``AccountPassword`` is absent; the first
            missing key in that order is named.
    """
    if version_id:
        response = get_secret_value(client, secret_name, version_id=version_id, stage=stage)
    else:
        response = get_secret_value(client, secret_name, stage=stage)
    payload = json.loads(response['SecretData'])

    missing = [key for key in ('Endpoint', 'AccountName', 'AccountPassword') if key not in payload]
    if missing:
        raise KeyError("%s key is missing from secret JSON" % missing[0])
    return payload


def get_alt_account_name(current_account_name, max_length=16):
    """Return the other account name of the rotation pair.

    Rotation alternates between ``name`` and ``name_rt``: a name ending in ``_rt`` has the
    suffix stripped, and any other name gets it appended.

    Args:
        current_account_name: Account name currently in use.
        max_length: Longest allowed account name (default 16, the previous hard-coded limit).

    Raises:
        ValueError: If appending the suffix would exceed ``max_length``, or if stripping it would
            leave an empty (anonymous) account name.
    """
    rotation_suffix = "_rt"
    if current_account_name.endswith(rotation_suffix):
        base_name = current_account_name[:-len(rotation_suffix)]
        if not base_name:
            raise ValueError(
                "Unable to rotate user, account_name %r has no name before the %s suffix"
                % (current_account_name, rotation_suffix))
        return base_name
    new_account_name = current_account_name + rotation_suffix
    if len(new_account_name) > max_length:
        raise ValueError(
            "Unable to rotate user, account_name length with %s appended would exceed %d characters"
            % (rotation_suffix, max_length))
    return new_account_name


def get_password_option(version):
    """Return the right-hand side placeholder for ``SET PASSWORD FOR %s = ...``.

    MySQL 8.0 removed the ``PASSWORD()`` function, so from major version 8 onward the plain
    ``%s`` form is used. Older MySQL and MariaDB keep ``PASSWORD(%s)``.

    Args:
        version: Server version string as returned by ``SELECT VERSION()``.
    """
    if 'mariadb' in version.lower():
        return "PASSWORD(%s)"
    try:
        major = int(version.split('.', 1)[0])
    except ValueError:
        # Unparseable version: keep the original prefix-based rule.
        return "%s" if version.startswith("8") else "PASSWORD(%s)"
    return "%s" if major >= 8 else "PASSWORD(%s)"


def get_escaped_encryption_statement(version):
    """Return the SQL prefix that sets an account's TLS ``REQUIRE`` clause.

    ``ALTER USER ... REQUIRE`` only exists from MySQL 5.7. Older servers (5.5 and 5.6) need
    ``GRANT USAGE ... REQUIRE``. The ``%%`` becomes a literal ``%`` host after pymysql's
    %-substitution of the account name.

    Args:
        version: Server version string as returned by ``SELECT VERSION()``.
    """
    legacy = "GRANT USAGE ON *.* TO %s@'%%' REQUIRE"
    modern = "ALTER USER %s@'%%' REQUIRE"
    numbers = version.split('-', 1)[0].split('.')
    try:
        major, minor = int(numbers[0]), int(numbers[1])
    except (ValueError, IndexError):
        # Unparseable version: keep the original prefix-based rule.
        return legacy if version.startswith("5.6") else modern
    return legacy if (major, minor) < (5, 7) else modern


def is_rds_replica_database(client, replica_dict, master_dict):
    """Return True if ``replica_dict``'s endpoint belongs to the master's RDS instance or a replica of it.

    Instance IDs are taken from the first DNS label of each ``Endpoint``. The replica's instance
    is then looked up with RDS DescribeDBInstances. Any lookup error, or an empty result,
    yields False.

    Args:
        client: AcsClient used for the RDS call.
        replica_dict: Secret payload whose ``Endpoint`` is checked.
        master_dict: Master secret payload whose ``Endpoint`` identifies the primary.
    """
    # NOTE(review): str.replace removes every 'io' substring, not only a suffix; an instance ID
    # containing 'io' mid-string would be altered. Kept as-is; confirm the intended endpoint format.
    replica_instance_id = replica_dict['Endpoint'].split(".")[0].replace('io', '')
    master_instance_id = master_dict['Endpoint'].split(".")[0].replace('io', '')
    try:
        describe_response = describe_db_instances(client, replica_instance_id)
    except Exception as err:
        logger.warning("Encountered error while verifying rds replica status: %s" % err)
        return False
    instances = (describe_response.get('Items') or {}).get("DBInstance")
    if not instances:
        logger.info("Cannot verify replica status - no RDS instance found with identifier: %s" % replica_instance_id)
        return False
    current_instance = instances[0]
    # A read-only instance names its primary in MasterInstanceId. A DBInstanceId match covers
    # two different endpoints that resolve to the same instance.
    return master_instance_id in (current_instance.get('MasterInstanceId'),
                                  current_instance.get('DBInstanceId'))


def get_secret_value(client, secret_name, version_id=None, stage=None):
    """Call KMS GetSecretValue and return the decoded JSON response.

    ``version_id`` and ``stage`` are only sent when truthy.
    """
    req = GetSecretValueRequest()
    req.set_accept_format('json')
    req.set_SecretName(secret_name)
    optional_params = ((version_id, req.set_VersionId), (stage, req.set_VersionStage))
    for value, setter in optional_params:
        if value:
            setter(value)
    return json.loads(client.do_action_with_exception(req))


def put_secret_value(client, secret_name, version_id, secret_data, version_stages=None):
    """Call KMS PutSecretValue to store ``secret_data`` as a new version.

    ``version_stages`` is sent only when truthy. Returns the decoded JSON response.
    """
    req = PutSecretValueRequest()
    req.set_accept_format('json')
    req.set_SecretName(secret_name)
    req.set_VersionId(version_id)
    req.set_SecretData(secret_data)
    if version_stages:
        req.set_VersionStages(version_stages)
    raw_response = client.do_action_with_exception(req)
    return json.loads(raw_response)


def get_random_password(client, exclude_characters=None):
    """Ask KMS GetRandomPassword for a password and return the decoded JSON response.

    ``exclude_characters`` is sent only when truthy.
    """
    req = GetRandomPasswordRequest()
    req.set_accept_format('json')
    if exclude_characters:
        req.set_ExcludeCharacters(exclude_characters)
    raw_response = client.do_action_with_exception(req)
    return json.loads(raw_response)


def update_secret_version_stage(client, secret_name, version_stage, remove_from_version=None, move_to_version=None):
    """Call KMS UpdateSecretVersionStage to move or remove ``version_stage``.

    ``remove_from_version`` and ``move_to_version`` are sent only when truthy. Returns the
    decoded JSON response.
    """
    req = UpdateSecretVersionStageRequest()
    req.set_accept_format('json')
    req.set_SecretName(secret_name)
    req.set_VersionStage(version_stage)
    optional_params = ((remove_from_version, req.set_RemoveFromVersion),
                       (move_to_version, req.set_MoveToVersion))
    for value, setter in optional_params:
        if value:
            setter(value)
    return json.loads(client.do_action_with_exception(req))


def describe_db_instances(client, db_instance_id):
    """Call RDS DescribeDBInstances for one instance ID and return the decoded JSON response."""
    req = DescribeDBInstancesRequest()
    req.set_DBInstanceId(db_instance_id)
    req.set_accept_format('json')
    raw_response = client.do_action_with_exception(req)
    return json.loads(raw_response)