''''''
"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import os
import sys
import time
import uuid
import yaml
import hashlib
import base64
import requests
import tempfile
import importlib
import revive
import shutil
import base64
from loguru import logger
from revive.conf.config import base_url
[docs]
def customer_createTrain(machineCode: str,
                         trainModelSimulatorTotalCount: str,
                         trainModelPolicyTotalCount: str,
                         trainDataRowsCount: str,
                         yamlNodeCount: str,
                         yamlFileClientUrl: str,
                         configFileClientUrl: str,
                         logFileClientUrl: str,
                         userPrivateKey: str,
                         ):
    """
    Verify the user's training privileges.

    API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#uFKnl

    Sends a multipart POST to ``api/customer/createTrain`` containing the
    training statistics, the client-side file paths and the contents of the
    yaml and config files.

    Args:
        machineCode: Machine identifier; it is base64-encoded before sending.
        trainModelSimulatorTotalCount: Number of venv trainings requested.
        trainModelPolicyTotalCount: Number of policy trainings requested.
        trainDataRowsCount: Number of rows in the training data.
        yamlNodeCount: Number of nodes in the decision-flow graph.
        yamlFileClientUrl: Local path of the yaml file (uploaded).
        configFileClientUrl: Local path of the config file (uploaded).
        logFileClientUrl: Local path of the log file (only the path is sent).
        userPrivateKey: The user's access key.

    Returns:
        The decoded JSON response. When ``res["data"]["license"]`` is present
        it is replaced by its base64-decoded ``bytes``.

    Raises:
        OSError: If one of the files to upload cannot be opened.
        requests.RequestException: On network failure or the 10 s timeout.
        ValueError: If the response body is not valid JSON.
    """
    url = base_url + "api/customer/createTrain"
    payload = {
        'machineCode': base64.b64encode(machineCode.encode("utf-8")).decode(),
        'trainModelSimulatorTotalCount': trainModelSimulatorTotalCount,
        'trainModelPolicyTotalCount': trainModelPolicyTotalCount,
        'trainDataRowsCount': trainDataRowsCount,
        'yamlNodeCount': yamlNodeCount,
        'yamlFileClientUrl': yamlFileClientUrl,
        'configFileClientUrl': configFileClientUrl,
        'logFileClientUrl': logFileClientUrl,
        'userPrivateKey': userPrivateKey,
        'version': revive.__version__,
    }
    headers = {
        'requestId': uuid.uuid4().hex,
    }

    # Keep the uploads inside a ``with`` so the handles are closed even when
    # the request raises (the previous code leaked both file objects).
    with open(yamlFileClientUrl, 'rb') as yaml_fp, open(configFileClientUrl, 'rb') as config_fp:
        files = [
            ('yamlFile', (os.path.basename(yamlFileClientUrl), yaml_fp, 'application/octet-stream')),
            ('configFile', (os.path.basename(configFileClientUrl), config_fp, 'application/octet-stream')),
        ]
        response = requests.request("POST", url, headers=headers, data=payload, files=files, timeout=10)

    # SECURITY: the body used to be passed to ``eval``, which executes whatever
    # the server (or anything between us and it) sends back and also breaks on
    # the JSON literals true/false/null. Parse it as JSON instead.
    res = response.json()

    # Error responses may carry no license; leave them untouched so the caller
    # can still inspect ``res["code"]`` instead of crashing here.
    data = res.get("data") if isinstance(res, dict) else None
    if isinstance(data, dict) and isinstance(data.get("license"), (str, bytes)):
        data["license"] = base64.b64decode(data["license"])

    logger.info(f"customer_createTrain,{res['code']}")
    return res
