Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion diabetes_regression/evaluate/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
"--model_name",
type=str,
help="Name of the Model",
default="sklearn_regression_model.pkl",
default="diabetes_model.pkl",
)

parser.add_argument(
Expand Down
4 changes: 4 additions & 0 deletions diabetes_regression/parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
"evaluation":
{

},
"registration":
{
"tags": ["mse"]
},
"scoring":
{
Expand Down
40 changes: 30 additions & 10 deletions diabetes_regression/register/register_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ARISING IN ANY WAY OUT OF THE USE OF THE SOFTWARE CODE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
"""
import json
import os
import sys
import argparse
Expand Down Expand Up @@ -69,8 +70,9 @@ def main():
"--model_name",
type=str,
help="Name of the Model",
default="sklearn_regression_model.pkl",
default="diabetes_model.pkl",
)

parser.add_argument(
"--step_input",
type=str,
Expand All @@ -85,40 +87,58 @@ def main():
model_name = args.model_name
model_path = args.step_input

print("Getting registration parameters")

# Load the registration parameters from the parameters file
with open("parameters.json") as f:
pars = json.load(f)
try:
register_args = pars["registration"]
except KeyError:
print("Could not load registration values from file")
register_args = {"tags": []}

model_tags = {}
for tag in register_args["tags"]:
try:
mtag = run.parent.get_metrics()[tag]
model_tags[tag] = mtag
except KeyError:
print(f"Could not find {tag} metric on parent run.")

# load the model
print("Loading model from " + model_path)
model_file = os.path.join(model_path, model_name)
model = joblib.load(model_file)
model_mse = run.parent.get_metrics()["mse"]
parent_tags = run.parent.get_tags()
try:
build_id = parent_tags["BuildId"]
except KeyError:
build_id = None
print("BuildId tag not found on parent run.")
print("Tags present: {parent_tags}")
print(f"Tags present: {parent_tags}")
try:
build_uri = parent_tags["BuildUri"]
except KeyError:
build_uri = None
print("BuildUri tag not found on parent run.")
print("Tags present: {parent_tags}")
print(f"Tags present: {parent_tags}")

if (model is not None):
dataset_id = parent_tags["dataset_id"]
if (build_id is None):
register_aml_model(
model_file,
model_name,
model_mse,
model_tags,
exp,
run_id,
dataset_id)
elif (build_uri is None):
register_aml_model(
model_file,
model_name,
model_mse,
model_tags,
exp,
run_id,
dataset_id,
Expand All @@ -127,7 +147,7 @@ def main():
register_aml_model(
model_file,
model_name,
model_mse,
model_tags,
exp,
run_id,
dataset_id,
Expand All @@ -152,7 +172,7 @@ def model_already_registered(model_name, exp, run_id):
def register_aml_model(
model_path,
model_name,
model_mse,
model_tags,
exp,
run_id,
dataset_id,
Expand All @@ -162,8 +182,8 @@ def register_aml_model(
try:
tagsValue = {"area": "diabetes_regression",
"run_id": run_id,
"experiment_name": exp.name,
"mse": model_mse}
"experiment_name": exp.name}
tagsValue.update(model_tags)
if (build_id != 'none'):
model_already_registered(model_name, exp, run_id)
tagsValue["BuildId"] = build_id
Expand Down
2 changes: 1 addition & 1 deletion diabetes_regression/training/train_aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def main():
"--model_name",
type=str,
help="Name of the Model",
default="sklearn_regression_model.pkl",
default="diabetes_model.pkl",
)

parser.add_argument(
Expand Down