Coverage for src/m5py/export.py: 3%

136 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-22 17:10 +0000

1from numbers import Integral 

2 

3import numpy as np 

4from io import StringIO 

5 

6from sklearn.tree import _tree 

7from sklearn.tree._criterion import FriedmanMSE 

8 

9from m5py.main import is_leaf, ConstantLeafModel, M5Base, check_is_fitted 

10 

11 

def export_text_m5(decision_tree, out_file=None, max_depth=None,
                   feature_names=None, class_names=None, label='all',
                   target_name=None,
                   # filled=False, leaves_parallel=False,
                   impurity=True,
                   node_ids=False, proportion=False,
                   # rounded=False, rotate=False,
                   special_characters=False, precision=3, **kwargs):
    """Export a decision tree in TXT format.

    Note: this should be merged with ._export.export_text

    Inspired by WEKA and by
    >>> from sklearn.tree import export_graphviz

    This function generates a human-readable, text representation of the
    decision tree, which is then written into `out_file`.

    The sample counts that are shown are weighted with any sample_weights that
    might be present.

    Read more in the :ref:`User Guide <tree>`.

    Parameters
    ----------
    decision_tree : decision tree classifier
        The decision tree to be exported to text.

    out_file : file object or string, optional (default='tree.dot')
        Handle or name of the output file. If ``None``, the result is
        returned as a string.

    max_depth : int, optional (default=None)
        The maximum depth of the representation. If None, the tree is fully
        generated.

    feature_names : list of strings, optional (default=None)
        Names of each of the features.

    class_names : list of strings, bool or None, optional (default=None)
        Names of each of the target classes in ascending numerical order.
        Only relevant for classification and not supported for multi-output.
        If ``True``, shows a symbolic representation of the class name.

    label : {'all', 'root', 'none'}, optional (default='all')
        Whether to show informative labels for impurity, etc.
        Options include 'all' to show at every node, 'root' to show only at
        the top root node, or 'none' to not show at any node.

    target_name : optional string with the target name. If not provided, the
        target will not be displayed in the equations

    impurity : bool, optional (default=True)
        When set to ``True``, show the impurity at each node.

    node_ids : bool, optional (default=False)
        When set to ``True``, show the ID number on each node.

    proportion : bool, optional (default=False)
        When set to ``True``, change the display of 'values' and/or 'samples'
        to be proportions and percentages respectively.

    special_characters : bool, optional (default=False)
        When set to ``False``, ignore special characters for PostScript
        compatibility.

    precision : int, optional (default=3)
        Number of digits of precision for floating point in the values of
        impurity, threshold and value attributes of each node.

    kwargs : other keyword arguments for the linear model printer

    Returns
    -------
    dot_data : string
        String representation of the input tree in GraphViz dot format.
        Only returned if ``out_file`` is None.

    Examples
    --------
    >>> from sklearn.datasets import load_iris
    >>> from sklearn import tree

    >>> clf = tree.DecisionTreeClassifier()
    >>> iris = load_iris()

    >>> clf = clf.fit(iris.data, iris.target)
    >>> tree_to_text(clf, out_file='tree.txt') # doctest: +SKIP

    """
    # Linear models attached to leaves are stacked here during the recursion
    # and written at the bottom of the report as LM1, LM2, ... (1-based).
    models = []

    def add_model(node_model):
        # Stack `node_model` and return its 1-based LM identifier.
        models.append(node_model)
        return len(models)

    def node_to_str(tree, node_id, criterion, node_models=None):
        """ Generates the node content string """

        # Should labels be shown?
        labels = (label == 'root' and node_id == 0) or label == 'all'

        # PostScript compatibility for special characters
        if special_characters:
            characters = ['&#35;', '<SUB>', '</SUB>', '&le;', '<br/>', '>']
            node_string = '<'
        else:
            characters = ['#', '[', ']', '<=', '\\n', '']
            node_string = ''

        # -- If this node is not a leaf, Write the split decision criteria (x <= y)
        leaf = is_leaf(node_id, tree)
        if not leaf:
            if feature_names is not None:
                feature = feature_names[tree.feature[node_id]]
            else:
                feature = "X%s%s%s" % (characters[1],
                                       tree.feature[node_id],  # feature id for the split
                                       characters[2])
            node_string += '%s %s %s' % (feature,
                                         characters[3],  # <=
                                         round(tree.threshold[node_id],  # threshold for the split
                                               precision))
        else:
            node_string += 'LEAF'

        # Node details - start bracket [
        node_string += ' %s' % characters[1]

        # -- Write impurity
        if impurity:
            if isinstance(criterion, FriedmanMSE):
                criterion = "friedman_mse"
            elif not isinstance(criterion, str):
                criterion = "impurity"
            if labels:
                node_string += '%s=' % criterion
            node_string += str(round(tree.impurity[node_id], precision)) + ', '

        # -- Write node sample count
        if labels:
            node_string += 'samples='
        if proportion:
            percent = (100. * tree.n_node_samples[node_id] /
                       float(tree.n_node_samples[0]))
            node_string += str(round(percent, 1)) + '%'
        else:
            node_string += str(tree.n_node_samples[node_id])

        # Node details - end bracket ]
        node_string += '%s' % characters[2]

        # -- Write node class distribution / regression value
        if tree.n_outputs == 1:
            value = tree.value[node_id][0, :]
        else:
            value = tree.value[node_id]

        if proportion and tree.n_classes[0] != 1:
            # For classification this will show the proportion of samples
            value = value / tree.weighted_n_node_samples[node_id]
        if tree.n_classes[0] == 1:
            # Regression
            value_text = np.around(value, precision)
        elif proportion:
            # Classification
            value_text = np.around(value, precision)
        elif np.all(np.equal(np.mod(value, 1), 0)):
            # Classification without floating-point weights
            value_text = value.astype(int)
        else:
            # Classification with floating-point weights
            value_text = np.around(value, precision)

        # Strip whitespace
        value_text = str(value_text.astype('S32')).replace("b'", "'")
        value_text = value_text.replace("' '", ", ").replace("'", "")
        if tree.n_classes[0] == 1 and tree.n_outputs == 1:
            value_text = value_text.replace("[", "").replace("]", "")
        value_text = value_text.replace("\n ", characters[4])

        if node_models is None:
            node_string += ' : '
            if labels:
                node_string += 'value='
        else:
            nodemodel = node_models[node_id]
            model_err_val = np.around(nodemodel.error, precision)
            if leaf:
                if isinstance(nodemodel, ConstantLeafModel):
                    # the model does not contain the value. rely on the value_text computed from tree
                    value_text = " : %s (err=%s, params=%s)" % (value_text, model_err_val, nodemodel.n_params)
                else:
                    # put the model in the stack, we'll write it later
                    model_id = add_model(nodemodel)
                    value_text = " : LM%s (err=%s, params=%s)" % (model_id, model_err_val, nodemodel.n_params)
            else:
                # replace the value text with error at this node and number of parameters
                value_text = " (err=%s, params=%s)" % (model_err_val, nodemodel.n_params)

        node_string += value_text

        # Write node majority class
        if (class_names is not None and
                tree.n_classes[0] != 1 and
                tree.n_outputs == 1):
            # Only done for single-output classification trees
            node_string += ', '
            if labels:
                node_string += 'class='
            if class_names is not True:
                class_name = class_names[np.argmax(value)]
            else:
                class_name = "y%s%s%s" % (characters[1],
                                          np.argmax(value),
                                          characters[2])
            node_string += class_name

        return node_string + characters[5]

    def recurse(tree, node_id, criterion, parent=None, depth=0, node_models=None):
        # Depth-first walk writing one line per node; `parent` is kept for
        # signature compatibility even though it is not used in the output.
        if node_id == _tree.TREE_LEAF:
            raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF)

        # Add node with description
        if max_depth is None or depth <= max_depth:
            indent_str = ("| " * depth)
            if node_ids:
                out_file.write('%d| %s%s\n' % (node_id, indent_str, node_to_str(tree, node_id, criterion,
                                                                                node_models=node_models)))
            else:
                out_file.write('%s%s\n' % (indent_str, node_to_str(tree, node_id, criterion, node_models=node_models)))

            # Recurse on Children if needed
            left_child = tree.children_left[node_id]
            right_child = tree.children_right[node_id]

            # if not is_leaf(node_id, tree)
            if left_child != _tree.TREE_LEAF:
                # that means that node_id is not a leaf (see is_leaf() below.): recurse on children
                recurse(tree, left_child, criterion=criterion, parent=node_id, depth=depth + 1,
                        node_models=node_models)
                recurse(tree, right_child, criterion=criterion, parent=node_id, depth=depth + 1,
                        node_models=node_models)

        else:
            ranks['leaves'].append(str(node_id))
            # BUG FIX: the original called out_file.write('%d| (...)\n') with no
            # format argument, raising TypeError whenever max_depth truncated
            # the output. Supply node_id, matching the '%d| ...' form above.
            out_file.write('%d| (...)\n' % node_id)

    def write_models(models):
        # Append the textual form of each stacked leaf model (LM1, LM2, ...).
        for i, model in enumerate(models):
            out_file.write("LM%s: %s\n" % (i + 1, model.to_text(feature_names=feature_names, precision=precision,
                                                                target_name=target_name, **kwargs)))

    # Main
    check_is_fitted(decision_tree, 'tree_')
    own_file = False
    return_string = False
    try:
        if isinstance(out_file, str):
            out_file = open(out_file, "w", encoding="utf-8")
            own_file = True

        if out_file is None:
            return_string = True
            out_file = StringIO()

        if isinstance(precision, Integral):
            if precision < 0:
                raise ValueError("'precision' should be greater or equal to 0."
                                 " Got {} instead.".format(precision))
        else:
            raise ValueError("'precision' should be an integer. Got {}"
                             " instead.".format(type(precision)))

        # Check length of feature_names before getting into the tree node
        # Raise error if length of feature_names does not match
        # n_features_ in the decision_tree
        if feature_names is not None:
            if len(feature_names) != decision_tree.n_features_in_:
                raise ValueError("Length of feature_names, %d "
                                 "does not match number of features, %d"
                                 % (len(feature_names),
                                    decision_tree.n_features_in_))

        # The depth of each node for plotting with 'leaf' option TODO probably remove
        ranks = {'leaves': []}

        # Tree title
        if isinstance(decision_tree, M5Base):
            if hasattr(decision_tree, 'installed_smoothing_constant'):
                details = "pre-smoothed with constant %s" % decision_tree.installed_smoothing_constant
            else:
                if decision_tree.use_smoothing == 'installed':
                    details = "under construction - not pre-smoothed yet"
                else:
                    details = "unsmoothed - but this can be done at prediction time"

            # add more info or M5P
            out_file.write('%s (%s):\n' % (type(decision_tree).__name__, details))
        else:
            # generic title
            out_file.write('%s :\n' % type(decision_tree).__name__)

        # some space for readability
        out_file.write('\n')

        # Now recurse the tree and add node & edge attributes
        if isinstance(decision_tree, _tree.Tree):
            recurse(decision_tree, 0, criterion="impurity")
        elif isinstance(decision_tree, M5Base) and hasattr(decision_tree, 'node_models'):
            recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion, node_models=decision_tree.node_models)

            # extra step: write all models
            out_file.write("\n")
            write_models(models)
        else:
            recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)

        # Return the text if needed
        if return_string:
            return out_file.getvalue()

    finally:
        # Only close handles this function opened itself; caller-provided
        # file objects (and the StringIO used for return_string) stay open.
        if own_file:
            out_file.close()