Comments (5)

Codefmeister commented on July 21, 2024

Thanks for your kindness. It's very helpful.

from swa_gaussian.

wjmaddox commented on July 21, 2024

@izmailovpavel may still have the notebook, but try plotting the x-axis on a log scale. Also double-check that your signs are correct, since they could be flipped.
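
For reference, here is a minimal sketch of how per-bin confidence/accuracy pairs like the ones plotted in this thread are commonly computed. It assumes equal-width bins over the max predicted probability. The helper name `calibration_curve` and the binning scheme are assumptions, since the thread does not say how the curves in the figure were produced. Under this convention a positive `confidence - accuracy` means the model is overconfident in that bin, which is the sign worth double-checking.

```python
import numpy as np


def calibration_curve(probs, labels, n_bins=15):
    """Hypothetical helper: mean confidence and accuracy per confidence bin.

    probs:  (N, K) array of predicted class probabilities.
    labels: (N,) array of integer class labels.
    Bins are equal-width over [0, 1]; empty bins are skipped.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    confidence = probs.max(axis=1)  # "confidence" = max predicted probability
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Use only the interior edges so bin ids run from 0 to n_bins - 1.
    bin_ids = np.digitize(confidence, edges[1:-1], right=True)
    conf_per_bin, acc_per_bin = [], []
    for b in range(n_bins):
        in_bin = bin_ids == b
        if in_bin.any():
            conf_per_bin.append(confidence[in_bin].mean())
            acc_per_bin.append(correct[in_bin].mean())
    return {"confidence": np.array(conf_per_bin), "accuracy": np.array(acc_per_bin)}


# Toy example: one underconfident bin and one overconfident bin.
probs = np.array([[0.95, 0.05], [0.85, 0.15], [0.65, 0.35], [0.30, 0.70]])
labels = np.array([0, 1, 0, 1])
curve = calibration_curve(probs, labels, n_bins=5)
gap = curve["confidence"] - curve["accuracy"]  # > 0 means overconfident
```

In the toy example the two occupied bins give a gap of about -0.325 (underconfident) and 0.4 (overconfident).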

Codefmeister commented on July 21, 2024

Thank you! I tried a log scale, but the result looks a bit strange; maybe I need to define a proper transformation for the x-ticks.
I would greatly appreciate it if @izmailovpavel could give some pointers for reproducing this beautiful figure.
Thanks for your kindness.
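
A quick numeric check suggests why a plain log scale on the confidence axis looks strange: the log of values close to 1 is nearly 0, so high-confidence points bunch together, whereas taking the log of `1 - confidence` spreads them out. The three confidence values below are arbitrary illustration points.

```python
import numpy as np

conf = np.array([0.9, 0.99, 0.999])

plain_log = np.log(conf)         # approx. [-0.105, -0.010, -0.001]: bunched near 0
stretched = -np.log(1.0 - conf)  # approx. [2.303, 4.605, 6.908]: evenly spaced

print(np.diff(plain_log))  # gaps shrink as confidence approaches 1
print(np.diff(stretched))  # gaps are equal (ln 10 each)
```

This is the idea behind transforming the axis with something like `-log(1 - x)` rather than `log(x)`.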

izmailovpavel commented on July 21, 2024

Hey @Codefmeister, something seems off in how your x-ticks are arranged. Here's our code for making the plots:

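import numpy as np
import matplotlib.pyplot as plt
import matplotlib.scale as mscale
import matplotlib.transforms as mtransforms
from matplotlib.ticker import FormatStrFormatter

# new_methods() is defined elsewhere in the original notebook (not shown);
# it supplies the plot label and color for each method name.
styles = {name: (label, color) for (name, label, _, color) in new_methods().name_marker_pairs}

methods = {'SWAG-Cov', 'SWA-temp', 'SWA-Drop', 'SGD', 'SWAG-Diag', 'Laplace-SGD', 'SGLD'}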

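class CustomScale(mscale.ScaleBase):
    """Axis scale x -> -log(1 + eps - x), which stretches confidences near 1."""
    name = 'custom'
    eps = 0.002

    def __init__(self, axis, **kwargs):
        # matplotlib >= 3.1 expects the axis to be passed to ScaleBase.__init__
        super().__init__(axis)
        self.thresh = None  # unused; only forwarded to the transforms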

    def get_transform(self):
        return self.CustomTransform(self.thresh)

    def set_default_locators_and_formatters(self, axis):
        pass

    class CustomTransform(mtransforms.Transform):
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, thresh):
            mtransforms.Transform.__init__(self)
            self.thresh = thresh

        def transform_non_affine(self, a):
            # Forward map: confidence -> -log(1 + eps - confidence)
            return -np.log(1 + CustomScale.eps - a)

        def inverted(self):
            return CustomScale.InvertedCustomTransform(self.thresh)

    class InvertedCustomTransform(mtransforms.Transform):
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, thresh):
            mtransforms.Transform.__init__(self)
            self.thresh = thresh

        def transform_non_affine(self, a):
            # Inverse map: y -> 1 + eps - exp(-y)
            return 1 + CustomScale.eps - np.exp(-a)

        def inverted(self):
            return CustomScale.CustomTransform(self.thresh)


mscale.register_scale(CustomScale)

fig, axes = plt.subplots(figsize=(37, 8), nrows=1, ncols=4)
plt.subplots_adjust(wspace=0.3, bottom=0.25)

def calibration_plot(results, ds, model):
    """Plot confidence - accuracy vs. confidence for each selected method on the current axes."""
    for method, curve in sorted(results.items()):
        if method not in methods:
            continue
        label, color = styles[method]
        if curve is not None:
            plt.plot(curve['confidence'], curve['confidence'] - curve['accuracy'], linewidth=4, marker='o', markersize=8,
                     color=color, label=label, zorder=3)
    # Dashed zero line marks perfect calibration
    plt.plot(np.linspace(0.1, 1.0, 100), np.zeros(100), 'k--', dashes=(5, 5), linewidth=3, zorder=2)

    plt.gca().set_xscale('custom')

    # Ticks at 1 - t, with t geometrically spaced from 0.8 down to 0.002
    ticks = 1.0 - np.logspace(np.log(0.8), np.log(0.002), 6, base=np.e)
    plt.xticks(ticks, fontsize=22)
    plt.yticks(fontsize=22)
    plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    plt.gca().xaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    plt.margins(x=0.03)
    plt.ylabel('Confidence - Accuracy', fontsize=28)
    plt.xlabel('Confidence (max prob)', fontsize=28)
    plt.title('%s %s' % (model, ds), fontsize=28, y=1.02)
    plt.grid()

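# load_dict is defined elsewhere in the original notebook (not shown); it
# presumably loads a pickled dict mapping method names to calibration curves.
plt.sca(axes[0])
calibration_plot(load_dict('./data/calibrations/c100_wrn_new.pkl'), 'CIFAR-100', 'WideResNet28x10')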
plt.sca(axes[1])
calibration_plot(load_dict('./data/calibrations/stl_wrn.pkl'), 'CIFAR-10 $\\rightarrow$ STL-10', 'WideResNet28x10')
plt.sca(axes[2])
calibration_plot(load_dict('./data/calibrations/imagenet_densenet161.pkl'), 'ImageNet', 'DenseNet-161')
plt.sca(axes[3])
calibration_plot(load_dict('./data/calibrations/imagenet_resnet152.pkl'), 'ImageNet', 'ResNet-152')


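# Shared legend below the subplots, with thicker lines and larger markers
handles, labels = axes[0].get_legend_handles_labels()
leg = plt.figlegend(handles, labels, fontsize=28, loc='lower center', bbox_to_anchor=(0.43, 0.0), ncol=6)
# Legend.legendHandles was renamed to legend_handles in matplotlib 3.7
legend_handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
for legobj in legend_handles:
    legobj.set_linewidth(6.0)
    # The private _legmarker attribute no longer exists in recent matplotlib;
    # set the marker size on the legend line itself.
    legobj.set_markersize(12.0)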

plt.savefig('./pics/calibration_curves.pdf', format='pdf', bbox_inches='tight')
plt.show()

It was originally written by @timgaripov.
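
On matplotlib 3.1 or newer, the same axis can probably be built more compactly with the built-in `'function'` scale, which takes a `(forward, inverse)` pair instead of a `ScaleBase` subclass. The sketch below assumes the same `eps = 0.002` offset and the same tick positions as the code above (confidences from 0.2 up to 0.998); it is an alternative, not the code used for the paper's figure.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FormatStrFormatter

EPS = 0.002  # same offset as CustomScale.eps


def forward(x):
    # confidence -> -log(1 + eps - confidence); stretches the region near 1
    return -np.log(1.0 + EPS - x)


def inverse(y):
    # inverse of forward: y -> 1 + eps - exp(-y)
    return 1.0 + EPS - np.exp(-y)


fig, ax = plt.subplots(figsize=(8, 6))
ax.set_xscale("function", functions=(forward, inverse))
# Perfect-calibration reference line, as in calibration_plot
ax.plot(np.linspace(0.1, 1.0, 100), np.zeros(100), "k--", linewidth=3)
ticks = 1.0 - np.logspace(np.log(0.8), np.log(0.002), 6, base=np.e)
ax.set_xticks(ticks)
ax.xaxis.set_major_formatter(FormatStrFormatter("%.3f"))
ax.set_xlabel("Confidence (max prob)")
ax.set_ylabel("Confidence - Accuracy")
ax.grid()
```

The trade-off is that `'function'` scales are not registered by name, so each axes needs its own `set_xscale` call, whereas the registered `'custom'` scale can be reused anywhere by name.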

izmailovpavel commented on July 21, 2024

For another paper, I used the following code to plot calibration curves, which is a lot simpler. Note that it plots Accuracy - Confidence, the opposite sign of the figure above:

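import matplotlib.pyplot as plt

plt.figure(figsize=(3, 3))


def plot_calibration(arr):
    # arr: dict with 'confidence' and 'accuracy' arrays plus a 'color' entry
    plt.plot(arr["confidence"], arr["accuracy"] - arr["confidence"],
             "-o", color=arr["color"], mec="k", ms=7, lw=3)

# new_calibration_arr, matt_arr and the *_color variables are defined
# elsewhere in the original notebook (not shown).
# plot_calibration({**matt_arr["deep_ensemble_calibration"], "color": de_color})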
plot_calibration({**new_calibration_arr["deep_ensemble"].item(), "color": de_color})
plot_calibration({**new_calibration_arr["sgld"].item(), "color": sgld_color})
plot_calibration({**new_calibration_arr["sgld_mom_clr_prec"].item(), "color": sgld_hot_color})
plot_calibration({**matt_arr["hmc_calibration"], "color": "orange"})
# plot_calibration({**matt_arr["sgld_calibration"], "color": sgld_color})
# plot_calibration({**matt_arr["sgld_hot_calibration"], "color": sgld_hot_color})
plt.hlines(0., 0., 1., color="k", linestyle="dashed")
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel("Confidence", fontsize=16)
plt.ylabel("Accuracy - Confidence", fontsize=16)
plt.grid()
plt.xlim(0.35, 1.05)
plt.savefig("calibration_curve.pdf", bbox_inches="tight")
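
A self-contained variant of the simpler snippet, assuming each curve is a dict with array-like `'confidence'` and `'accuracy'` entries as above. It draws on an explicit `Axes` and takes the color as an argument instead of merging it into the dict. The toy values are invented purely to exercise the function.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def plot_calibration(ax, curve, color, label=None):
    """Plot accuracy - confidence against confidence for one method on `ax`."""
    conf = np.asarray(curve["confidence"], dtype=float)
    acc = np.asarray(curve["accuracy"], dtype=float)
    ax.plot(conf, acc - conf, "-o", color=color, mec="k", ms=7, lw=3, label=label)


fig, ax = plt.subplots(figsize=(3, 3))
# Made-up values, only to exercise the function.
toy_curve = {"confidence": [0.5, 0.7, 0.9, 0.99], "accuracy": [0.55, 0.68, 0.85, 0.93]}
plot_calibration(ax, toy_curve, color="tab:blue", label="toy model")
ax.hlines(0.0, 0.0, 1.0, color="k", linestyle="dashed")  # perfect calibration
ax.set_xlabel("Confidence", fontsize=16)
ax.set_ylabel("Accuracy - Confidence", fontsize=16)
ax.grid()
ax.set_xlim(0.35, 1.05)
```

Passing the `Axes` explicitly makes it straightforward to put several datasets side by side with `plt.subplots(ncols=...)`, as in the four-panel figure.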