[docs]
def customer_uploadTrainFile(trainId: str,
                             accessToken: str,
                             yamlFile: str = None,
                             configFile: str = None,
                             logFile: str = None,
                             ):
    """
    Upload the history train log.

    API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#r5IPw

    Sends a multipart POST to ``api/customer/uploadTrainFile``. All three file
    fields (``yamlFile``, ``configFile``, ``logFile``) are always present in
    the request: a field whose path is ``None`` is filled with an empty
    temporary file named ``none.yaml`` / ``none.json`` / ``none.log``.

    Args:
        trainId: Identifier of the training the files belong to.
        accessToken: Value sent in the ``Authorization`` header.
        yamlFile: Local path of the yaml file, or ``None`` to skip it.
        configFile: Local path of the config file, or ``None`` to skip it.
        logFile: Local path of the log file, or ``None`` to skip it.

    Returns:
        The decoded JSON response.

    Raises:
        OSError: If one of the supplied files cannot be opened.
        requests.RequestException: On network failure or the 10 s timeout.
        ValueError: If the response body is not valid JSON.
    """
    from contextlib import ExitStack

    url = base_url + "api/customer/uploadTrainFile"

    # (multipart field, payload key for the client path, placeholder name, path)
    slots = (
        ('yamlFile', 'yamlFileClientUrl', 'none.yaml', yamlFile),
        ('configFile', 'configFileClientUrl', 'none.json', configFile),
        ('logFile', 'logFileClientUrl', 'none.log', logFile),
    )

    payload = {'trainId': trainId}
    payload.update({client_key: path for _, client_key, _, path in slots if path is not None})

    headers = {
        'requestId': uuid.uuid4().hex,
        'Authorization': accessToken,
    }

    # ExitStack closes the placeholder and every opened upload, even if the
    # request raises (previously the uploads were never closed and the
    # temporary file leaked on error).
    with ExitStack() as stack:
        placeholder = stack.enter_context(tempfile.NamedTemporaryFile(delete=True))
        files = []
        for field, _, placeholder_name, path in slots:
            if path is None:
                files.append((field, (placeholder_name, placeholder, 'application/octet-stream')))
            else:
                handle = stack.enter_context(open(path, 'rb'))
                files.append((field, (os.path.basename(path), handle, 'application/octet-stream')))
        response = requests.request("POST", url, headers=headers, data=payload, files=files, timeout=10)

    # SECURITY: the body used to be passed to ``eval`` (twice); parse it once
    # as JSON so server-controlled text is never executed.
    result = response.json()
    logger.info(f"customer_uploadTrainFile,{result['code']}")
    return result
[docs]
def customer_uploadTrainLog(trainId: str,
                            logFile: str,
                            trainType: str,
                            trainResult: str,
                            trainScore: str,
                            accessToken: str,
                            ):
    """
    Upload the log after a trial is trained.

    API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#KvKWx

    Sends a multipart POST to ``api/customer/uploadTrainLog`` with the
    training metadata and the contents of ``logFile``.

    Args:
        trainId: Identifier of the training the log belongs to.
        logFile: Local path of the log file to upload.
        trainType: Value sent as the ``trainType`` form field.
        trainResult: Value sent as the ``trainResult`` form field.
        trainScore: Value sent as the ``trainScore`` form field.
        accessToken: Value sent in the ``Authorization`` header.

    Returns:
        The decoded JSON response.

    Raises:
        OSError: If ``logFile`` cannot be opened.
        requests.RequestException: On network failure or the 10 s timeout.
        ValueError: If the response body is not valid JSON.
    """
    url = base_url + "api/customer/uploadTrainLog"
    payload = {
        'trainId': trainId,
        'trainType': trainType,
        'trainResult': trainResult,
        'trainScore': trainScore,
    }
    headers = {
        'requestId': uuid.uuid4().hex,
        'Authorization': accessToken,
    }

    # Close the log file once the request is done (it used to be leaked).
    with open(logFile, 'rb') as log_fp:
        files = [('logFile', (os.path.basename(logFile), log_fp, 'application/octet-stream'))]
        response = requests.request("POST", url, headers=headers, data=payload, files=files, timeout=10)

    # SECURITY: the body used to be passed to ``eval`` (twice); parse it once
    # as JSON so server-controlled text is never executed.
    result = response.json()
    logger.info(f"customer_uploadTrainLog,{result['code']}")
    return result
