Home > life is fun > Spark on Yarn:Softmax Regression算法的实现

Spark on Yarn:Softmax Regression算法的实现


Taobao

Spark on Yarn:Softmax Regression算法的实现

我们团队一直在对Spark在大规模数据挖掘、机器学习上的应用进行实践和探索,本系列文章是我们在使用Spark进行分布式开发的总结,有不足的地方,欢迎大家指正和交流,希望更多的人加入我们。

本文是通过在Spark on Yarn上实现Softmax Regression算法的训练部分,了解和测试Spark on Yarn的实际生产性能

1. Softmax Regression

Softmax Regression 也叫 multinomial logistic regression,可以处理多分类问题,可以看作 logistic regression 在多分类上的推广。关于 Softmax Regression 细节的描述和 loss function 的推导可以参考这里 或者 Andrew Ng的讲义 cs229 note

2. Softmax Regression在Spark上的实现

Spark源码的 example 包有 logistic regression 的实现,跟它一样,本文描述的 softmax regression 实现也采用批量梯度下降算法进行优化求解(实际应用中,批量梯度下降的求解速度较慢,更常见的是lbfgs)。我们实验中采用的一份实际数据集包含130万条数据,有100个类别,feature 是接近100万维的稀疏向量。

在本文描述的实现中,权重矩阵用一个稠密矩阵类进行封装,数据样本则通过一个基于 HashMap 的稀疏向量 SparseVector 类来表示,该类包含了点乘、相加等向量操作。这两个类基本上囊括了 softmax regression 的所有计算方面的基本操作。

下面介绍一下我们的实现版本的主要步骤:

主要步骤说明:

  • 用 textFile 读取数据文件并转换成 DataPoint 结构的RDD,调用 cache(),使得 dataPoint 保存在内存中。因为在多次迭代中都需要用到相同的训练数据,所以把它用 RDD 的方式存放在内存中,可以加快迭代速度,这也正是 Spark 的优势所在。cache() 是表示只存放在内存中,是 persist() 的一种默认行为,即 persist(StorageLevel.MEMORY_ONLY),在内存空间允许的情况下,这是最高效的一种缓存方式。persist() 还支持多种 其他方式
  • 对于权重矩阵,由于其将近400M,我们采用 Spark 的 broadcast 方式把它传递给 worker,并且由于权重矩阵在每次迭代后都会更新,因此更新后都会重新 broadcast 一份新的数据。权重矩阵的计算是我们在实现过程中遇到的比较棘手的问题之一。
  • 对于梯度的计算,我们用 flatMap 的方式进行计算,即输入一个 dataPoint 数据和权重矩阵,输出m 个 k 维的列向量(用(列id,向量)的 pair 进行输出),其中 m 是指该 dataPoint 的 x 向量中不为 0 的列数;然后再用reduceByKey,把相同列 id 的向量进行加和;最后 collect 到 master 中逐列更新权重矩阵,迭代周期结束。

3. 性能

我们在公司云梯1的yarn集群上进行了Spark的性能测试。测试数据集有100类,包含130万条数据,每条数据是近100万维的稀疏向量。在这个测试集上,我们测试了我们开发的softmax regression的运行时间,结果见图1和图2。

softmax_1 

 

图1

图1显示的是softmax regression在yarn上分配30个worker,每个worker 12G内存时,进行50次迭代的执行时间,图中的横坐标代表第 i 次迭代,纵坐标代表执行时间,单位是秒。从图中可以看到,每次迭代的执行时间存在一定的波动,但相对还是比较稳定的,基本维持在400秒上下。

softmax_2 

 

图2

图2显示的是softmax regression在yarn上分配10、20、30个worker,每个worker 12G内存时,进行10次迭代各自的执行时间。同样的,图中的横坐标代表第 i 次迭代,纵坐标代表执行时间,单位是秒。从图中可以看到,随着分配的worker数量的增加,执行时间相应减少,但并没有呈现线性的关系。

4. 总结

总的来说,使用Spark进行分布式开发,代码还是比较简短的(我们在前面的代码中没有贴出所有实现,但整体来讲,由于不需要介入数据的传输逻辑控制、容错等,还是相对简单一些)。当然,由于Spark的计算、存储都在内存中,因此当数据量很大时,想达到较好的性能,对开发人员还是有一定的要求,调优、调试都需要花费不少投入才行。而从性能上来说,目前的测试结果表明Spark目前的版本还是比较稳定的。至于计算时间方面,则取决于算法的并行方案、机器数量、机器内存数等多方面的条件,本文的实验说明随着机器数量、内存的增加,计算时间有明显的减少。而算法的并行方案,目前还没有深入的研究,还请有经验的同学指点一二。

Advertisements
Categories: life is fun
  1. No comments yet.
  1. No trackbacks yet.

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. Log Out / Change )

Twitter picture

You are commenting using your Twitter account. Log Out / Change )

Facebook photo

You are commenting using your Facebook account. Log Out / Change )

Google+ photo

You are commenting using your Google+ account. Log Out / Change )

Connecting to %s

%d bloggers like this: