Regression tree

class EarlyStopping.RegressionTree(design: array, response: array, min_samples_split, true_signal=None, true_noise_vector=None)

A class for constructing a regression tree and computing theoretical quantities.

Parameters

design: array. The design matrix for the regression tree.

response: array. The response variable values.

min_samples_split: int. The minimum number of samples a node must contain to be split further; nodes with fewer samples become terminal nodes.

true_signal: array or None. The true signal underlying the response. Used only in simulation contexts for computing theoretical quantities.

true_noise_vector: array or None. The true noise vector of the response. Used only in simulation contexts for computing theoretical quantities.

Attributes

sample_size: int. The number of samples in the data.

dimension: int. The number of variables in the design matrix.

regression_tree: Node. The root node of the regression tree.

residuals: array. Stores the mean squared error residuals at each level of the tree.

bias2: array. Stores the squared bias values at each level of the tree.

variance: array. Stores the variance values at each level of the tree.

risk: array. Stores the risk at each level of the tree, computed from the squared bias and the variance.
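
To illustrate how the per-level attributes relate, the following minimal sketch assumes the usual decomposition risk = bias2 + variance at each level. The helper name `risk_per_level` and all numbers are hypothetical and are not part of the package; the class may compute its risk differently.

```python
def risk_per_level(bias2, variance):
    """Combine per-level squared bias and variance into a per-level risk.

    Assumption: risk = squared bias + variance at every level.
    """
    if len(bias2) != len(variance):
        raise ValueError("bias2 and variance need one entry per level")
    return [b + v for b, v in zip(bias2, variance)]


# Made-up values for four tree levels: bias shrinks, variance grows.
bias2 = [4.0, 1.5, 0.5, 0.25]
variance = [0.25, 0.5, 1.0, 2.0]

risk = risk_per_level(bias2, variance)
best_level = risk.index(min(risk))  # level with the smallest risk
print(risk, best_level)
```

In this toy example the risk first falls and then rises again, which is the trade-off that the stopping rules below try to balance.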

Methods

iterate(max_depth=None)

Grows a regression tree up to the specified depth.

predict(design, depth)

Predicts target values using the regression tree at a given depth.

get_discrepancy_stop(critical_value, max_depth=None)

Finds the first generation satisfying the discrepancy principle.

get_balanced_oracle(max_depth=None)

Computes the balanced oracle generation.

RegressionTree.iterate(max_depth: int = None)

Grows the regression tree up to the specified depth.

Parameters

max_depth: int or None. Maximum depth to which the tree should grow. If None, the tree grows fully.

RegressionTree.get_balanced_oracle(max_depth=None)

Computes the balanced oracle generation based on the squared bias and the variance.

Parameters

max_depth: int or None. Maximum depth to which the tree is grown if it has not been grown yet.
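
As a rough sketch of the idea, the snippet below assumes that the balanced oracle is the first generation at which the squared bias no longer exceeds the variance. That definition, the function name `balanced_oracle`, and the numbers are assumptions made for illustration; they are not taken from the package source.

```python
def balanced_oracle(bias2, variance):
    """Return the first generation with squared bias <= variance.

    Returns None if the condition is never met.
    """
    for generation, (b, v) in enumerate(zip(bias2, variance)):
        if b <= v:
            return generation
    return None


# Made-up per-generation values: bias shrinks, variance grows.
bias2 = [4.0, 1.5, 0.5, 0.25]
variance = [0.25, 0.5, 1.0, 2.0]

oracle_generation = balanced_oracle(bias2, variance)
print(oracle_generation)
```

Because this quantity depends on the bias and variance, it is only available in simulations where true_signal and true_noise_vector are known.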

RegressionTree.get_discrepancy_stop(critical_value, max_depth=None)

Finds the first generation where the discrepancy principle is met.

Parameters

critical_value: float. Threshold for discrepancy-based stopping.

max_depth: int or None. Maximum depth to which the tree is grown if it has not been grown yet.
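
The stopping rule can be sketched as follows, assuming that the discrepancy principle is met at the first generation whose residual is less than or equal to the critical value. The function name `discrepancy_stop` and the residual values are hypothetical; the method itself may differ in detail.

```python
def discrepancy_stop(residuals, critical_value):
    """Return the first generation whose residual is <= critical_value.

    Returns None if no generation meets the threshold.
    """
    for generation, residual in enumerate(residuals):
        if residual <= critical_value:
            return generation
    return None


# Made-up mean squared residuals, one per generation of the tree.
residuals = [5.0, 2.5, 1.2, 0.8, 0.3]

stop = discrepancy_stop(residuals, critical_value=1.0)
print(stop)
```

Unlike the balanced oracle, this rule uses only the residuals, so it can be evaluated on real data where the true signal is unknown.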

RegressionTree.predict(design: DataFrame | ndarray, depth: int)

Predicts target values for the given design data using the regression tree at the specified depth.

Parameters

design: DataFrame or ndarray. Input design matrix for predictions.

depth: int. Depth level of the tree to use for predictions. If 0, returns the unconditional mean.
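
The effect of the depth argument can be illustrated with a toy tree. The sketch below is not the package's implementation: the nested-dictionary tree layout, the helper `predict_at_depth`, and all values are hypothetical. It only shows how truncating the walk at a given depth yields coarser predictions, with depth 0 returning the root mean.

```python
def predict_at_depth(node, x, depth):
    """Follow at most `depth` splits, then return the mean of the node reached."""
    while depth > 0 and node.get("left") is not None:
        if x[node["feature"]] <= node["threshold"]:
            node = node["left"]
        else:
            node = node["right"]
        depth -= 1
    return node["mean"]


# Toy tree: every node stores the mean response of its samples;
# internal nodes also store a split (feature index and threshold).
tree = {
    "mean": 2.5, "feature": 0, "threshold": 0.5,
    "left": {"mean": 1.0},
    "right": {
        "mean": 4.0, "feature": 1, "threshold": 0.0,
        "left": {"mean": 3.0},
        "right": {"mean": 5.0},
    },
}

x = [0.9, 1.0]
predictions = [predict_at_depth(tree, x, d) for d in range(3)]
print(predictions)
```

A sample that reaches a terminal node before the requested depth keeps that node's mean, so asking for a larger depth than the tree has does not change the prediction.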