深圳幻海软件技术有限公司 欢迎您!

ML.NET Cookbook:(7)如何训练回归模型?

2023-03-22

通常,为了在ML.NET中训练任何模型,您将经历三个步骤:弄清楚训练数据如何以IDataView形式进入ML.NET。将“学习管道”构建为一系列基本的“运算符”(估计器)。在管道上调用Fit以获得经过训练的模型。示例文件[1]:feature_0;feature_1;feature_2;featur

通常,为了在ML.NET中训练任何模型,您将经历三个步骤:

  1. 弄清楚训练数据如何以IDataView形式进入ML.NET。

  2. 将“学习管道”构建为一系列基本的“运算符”(估计器)。

  3. 在管道上调用Fit以获得经过训练的模型。

示例文件[1]:

  1. feature_0;feature_1;feature_2;feature_3;feature_4;feature_5;feature_6;feature_7;feature_8;feature_9;feature_10;target
  2. -2.75;0.77;-0.61;0.14;1.39;0.38;-0.53;-0.50;-2.13;-0.39;0.46;140.66
  3. -0.61;-0.37;-0.12;0.55;-1.00;0.84;-0.02;1.30;-0.24;-0.50;-2.12;148.12
  4. -0.85;-0.91;1.81;0.02;-0.78;-1.41;-1.09;-0.65;0.90;-0.37;-0.22;402.20

在上面的文件中,最后一列(第12列)是我们预测的标签,前面所有的都是特征。

  1. // 第一步:将数据加载为IDataView。
  2. // 首先,我们定义加载器:指定数据列以及它们在文本文件中的位置。
  3. // 将数据加载到数据视图中。但是请记住,加载器是延迟执行的,所以实际加载将在访问数据时发生。
  4. var trainData = mlContext.Data.LoadFromTextFile<RegressionData>(dataPath,
  5.     // 默认分隔符是tab,但数据集使用分号。
  6.     separatorChar: ';'
  7. );
  8. // 有时,当数据要在某个地方多次使用时,在首次访问后将数据缓存在内存中可以节省一些加载时间。缓存机制也是延迟执行的;它只在使用后才缓存东西。用户可以用“cachedTrainData”替换“trainData”的所有后续用法。
  9. // 我们仍然使用“trainData”,因为提供相同缓存功能的缓存步骤将插入到所考虑的“管道”中。
  10. var cachedTrainData = mlContext.Data.Cache(trainData);
  11. // 第二步:定义学习管道。
  12. // 我们用加载器的输出“启动”管道。
  13. var pipeline =
  14.     // 首先“规范化”数据(对于所有样本,重新缩放到-1和1之间)
  15.     mlContext.Transforms.NormalizeMinMax("FeatureVector")
  16.     // 我们增加了一个在内存中缓存数据的步骤,使得下游的迭代训练算法能够有效地对数据进行多次扫描。否则,下面的训练器将多次从磁盘加载数据。缓存机制使用按需策略。
  17.     // 在任何下游步骤中访问的数据都将在首次使用后被缓存。通常,您只需要在可训练步骤之前添加一个缓存步骤,因为如果数据只扫描一次,则缓存没有帮助。如果用户没有足够的内存来存储整个数据集,则可以删除此步骤。请注意,在上游Transforms.Normalize步骤中,我们只扫描数据一次,因此添加缓存步骤是没有帮助的。
  18.     .AppendCacheCheckpoint(mlContext)
  19.     // 添加SDCA回归训练器。
  20.     .Append(mlContext.Regression.Trainers.Sdca(labelColumnName: "Target", featureColumnName: "FeatureVector"));
  21. //第三步: 在管道上调用`Fit`
  22. var model = pipeline.Fit(trainData);

参考资料

[1]

示例文件: https://github.com/dotnet/machinelearning/blob/main/test/data/generated_regression_dataset.csv

文章知识点与官方知识档案匹配,可进一步学习相关知识
OpenCV技能树首页概览14595 人正在系统学习中