Comments (3)
Hi chickenbestlover,
Can you explain why one function ends with 'return total_loss / (nbatch)' while the other ends with 'return total_loss / (nbatch+1)'?
Thanks,
Chauncey
from rnn-time-series-anomaly-detection.
def evaluate_1step_pred(args, model, test_dataset):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0
    with torch.no_grad():
        hidden = model.init_hidden(args.eval_batch_size)
        for nbatch, i in enumerate(range(0, test_dataset.size(0) - 1, args.bptt)):
            inputSeq, targetSeq = get_batch(args, test_dataset, i)
            outSeq, hidden = model.forward(inputSeq, hidden)
            loss = criterion(outSeq.view(args.batch_size, -1), targetSeq.view(args.batch_size, -1))
            hidden = model.repackage_hidden(hidden)
            total_loss += loss.item()
    return total_loss / nbatch  # You mean here? This is my mistake. It should be total_loss / (nbatch + 1).
By the way, the function evaluate_1step_pred has been deprecated and is no longer used to evaluate accuracy.
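For anyone reading later, here is a minimal sketch of why the divisor needs the +1. It uses plain Python floats instead of the real model and criterion, and the function names (mean_loss_buggy, mean_loss_fixed, mean_loss_counted) are made up for illustration and are not part of the repository. Because enumerate starts counting at 0, nbatch holds the index of the last batch after the loop, which is one less than the number of batches.

```python
def mean_loss_buggy(losses):
    """Mirrors the original return statement: divides by the last index."""
    total = 0.0
    for nbatch, loss in enumerate(losses):
        total += loss
    return total / nbatch  # nbatch == len(losses) - 1 here


def mean_loss_fixed(losses):
    """Mirrors the corrected return statement: divides by the batch count."""
    total = 0.0
    for nbatch, loss in enumerate(losses):
        total += loss
    return total / (nbatch + 1)


def mean_loss_counted(losses):
    """Alternative that keeps an explicit counter, so no index arithmetic is needed."""
    total, count = 0.0, 0
    for loss in losses:
        total += loss
        count += 1
    return total / count if count else float("nan")


if __name__ == "__main__":
    batch_losses = [1.0, 2.0, 3.0]  # stand-ins for loss.item() values
    print(mean_loss_buggy(batch_losses))    # 3.0 (6.0 / 2, too large)
    print(mean_loss_fixed(batch_losses))    # 2.0 (6.0 / 3, the true mean)
    print(mean_loss_counted(batch_losses))  # 2.0
```

With only one batch the buggy version divides by zero, so the error is not just a small bias on short test sets.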
Thank you.
Also, it seems you augment the training data in any case.
def preprocessing(self, path, train=True):
    """ Read, Standardize, Augment """
    with open(str(path), 'rb') as f:
        data = torch.FloatTensor(pickle.load(f))
        label = data[:, -1]
        data = data[:, :-1]
    if train:
        self.mean = data.mean(dim=0)
        self.std = data.std(dim=0)
        self.length = len(data)
        data, label = self.augmentation(data, label)
    else:
        if self.augment_test_data:
            data, label = self.augmentation(data, label)
    data = standardization(data, self.mean, self.std)
    return data, label