import random
import numpy as np
from catboost import CatBoostRegressor
from lightgbm import LGBMRegressor
from sklearn.datasets import load_diabetes
from sklearn.ensemble import ExtraTreesRegressor, AdaBoostRegressor, GradientBoostingRegressor, RandomForestRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor
from evolutionary_forest.forest import EvolutionaryForestRegressor
from evolutionary_forest.utils import get_feature_importance, plot_feature_importance, feature_append
random.seed(0)
np.random.seed(0)
# Load Dataset
X, y = load_diabetes(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
# Train Random Forest
r = RandomForestRegressor()
r.fit(x_train, y_train)
print('随机森林R2分数', r2_score(y_test, r.predict(x_test)))
# Train Evolutionary Forest
r = EvolutionaryForestRegressor(max_height=3, normalize=True, select='AutomaticLexicase',
gene_num=10, boost_size=100, n_gen=20, n_pop=200, cross_pb=1,
base_learner='Random-DT', verbose=True, n_process=64)
r.fit(x_train, y_train)
print('演化森林R2分数', r2_score(y_test, r.predict(x_test)))
PicklingError Traceback (most recent call last)
/tmp/ipykernel_24448/3027070399.py in <module>
29 gene_num=10, boost_size=100, n_gen=20, n_pop=200, cross_pb=1,
30 base_learner='Random-DT', verbose=True, n_process=64)
---> 31 r.fit(x_train, y_train)
32 print('演化森林R2分数', r2_score(y_test, r.predict(x_test)))
~/anaconda3/lib/python3.9/site-packages/evolutionary_forest/forest.py in fit(self, X, y, test_X)
2137 else:
2138 # Not using gradient boosting mode
-> 2139 pop, log = self.eaSimple(self.pop, self.toolbox, self.cross_pb, self.mutation_pb, self.n_gen,
2140 stats=mstats, halloffame=self.hof, verbose=self.verbose)
2141 self.pop = pop
~/anaconda3/lib/python3.9/site-packages/evolutionary_forest/forest.py in eaSimple(self, population, toolbox, cxpb, mutpb, ngen, stats, halloffame, verbose)
2548 invalid_ind = self.multiobjective_evaluation(toolbox, population)
2549 else:
-> 2550 invalid_ind = self.population_evaluation(toolbox, population)
2551 if self.environmental_selection == 'NSGA2-Mixup':
2552 self.mixup_evaluation(self.toolbox, population)
~/anaconda3/lib/python3.9/site-packages/evolutionary_forest/forest.py in population_evaluation(self, toolbox, population)
4385 # distribute tasks
4386 if self.n_process > 1:
-> 4387 data = [next(f) for f in fitnesses]
4388 results = list(self.pool.map(calculate_score, data))
4389 else:
~/anaconda3/lib/python3.9/site-packages/evolutionary_forest/forest.py in <listcomp>(.0)
4385 # distribute tasks
4386 if self.n_process > 1:
-> 4387 data = [next(f) for f in fitnesses]
4388 results = list(self.pool.map(calculate_score, data))
4389 else:
~/anaconda3/lib/python3.9/site-packages/evolutionary_forest/forest.py in fitness_evaluation(self, individual)
714 information: EvaluationResults
715 if self.n_process > 1:
--> 716 y_pred, estimators, information = yield pipe, dill.dumps(genes, protocol=-1)
717 else:
718 y_pred, estimators, information = yield pipe, genes
~/anaconda3/lib/python3.9/site-packages/dill/_dill.py in dumps(obj, protocol, byref, fmode, recurse, **kwds)
302 """
303 file = StringIO()
--> 304 dump(obj, file, protocol, byref, fmode, recurse, **kwds)#, strictio)
305 return file.getvalue()
306
~/anaconda3/lib/python3.9/site-packages/dill/_dill.py in dump(obj, file, protocol, byref, fmode, recurse, **kwds)
274 _kwds = kwds.copy()
275 _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
--> 276 Pickler(file, protocol, **_kwds).dump(obj)
277 return
278
~/anaconda3/lib/python3.9/site-packages/dill/_dill.py in dump(self, obj)
496 raise PicklingError(msg)
497 else:
--> 498 StockPickler.dump(self, obj)
499 stack.clear() # clear record of 'recursion-sensitive' pickled objects
500 return
~/anaconda3/lib/python3.9/pickle.py in dump(self, obj)
485 if self.proto >= 4:
486 self.framer.start_framing()
--> 487 self.save(obj)
488 self.write(STOP)
489 self.framer.end_framing()
~/anaconda3/lib/python3.9/pickle.py in save(self, obj, save_persistent_id)
558 f = self.dispatch.get(t)
559 if f is not None:
--> 560 f(self, obj) # Call unbound method with explicit self
561 return
562
~/anaconda3/lib/python3.9/pickle.py in save_list(self, obj)
929
930 self.memoize(obj)
--> 931 self._batch_appends(obj)
932
933 dispatch[list] = save_list
~/anaconda3/lib/python3.9/pickle.py in _batch_appends(self, items)
953 write(MARK)
954 for x in tmp:
--> 955 save(x)
956 write(APPENDS)
957 elif n:
~/anaconda3/lib/python3.9/pickle.py in save(self, obj, save_persistent_id)
601
602 # Save the reduce() output and finally memoize the object
--> 603 self.save_reduce(obj=obj, *rv)
604
605 def persistent_id(self, obj):
~/anaconda3/lib/python3.9/pickle.py in save_reduce(self, func, args, state, listitems, dictitems, state_setter, obj)
708
709 if listitems is not None:
--> 710 self._batch_appends(listitems)
711
712 if dictitems is not None:
~/anaconda3/lib/python3.9/pickle.py in _batch_appends(self, items)
953 write(MARK)
954 for x in tmp:
--> 955 save(x)
956 write(APPENDS)
957 elif n:
~/anaconda3/lib/python3.9/pickle.py in save(self, obj, save_persistent_id)
601
602 # Save the reduce() output and finally memoize the object
--> 603 self.save_reduce(obj=obj, *rv)
604
605 def persistent_id(self, obj):
~/anaconda3/lib/python3.9/pickle.py in save_reduce(self, func, args, state, listitems, dictitems, state_setter, obj)
685 "args[0] from __newobj__ args has the wrong class")
686 args = args[1:]
--> 687 save(cls)
688 save(args)
689 write(NEWOBJ)
~/anaconda3/lib/python3.9/pickle.py in save(self, obj, save_persistent_id)
558 f = self.dispatch.get(t)
559 if f is not None:
--> 560 f(self, obj) # Call unbound method with explicit self
561 return
562
~/anaconda3/lib/python3.9/site-packages/dill/_dill.py in save_type(pickler, obj)
1437 #print ("%s\n%s" % (obj.__bases__, obj.__dict__))
1438 name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
-> 1439 StockPickler.save_global(pickler, obj, name=name)
1440 log.info("# T4")
1441 return
~/anaconda3/lib/python3.9/pickle.py in save_global(self, obj, name)
1068 obj2, parent = _getattribute(module, name)
1069 except (ImportError, KeyError, AttributeError):
-> 1070 raise PicklingError(
1071 "Can't pickle %r: it's not found as %s.%s" %
1072 (obj, module_name, name)) from None
PicklingError: Can't pickle <class 'deap.gp.rand101'>: it's not found as deap.gp.rand101