Decision Tree Writer
This package trains a binary classification decision tree on a list of labeled dictionaries or class instances, then writes a new .py file containing the code for the trained model.
Installation
Simply run py -m pip install decision-tree-writer
from the command line (Windows) or python3 -m pip install decision-tree-writer
(Unix/macOS) and you're ready to go!
Usage
1) Train the model
Use the DecisionTreeWriter class to train a model on a data set and write the code to a new file in a specified folder (by default, the same folder as your code):
from decision_tree_writer import DecisionTreeWriter
# Here we're using some of the famous iris data set for an example.
# You could alternatively make an Iris class with the same
# attributes as the keys of each of these dictionaries.
iris_data = [
{ "species": "setosa", "sepal_length": 5.2, "sepal_width": 3.5,
"petal_length": 1.5, "petal_width": 0.2},
{ "species": "setosa", "sepal_length": 5.2, "sepal_width": 4.1,
"petal_length": 1.5, "petal_width": 0.1},
{ "species": "setosa", "sepal_length": 5.4, "sepal_width": 3.7,
"petal_length": 1.5, "petal_width": 0.2},
{ "species": "versicolor", "sepal_length": 6.2, "sepal_width": 2.2,
"petal_length": 4.5, "petal_width": 1.5},
{ "species": "versicolor", "sepal_length": 5.7, "sepal_width": 2.9,
"petal_length": 4.2, "petal_width": 1.3},
{ "species": "versicolor", "sepal_length": 5.6, "sepal_width": 2.9,
"petal_length": 3.6, "petal_width": 1.3},
{ "species": "virginica", "sepal_length": 7.2, "sepal_width": 3.2,
"petal_length": 6.0, "petal_width": 1.8},
{ "species": "virginica", "sepal_length": 6.1, "sepal_width": 2.6,
"petal_length": 5.6, "petal_width": 1.4},
{ "species": "virginica", "sepal_length": 6.8, "sepal_width": 3.0,
"petal_length": 5.5, "petal_width": 2.1}
]
# Create the writer.
# You must specify which attribute or key is the label of the data items.
# You can also specify the max branching depth of the tree (default [and max] is 998)
# or how many data items there must be to make a new branch (default is 1).
writer = DecisionTreeWriter(label_name="species")
# Trains a new model and saves it to a new .py file
writer.create_tree(iris_data, True, "Iris Classifier")
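The comments above mention that class instances can be used instead of dictionaries. A minimal sketch of such a class follows. The Iris dataclass and the iris_objects name are illustrative assumptions, not part of the package, and only two rows are shown for brevity.

```python
from dataclasses import dataclass


# Hypothetical class whose attributes mirror the dictionary keys used above.
@dataclass
class Iris:
    species: str
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float


rows = [
    {"species": "setosa", "sepal_length": 5.2, "sepal_width": 3.5,
     "petal_length": 1.5, "petal_width": 0.2},
    {"species": "virginica", "sepal_length": 6.8, "sepal_width": 3.0,
     "petal_length": 5.5, "petal_width": 2.1},
]

# Turn each labeled dictionary into an Iris instance.
iris_objects = [Iris(**row) for row in rows]
print(iris_objects[0].species)
```

A list like iris_objects would then presumably be passed to create_tree in place of iris_data, with label_name="species" still naming the label attribute.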
2) Using the new decision tree
A new Python file containing one function will appear in the specified folder. The file is named after your decision tree model, plus a UUID to ensure the name is unique. The generated code will look something like this:
from decision_tree_writer.BaseDecisionTree import *
# class-like syntax because it acts like it's instantiating a class.
def IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d() -> 'BaseDecisionTree':
    """
    IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d
    has been trained to identify the species of a given object.
    """
    tree = BaseDecisionTree(None, dict,
                            'IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d')
    tree.root = Branch(lambda x: x['sepal_length'] <= 5.5)
    tree.root.l = Leaf('setosa')
    tree.root.r = Branch(lambda x: x['petal_length'] <= 5.0)
    tree.root.r.l = Leaf('versicolor')
    tree.root.r.r = Leaf('virginica')
    return tree
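To see how a tree of this shape reaches a label, here is a standalone sketch that does not use the package. The Branch and Leaf classes and the classify function below are simplified stand-ins written for illustration, not the package's real implementation. The sketch assumes the l child is followed when a branch's test is true and the r child otherwise, which matches the example output shown later in this README.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class Leaf:
    """Stand-in leaf: holds the label to return."""
    label: Any


@dataclass
class Branch:
    """Stand-in branch: a yes/no test plus two children."""
    test: Callable[[Any], bool]
    l: Optional[Any] = None  # assumed: followed when test(item) is True
    r: Optional[Any] = None  # assumed: followed when test(item) is False


def classify(node, item):
    """Walk down from node until a Leaf is reached, then return its label."""
    while isinstance(node, Branch):
        node = node.l if node.test(item) else node.r
    return node.label


# Same structure as the generated tree above.
root = Branch(lambda x: x["sepal_length"] <= 5.5)
root.l = Leaf("setosa")
root.r = Branch(lambda x: x["petal_length"] <= 5.0)
root.r.l = Leaf("versicolor")
root.r.r = Leaf("virginica")

sample = {"sepal_length": 6.2, "sepal_width": 2.2,
          "petal_length": 4.5, "petal_width": 1.5}
print(classify(root, sample))
```

Each item answers at most two questions here: one about sepal length, and one about petal length if the first answer is no.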
Important note: if you train your model with class instance data you will have to import that class in the new file. That might look like:
from decision_tree_writer.BaseDecisionTree import *
from wherever import Iris
def IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d() -> 'BaseDecisionTree':
    tree = BaseDecisionTree(None, Iris,
                            'IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d')
    # ... the generated branches and leaves follow here ...
    return tree
Now use the factory function to create an instance of the model. The model has two important methods: classify_one, which takes a data item of the same type the model was trained on and returns the label it predicts for that item, and classify_many, which does the same for a list of data items and returns a list of labels.
Example:
tree = IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d()
print(tree.classify_one(
{ "sepal_length": 5.4, "sepal_width": 3.2,
"petal_length": 1.6, "petal_width": 0.3})) # output: 'setosa'
Bugs or questions
If you find any problems with this package or have any questions, please create an issue on this package's GitHub repo.