generated from databricks-industry-solutions/industry-solutions-blueprints
-
Notifications
You must be signed in to change notification settings - Fork 7
/
06_DNS_Analytics_ScoreDomain.py
66 lines (42 loc) · 2.17 KB
/
06_DNS_Analytics_ScoreDomain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Databricks notebook source
# MAGIC %md
# MAGIC You may find this series of notebooks at https://github.com/databricks-industry-solutions/dns-analytics. For more information about this solution accelerator, visit https://www.databricks.com/solutions/accelerators/threat-detection.
# COMMAND ----------
# MAGIC %md Read Parameterized inputs
# COMMAND ----------
# Parameterize the notebook: expose a single text widget for the domain to score.
# Clear any widgets left over from a previous run so the input panel starts clean.
dbutils.widgets.removeAll()
# Widget name "DomainName", default value "", display label "01. Domain to be scored".
dbutils.widgets.text("DomainName","","01. Domain to be scored")
# COMMAND ----------
# Read the widget's current value (empty string if the caller supplied nothing);
# consumed by the scoring cell at the bottom of this notebook.
domain=dbutils.widgets.get("DomainName")
# COMMAND ----------
# MAGIC %md Download Databricks Trained DGA Detection model file for scoring
# COMMAND ----------
# MAGIC %sh
# MAGIC # Download the pre-trained DGA model artifacts once; skip if the directory exists.
# MAGIC # -f: fail on HTTP errors instead of silently saving an error page as the model
# MAGIC #     file (a 404 would otherwise leave an HTML document where a pickle belongs);
# MAGIC # -sS: quiet progress output but still print real errors; -L: follow redirects.
# MAGIC if [ ! -d /tmp/dga_model ]; then
# MAGIC mkdir -p /tmp/dga_model
# MAGIC curl -fsSL -o /tmp/dga_model/python_model.pkl https://raw.githubusercontent.com/zaferbil/dns-notebook-datasets/master/model/python_model.pkl
# MAGIC curl -fsSL -o /tmp/dga_model/MLmodel https://raw.githubusercontent.com/zaferbil/dns-notebook-datasets/master/model/MLmodel
# MAGIC curl -fsSL -o /tmp/dga_model/conda.yaml https://raw.githubusercontent.com/zaferbil/dns-notebook-datasets/master/model/conda.yaml
# MAGIC fi
# COMMAND ----------
# MAGIC %md Load the model using mlflow
# COMMAND ----------
# Load the DGA (Domain Generation Algorithm) detection model once per Python session.
# mlflow.pyfunc.load_model is expensive, so we cache the loaded model across cell
# re-executions. NOTE(fix): the original guard relied only on a spark.conf flag, but
# that flag is Spark-session scoped and outlives the Python REPL — after a kernel
# restart (or detach/re-attach on a long-lived cluster) the flag could still read
# "true" while `loaded_model` no longer existed, causing a NameError at scoring time.
# We therefore also verify the Python variable itself is present.
import json

# Databricks notebook context; the notebook path makes the conf key unique per notebook.
ctx = json.loads(dbutils.notebook.entry_point.getDbutils().notebook().getContext().toJson())
_flag_key = f"dga_model_is_loaded_{ctx['extraContext']['notebook_path']}"
if 'loaded_model' not in globals() or spark.conf.get(_flag_key, "false") == "false":
    import mlflow
    import mlflow.pyfunc
    # You can change this to your own path copied from the output of the 4th notebook.
    model_path = 'dbfs:/FileStore/tables/dga_model'
    # Copy the artifacts fetched by the %sh cell above into DBFS so mlflow can read them.
    dbutils.fs.cp("file:/tmp/dga_model/", model_path, True)
    print(f"loading model from {model_path}")
    loaded_model = mlflow.pyfunc.load_model(model_path)
    # Keep setting the flag for backward compatibility with any other reader of this key.
    spark.conf.set(_flag_key, "true")
# COMMAND ----------
# MAGIC %md Score the domain name with the function.
# COMMAND ----------
# Score the user-supplied domain with the cached DGA model and report the result.
dga_score = loaded_model.predict(domain)
print(f'Score for Domain {domain} is : {dga_score}')
# COMMAND ----------
# Sanity-check example — uncomment to score a known-benign domain:
#print(f'Test Execution for google.com: {loaded_model.predict("google.com")}')