Kernel: Python [conda env:py37]
In [1]:
Model Evaluation and Improvement
In [2]:
Test set score: 0.88
Cross-Validation
In [3]:
Cross-Validation in scikit-learn
In [4]:
Cross-validation scores: [0.961 0.922 0.958]
In [5]:
Cross-validation scores: [1. 0.967 0.933 0.9 1. ]
In [6]:
Average cross-validation score: 0.96
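The code cells are not shown in this export; a minimal sketch that produces this kind of output with `cross_val_score` (assuming the iris dataset and a `LogisticRegression` model; `max_iter` is raised because recent scikit-learn versions warn about non-convergence with the default, and exact scores vary by version):

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

iris = load_iris()
# max_iter raised so the lbfgs solver converges on this data
logreg = LogisticRegression(max_iter=1000)

# cv=5 performs five-fold cross-validation: five fits, each scored
# on a different held-out fifth of the data
scores = cross_val_score(logreg, iris.data, iris.target, cv=5)
print("Cross-validation scores: {}".format(scores))
print("Average cross-validation score: {:.2f}".format(scores.mean()))
```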
In [7]:
{'fit_time': array([0.001, 0.001, 0.001, 0.001, 0.001]),
'score_time': array([0., 0., 0., 0., 0.]),
'test_score': array([1. , 0.967, 0.933, 0.9 , 1. ]),
'train_score': array([0.95 , 0.967, 0.967, 0.975, 0.958])}
In [8]:
| | fit_time | score_time | test_score | train_score |
|---|---|---|---|---|
| 0 | 9.79e-04 | 2.69e-04 | 1.00 | 0.95 |
| 1 | 7.50e-04 | 2.07e-04 | 0.97 | 0.97 |
| 2 | 8.71e-04 | 2.95e-04 | 0.93 | 0.97 |
| 3 | 8.22e-04 | 2.60e-04 | 0.90 | 0.97 |
| 4 | 1.44e-03 | 4.58e-04 | 1.00 | 0.96 |
Mean times and scores:
fit_time 9.73e-04
score_time 2.98e-04
test_score 9.60e-01
train_score 9.63e-01
dtype: float64
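A sketch of the `cross_validate` call behind output like the above: unlike `cross_val_score`, it also reports fit and score times, and training scores when `return_train_score=True` (timings will differ on your machine):

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

iris = load_iris()
logreg = LogisticRegression(max_iter=1000)

# returns a dict of arrays, one entry per fold for each measurement
res = cross_validate(logreg, iris.data, iris.target, cv=5,
                     return_train_score=True)
res_df = pd.DataFrame(res)
print(res_df)
print("Mean times and scores:\n{}".format(res_df.mean()))
```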
Benefits of Cross-Validation
Stratified K-Fold cross-validation and other strategies
In [9]:
Iris labels:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2]
In [10]:
More control over cross-validation
In [11]:
In [12]:
Cross-validation scores:
[1. 0.933 0.433 0.967 0.433]
In [13]:
Cross-validation scores:
[0. 0. 0.]
In [14]:
Cross-validation scores:
[0.9 0.96 0.96]
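The outputs above can be reproduced with an explicit `KFold` splitter. Because the iris labels are sorted by class, unshuffled three-fold splits put one entire class in each test fold, so the model never sees the test class during training and accuracy collapses to zero; shuffling before splitting repairs this (a sketch, assuming `LogisticRegression` on iris):

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score

iris = load_iris()
logreg = LogisticRegression(max_iter=1000)

# unshuffled 3-fold on sorted labels: every test fold is an unseen class
kfold = KFold(n_splits=3)
scores_unshuffled = cross_val_score(logreg, iris.data, iris.target, cv=kfold)
print("Cross-validation scores:\n{}".format(scores_unshuffled))

# shuffling the samples before splitting fixes the problem
kfold = KFold(n_splits=3, shuffle=True, random_state=0)
scores_shuffled = cross_val_score(logreg, iris.data, iris.target, cv=kfold)
print("Cross-validation scores:\n{}".format(scores_shuffled))
```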
Leave-one-out cross-validation
In [15]:
Number of cv iterations: 150
Mean accuracy: 0.95
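Leave-one-out cross-validation can be sketched as follows: each split holds out a single sample, so on the 150-sample iris dataset it performs 150 fits:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import LeaveOneOut, cross_val_score

iris = load_iris()
logreg = LogisticRegression(max_iter=1000)

# one test sample per split: as many iterations as there are samples
loo = LeaveOneOut()
scores = cross_val_score(logreg, iris.data, iris.target, cv=loo)
print("Number of cv iterations:", len(scores))
print("Mean accuracy: {:.2f}".format(scores.mean()))
```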
Shuffle-split cross-validation
In [16]:
In [17]:
Cross-validation scores:
[0.973 0.933 0.933 0.933 0.92 0.853 0.947 0.813 0.947 0.96 ]
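Shuffle-split cross-validation draws independent random train/test partitions; a sketch producing ten 50/50 splits (a `random_state` is added here for reproducibility, so the exact scores differ from those above):

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ShuffleSplit, cross_val_score

iris = load_iris()
logreg = LogisticRegression(max_iter=1000)

# ten random splits, each using half the data for training, half for testing
shuffle_split = ShuffleSplit(test_size=.5, train_size=.5, n_splits=10,
                             random_state=0)
scores = cross_val_score(logreg, iris.data, iris.target, cv=shuffle_split)
print("Cross-validation scores:\n{}".format(scores))
```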
Cross-validation with groups
In [18]:
In [19]:
Cross-validation scores:
[0.75 0.8 0.667]
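Grouped cross-validation keeps all samples from one group (say, one patient) on the same side of the split; a sketch on synthetic data with twelve samples in four groups:

```python
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GroupKFold, cross_val_score

# synthetic data: 12 samples belonging to 4 groups (e.g. 4 patients)
X, y = make_blobs(n_samples=12, random_state=0)
groups = [0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3]

# GroupKFold never splits a group across training and test data
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y,
                         groups=groups, cv=GroupKFold(n_splits=3))
print("Cross-validation scores:\n{}".format(scores))
```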
Grid Search
Simple Grid Search
In [20]:
Size of training set: 112 size of test set: 38
Best score: 0.97
Best parameters: {'C': 100, 'gamma': 0.001}
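A sketch of the naive grid search that produces output like the above: two nested loops over candidate values of `C` and `gamma`, keeping the combination with the best test-set score (which, as the next section argues, is exactly the mistake to avoid):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)
print("Size of training set: {} size of test set: {}".format(
    X_train.shape[0], X_test.shape[0]))

best_score = 0
for gamma in [0.001, 0.01, 0.1, 1, 10, 100]:
    for C in [0.001, 0.01, 0.1, 1, 10, 100]:
        # train an SVC for each parameter combination
        svm = SVC(gamma=gamma, C=C)
        svm.fit(X_train, y_train)
        # NB: evaluating on the test set here leaks information into
        # the parameter choice
        score = svm.score(X_test, y_test)
        if score > best_score:
            best_score = score
            best_parameters = {'C': C, 'gamma': gamma}

print("Best score: {:.2f}".format(best_score))
print("Best parameters: {}".format(best_parameters))
```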
The danger of overfitting the parameters and the validation set
In [21]:
In [22]:
Size of training set: 84 size of validation set: 28 size of test set: 38
Best score on validation set: 0.96
Best parameters: {'C': 10, 'gamma': 0.001}
Test set score with best parameters: 0.92
Grid Search with Cross-Validation
In [23]:
SVC(C=100, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape='ovr', degree=3, gamma=0.01, kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)
In [24]:
In [25]:
In [26]:
Parameter grid:
{'C': [0.001, 0.01, 0.1, 1, 10, 100], 'gamma': [0.001, 0.01, 0.1, 1, 10, 100]}
In [27]:
In [28]:
In [29]:
GridSearchCV(cv=5, error_score='raise-deprecating',
estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape='ovr', degree=3, gamma='auto_deprecated',
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False),
iid='warn', n_jobs=None,
param_grid={'C': [0.001, 0.01, 0.1, 1, 10, 100], 'gamma': [0.001, 0.01, 0.1, 1, 10, 100]},
pre_dispatch='2*n_jobs', refit=True, return_train_score=True,
scoring=None, verbose=0)
In [30]:
Test set score: 0.97
In [31]:
Best parameters: {'C': 100, 'gamma': 0.01}
Best cross-validation score: 0.97
In [32]:
Best estimator:
SVC(C=100, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape='ovr', degree=3, gamma=0.01, kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)
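The `GridSearchCV` workflow behind these outputs can be sketched as follows: build the parameter grid, fit on the training set (which runs cross-validation over every combination and then refits the best one), and only then score the held-out test set (exact best parameters may differ slightly across scikit-learn versions):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100],
              'gamma': [0.001, 0.01, 0.1, 1, 10, 100]}
grid_search = GridSearchCV(SVC(), param_grid, cv=5,
                           return_train_score=True)

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)

# runs 6 * 6 * 5 = 180 fits, then refits the best parameters on all
# of X_train (refit=True is the default)
grid_search.fit(X_train, y_train)

print("Test set score: {:.2f}".format(grid_search.score(X_test, y_test)))
print("Best parameters: {}".format(grid_search.best_params_))
print("Best cross-validation score: {:.2f}".format(grid_search.best_score_))
print("Best estimator:\n{}".format(grid_search.best_estimator_))
```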
Analyzing the result of cross-validation
In [33]:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | ... | split3_train_score | split4_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.40e-03 | 6.85e-04 | 4.92e-04 | 1.69e-04 | ... | 0.37 | 0.36 | 0.37 | 2.85e-03 |
| 1 | 1.24e-03 | 5.33e-04 | 4.27e-04 | 1.46e-04 | ... | 0.37 | 0.36 | 0.37 | 2.85e-03 |
| 2 | 1.10e-03 | 3.77e-04 | 6.01e-04 | 4.14e-04 | ... | 0.37 | 0.36 | 0.37 | 2.85e-03 |
| 3 | 8.72e-04 | 3.81e-05 | 3.39e-04 | 3.83e-05 | ... | 0.37 | 0.36 | 0.37 | 2.85e-03 |
| 4 | 1.40e-03 | 4.82e-04 | 6.03e-04 | 1.58e-04 | ... | 0.37 | 0.36 | 0.37 | 2.85e-03 |
5 rows × 22 columns
In [34]:
<matplotlib.collections.PolyCollection at 0x7fcb17fd24a8>
In [35]:
<matplotlib.colorbar.Colorbar at 0x7fcb184cf940>
In [36]:
List of grids:
[{'kernel': ['rbf'], 'C': [0.001, 0.01, 0.1, 1, 10, 100], 'gamma': [0.001, 0.01, 0.1, 1, 10, 100]}, {'kernel': ['linear'], 'C': [0.001, 0.01, 0.1, 1, 10, 100]}]
In [37]:
Best parameters: {'C': 100, 'gamma': 0.01, 'kernel': 'rbf'}
Best cross-validation score: 0.97
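A sketch of searching over a list of grids, as shown above: `gamma` only applies to the `rbf` kernel, so the `linear` grid omits it and `GridSearchCV` tries 36 + 6 = 42 combinations in total:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

# two grids: gamma is only meaningful for the rbf kernel
param_grid = [{'kernel': ['rbf'],
               'C': [0.001, 0.01, 0.1, 1, 10, 100],
               'gamma': [0.001, 0.01, 0.1, 1, 10, 100]},
              {'kernel': ['linear'],
               'C': [0.001, 0.01, 0.1, 1, 10, 100]}]

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)
grid_search = GridSearchCV(SVC(), param_grid, cv=5)
grid_search.fit(X_train, y_train)
print("Best parameters: {}".format(grid_search.best_params_))
print("Best cross-validation score: {:.2f}".format(grid_search.best_score_))
```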
In [38]:
| | 0 | 1 | 2 | 3 | ... | 38 | 39 | 40 | 41 |
|---|---|---|---|---|---|---|---|---|---|
| mean_fit_time | 0.00073 | 0.00071 | 0.00085 | 0.0011 | ... | 0.00033 | 0.0003 | 0.0003 | 0.00031 |
| std_fit_time | 5.5e-05 | 4.3e-05 | 0.00028 | 0.00036 | ... | 2.1e-05 | 9.6e-06 | 1.4e-05 | 2.5e-05 |
| mean_score_time | 0.00028 | 0.00027 | 0.00033 | 0.00047 | ... | 0.00017 | 0.00016 | 0.00016 | 0.00016 |
| std_score_time | 2e-05 | 1.2e-05 | 0.00013 | 0.00025 | ... | 4.2e-07 | 2.3e-06 | 4.1e-06 | 9.9e-07 |
| param_C | 0.001 | 0.001 | 0.001 | 0.001 | ... | 0.1 | 1 | 10 | 100 |
| param_gamma | 0.001 | 0.01 | 0.1 | 1 | ... | NaN | NaN | NaN | NaN |
| param_kernel | rbf | rbf | rbf | rbf | ... | linear | linear | linear | linear |
| params | {'C': 0.001, 'gamma': 0.001, 'kernel': 'rbf'} | {'C': 0.001, 'gamma': 0.01, 'kernel': 'rbf'} | {'C': 0.001, 'gamma': 0.1, 'kernel': 'rbf'} | {'C': 0.001, 'gamma': 1, 'kernel': 'rbf'} | ... | {'C': 0.1, 'kernel': 'linear'} | {'C': 1, 'kernel': 'linear'} | {'C': 10, 'kernel': 'linear'} | {'C': 100, 'kernel': 'linear'} |
| split0_test_score | 0.38 | 0.38 | 0.38 | 0.38 | ... | 0.96 | 1 | 0.96 | 0.96 |
| split1_test_score | 0.35 | 0.35 | 0.35 | 0.35 | ... | 0.91 | 0.96 | 1 | 1 |
| split2_test_score | 0.36 | 0.36 | 0.36 | 0.36 | ... | 1 | 1 | 1 | 1 |
| split3_test_score | 0.36 | 0.36 | 0.36 | 0.36 | ... | 0.91 | 0.95 | 0.91 | 0.91 |
| split4_test_score | 0.38 | 0.38 | 0.38 | 0.38 | ... | 0.95 | 0.95 | 0.95 | 0.95 |
| mean_test_score | 0.37 | 0.37 | 0.37 | 0.37 | ... | 0.95 | 0.97 | 0.96 | 0.96 |
| std_test_score | 0.011 | 0.011 | 0.011 | 0.011 | ... | 0.033 | 0.022 | 0.034 | 0.034 |
| rank_test_score | 27 | 27 | 27 | 27 | ... | 11 | 1 | 3 | 3 |
| split0_train_score | 0.36 | 0.36 | 0.36 | 0.36 | ... | 0.97 | 0.99 | 0.99 | 0.99 |
| split1_train_score | 0.37 | 0.37 | 0.37 | 0.37 | ... | 0.98 | 0.98 | 0.99 | 0.99 |
| split2_train_score | 0.37 | 0.37 | 0.37 | 0.37 | ... | 0.94 | 0.98 | 0.98 | 0.99 |
| split3_train_score | 0.37 | 0.37 | 0.37 | 0.37 | ... | 0.98 | 0.99 | 0.99 | 1 |
| split4_train_score | 0.36 | 0.36 | 0.36 | 0.36 | ... | 0.97 | 0.99 | 1 | 1 |
| mean_train_score | 0.37 | 0.37 | 0.37 | 0.37 | ... | 0.97 | 0.98 | 0.99 | 0.99 |
| std_train_score | 0.0029 | 0.0029 | 0.0029 | 0.0029 | ... | 0.012 | 0.0055 | 0.007 | 0.0055 |
23 rows × 42 columns
Using different cross-validation strategies with grid search
Nested cross-validation
In [39]:
Cross-validation scores: [0.967 1. 0.967 0.967 1. ]
Mean cross-validation score: 0.98
In [40]:
In [41]:
Cross-validation scores: [0.967 1. 0.967 0.967 1. ]
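Nested cross-validation can be sketched in one line: pass a `GridSearchCV` object as the estimator to `cross_val_score`, so each outer fold runs a complete inner grid search, and the resulting scores estimate how well the whole tuning procedure generalizes:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, cross_val_score
from sklearn.svm import SVC

iris = load_iris()
param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100],
              'gamma': [0.001, 0.01, 0.1, 1, 10, 100]}

# outer loop: 5 disjoint test sets; inner loop: full grid search per fold
scores = cross_val_score(GridSearchCV(SVC(), param_grid, cv=5),
                         iris.data, iris.target, cv=5)
print("Cross-validation scores:", scores)
print("Mean cross-validation score:", scores.mean())
```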
Parallelizing cross-validation and grid search
Evaluation Metrics and Scoring
Keep the End Goal in Mind
Metrics for Binary Classification
Kinds of errors
Imbalanced datasets
In [42]:
In [43]:
Unique predicted labels: [False]
Test score: 0.90
In [44]:
Test score: 0.92
In [45]:
dummy score: 0.80
logreg score: 0.98
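A sketch of the imbalanced setup behind these outputs: the digits data recast as "nine vs. not nine", where roughly 90% of samples are negative. Note the 0.80 dummy score above comes from a randomized (`stratified`) dummy; the sketch below uses `most_frequent`, which always predicts the majority class and still reaches about 0.90 accuracy:

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

digits = load_digits()
y = digits.target == 9  # heavily imbalanced: ~90% "not nine"
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, y, random_state=0)

# always predicting the majority class already gets ~90% accuracy,
# so accuracy alone says little on this task
dummy = DummyClassifier(strategy='most_frequent').fit(X_train, y_train)
print("Unique predicted labels: {}".format(np.unique(dummy.predict(X_test))))
print("dummy score: {:.2f}".format(dummy.score(X_test, y_test)))

# a real model is far better, but accuracy barely shows the gap
logreg = LogisticRegression(max_iter=5000).fit(X_train, y_train)
print("logreg score: {:.2f}".format(logreg.score(X_test, y_test)))
```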
Confusion matrices
In [46]:
Confusion matrix:
[[401 2]
[ 8 39]]
In [47]:
In [48]:
In [49]:
Most frequent class:
[[403 0]
[ 47 0]]
Dummy model:
[[366 37]
[ 43 4]]
Decision tree:
[[390 13]
[ 24 23]]
Logistic regression:
[[401 2]
[ 8 39]]
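Computing such a confusion matrix can be sketched as follows (same "nine vs. not nine" task; exact counts depend on the model and scikit-learn version):

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

digits = load_digits()
y = digits.target == 9
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, y, random_state=0)

logreg = LogisticRegression(max_iter=5000).fit(X_train, y_train)
pred_logreg = logreg.predict(X_test)

# rows are true classes, columns predicted classes:
# [[TN FP]
#  [FN TP]]
confusion = confusion_matrix(y_test, pred_logreg)
print("Confusion matrix:\n{}".format(confusion))
```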
Relation to accuracy
Precision, recall and f-score
In [50]:
f1 score most frequent: 0.00
f1 score dummy: 0.09
f1 score tree: 0.55
f1 score logistic regression: 0.89
/home/andy/checkout/scikit-learn/sklearn/metrics/classification.py:1143: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 due to no predicted samples.
'precision', 'predicted', average, warn_for)
In [51]:
precision recall f1-score support
not nine 0.90 1.00 0.94 403
nine 0.00 0.00 0.00 47
micro avg 0.90 0.90 0.90 450
macro avg 0.45 0.50 0.47 450
weighted avg 0.80 0.90 0.85 450
/home/andy/checkout/scikit-learn/sklearn/metrics/classification.py:1143: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.
'precision', 'predicted', average, warn_for)
In [52]:
precision recall f1-score support
not nine 0.89 0.91 0.90 403
nine 0.10 0.09 0.09 47
micro avg 0.82 0.82 0.82 450
macro avg 0.50 0.50 0.50 450
weighted avg 0.81 0.82 0.82 450
In [53]:
precision recall f1-score support
not nine 0.98 1.00 0.99 403
nine 0.95 0.83 0.89 47
micro avg 0.98 0.98 0.98 450
macro avg 0.97 0.91 0.94 450
weighted avg 0.98 0.98 0.98 450
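A sketch of producing these reports: `f1_score` condenses precision and recall into one number, and `classification_report` prints all three per class plus the averages shown above:

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, f1_score
from sklearn.model_selection import train_test_split

digits = load_digits()
y = digits.target == 9
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, y, random_state=0)
pred = LogisticRegression(max_iter=5000).fit(X_train, y_train).predict(X_test)

# f1 is the harmonic mean of precision and recall for the positive class
f1 = f1_score(y_test, pred)
print("f1 score logistic regression: {:.2f}".format(f1))

# per-class precision, recall, f1 and support, plus averages
report = classification_report(y_test, pred,
                               target_names=["not nine", "nine"])
print(report)
```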
Taking uncertainty into account
In [54]:
In [55]:
In [56]:
precision recall f1-score support
0 0.97 0.89 0.93 104
1 0.35 0.67 0.46 9
micro avg 0.88 0.88 0.88 113
macro avg 0.66 0.78 0.70 113
weighted avg 0.92 0.88 0.89 113
In [57]:
In [58]:
precision recall f1-score support
0 1.00 0.82 0.90 104
1 0.32 1.00 0.49 9
micro avg 0.83 0.83 0.83 113
macro avg 0.66 0.91 0.69 113
weighted avg 0.95 0.83 0.87 113
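The threshold experiment above can be sketched as follows: by default an SVC predicts positive when its `decision_function` exceeds 0; lowering that threshold classifies more points as positive, which can only raise recall, at the cost of precision (the imbalanced blobs data matches the support counts shown above):

```python
from sklearn.datasets import make_blobs
from sklearn.metrics import classification_report, recall_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# imbalanced blobs: 400 spread-out negatives, 50 tight positives
X, y = make_blobs(n_samples=(400, 50), cluster_std=[7.0, 2],
                  random_state=22)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
svc = SVC(gamma=.05).fit(X_train, y_train)

y_pred_default = svc.predict(X_test)  # implicit threshold: > 0
print(classification_report(y_test, y_pred_default))

# lowering the threshold trades precision for recall
y_pred_lower_threshold = svc.decision_function(X_test) > -.8
print(classification_report(y_test, y_pred_lower_threshold))

rec_default = recall_score(y_test, y_pred_default)
rec_lower = recall_score(y_test, y_pred_lower_threshold)
```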
Precision-Recall curves and ROC curves
In [59]:
In [60]:
<matplotlib.legend.Legend at 0x7fcb17f90e10>
In [61]:
<matplotlib.legend.Legend at 0x7fcb18404240>
In [62]:
f1_score of random forest: 0.610
f1_score of svc: 0.656
In [63]:
Average precision of random forest: 0.660
Average precision of svc: 0.666
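A sketch of computing a precision-recall curve and its average-precision summary for the SVC (the random forest variant above would pass `predict_proba(X_test)[:, 1]` instead of the decision function):

```python
from sklearn.datasets import make_blobs
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = make_blobs(n_samples=(400, 50), cluster_std=[7.0, 2],
                  random_state=22)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
svc = SVC(gamma=.05).fit(X_train, y_train)

# one (precision, recall) point for every possible decision threshold
precision, recall, thresholds = precision_recall_curve(
    y_test, svc.decision_function(X_test))

# average precision summarizes the whole curve in a single number
ap = average_precision_score(y_test, svc.decision_function(X_test))
print("Average precision of svc: {:.3f}".format(ap))
```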
Receiver Operating Characteristics (ROC) and AUC
In [64]:
<matplotlib.legend.Legend at 0x7fcb180db978>
In [65]:
<matplotlib.legend.Legend at 0x7fcb180dba20>
In [66]:
AUC for Random Forest: 0.937
AUC for SVC: 0.916
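The ROC analysis can be sketched analogously: `roc_curve` returns false and true positive rates over all thresholds, and `roc_auc_score` integrates the curve into a single ranking-quality number:

```python
from sklearn.datasets import make_blobs
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = make_blobs(n_samples=(400, 50), cluster_std=[7.0, 2],
                  random_state=22)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
svc = SVC(gamma=.05).fit(X_train, y_train)

# ROC curve: false positive rate vs. true positive rate per threshold
fpr, tpr, thresholds = roc_curve(y_test, svc.decision_function(X_test))

# AUC: probability that a random positive ranks above a random negative
auc = roc_auc_score(y_test, svc.decision_function(X_test))
print("AUC for SVC: {:.3f}".format(auc))
```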
In [67]:
gamma = 1.00 accuracy = 0.90 AUC = 0.50
gamma = 0.05 accuracy = 0.90 AUC = 1.00
gamma = 0.01 accuracy = 0.90 AUC = 1.00
<matplotlib.legend.Legend at 0x7fcb17f67e48>
Metrics for Multiclass Classification
In [68]:
Accuracy: 0.953
Confusion matrix:
[[37 0 0 0 0 0 0 0 0 0]
[ 0 39 0 0 0 0 2 0 2 0]
[ 0 0 41 3 0 0 0 0 0 0]
[ 0 0 1 43 0 0 0 0 0 1]
[ 0 0 0 0 38 0 0 0 0 0]
[ 0 1 0 0 0 47 0 0 0 0]
[ 0 0 0 0 0 0 52 0 0 0]
[ 0 1 0 1 1 0 0 45 0 0]
[ 0 3 1 0 0 0 0 0 43 1]
[ 0 0 0 1 0 1 0 0 1 44]]
In [69]:
In [70]:
precision recall f1-score support
0 1.00 1.00 1.00 37
1 0.89 0.91 0.90 43
2 0.95 0.93 0.94 44
3 0.90 0.96 0.92 45
4 0.97 1.00 0.99 38
5 0.98 0.98 0.98 48
6 0.96 1.00 0.98 52
7 1.00 0.94 0.97 48
8 0.93 0.90 0.91 48
9 0.96 0.94 0.95 47
micro avg 0.95 0.95 0.95 450
macro avg 0.95 0.95 0.95 450
weighted avg 0.95 0.95 0.95 450
In [71]:
Micro average f1 score: 0.953
Macro average f1 score: 0.954
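Computing the micro and macro averages above can be sketched as follows: micro averaging pools all samples before computing f1, while macro averaging takes the unweighted mean of the ten per-class f1 scores, so it weights rare classes equally:

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, random_state=0)
pred = LogisticRegression(max_iter=5000).fit(X_train, y_train).predict(X_test)

# micro: compute counts over all samples pooled together
micro = f1_score(y_test, pred, average="micro")
# macro: unweighted mean of the ten per-class f1 scores
macro = f1_score(y_test, pred, average="macro")
print("Micro average f1 score: {:.3f}".format(micro))
print("Macro average f1 score: {:.3f}".format(macro))
```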
Regression metrics
Using evaluation metrics in model selection
In [72]:
Default scoring: [0.9 0.9 0.9 0.9 0.9]
Explicit accuracy scoring: [0.9 0.9 0.9 0.9 0.9]
AUC scoring: [0.997 0.997 0.996 0.998 0.992]
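Switching the metric used during cross-validation is a matter of the `scoring` parameter, as this sketch shows. Note the flat 0.90 accuracies above reflect an older scikit-learn SVC default for `gamma`; with the current `gamma='scale'` default the numbers will differ, but the mechanism is the same:

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC

digits = load_digits()
# imbalanced binary task: is the digit a nine?
y = digits.target == 9

# default scoring for classification is accuracy
default_scores = cross_val_score(SVC(), digits.data, y, cv=5)
print("Default scoring:", default_scores)

# scoring="roc_auc" evaluates AUC on each fold instead
auc_scores = cross_val_score(SVC(), digits.data, y, cv=5,
                             scoring="roc_auc")
print("AUC scoring:", auc_scores)
```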
In [73]:
| | fit_time | score_time | test_accuracy | train_accuracy | test_roc_auc | train_roc_auc | test_recall_macro | train_recall_macro |
|---|---|---|---|---|---|---|---|---|
| 0 | 0.24 | 0.15 | 0.9 | 1.0 | 1.00 | 1.0 | 0.5 | 1.0 |
| 1 | 0.23 | 0.16 | 0.9 | 1.0 | 1.00 | 1.0 | 0.5 | 1.0 |
| 2 | 0.22 | 0.15 | 0.9 | 1.0 | 1.00 | 1.0 | 0.5 | 1.0 |
| 3 | 0.22 | 0.16 | 0.9 | 1.0 | 1.00 | 1.0 | 0.5 | 1.0 |
| 4 | 0.22 | 0.15 | 0.9 | 1.0 | 0.99 | 1.0 | 0.5 | 1.0 |
In [74]:
Grid-Search with accuracy
Best parameters: {'gamma': 0.0001}
Best cross-validation score (accuracy): 0.970
Test set AUC: 0.992
Test set accuracy: 0.973
In [75]:
Grid-Search with AUC
Best parameters: {'gamma': 0.01}
Best cross-validation score (AUC): 0.997
Test set AUC: 1.000
Test set accuracy: 1.000
In [87]:
Available scorers:
['explained_variance', 'r2', 'neg_median_absolute_error', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'accuracy', 'roc_auc', 'balanced_accuracy', 'average_precision', 'neg_log_loss', 'brier_score_loss', 'adjusted_rand_score', 'homogeneity_score', 'completeness_score', 'v_measure_score', 'mutual_info_score', 'adjusted_mutual_info_score', 'normalized_mutual_info_score', 'fowlkes_mallows_score', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted']
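A sketch of listing the valid `scoring` strings: recent scikit-learn versions expose `sklearn.metrics.get_scorer_names()`, while older versions (like the one this notebook ran on) used the `SCORERS` dict instead, so the sketch tries both:

```python
from sklearn import metrics

# current API first, falling back to the pre-1.0 SCORERS dict
try:
    names = sorted(metrics.get_scorer_names())
except AttributeError:
    names = sorted(metrics.SCORERS)
print("Available scorers:\n{}".format(names))
```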
Summary and Outlook