Regression tree
- class EarlyStopping.RegressionTree(design: array, response: array, min_samples_split, true_signal=None, true_noise_vector=None)
A class for constructing a regression tree and computing theoretical quantities.
Parameters
design: array. The design matrix for the regression tree.
response: array. The response variable values.
min_samples_split: int. The maximum number of samples a terminal node may contain.
true_signal: array or None. The true signal. Used only in simulations, to compute theoretical quantities.
true_noise_vector: array or None. The true noise vector. Used only in simulations, to compute theoretical quantities.
Attributes
sample_size: int. The number of samples in the data.
dimension: int. The number of variables in the design matrix.
regression_tree: Node. The root node of the regression tree.
residuals: array. The mean squared residuals at each level of the tree.
bias2: array. The squared bias at each level of the tree.
variance: array. The variance at each level of the tree.
risk: array. The risk at each level of the tree, computed from the squared bias and the variance.
Methods
iterate(max_depth=None): Grows the regression tree up to the specified depth.
predict(design, depth): Predicts target values using the regression tree at a given depth.
get_discrepancy_stop(critical_value, max_depth=None): Finds the first generation satisfying the discrepancy principle.
get_balanced_oracle(max_depth=None): Computes the balanced oracle generation.
- RegressionTree.iterate(max_depth: int = None)
Grows the regression tree up to the specified depth.
Parameters
max_depth: int or None. Maximum depth to which the tree should grow. If None, the tree grows fully.
- RegressionTree.get_balanced_oracle(max_depth=None)
Computes the balanced oracle generation from the squared bias and the variance.
Parameters
max_depth: int or None. Maximum depth to which the tree is grown if it has not been grown yet.
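As a rough illustration of what a balanced oracle is, the sketch below picks the first level at which the squared bias is no larger than the variance. This rule is an assumption taken from the usual definition in the early stopping literature; the helper `balanced_oracle` and the example values are hypothetical and are not part of the library, whose own rule may differ in details such as tie-breaking.

```python
def balanced_oracle(bias2, variance):
    """Return the first level at which squared bias <= variance.

    Assumed rule for illustration only; not the library's implementation.
    """
    for level, (b, v) in enumerate(zip(bias2, variance)):
        if b <= v:
            return level
    # No level balances the two terms within the grown depth.
    return None


# Hypothetical per-level values: bias shrinks and variance grows with depth.
bias2 = [4.0, 1.5, 0.6, 0.2]
variance = [0.1, 0.4, 0.7, 1.1]
oracle_level = balanced_oracle(bias2, variance)
```

In this sketch the per-level inputs play the role of the `bias2` and `variance` attributes, which are only meaningful when `true_signal` is available, that is, in simulations.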
- RegressionTree.get_discrepancy_stop(critical_value, max_depth=None)
Finds the first generation at which the discrepancy principle is met.
Parameters
critical_value: float. Threshold for discrepancy-based stopping.
max_depth: int or None. Maximum depth to which the tree is grown if it has not been grown yet.
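The discrepancy principle can be sketched as a scan over the per-generation residuals that stops at the first value at or below the threshold. The helper `discrepancy_stop`, the choice of `<=` as the comparison, and the residual values are assumptions made for illustration and are not taken from the library.

```python
def discrepancy_stop(residuals, critical_value):
    """Return the first generation whose residual is <= critical_value.

    Assumed stopping rule for illustration; not the library's implementation.
    """
    for generation, residual in enumerate(residuals):
        if residual <= critical_value:
            return generation
    # The threshold was never reached within the available generations.
    return None


# Hypothetical mean squared residuals, decreasing as the tree grows.
residuals = [2.5, 1.4, 0.9, 0.5]
stop = discrepancy_stop(residuals, critical_value=1.0)
```

A common choice of threshold in this setting is an estimate of the noise level, which is why the rule needs only the observed residuals and not the true signal.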
- RegressionTree.predict(design: DataFrame | ndarray, depth: int)
Predicts target values for the given design data using the regression tree at the specified depth.
Parameters
design: DataFrame or ndarray. Input design matrix for predictions.
depth: int. Depth level of the tree to use for predictions. If 0, returns the unconditional mean.
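To make the role of `depth` concrete, the following minimal sketch predicts from a tree truncated at a given depth: it follows at most `depth` splits and returns the mean stored at the node it reaches. The nested-dictionary layout and the helper `predict_at_depth` are hypothetical stand-ins and do not reflect the library's `Node` class.

```python
def predict_at_depth(node, x, depth):
    """Follow at most `depth` splits, then return the mean of the reached node."""
    while depth > 0 and node.get("left") is not None:
        go_left = x[node["feature"]] <= node["threshold"]
        node = node["left"] if go_left else node["right"]
        depth -= 1
    return node["mean"]


# Hypothetical tree with one split on feature 0 at threshold 0.5.
tree = {
    "mean": 2.0,  # unconditional mean of the response at the root
    "feature": 0,
    "threshold": 0.5,
    "left": {"mean": 1.0},
    "right": {"mean": 3.0},
}

at_root = predict_at_depth(tree, [0.2], depth=0)  # root mean, no splits used
at_one = predict_at_depth(tree, [0.2], depth=1)   # mean of the left child
```

With `depth=0` no split is followed, so the sketch returns the root mean, matching the documented behaviour of returning the unconditional mean.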