Comments (5)
Thanks for your kindness. It's very helpful.
from swa_gaussian.
@izmailovpavel may still have the notebook, but try plotting x on a log scale. Also double-check that your signs are correct, as they could be flipped.
Thank you! I have tried a log scale, but the result looks a little strange. Maybe I should define a proper transformation for the xticks.
I would greatly appreciate it if @izmailovpavel could provide some pointers for reproducing this beautiful figure.
Thanks for your kindness.
Hey @Codefmeister, something seems off in how your xticks are arranged. Here's our code for making the plots:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.scale as mscale
import matplotlib.transforms as mtransforms
from matplotlib.ticker import FormatStrFormatter

# new_methods() and load_dict() are not defined in this snippet; they are
# presumably helpers from the authors' own code.
styles = {name: (label, color) for (name, label, _, color) in new_methods().name_marker_pairs}
methods = {'SWAG-Cov', 'SWA-temp', 'SWA-Drop', 'SGD', 'SWAG-Diag', 'Laplace-SGD', 'SGLD'}


class CustomScale(mscale.ScaleBase):
    """Axis scale that stretches the region close to confidence = 1."""
    name = 'custom'
    eps = 0.002  # offset that keeps the log finite at confidence = 1

    def __init__(self, axis, **kwargs):
        # Recent matplotlib releases expect the axis to be passed on here;
        # the original snippet called ScaleBase.__init__(self) without it.
        mscale.ScaleBase.__init__(self, axis)
        self.thresh = None #thresh

    def get_transform(self):
        return self.CustomTransform(self.thresh)

    def set_default_locators_and_formatters(self, axis):
        pass

    class CustomTransform(mtransforms.Transform):
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, thresh):
            mtransforms.Transform.__init__(self)
            self.thresh = thresh

        def transform_non_affine(self, a):
            # Data -> axis coordinates: logarithmic in (1 + eps - confidence).
            return -np.log(1 + CustomScale.eps - a)

        def inverted(self):
            return CustomScale.InvertedCustomTransform(self.thresh)

    class InvertedCustomTransform(mtransforms.Transform):
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, thresh):
            mtransforms.Transform.__init__(self)
            self.thresh = thresh

        def transform_non_affine(self, a):
            # Axis coordinates -> data: exact inverse of CustomTransform.
            return 1 + CustomScale.eps - np.exp(-a)

        def inverted(self):
            return CustomScale.CustomTransform(self.thresh)


mscale.register_scale(CustomScale)

fig, axes = plt.subplots(figsize=(37, 8), nrows=1, ncols=4)
plt.subplots_adjust(wspace=0.3, bottom=0.25)


def calibration_plot(results, ds, model):
    for method, curve in sorted(results.items()):
        #print(method, 'YN'[int(curve is None)])
        if method not in methods:
            continue
        label, color = styles[method]
        if curve is not None:
            plt.plot(curve['confidence'], curve['confidence'] - curve['accuracy'], linewidth=4, marker='o', markersize=8,
                     color=color, label='%s' % (label), zorder=3)
    # Dashed zero line: perfect calibration.
    plt.plot(np.linspace(0.1, 1.0, 100), np.zeros(100), 'k--', dashes=(5, 5), linewidth=3, zorder=2)
    plt.gca().set_xscale('custom')
    # Ticks are spaced geometrically in (1 - confidence), from 0.8 down to 0.002.
    ticks = 1.0 - np.logspace(np.log(0.8), np.log(0.002), 6, base=np.e)
    plt.xticks(ticks, fontsize=22)
    plt.yticks(fontsize=22)
    plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    plt.gca().xaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    plt.margins(x=0.03)
    plt.ylabel('Confidence - Accuracy', fontsize=28)
    plt.xlabel('Confidence (max prob)', fontsize=28)
    plt.title('%s %s' % (model, ds), fontsize=28, y=1.02)
    plt.grid()


plt.sca(axes[0])
calibration_plot(load_dict('./data/calibrations/c100_wrn_new.pkl'), 'CIFAR-100', 'WideResNet28x10')
plt.sca(axes[1])
calibration_plot(load_dict('./data/calibrations/stl_wrn.pkl'), 'CIFAR-10 $\\rightarrow$ STL-10', 'WideResNet28x10')
plt.sca(axes[2])
calibration_plot(load_dict('./data/calibrations/imagenet_densenet161.pkl'), 'ImageNet', 'DenseNet-161')
plt.sca(axes[3])
calibration_plot(load_dict('./data/calibrations/imagenet_resnet152.pkl'), 'ImageNet', 'ResNet-152')
#plt.sca(axes[1])

# One shared legend for all four panels, taken from the first panel.
handles, labels = axes[0].get_legend_handles_labels()
leg = plt.figlegend(handles, labels, fontsize=28, loc='lower center', bbox_to_anchor=(0.43, 0.0), ncol=6)
# On recent matplotlib releases the attribute below is named legend_handles,
# and the marker size may need to be set on legobj directly.
for legobj in leg.legendHandles:
    legobj.set_linewidth(6.0)
    legobj._legmarker.set_markersize(12.0)
plt.savefig('./pics/calibration_curves.pdf', format='pdf', bbox_inches='tight')
plt.show()
It was originally written by @timgaripov.
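To see what the custom scale does to the x-axis, here is a minimal sketch of the same transform outside matplotlib. It uses plain `math` instead of NumPy so it runs on its own; the names `forward`, `inverse`, and `EPS` are illustrative and not part of the original code.

```python
import math

EPS = 0.002  # same offset as CustomScale.eps in the snippet above


def forward(confidence):
    """Map a confidence in [0, 1] to its position on the custom x-axis."""
    return -math.log(1 + EPS - confidence)


def inverse(position):
    """Map an axis position back to a confidence value."""
    return 1 + EPS - math.exp(-position)


# Confidences close to 1 are spread apart; low confidences are squeezed together.
for c in (0.2, 0.9, 0.99, 0.999):
    print(f"confidence={c:.3f} -> axis position={forward(c):.3f}")

# Tick locations equivalent to the np.logspace call above:
# geometric spacing in (1 - confidence), from 0.8 down to 0.002.
n_ticks = 6
hi, lo = 0.8, 0.002
ticks = [1.0 - hi * (lo / hi) ** (i / (n_ticks - 1)) for i in range(n_ticks)]
print(ticks)
```

Because the ticks are geometric in `1 - confidence`, they run from 0.2 up to 0.998 and cluster near 1, which is where most of the predictions of a confident classifier fall.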
For another paper, I used this code to plot the calibration curves, which is a lot simpler:
import matplotlib.pyplot as plt

# new_calibration_arr, matt_arr, and the *_color variables are not defined in
# this snippet; they come from the notebook this code was taken from.
plt.figure(figsize=(3, 3))


def plot_calibration(arr):
    # arr holds per-bin "confidence" and "accuracy" arrays plus a line "color".
    plt.plot(arr["confidence"], arr["accuracy"] - arr["confidence"],
             "-o", color=arr["color"], mec="k", ms=7, lw=3)


# plot_calibration({**matt_arr["deep_ensemble_calibration"], "color": de_color})
plot_calibration({**new_calibration_arr["deep_ensemble"].item(), "color": de_color})
plot_calibration({**new_calibration_arr["sgld"].item(), "color": sgld_color})
plot_calibration({**new_calibration_arr["sgld_mom_clr_prec"].item(), "color": sgld_hot_color})
plot_calibration({**matt_arr["hmc_calibration"], "color": "orange"})
# plot_calibration({**matt_arr["sgld_calibration"], "color": sgld_color})
# plot_calibration({**matt_arr["sgld_hot_calibration"], "color": sgld_hot_color})

# Dashed zero line: perfect calibration.
plt.hlines(0., 0., 1., color="k", linestyle="dashed")
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel("Confidence", fontsize=16)
plt.ylabel("Accuracy - Confidence", fontsize=16)
plt.grid()
plt.xlim(0.35, 1.05)
plt.savefig("calibration_curve.pdf", bbox_inches="tight")
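Both snippets expect per-bin `confidence` and `accuracy` values, but neither shows how they are computed. The following is a minimal sketch of one way to produce them, assuming equal-count bins over samples sorted by confidence. The function name `calibration_curve`, the choice of 20 bins, and the binning scheme are assumptions for illustration, not necessarily what was used for the figures.

```python
def calibration_curve(probs, labels, n_bins=20):
    """Return per-bin mean confidence and accuracy.

    probs:  list of per-class probability lists, one per sample.
    labels: list of integer class ids, one per sample.
    Assumption: samples are sorted by confidence (max probability) and split
    into bins holding (nearly) equal numbers of samples.
    """
    confidences = [max(p) for p in probs]
    correct = [float(p.index(max(p)) == y) for p, y in zip(probs, labels)]
    order = sorted(range(len(confidences)), key=confidences.__getitem__)
    n = len(order)
    n_bins = min(n_bins, n)  # never create more bins than samples
    curve = {"confidence": [], "accuracy": []}
    for b in range(n_bins):
        idx = order[b * n // n_bins:(b + 1) * n // n_bins]
        curve["confidence"].append(sum(confidences[i] for i in idx) / len(idx))
        curve["accuracy"].append(sum(correct[i] for i in idx) / len(idx))
    return curve


# Tiny usage example with four two-class predictions and two bins.
probs = [[0.9, 0.1], [0.6, 0.4], [0.2, 0.8], [0.55, 0.45]]
labels = [0, 1, 1, 0]
curve = calibration_curve(probs, labels, n_bins=2)
gap = [c - a for c, a in zip(curve["confidence"], curve["accuracy"])]
print(curve, gap)
```

Note the sign convention when comparing plots: the first snippet in this thread plots confidence minus accuracy, while the second plots accuracy minus confidence, so the same curve appears mirrored about zero. To pass the result to either plotting snippet, convert the lists to NumPy arrays first, since both subtract the two arrays element-wise.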