# K-fold cross-validation driver: trains one model per fold and collects
# the per-fold test outputs/metrics keyed by the (stringified) fold index.
# Relies on names from the enclosing scope not visible in this chunk:
# ``data`` (a data manager), ``shuffle``, ``random_seed``, ``split_val``,
# ``split_labels``, ``train_kwargs``/``test_kwargs``, the ``outputs`` and
# ``metrics_val`` accumulators, and ``self`` (the experiment object).
if num_splits is None:
    # fall back to the common default of 10 folds when the caller
    # did not specify a count
    num_splits = 10
    logger.warning("num_splits not defined, using default value of \
                    10 splits instead ")

if isinstance(data.dataset, BaseLazyDataset):
    # lazy datasets load samples on access, so touching every sample
    # just to read its label is expensive — warn the user up front
    logger.warning("A lazy dataset is given for stratified kfold. \
                    Iterating over the dataset to extract labels for \
                    stratification may be a massive overhead")

# split over index positions; the actual samples are materialized later
# through ``get_subset``
split_idxs = list(range(len(data.dataset)))

# NOTE(review): the warning above talks about *stratified* kfold and
# ``split_labels`` is passed to ``split`` below, but plain ``KFold``
# ignores its ``y`` argument entirely, so no stratification happens —
# confirm whether ``StratifiedKFold`` was intended here.
fold = KFold(n_splits=num_splits, shuffle=shuffle,
             random_state=random_seed)

for idx, (_train_idxs, test_idxs) in enumerate(fold.split(split_idxs,
                                                          split_labels)):
    # extract data from single manager
    _train_data = data.get_subset(_train_idxs)
    _split_idxs = list(range(len(_train_data.dataset)))

    # carve a validation subset (fraction ``split_val``) out of this
    # fold's training portion; ``n_splits=1`` means the loop below
    # executes exactly once
    val_fold = ShuffleSplit(n_splits=1,
                            test_size=split_val,
                            random_state=random_seed)
    for train_idxs, val_idx in val_fold.split(_split_idxs):
        train_data = _train_data.get_subset(train_idxs)
        val_data = _train_data.get_subset(val_idx)

    test_data = data.get_subset(test_idxs)

    # update manager behavior for train and test case
    train_data.update_state_from_dict(train_kwargs)
    # validation uses the *test* configuration — presumably to disable
    # training-only behavior (e.g. augmentation); confirm against the
    # semantics of ``test_kwargs``
    val_data.update_state_from_dict(test_kwargs)
    test_data.update_state_from_dict(test_kwargs)

    # train a fresh model on this fold ...
    model = self.run(train_data, val_data,
                     num_epochs=num_epochs,
                     fold=idx,
                     **kwargs)
    # ... then evaluate it on the held-out test split of the fold
    _outputs, _metrics_val = self.test(
        self.params, model, test_data)

    # accumulate per-fold results under the fold index as string key
    outputs[str(idx)] = _outputs
    metrics_val[str(idx)] = _metrics_val

return outputs, metrics_val