View source on GitHub
|
Approximates the AUC (Area under the curve) of the ROC or PR curves.
Inherits From: Metric
tf.keras.metrics.AUC(
num_thresholds=200,
curve='ROC',
summation_method='interpolation',
name=None,
dtype=None,
thresholds=None,
multi_label=False,
num_labels=None,
label_weights=None,
from_logits=False
)
The AUC (area under the curve) of the ROC (receiver operating characteristic; default) or PR (precision-recall) curve is a quality measure of binary classifiers. Unlike accuracy, and like cross-entropy losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
This class approximates AUCs using a Riemann sum. During the metric accumulation phase, predictions are accumulated within predefined buckets by value. The AUC is then computed by interpolating per-bucket averages. These buckets define the evaluated operational points.
This metric creates four local variables, true_positives,
true_negatives, false_positives and false_negatives, that are used to
compute the AUC. To discretize the curve, a linearly spaced set of
thresholds is used to compute pairs of recall and precision values. The area
under the ROC curve is therefore computed using the height of the recall
values by the false positive rate, while the area under the PR curve is
computed using the height of the precision values by the recall.
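The discretization described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the helper names (`linear_thresholds`, `confusion_counts`, `roc_auc`) are hypothetical, and the rule that a prediction counts as positive when it is strictly greater than the threshold is an assumption chosen to agree with the worked example in this page.

```python
def linear_thresholds(num_thresholds, epsilon=1e-7):
    """Linearly spaced thresholds, padded just outside [0, 1] (hypothetical helper)."""
    inner = [(i + 1) / (num_thresholds - 1) for i in range(num_thresholds - 2)]
    return [0.0 - epsilon] + inner + [1.0 + epsilon]


def confusion_counts(y_true, y_pred, thresholds, sample_weight=None):
    """Weighted (tp, fp, fn, tn) for each threshold."""
    if sample_weight is None:
        sample_weight = [1.0] * len(y_true)
    counts = []
    for t in thresholds:
        tp = fp = fn = tn = 0.0
        for label, pred, w in zip(y_true, y_pred, sample_weight):
            predicted_positive = pred > t  # assumption: strict comparison
            if label and predicted_positive:
                tp += w
            elif label:
                fn += w
            elif predicted_positive:
                fp += w
            else:
                tn += w
        counts.append((tp, fp, fn, tn))
    return counts


def roc_auc(y_true, y_pred, num_thresholds=200, sample_weight=None):
    """Area under the discretized ROC curve using mid-point heights."""
    thresholds = linear_thresholds(num_thresholds)
    counts = confusion_counts(y_true, y_pred, thresholds, sample_weight)
    tpr = [tp / (tp + fn) if tp + fn else 0.0 for tp, fp, fn, tn in counts]
    fpr = [fp / (fp + tn) if fp + tn else 0.0 for tp, fp, fn, tn in counts]
    area = 0.0
    for i in range(len(thresholds) - 1):
        width = fpr[i] - fpr[i + 1]            # step along the x-axis
        height = (tpr[i] + tpr[i + 1]) / 2.0   # average recall over the step
        area += width * height
    return area


print(roc_auc([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], num_thresholds=3))  # 0.75
```

With three thresholds the sketch reproduces the 0.75 of the documented example; passing `sample_weight=[1, 0, 0, 1]` masks the two middle examples and yields 1.0.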
This value is ultimately returned as auc, an idempotent operation that
computes the area under a discretized curve of precision versus recall
values (computed using the aforementioned variables). The num_thresholds
variable controls the degree of discretization with larger numbers of
thresholds more closely approximating the true AUC. The quality of the
approximation may vary dramatically depending on num_thresholds. The
thresholds parameter can be used to manually specify thresholds which
split the predictions more evenly.
For a best approximation of the real AUC, predictions should be
distributed approximately uniformly in the range [0, 1] (if
from_logits=False). The quality of the AUC approximation may be poor if
this is not the case. Setting summation_method to 'minoring' or 'majoring'
can help quantify the error in the approximation by providing a lower or upper
bound estimate of the AUC.
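To see how the three summation methods bracket the estimate, here is a small sketch for the ROC case. It assumes the curve points are already available as lists ordered by increasing threshold (so the x-values decrease); `riemann_auc` is a hypothetical helper, not part of the Keras API.

```python
def riemann_auc(x, y, summation_method="interpolation"):
    """Riemann sum over curve points (x[i], y[i]) ordered by increasing threshold."""
    area = 0.0
    for i in range(len(x) - 1):
        if summation_method == "interpolation":
            height = (y[i] + y[i + 1]) / 2.0   # mid-point of the two heights
        elif summation_method == "minoring":
            height = min(y[i], y[i + 1])       # lower bound on the step
        elif summation_method == "majoring":
            height = max(y[i], y[i + 1])       # upper bound on the step
        else:
            raise ValueError(f"Unknown summation_method: {summation_method!r}")
        area += (x[i] - x[i + 1]) * height
    return area


# Points taken from the three-threshold example in this page.
fp_rate = [1.0, 0.0, 0.0]
tp_rate = [1.0, 0.5, 0.0]
lower = riemann_auc(fp_rate, tp_rate, "minoring")       # 0.5
middle = riemann_auc(fp_rate, tp_rate, "interpolation")  # 0.75
upper = riemann_auc(fp_rate, tp_rate, "majoring")        # 1.0
```

The gap between `lower` and `upper` gives a rough sense of how much the result depends on the discretization; a wide gap suggests using more thresholds or custom ones.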
If sample_weight is None, weights default to 1.
Use sample_weight of 0 to mask values.
Args | |
|---|---|
num_thresholds
|
(Optional) The number of thresholds to
use when discretizing the ROC curve. Values must be > 1.
Defaults to 200.
|
curve
|
(Optional) Specifies the name of the curve to be computed,
'ROC' (default) or 'PR' for the Precision-Recall-curve.
|
summation_method
|
(Optional) Specifies the Riemann summation method used.
'interpolation' (default) applies a mid-point summation scheme for
ROC. For PR-AUC, interpolates (true/false) positives but not
the ratio that is precision (see Davis & Goadrich 2006 for
details); 'minoring' applies left summation for increasing
intervals and right summation for decreasing intervals; 'majoring'
does the opposite.
|
name
|
(Optional) string name of the metric instance. |
dtype
|
(Optional) data type of the metric result. |
thresholds
|
(Optional) A list of floating point values to use as the
thresholds for discretizing the curve. If set, the num_thresholds
parameter is ignored. Values should be in [0, 1]. Endpoint
thresholds equal to {-epsilon, 1+epsilon} for a small positive
epsilon value will be automatically included with these to correctly
handle predictions equal to exactly 0 or 1.
|
multi_label
|
boolean indicating whether multilabel data should be
treated as such, wherein AUC is computed separately for each label
and then averaged across labels, or (when False) if the data
should be flattened into a single label before AUC computation. In
the latter case, when multilabel data is passed to AUC, each
label-prediction pair is treated as an individual data point. Should
be set to False for multi-class data.
|
num_labels
|
(Optional) The number of labels, used when multi_label is
True. If num_labels is not specified, then state variables get
created on the first call to update_state.
|
label_weights
|
(Optional) list, array, or tensor of non-negative weights
used to compute AUCs for multilabel data. When multi_label is
True, the weights are applied to the individual label AUCs when they
are averaged to produce the multi-label AUC. When it's False, they
are used to weight the individual label predictions in computing the
confusion matrix on the flattened data. Note that this is unlike
class_weights in that class_weights weights the example
depending on the value of its label, whereas label_weights depends
only on the index of that label before flattening; therefore
label_weights should not be used for multi-class data.
|
from_logits
|
boolean indicating whether the predictions (y_pred in
update_state) are probabilities or sigmoid logits. As a rule of thumb,
when using a keras loss, the from_logits constructor argument of the
loss should match the AUC from_logits constructor argument.
|
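As a rough illustration of how per-label AUCs might be combined when multi_label is True, the sketch below averages them, optionally using label_weights. The helper name `combine_label_aucs` is hypothetical, and normalizing by the sum of the weights is an assumption; the text above only states that the weights are applied when the label AUCs are averaged.

```python
def combine_label_aucs(per_label_aucs, label_weights=None):
    """Average per-label AUC values, optionally weighted (hypothetical helper)."""
    if label_weights is None:
        return sum(per_label_aucs) / len(per_label_aucs)
    total_weight = sum(label_weights)
    if total_weight == 0:
        return 0.0  # assumption: avoid dividing by zero when all weights are 0
    weighted_sum = sum(a * w for a, w in zip(per_label_aucs, label_weights))
    return weighted_sum / total_weight


unweighted = combine_label_aucs([0.5, 1.0])           # 0.75
weighted = combine_label_aucs([0.5, 1.0], [1.0, 3.0])  # 0.875
```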
Example:
m = keras.metrics.AUC(num_thresholds=3)
m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
# threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
# tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
# auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0)))
#     = 0.75
m.result()  # 0.75

m.reset_state()
m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
               sample_weight=[1, 0, 0, 1])
m.result()  # 1.0
Usage with compile() API:
# Reports the AUC of a model outputting a probability.
model.compile(optimizer='sgd',
loss=keras.losses.BinaryCrossentropy(),
metrics=[keras.metrics.AUC()])
# Reports the AUC of a model outputting a logit.
model.compile(optimizer='sgd',
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.AUC(from_logits=True)])
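Thresholds lie in [0, 1], so logits need to be mapped to probabilities before they can be compared against them. The sketch below assumes this mapping is a sigmoid, as suggested by the phrase "sigmoid logits" in the from_logits description; the `sigmoid` helper is written here for illustration and is not a Keras call.

```python
import math


def sigmoid(logit):
    """Numerically stable logistic function."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


logits = [-2.0, 0.0, 3.0]
probabilities = [sigmoid(v) for v in logits]
# A logit of 0 maps to probability 0.5; the ordering of predictions is preserved.
```

Because the sigmoid is monotonic, the ranking of predictions does not change; what changes is which threshold bucket each prediction falls into.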
Attributes | |
|---|---|
dtype
|
|
thresholds
|
The thresholds used for evaluating AUC. |
variables
|
|
Methods
add_variable
add_variable(
shape, initializer, dtype=None, aggregation='sum', name=None
)
add_weight
add_weight(
shape=(), initializer=None, dtype=None, name=None
)
from_config
@classmethod
from_config(
config
)
get_config
get_config()
Return the serializable config of the metric.
interpolate_pr_auc
interpolate_pr_auc()
Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
https://www.biostat.wisc.edu/~page/rocpr.pdf
Note here we derive & use a closed formula not present in the paper as follows:
Precision = TP / (TP + FP) = TP / P
Modeling all of TP (true positive), FP (false positive) and their sum P = TP + FP (predicted positive) as varying linearly within each interval [A, B] between successive thresholds, we get
Precision slope = dTP / dP
= (TP_B - TP_A) / (P_B - P_A)
= (TP - TP_A) / (P - P_A)
Precision = (TP_A + slope * (P - P_A)) / P
The area within the interval is (slope / total_pos_weight) times
int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
where dTP == TP_B - TP_A.
Note that when P_A == 0 the above calculation simplifies into
int_A^B{Precision.dTP} = int_A^B{slope * dTP}
= slope * (TP_B - TP_A)
which is really equivalent to imputing constant precision throughout the first bucket having >0 true positives.
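The closed-form expression above can be checked for a single interval [A, B] with a few lines of Python. This is a sketch under stated assumptions: `pr_interval_area` is a hypothetical helper, A denotes the end of the interval with fewer predicted positives, and the log term is dropped whenever P_A or P_B is not positive, which matches the P_A == 0 simplification described above.

```python
import math


def pr_interval_area(tp_a, fp_a, tp_b, fp_b, total_pos_weight):
    """Area under the PR curve contributed by one interval [A, B]."""
    p_a = tp_a + fp_a          # predicted positives at A
    p_b = tp_b + fp_b          # predicted positives at B
    d_tp = tp_b - tp_a
    d_p = p_b - p_a
    if d_p <= 0 or total_pos_weight <= 0:
        return 0.0             # assumption: empty interval contributes nothing
    slope = d_tp / d_p
    intercept = tp_b - slope * p_b
    # Skip the log term when it is undefined (P_A == 0 case in the text).
    log_term = math.log(p_b / p_a) if p_a > 0 and p_b > 0 else 0.0
    return slope * (d_tp + intercept * log_term) / total_pos_weight


# Precision stays at 0.5 while recall moves from 0.25 to 0.75.
constant_precision = pr_interval_area(1.0, 1.0, 3.0, 3.0, total_pos_weight=4.0)
# First bucket with true positives: P_A == 0, so only slope * dTP remains.
first_bucket = pr_interval_area(0.0, 0.0, 2.0, 2.0, total_pos_weight=4.0)
```

In the first call the intercept is zero, so the result reduces to precision times the change in recall (0.5 * 0.5 = 0.25), which is an easy sanity check on the formula.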
| Returns | |
|---|---|
pr_auc
|
an approximation of the area under the P-R curve. |
reset_state
reset_state()
Reset all of the metric state variables.
This function is called between epochs/steps, when a metric is evaluated during training.
result
result()
Compute the current metric value.
| Returns | |
|---|---|
| A scalar tensor, or a dictionary of scalar tensors. |
stateless_reset_state
stateless_reset_state()
stateless_result
stateless_result(
metric_variables
)
stateless_update_state
stateless_update_state(
metric_variables, *args, **kwargs
)
update_state
update_state(
y_true, y_pred, sample_weight=None
)
Accumulates confusion matrix statistics.
| Args | |
|---|---|
y_true
|
The ground truth values. |
y_pred
|
The predicted values. |
sample_weight
|
Optional weighting of each example. Can
be a tensor whose rank is either 0, or the same rank as
y_true, and must be broadcastable to y_true. Defaults to
1.
|
__call__
__call__(
*args, **kwargs
)
Call self as a function.