Alink漫談(一) : 從KMeans算法實現不同看Alink設計思想
Alink漫談(一) : 從KMeans算法實現不同看Alink設計思想
0x00 摘要
Alink 是阿里巴巴基于實時計算引擎 Flink 研發的新一代機器學習算法平臺,是業界首個同時支持批式算法、流式算法的機器學習平臺。本文將帶領大家從多重角度出發來分析推測Alink的設計思路。
因為Alink的公開資料太少,所以以下均為自行揣測,肯定會有疏漏錯誤,希望大家指出,我會隨時更新。
0x01 Flink 是什么
Apache Flink是由Apache軟件基金會開發的開源流處理框架,它通過實現了 Google Dataflow 流式計算模型實現了高吞吐、低延遲、高性能兼具實時流式計算框架。
其核心是用Java和Scala編寫的分布式流數據流引擎。Flink以數據并行和流水線方式執行任意流數據程序,Flink的流水線運行時系統可以執行批處理和流處理程序。此外,Flink的運行時本身也支持迭代算法的執行。
0x02 Alink 是什么
Alink 是阿里巴巴計算平臺事業部PAI團隊從2017年開始基于實時計算引擎 Flink 研發的新一代機器學習算法平臺,提供豐富的算法組件庫和便捷的操作框架,開發者可以一鍵搭建覆蓋數據處理、特征工程、模型訓練、模型預測的算法模型開發全流程。項目之所以定為Alink,是取自相關名稱(Alibaba, Algorithm, AI, Flink, Blink)的公共部分。
借助Flink在批流一體化方面的優勢,Alink能夠為批流任務提供一致性的操作。在2017年初,阿里團隊通過調研團隊看到了Flink在批流一體化方面的優勢及底層引擎的優秀性能,于是基于Flink重新設計研發了機器學習算法庫,即Alink平臺。該平臺于2018年在阿里集團內部上線,隨后不斷改進完善,在阿里內部錯綜復雜的業務場景中鍛煉成長。
0x03 Alink設計思路
因為目前關于Alink設計的公開資料比較少,我們手頭只有其源碼,看起來只能從代碼反推。但是世界上的事物都不是孤立的,我們還有其他角度來幫助我們判斷推理。所以下面就讓我們來進行推斷。
1. 白手起家
FlinkML 是 Flink 社區現存的一套機器學習算法庫,這一套算法庫已經存在很久而且更新比較緩慢。
Alink團隊起初面臨的抉擇是:是否要基于 Flink ML 進行開發,或者對 Flink ML進行更新。
經過研究,Alink團隊發現,Flink ML 其僅支持10余種算法,支持的數據結構也不夠通用,在算法性能方面做的優化也比較少,而且其代碼也很久沒有更新。所以,他們放棄了基于舊版FlinkML進行改進、升級的想法,決定基于Flink重新設計研發機器學習算法庫。
所以我們要分析的就是如何從無到有設計出一個新的機器學習平臺/框架。
2. 替代品如何造成威脅
因為Alink是市場的新進入者,所以Alink的最大問題就是如何替代市場上的現有產品。
邁克爾·波特用 “替代品威脅” 來解釋用戶的整個替代邏輯,當新產品能牢牢掌握住這一點,就有可能在市場上獲得非常好的表現,打敗競爭對手。
假如現在想從0到1構建一個機器學習庫或者機器學習框架,那么我們需要從商業意識和商業邏輯出發,來思考這個產品的價值所在,就能對這個產品做個比較精確的定義,從而能夠確定產品路線。
產品需要解決應用環境下的綜合性問題,產品的價值體現,可以分拆了三個維度。
- 用戶的角度:價值體現在用戶使用,獲取產品的意愿。這個就是換用成本的問題,一旦換用成本過高,這個產品就很難成功。
- 競爭對手的角度: 產品的競爭力,最終都體現為用戶為了獲取該產品愿意支付的最高成本上限,當一個替代品進入市場,必須有能給用戶足夠的洞理驅使用戶換用替代品。
- 企業的角度:站在企業的角度,實際就是成本結構和收益的規模性問題 。
下面就讓我們逐一分析。
3. 用戶角度看設計
這個就是換用成本的問題,一旦換用成本過高,這個產品就很難成功。
Alink大略有兩種用戶:算法工程師,應用工程師。
Alink算法工程師特指實現機器學習算法的工程師。Alink應用工程師就是應用Alink AI算法做業務的工程師。這兩類用戶的換用成本都是Alink需要考慮的。
新產品對于用戶來說,有兩個大的問題:產品底層邏輯和開發工具。一個優秀的新產品絕對不能在這兩個問題上增加用戶的換用成本。
底層邏輯Flink
Flink這個平臺博大精深,無論是熟悉其API還是深入理解系統架構都不是容易的事情。如果Alink用戶還需要熟悉Flink,那勢必造成ALink用戶的換用成本,所以這點應該盡量避免。
-
對于算法工程師,他們應該主要把思路集中在算法上,而盡量不用關心Flink內部的細節,如果一定要熟悉Flink,那么越少越好;
-
對于應用工程師,他們主要的需求就是API接口越簡單越好,他們最理想的狀態應該是:完全感覺不到Flink的存在。
綜上所述,Alink的原則之一應該是 :算法的歸算法,Flink的歸Flink,盡量屏蔽AI算法和Flink之間的聯系。
開發工具
開發工具就是究竟用什么語言開發。Flink的開發語言主要是JAVA,SCALA,Python。而機器學習世界中主要還是Python。
-
首先要排除SCALA。因為Scala 是一門很難掌握的語言,它的規則是基于數學類型理論的,學習曲線相當陡峭。一個能夠領會規則和語言特性的優秀程序員,使用 Scala 會比使用 Java 更高效,但是一個普通程序員的生產力,從功能實現上來看,效率則會相反。
讓我們看看基于Flink的原生KMeans SCALA代碼,很多人看了之后恐怕都會懵圈。
val finalCentroids = centroids.iterate(params.getInt("iterations", 10)) { currentCentroids => val newCentroids = points .map(new SelectNearestCenter).withBroadcastSet(currentCentroids, "centroids") .map { x => (x._1, x._2, 1L) }.withForwardedFields("_1; _2") .groupBy(0) .reduce { (p1, p2) => (p1._1, p1._2.add(p2._2), p1._3 + p2._3) }.withForwardedFields("_1") .map { x => new Centroid(x._1, x._2.div(x._3)) }.withForwardedFields("_1->id") newCentroids } -
其次是選擇JAVA還是Python開發具體算法。Alink內部肯定進行了很多權宜和抉擇。因為這個不單單是哪個語言本身更合適,也涉及到Alink團隊內部有哪些資源,比如是JAVA工程師更多還是Python更多。最終Alink選擇了JAVA來開發算法。
-
最后是API。這個就沒有什么疑問了,Alink提供了Python和JAVA兩種語言的API,直接可參見GitHub的介紹。
在 PyAlink 中,算法組件提供的接口基本與 Java API 一致,即通過默認構造方法創建一個算法組件,然后通過
setXXX設置參數,通過link/linkTo/linkFrom與其他組件相連。 這里利用 Jupyter 的自動補全機制可以提供書寫便利。
另外,如果采用JAVA或者Python,肯定有大量現有代碼可以修改復用。如果采用SCALA,就難以復用之前的積累。
綜上所述,Alink的原則之一應該是 :采用最簡單,最常見的開發語言和設計思維。
4. 競爭對手角度看設計
Alink的競爭對手大略可以認為是Spark ML, Flink ML, Scikit-learn。
他們是市場上的現有力量,擁有大量的用戶。用戶已經熟悉了這些競爭對手的設計思路,開發策略,基本概念和API。除非Alink能夠提供一種神奇簡便的API,否則Alink應該在設計上最大程度借鑒這些競爭對手。
比如機器學習開發中有如下常見概念:Transformer,Estimator,PipeLine,Parameter。這些概念 Alink 應該盡量提供。
綜上所述,**Alink的原則之一應該是 :盡量借鑒市面上通用的設計思路和開發模式,讓開發者無縫切換 **。
從 Alink的目錄結構中 ,我們可以看出,Alink確實提供了這些常見概念。
比如 Pipeline,Trainer,Model,Estimator。我們會在后續文章中再詳細介紹這些概念。
./java/com/alibaba/alink:
common operator params pipeline
./java/com/alibaba/alink/params:
associationrule evaluation nlp regression statistics
classification feature onlinelearning shared tuning
clustering io outlier similarity udf
dataproc mapper recommendation sql validators
./java/com/alibaba/alink/pipeline:
EstimatorBase.java ModelBase.java Trainer.java feature
LocalPredictable.java ModelExporterUtils.java TransformerBase.java nlp
LocalPredictor.java Pipeline.java classification recommendation
MapModel.java PipelineModel.java clustering regression
MapTransformer.java PipelineStageBase.java dataproc tuning
5. 企業角度看設計
這是成本結構和收益的規模性問題。從而決定了Alink在開發時候,必須盡量提高開發工程師的效率,提高生產力。前面提到的棄用SCALA,部分也出于這個考慮。
挑戰集中在:
- 如何在對開發者最大程度屏蔽Flink的情況下,依然利用好Flink的各種能力。
- 如何構建一套相應打法和戰術體系,即middleware或者adapter,讓用戶基于此可以快速開發算法
舉個例子:
-
肯定有個別開發者,其對Flink特別熟悉,他們可以運用各種Flink API和函數編程思維開發出高效率的算法。這種開發者,我們可以稱為是武松武都頭。他們類似特種兵,能上戰場沖鋒陷陣,也能吊打白額大蟲。
-
但是絕大多數開發者對Flink不熟悉,他們更熟悉AI算法和命令式編程思路。這種開發者我們可以認為他們屬于八十萬禁軍或者是玄甲軍,北府兵,魏武卒,背嵬軍。這種才是實際開發中的主力部隊和常規套路。
我們需要針對八十萬禁軍,讓林沖林教頭設計出一套適合正規作戰的槍棒打法。或者針對背嵬軍,讓岳飛岳元帥設計一套馬軍沖陣機制。
因此,**Alink的原則之一應該是 :構建一套戰術打法(middleware或者adapter),即屏蔽了Flink,又可以利用好Flink,還可以讓用戶基于此可以快速開發算法 **。
我們想想看大概有哪些基礎工作需要做:
- 如何初始化
- 如果通信
- 如何分割代碼,如何廣播代碼
- 如果分割數據,如何廣播數據
- 如何迭代算法
- ......
讓我們看看Alink做了哪些努力,這點從其目錄結構可以看出有queue,operator,mapper等等構建架構所必須的數據結構:
./java/com/alibaba/alink/common:
MLEnvironment.java linalg MLEnvironmentFactory.java mapper
VectorTypes.java model comqueue utils io
./java/com/alibaba/alink/operator:
AlgoOperator.java common batch stream
其中最重要的概念是BaseComQueue,這是把通信或者計算抽象成ComQueueItem,然后把ComQueueItem串聯起來形成隊列。這樣就形成了面向迭代計算場景的一套迭代通信計算框架。其他數據結構大多是圍繞著BaseComQueue來具體運作。
/**
* Base class for the com(Computation && Communicate) queue.
*/
public class BaseComQueue<Q extends BaseComQueue<Q>> implements Serializable {
/**
* All computation or communication functions.
*/
private final List<ComQueueItem> queue = new ArrayList<>();
/**
* sessionId for shared objects within this BaseComQueue.
*/
private final int sessionId = SessionSharedObjs.getNewSessionId();
/**
* The function executed to decide whether to break the loop.
*/
private CompareCriterionFunction compareCriterion;
/**
* The function executed when closing the iteration
*/
private CompleteResultFunction completeResult;
/**
* Max iteration count.
*/
private int maxIter = Integer.MAX_VALUE;
private transient ExecutionEnvironment executionEnvironment;
}
MLEnvironment 是另外一個重要的類。其封裝了Flink開發所必須要的運行上下文。用戶可以通過這個類來獲取各種實際運行環境,可以建立table,可以運行SQL語句。
/**
* The MLEnvironment stores the necessary context in Flink.
* Each MLEnvironment will be associated with a unique ID.
* The operations associated with the same MLEnvironment ID
* will share the same Flink job context.
*/
public class MLEnvironment {
private ExecutionEnvironment env;
private StreamExecutionEnvironment streamEnv;
private BatchTableEnvironment batchTableEnv;
private StreamTableEnvironment streamTableEnv;
}
6. 設計原則總結
下面我們可以總結下Alink部分設計原則
-
算法的歸算法,Flink的歸Flink,盡量屏蔽AI算法和Flink之間的聯系。
-
采用最簡單,最常見的開發語言。
-
盡量借鑒市面上通用的設計思路和開發模式,讓開發者無縫切換。
-
構建一套戰術打法(middleware或者adapter),即屏蔽了Flink,又可以利用好Flink,還可以讓用戶基于此可以快速開發算法。
0x04 KMeans算法實現看設計
Flink和Alink源碼中,都提供了KMeans算法例子,所以我們就從KMeans入手看看Flink原生算法和Alink算法實現的區別。為了統一標準,我們都選用JAVA版本的算法實現。
1. KMeans算法
KMeans算法的思想比較簡單,假設我們要把數據分成K個類,大概可以分為以下幾個步驟:
- 隨機選取k個點,作為聚類中心;
- 計算每個點分別到k個聚類中心的聚類,然后將該點分到最近的聚類中心,這樣就行成了k個簇;
- 再重新計算每個簇的質心(均值);
- 重復以上2~4步,直到質心的位置不再發生變化或者達到設定的迭代次數。
2. Flink KMeans例子
K-Means 是迭代的聚類算法,初始設置K個聚類中心
- 在每一次迭代過程中,算法計算每個數據點到每個聚類中心的歐式距離
- 每個點被分配到它最近的聚類中心
- 隨后每個聚類中心被移動到所有被分配的點
- 移動的聚類中心被分配到下一次迭代
- 算法在固定次數的迭代之后終止(在本實現中,參數設置)
- 或者聚類中心在迭代中不在移動
- 本項目是工作在二維平面的數據點上
- 它計算分配給集群中心的數據點
- 每個數據點都使用其所屬的最終集群(中心)的id進行注釋。
下面給出部分代碼,具體算法解釋可以在注釋中看到。
這里主要采用了Flink的批量迭代。其調用 DataSet 的 iterate(int) 方法創建一個 BulkIteration,迭代以此為起點,返回一個 IterativeDataSet,可以用常規運算符進行轉換。迭代調用的參數 int 指定最大迭代次數。
IterativeDataSet 調用 closeWith(DataSet) 方法來指定哪個轉換應該反饋到下一個迭代,可以選擇使用 closeWith(DataSet,DataSet) 指定終止條件。如果該 DataSet 為空,則它將評估第二個 DataSet 并終止迭代。如果沒有指定終止條件,則迭代在給定的最大次數迭代后終止。
public class KMeans {
public static void main(String[] args) throws Exception {
// Checking input parameters
final ParameterTool params = ParameterTool.fromArgs(args);
// set up execution environment
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.getConfig().setGlobalJobParameters(params); // make parameters available in the web interface
// get input data:
// read the points and centroids from the provided paths or fall back to default data
DataSet<Point> points = getPointDataSet(params, env);
DataSet<Centroid> centroids = getCentroidDataSet(params, env);
// set number of bulk iterations for KMeans algorithm
IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10));
DataSet<Centroid> newCentroids = points
// compute closest centroid for each point
.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
// count and sum point coordinates for each centroid
.map(new CountAppender())
.groupBy(0).reduce(new CentroidAccumulator())
// compute new centroids from point counts and coordinate sums
.map(new CentroidAverager());
// feed new centroids back into next iteration
DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);
DataSet<Tuple2<Integer, Point>> clusteredPoints = points
// assign points to final clusters
.map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");
// emit result
if (params.has("output")) {
clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");
// since file sinks are lazy, we trigger the execution explicitly
env.execute("KMeans Example");
} else {
System.out.println("Printing result to stdout. Use --output to specify output path.");
clusteredPoints.print();
}
}
3. Alink KMeans示例
Alink中,Kmeans是分布在若干文件中,這里我們提取部分代碼來對照。
KMeansTrainBatchOp
這里是算法主程序,這里倒是看起來十分清爽干凈,但實際上是沒有這么簡單,Alink在其背后做了大量的基礎工作。
可以看出,算法實現的主要工作是:
- 構建了一個IterativeComQueue(BaseComQueue的缺省實現)。
- 初始化數據,這里有兩種辦法:initWithPartitionedData將DataSet分片緩存至內存。initWithBroadcastData將DataSet整體緩存至每個worker的內存。
- 將計算分割為若干ComputeFunction,比如KMeansPreallocateCentroid / KMeansAssignCluster / KMeansUpdateCentroids ...,串聯在IterativeComQueue。
- 運用AllReduce通信模型完成了數據同步。
public final class KMeansTrainBatchOp extends BatchOperator <KMeansTrainBatchOp>
implements KMeansTrainParams <KMeansTrainBatchOp> {
static DataSet <Row> iterateICQ(...省略...) {
return new IterativeComQueue()
.initWithPartitionedData(TRAIN_DATA, data)
.initWithBroadcastData(INIT_CENTROID, initCentroid)
.initWithBroadcastData(KMEANS_STATISTICS, statistics)
.add(new KMeansPreallocateCentroid())
.add(new KMeansAssignCluster(distance))
.add(new AllReduce(CENTROID_ALL_REDUCE))
.add(new KMeansUpdateCentroids(distance))
.setCompareCriterionOfNode0(new KMeansIterTermination(distance, tol))
.closeWith(new KMeansOutputModel(distanceType, vectorColName, latitudeColName, longitudeColName))
.setMaxIter(maxIter)
.exec();
}
}
KMeansPreallocateCentroid
預先分配聚類中心
public class KMeansPreallocateCentroid extends ComputeFunction {
public void calc(ComContext context) {
if (context.getStepNo() == 1) {
List<FastDistanceMatrixData> initCentroids = (List)context.getObj("initCentroid");
List<Integer> list = (List)context.getObj("statistics");
Integer vectorSize = (Integer)list.get(0);
context.putObj("vectorSize", vectorSize);
FastDistanceMatrixData centroid = (FastDistanceMatrixData)initCentroids.get(0);
Preconditions.checkArgument(centroid.getVectors().numRows() == vectorSize, "Init centroid error, size not equal!");
context.putObj("centroid1", Tuple2.of(context.getStepNo() - 1, centroid));
context.putObj("centroid2", Tuple2.of(context.getStepNo() - 1, new FastDistanceMatrixData(centroid)));
context.putObj("k", centroid.getVectors().numCols());
}
}
}
KMeansAssignCluster
為每個點(point)計算最近的聚類中心,為每個聚類中心的點坐標的計數和求和
/**
* Find the closest cluster for every point and calculate the sums of the points belonging to the same cluster.
*/
public class KMeansAssignCluster extends ComputeFunction {
@Override
public void calc(ComContext context) {
Integer vectorSize = context.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
Integer k = context.getObj(KMeansTrainBatchOp.K);
// get iterative coefficient from static memory.
Tuple2<Integer, FastDistanceMatrixData> stepNumCentroids;
if (context.getStepNo() % 2 == 0) {
stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID1);
} else {
stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID2);
}
if (null == distanceMatrix) {
distanceMatrix = new DenseMatrix(k, 1);
}
double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);
if (sumMatrixData == null) {
sumMatrixData = new double[k * (vectorSize + 1)];
context.putObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE, sumMatrixData);
}
Iterable<FastDistanceVectorData> trainData = context.getObj(KMeansTrainBatchOp.TRAIN_DATA);
if (trainData == null) {
return;
}
Arrays.fill(sumMatrixData, 0.0);
for (FastDistanceVectorData sample : trainData) {
KMeansUtil.updateSumMatrix(sample, 1, stepNumCentroids.f1, vectorSize, sumMatrixData, k, fastDistance,
distanceMatrix);
}
}
}
KMeansUpdateCentroids
基于點計數和坐標,計算新的聚類中心。
/**
* Update the centroids based on the sum of points and point number belonging to the same cluster.
*/
public class KMeansUpdateCentroids extends ComputeFunction {
@Override
public void calc(ComContext context) {
Integer vectorSize = context.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
Integer k = context.getObj(KMeansTrainBatchOp.K);
double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);
Tuple2<Integer, FastDistanceMatrixData> stepNumCentroids;
if (context.getStepNo() % 2 == 0) {
stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID2);
} else {
stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID1);
}
stepNumCentroids.f0 = context.getStepNo();
context.putObj(KMeansTrainBatchOp.K,
updateCentroids(stepNumCentroids.f1, k, vectorSize, sumMatrixData, distance));
}
}
4. 區別
代碼量
通過下面的分析可以看出,從實際業務代碼量角度說,其實差別不大。
- Flink的代碼量少;
- Alink的代碼量雖然大,但其本質就是把Flink版本的一些用戶定義類分離到自己不同類中,并且有很多讀取環境變量的代碼;
所以Alink代碼只能說比Flink原生實現略大。
耦合度
這里指的是與Flink的耦合度。能看出來Flink的KMeans算法需要大量的Flink類。而Alink被最大限度屏蔽了。
- Flink 算法需要引入的flink類如下
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
- Alink 算法需要引入的flink類如下,可以看出來ALink使用的都是基本設施,不涉及算子和復雜API,這樣就減少了用戶的負擔。
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;
編程模式
這是一個主要的區別。
- Flink 使用的是函數式編程。這種范式相對新穎,很多工程師不熟悉。
- Alink 依然使用了命令式編程。這樣的好處在于,大量現有算法代碼可以復用,也更符合絕大多數工程師的習慣。
- Flink 通過Flink的各種算子完成了操作,比如IterativeDataSet實現了迭代。但這種實現對于不熟悉Flink的工程師是個折磨。
- Alink 基于自己的框架,把計算代碼總結成了若干ComputeFunction,然后通過IterativeComQueue完成了具體算法的迭代。這樣用戶其實對Flink是不需要過多深入理解。
在下一期文章中,將從源碼角度分析驗證本文的設計思路。
0x05 參考
Spark ML簡介之Pipeline,DataFrame,Estimator,Transformer
斬獲GitHub 2000+ Star,阿里云開源的 Alink 機器學習平臺如何跑贏雙11數據“博弈”?|AI 技術生態論
浙公網安備 33010602011771號