对于任意一个训练后的模型:
estimator = SimpleFeedForwardEstimator(xxxx)
predictor = estimator.train(training_data) # 开始训练
可以保存、加载的是predictor
,使用下面的方法:
保存模型
from pathlib import Path
save_path = "my_model"
predictor.serialize(Path("models/{}/".format(save_path)))
加载模型
from pathlib import Path
from gluonts.model.predictor import Predictor
load_path = "my_model"
predictor = Predictor.deserialize(Path("models/{}/".format(load_path)))