
chenshanghao commented on May 24, 2024

Hi chickenbestlover,

Could you explain why one function returns 'total_loss / nbatch' and the other returns 'total_loss / (nbatch+1)'?

Thanks,
Chauncey

from rnn-time-series-anomaly-detection.

chickenbestlover commented on May 24, 2024
import torch

def evaluate_1step_pred(args, model, test_dataset):
    # Turn on evaluation mode, which disables dropout.
    # `criterion` and `get_batch` are defined elsewhere in the training script.
    model.eval()
    total_loss = 0
    with torch.no_grad():
        hidden = model.init_hidden(args.eval_batch_size)
        for nbatch, i in enumerate(range(0, test_dataset.size(0) - 1, args.bptt)):

            inputSeq, targetSeq = get_batch(args, test_dataset, i)
            outSeq, hidden = model.forward(inputSeq, hidden)

            loss = criterion(outSeq.view(args.batch_size, -1), targetSeq.view(args.batch_size, -1))
            hidden = model.repackage_hidden(hidden)
            total_loss += loss.item()

    return total_loss / nbatch  # You mean here? This is my mistake; it should be total_loss / (nbatch + 1).
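Why the divisor must be nbatch + 1: enumerate() numbers the batches from 0, so when the loop ends nbatch is the index of the last batch, one less than the number of batches. Dividing by nbatch overstates the average loss, and with a single batch it divides by zero. The sketch below mirrors the loop with plain Python lists so it runs without PyTorch; the helper names mean_batch_loss and count_bptt_batches and the loss values are hypothetical, not part of the repository.

```python
def mean_batch_loss(batch_losses):
    """Average per-batch losses the way evaluate_1step_pred accumulates them."""
    total_loss = 0.0
    nbatch = -1  # stays -1 if there are no batches at all
    for nbatch, loss in enumerate(batch_losses):
        total_loss += loss
    n_batches = nbatch + 1  # enumerate() starts at 0, so the count is last index + 1
    if n_batches == 0:
        raise ValueError("no batches to average")
    return total_loss / n_batches


def count_bptt_batches(n_steps, bptt):
    """Number of iterations of range(0, n_steps - 1, bptt), as in the evaluation loop."""
    return len(range(0, n_steps - 1, bptt))


losses = [0.5, 0.25, 0.75, 0.5]
print(mean_batch_loss(losses))      # 0.5
print(sum(losses) / 3)              # 0.666..., the overestimate from dividing by nbatch (= 3)
print(count_bptt_batches(101, 50))  # 2, since range(0, 100, 50) yields 0 and 50
```

The same count can be checked directly: the loop runs len(range(0, test_dataset.size(0) - 1, args.bptt)) times, which equals nbatch + 1 after the loop.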

By the way, the function evaluate_1step_pred has been deprecated and is no longer used to evaluate accuracy.


chenshanghao commented on May 24, 2024

Thank you.
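Also, it seems you augment the training data in every case: there is no flag for it, unlike the test data, which is only augmented when augment_test_data is set.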

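def preprocessing(self, path, train=True):
    """ Read, Standardize, Augment """

    # Load the pickled array; the last column holds the label.
    with open(str(path), 'rb') as f:
        data = torch.FloatTensor(pickle.load(f))
        label = data[:, -1]
        data = data[:, :-1]
    if train:
        # Normalization statistics come from the un-augmented training data.
        self.mean = data.mean(dim=0)
        self.std = data.std(dim=0)
        self.length = len(data)
        # Training data is augmented unconditionally.
        data, label = self.augmentation(data, label)
    else:
        # Test data is augmented only when augment_test_data is set.
        if self.augment_test_data:
            data, label = self.augmentation(data, label)

    # Both splits are standardized with the mean/std stored from the training split.
    data = standardization(data, self.mean, self.std)

    return data, label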
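For comparison, a minimal sketch of the same read/standardize/augment flow with training augmentation made optional as well. It uses a single channel stored in plain Python lists instead of torch tensors so it runs on its own. The augment_train_data flag, the PreprocessingSketch class, and the bodies of standardization and augmentation are hypothetical stand-ins, not the repository's code (its real augmentation may differ). What it keeps from the code above: mean and std are computed on the un-augmented training data, and both splits are standardized with those training statistics.

```python
import statistics


def standardization(column, mean, std):
    """Assumed behaviour of the repo's helper: z-score with the given statistics."""
    return [(x - mean) / std for x in column]


class PreprocessingSketch:
    def __init__(self, augment_train_data=True, augment_test_data=False):
        # Hypothetical flag: the original code has no switch for training augmentation.
        self.augment_train_data = augment_train_data
        self.augment_test_data = augment_test_data
        self.mean = None
        self.std = None

    def augmentation(self, data, label):
        """Placeholder augmentation: append a copy of the series shifted by 0.1."""
        return data + [x + 0.1 for x in data], label + label

    def preprocessing(self, data, label, train=True):
        if train:
            # Statistics come from the raw training data, before augmentation.
            # statistics.stdev is the sample std, like torch's default Tensor.std().
            self.mean = statistics.mean(data)
            self.std = statistics.stdev(data)
            if self.augment_train_data:
                data, label = self.augmentation(data, label)
        elif self.augment_test_data:
            data, label = self.augmentation(data, label)
        # Both splits use the training mean/std.
        return standardization(data, self.mean, self.std), label


prep = PreprocessingSketch(augment_train_data=False)
train_x, train_y = prep.preprocessing([1.0, 2.0, 3.0], [0, 0, 1], train=True)
print(train_x)  # [-1.0, 0.0, 1.0]
test_x, _ = prep.preprocessing([2.0, 4.0], [0, 1], train=False)
print(test_x)   # [0.0, 2.0]
```

Because the statistics are taken before augmentation, switching augment_train_data on or off changes how many training samples come out, but not the mean/std applied to the test split.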