[docs]
def check_license(cls):
# Make the encrypted ``revive_p`` module importable.
#
# First tries a local license taken from the ``REVIVE_LICENSE`` environment
# variable; if that fails, requests a license online via
# ``customer_createTrain``, stores it, and imports
# ``revive.dist.algo.venv.revive_p``.
#
# ``cls`` attributes read here: ``venv_mode``, ``policy_mode``, ``config``,
# ``dataset``, ``config_file``, ``revive_config_file_path``,
# ``revive_log_path`` and ``workspace``. ``cls.config`` is also written
# (``trainId`` and ``accessToken``).
#
# NOTE(review): indentation was lost in this view of the file, so the exact
# nesting of the try/except blocks cannot be read off the text. The comments
# below describe each statement and hedge where the enclosing scope is unclear.
from io import StringIO
from wurlitzer import pipes, STDOUT
out = StringIO()
# ``<revive package dir>/dist`` is put at the front of the import path.
package_path = os.path.join(os.path.dirname(revive.__file__), "dist")
sys.path.insert(0, package_path)
# Run ``pyarmor hdinfo`` with its stdout/stderr redirected into ``out``.
# ``out`` is not read anywhere else in this function — presumably the
# redirection is only there to keep the console quiet; TODO confirm.
with pipes(stdout=out, stderr=STDOUT):
os.system("pyarmor hdinfo")
try:
# --- Local license attempt ---------------------------------------------
REVIVE_LICENSE = os.getenv("REVIVE_LICENSE")
sys.PYARMOR_LICENSE = REVIVE_LICENSE
# Unset or empty variable: fall back to "/".
if not sys.PYARMOR_LICENSE:
sys.PYARMOR_LICENSE = "/"
try:
importlib.import_module(f'revive.algo.venv.revive_p')
logger.info(f"import revive.algo.venv.revive_p")
# NOTE(review): bare ``except`` — any failure of the import above lands here.
except:
# ``os.getenv`` returned None, i.e. the variable is not set at all.
if not isinstance(REVIVE_LICENSE, str):
logger.warning(f"Don't find valid 'REVIVE_LICENSE' in environment variables. Please check.")
time.sleep(3)
raise NotImplementedError
else:
try:
# The license file must be named exactly ``license.lic``.
# NOTE(review): ``assert`` is stripped under ``python -O``.
assert os.path.basename(REVIVE_LICENSE) == "license.lic"
# Install the license as the pyarmor runtime key, then retry the import
# from the ``revive.dist`` package.
shutil.copy(REVIVE_LICENSE,os.path.join(package_path, "pyarmor_runtime_005497", "pyarmor.rkey"))
importlib.import_module(f'revive.dist.algo.venv.revive_p')
# NOTE(review): "imsysport" in the message below looks like a typo for "import".
logger.info(f"imsysport revive.dist.algo.venv.revive_p")
except Exception as e:
logger.info(f"Can't find local REVIVE_LICENSE file. {e}")
raise NotImplementedError
# --- Online license fallback -------------------------------------------
# Presumably reached through the NotImplementedError raised above; being a
# bare ``except`` it also catches every other exception from the local attempt.
except:
logger.info(f"Get online REVIVE_LICENSE file.")
from revive.utils.license_utils import get_machine_info
# NOTE(review): hard-coded ``/tmp`` path; semantics of ``get_machine_info``
# are not visible from this file.
machineCode = get_machine_info(os.path.join("/tmp/machine_info.json"), online=True)
# Nothing to train: abort the process.
if cls.venv_mode is None and cls.policy_mode is None:
logger.error(f"Don't train venv and policy. Please check the venv_mode and policy_mode.")
sys.exit()
# Number of venv trainings: 1 for "once", the configured trial count for
# "tune", otherwise 0.
trainModelSimulatorTotalCount = 0
if cls.venv_mode == "once":
trainModelSimulatorTotalCount += 1
elif cls.venv_mode == "tune":
trainModelSimulatorTotalCount += cls.config["train_venv_trials"]
else:
pass
# Same counting rule for policy trainings.
trainModelPolicyTotalCount = 0
if cls.policy_mode == "once":
trainModelPolicyTotalCount += 1
elif cls.policy_mode == "tune":
trainModelPolicyTotalCount += cls.config["train_policy_trials"]
else:
pass
# Dataset statistics and absolute file paths reported to the server.
trainDataRowsCount = len(cls.dataset)
yamlNodeCount = len(cls.dataset.graph)
yamlFileClientUrl = os.path.abspath(cls.config_file)
configFileClientUrl = os.path.abspath(cls.revive_config_file_path)
logFileClientUrl = os.path.abspath(cls.revive_log_path)
# The user's access key is read from ``~/.revive/config.yaml``.
config_folder = os.path.join(os.path.expanduser('~'),".revive")
with open(os.path.join(config_folder,'config.yaml'), 'r', encoding='utf-8') as f:
revive_config = yaml.load(f, Loader=yaml.FullLoader)
if "accesskey" not in revive_config.keys():
logger.error(f"Please check the ``~/.revive/config.yaml`` file for the configuration.")
sys.exit()
userPrivateKey = revive_config["accesskey"]
args = {'machineCode': machineCode,
'trainModelSimulatorTotalCount': trainModelSimulatorTotalCount,
'trainModelPolicyTotalCount': trainModelPolicyTotalCount,
'trainDataRowsCount': trainDataRowsCount,
'yamlNodeCount': yamlNodeCount,
'yamlFileClientUrl': yamlFileClientUrl,
'configFileClientUrl': configFileClientUrl,
'logFileClientUrl': logFileClientUrl,
'userPrivateKey': userPrivateKey}
# Every value is sent as a string.
args = {k:str(v) for k,v in args.items()}
# ``range(1)``: the request is attempted exactly once.
for i in range(1):
result = customer_createTrain(**args)
if result["code"] == "General.Success":
trainId = result["data"]["trainId"]
license = result["data"]["license"]
accessToken = result["data"]["accessToken"]
# Remember the session identifiers on the config for later use.
cls.config["trainId"] = trainId
cls.config["accessToken"] = accessToken
# ``license`` is bytes here (``customer_createTrain`` base64-decodes it), hence
# the binary write. It is saved in the workspace and installed as the pyarmor
# runtime key.
license_path = os.path.abspath(os.path.join(cls.workspace, "license.lic"))
with open(license_path, "wb") as f:
f.writelines([license,])
sys.PYARMOR_LICENSE = license_path
shutil.copy(license_path,os.path.join(package_path, "pyarmor_runtime_005497", "pyarmor.rkey"))
time.sleep(1)
break
# NOTE(review): with indentation lost it is unclear whether this log line sits
# inside the loop body or after the loop.
logger.info(f"{result['code']}")
if result["code"] == "General.Success":
try:
# The response names an earlier training (``beforeTrainId``): upload whichever
# of the requested files still exist locally.
if result["data"]["beforeTrainId"]:
# logger.info(f"{result['msg']}")
args = {'trainId': result["data"]["beforeTrainId"],
'accessToken': result["data"]["accessToken"],}
if result["data"]["needYamlFileClientUrl"]:
if os.path.exists(result["data"]["needYamlFileClientUrl"]):
args["yamlFile"] = result["data"]["needYamlFileClientUrl"]
if result["data"]["needConfigFileClientUrl"]:
if os.path.exists(result["data"]["needConfigFileClientUrl"]):
args["configFile"] = result["data"]["needConfigFileClientUrl"]
if result["data"]["needLogFileClientUrl"]:
if os.path.exists(result["data"]["needLogFileClientUrl"]):
args["logFile"] = result["data"]["needLogFileClientUrl"]
# ``args`` starts with two keys, so more than two means at least one file
# was found.
if len(args) > 2:
customer_uploadTrainFile(**args)
# Best effort: a failed upload of old files is only logged.
except Exception as e:
logger.info(f"{e}")
else:
# The server refused the request: report the code and stop the process.
logger.error(f"{result['code']}. Please check the ``~/.revive/config.yaml`` file for the configuration.")
sys.exit()
# NOTE(review): presumably still part of the online-license path (import using
# the freshly installed key) — confirm against the original indentation.
importlib.import_module(f'revive.dist.algo.venv.revive_p')
logger.info(f"Import encryption venv algorithm module -> revive_p!")