Coverage for src/m5py/export.py: 3%
136 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-22 17:10 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-22 17:10 +0000
1from numbers import Integral
3import numpy as np
4from io import StringIO
6from sklearn.tree import _tree
7from sklearn.tree._criterion import FriedmanMSE
9from m5py.main import is_leaf, ConstantLeafModel, M5Base, check_is_fitted
def export_text_m5(decision_tree, out_file=None, max_depth=None,
                   feature_names=None, class_names=None, label='all',
                   target_name=None,
                   impurity=True,
                   node_ids=False, proportion=False,
                   special_characters=False, precision=3, **kwargs):
    """Export a decision tree (or M5 model tree) in plain-text format.

    Note: this should be merged with ._export.export_text

    Inspired by WEKA and by ``sklearn.tree.export_graphviz``.

    This function generates a human-readable, text representation of the
    decision tree, which is then written into `out_file`. One line is written
    per node, indented with one ``"| "`` per depth level. For M5 trees whose
    leaves hold non-constant models, the leaves are labelled ``LM<k>`` and the
    corresponding models are listed after the tree.

    The sample counts that are shown are weighted with any sample_weights that
    might be present.

    Parameters
    ----------
    decision_tree : fitted decision tree estimator
        The decision tree to be exported to text. It must expose a ``tree_``
        attribute (this is checked with ``check_is_fitted``).

    out_file : file object or string, optional (default=None)
        Handle or name of the output file. If a string is given, the file is
        opened for writing (UTF-8) and closed by this function. If ``None``,
        the result is returned as a string.

    max_depth : int, optional (default=None)
        The maximum depth of the representation. If None, the tree is fully
        generated. Subtrees below this depth are replaced by ``(...)``.

    feature_names : list of strings, optional (default=None)
        Names of each of the features.

    class_names : list of strings, bool or None, optional (default=None)
        Names of each of the target classes in ascending numerical order.
        Only relevant for classification and not supported for multi-output.
        If ``True``, shows a symbolic representation of the class name.

    label : {'all', 'root', 'none'}, optional (default='all')
        Whether to show informative labels for impurity, etc.
        Options include 'all' to show at every node, 'root' to show only at
        the top root node, or 'none' to not show at any node.

    target_name : string, optional (default=None)
        Name of the target, forwarded to the linear model printer. If not
        provided, the target will not be displayed in the equations.

    impurity : bool, optional (default=True)
        When set to ``True``, show the impurity at each node.

    node_ids : bool, optional (default=False)
        When set to ``True``, show the ID number on each node.

    proportion : bool, optional (default=False)
        When set to ``True``, change the display of 'values' and/or 'samples'
        to be proportions and percentages respectively.

    special_characters : bool, optional (default=False)
        When set to ``False``, ignore special characters for PostScript
        compatibility.

    precision : int, optional (default=3)
        Number of digits of precision for floating point in the values of
        impurity, threshold and value attributes of each node.

    kwargs : other keyword arguments for the linear model printer

    Returns
    -------
    text : string
        Text representation of the input tree.
        Only returned if ``out_file`` is None; otherwise ``None`` is returned.

    Raises
    ------
    ValueError
        If `precision` is not a non-negative integer, or if the length of
        `feature_names` does not match ``decision_tree.n_features_in_``.

    Examples
    --------
    >>> text = export_text_m5(fitted_tree)  # doctest: +SKIP
    >>> export_text_m5(fitted_tree, out_file='tree.txt')  # doctest: +SKIP
    """

    # Non-constant leaf models met while walking the tree; printed at the end.
    models = []

    def add_model(node_model):
        """Register a leaf model and return its 1-based identifier."""
        models.append(node_model)
        return len(models)

    def node_to_str(tree, node_id, criterion, node_models=None):
        """Generate the node content string."""

        # Should labels be shown?
        labels = (label == 'root' and node_id == 0) or label == 'all'

        # PostScript compatibility for special characters
        if special_characters:
            characters = ['#', '<SUB>', '</SUB>', '≤', '<br/>', '>']
            node_string = '<'
        else:
            characters = ['#', '[', ']', '<=', '\\n', '']
            node_string = ''

        # -- If this node is not a leaf, write the split decision criteria (x <= y)
        leaf = is_leaf(node_id, tree)
        if leaf:
            node_string += 'LEAF'
        else:
            if feature_names is not None:
                feature = feature_names[tree.feature[node_id]]
            else:
                feature = "X%s%s%s" % (characters[1],
                                       tree.feature[node_id],  # feature id for the split
                                       characters[2])
            node_string += '%s %s %s' % (feature,
                                         characters[3],  # <=
                                         round(tree.threshold[node_id],  # threshold for the split
                                               precision))

        # Node details - start bracket [
        node_string += ' %s' % characters[1]

        # -- Write impurity
        if impurity:
            if isinstance(criterion, FriedmanMSE):
                criterion = "friedman_mse"
            elif not isinstance(criterion, str):
                criterion = "impurity"
            if labels:
                node_string += '%s=' % criterion
            node_string += str(round(tree.impurity[node_id], precision)) + ', '

        # -- Write node sample count
        if labels:
            node_string += 'samples='
        if proportion:
            percent = (100. * tree.n_node_samples[node_id] /
                       float(tree.n_node_samples[0]))
            node_string += str(round(percent, 1)) + '%'
        else:
            node_string += str(tree.n_node_samples[node_id])

        # Node details - end bracket ]
        node_string += '%s' % characters[2]

        # -- Write node class distribution / regression value
        if tree.n_outputs == 1:
            value = tree.value[node_id][0, :]
        else:
            value = tree.value[node_id]

        is_regression = tree.n_classes[0] == 1
        if proportion and not is_regression:
            # For classification this will show the proportion of samples
            value = value / tree.weighted_n_node_samples[node_id]

        if (not is_regression and not proportion
                and np.all(np.equal(np.mod(value, 1), 0))):
            # Classification without floating-point weights: show plain counts
            value_text = value.astype(int)
        else:
            # Regression, proportions, or classification with float weights
            value_text = np.around(value, precision)

        # Strip whitespace
        value_text = str(value_text.astype('S32')).replace("b'", "'")
        value_text = value_text.replace("' '", ", ").replace("'", "")
        if is_regression and tree.n_outputs == 1:
            value_text = value_text.replace("[", "").replace("]", "")
        value_text = value_text.replace("\n ", characters[4])

        if node_models is None:
            node_string += ' : '
            if labels:
                node_string += 'value='
        else:
            nodemodel = node_models[node_id]
            model_err_val = np.around(nodemodel.error, precision)
            if not leaf:
                # replace the value text with error at this node and number of parameters
                value_text = " (err=%s, params=%s)" % (model_err_val, nodemodel.n_params)
            elif isinstance(nodemodel, ConstantLeafModel):
                # the model does not contain the value. rely on the value_text computed from tree
                value_text = " : %s (err=%s, params=%s)" % (value_text, model_err_val, nodemodel.n_params)
            else:
                # put the model in the stack, we'll write it later
                model_id = add_model(nodemodel)
                value_text = " : LM%s (err=%s, params=%s)" % (model_id, model_err_val, nodemodel.n_params)

        node_string += value_text

        # Write node majority class
        if (class_names is not None and
                tree.n_classes[0] != 1 and
                tree.n_outputs == 1):
            # Only done for single-output classification trees
            node_string += ', '
            if labels:
                node_string += 'class='
            if class_names is not True:
                class_name = class_names[np.argmax(value)]
            else:
                class_name = "y%s%s%s" % (characters[1],
                                          np.argmax(value),
                                          characters[2])
            node_string += class_name

        return node_string + characters[5]

    def recurse(tree, node_id, criterion, depth=0, node_models=None):
        """Write the line for `node_id`, then walk its children depth-first."""
        if node_id == _tree.TREE_LEAF:
            raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF)

        # Common line prefix: optional node id, then one "| " per depth level
        line_prefix = "| " * depth
        if node_ids:
            line_prefix = '%d| %s' % (node_id, line_prefix)

        if max_depth is not None and depth > max_depth:
            # Depth limit exceeded: write a truncation marker in place of the
            # subtree. The parent stops recursing here, so there is exactly
            # one marker per hidden subtree.
            out_file.write('%s(...)\n' % line_prefix)
            return

        # Add node with description
        out_file.write('%s%s\n' % (line_prefix,
                                   node_to_str(tree, node_id, criterion, node_models=node_models)))

        # Recurse on children if needed
        left_child = tree.children_left[node_id]
        right_child = tree.children_right[node_id]

        if left_child != _tree.TREE_LEAF:
            # node_id is not a leaf: recurse on both children
            recurse(tree, left_child, criterion=criterion, depth=depth + 1,
                    node_models=node_models)
            recurse(tree, right_child, criterion=criterion, depth=depth + 1,
                    node_models=node_models)

    def write_models(models):
        """Write one ``LM<k>: <model text>`` line per registered leaf model."""
        for model_id, model in enumerate(models, start=1):
            out_file.write("LM%s: %s\n" % (model_id, model.to_text(feature_names=feature_names, precision=precision,
                                                                 target_name=target_name, **kwargs)))

    # Main
    check_is_fitted(decision_tree, 'tree_')

    # Validate the arguments *before* opening any file, so that a bad call
    # does not create or truncate the destination.
    if not isinstance(precision, Integral):
        raise ValueError("'precision' should be an integer. Got {}"
                         " instead.".format(type(precision)))
    if precision < 0:
        raise ValueError("'precision' should be greater or equal to 0."
                         " Got {} instead.".format(precision))

    # Raise error if length of feature_names does not match
    # n_features_in_ in the decision_tree
    if feature_names is not None:
        if len(feature_names) != decision_tree.n_features_in_:
            raise ValueError("Length of feature_names, %d "
                             "does not match number of features, %d"
                             % (len(feature_names),
                                decision_tree.n_features_in_))

    own_file = False
    return_string = False
    try:
        if isinstance(out_file, str):
            out_file = open(out_file, "w", encoding="utf-8")
            own_file = True

        if out_file is None:
            return_string = True
            out_file = StringIO()

        # Tree title
        if isinstance(decision_tree, M5Base):
            if hasattr(decision_tree, 'installed_smoothing_constant'):
                details = "pre-smoothed with constant %s" % decision_tree.installed_smoothing_constant
            elif decision_tree.use_smoothing == 'installed':
                details = "under construction - not pre-smoothed yet"
            else:
                details = "unsmoothed - but this can be done at prediction time"

            # add more info for M5P
            out_file.write('%s (%s):\n' % (type(decision_tree).__name__, details))
        else:
            # generic title
            out_file.write('%s :\n' % type(decision_tree).__name__)

        # some space for readability
        out_file.write('\n')

        # Now recurse the tree and write one line per node
        if isinstance(decision_tree, _tree.Tree):
            recurse(decision_tree, 0, criterion="impurity")
        elif isinstance(decision_tree, M5Base) and hasattr(decision_tree, 'node_models'):
            recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion, node_models=decision_tree.node_models)

            # extra step: write all models
            out_file.write("\n")
            write_models(models)
        else:
            recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)

        # Return the text if needed
        if return_string:
            return out_file.getvalue()

    finally:
        # Only close handles that this function opened itself
        if own_file:
            out_file.close()