Java 開發人員的機器學習,第 1 部分
在機器學習中,成本函數
成本函數顯示模型與訓練資料的擬合程度。為了確定上面所示的目標函數的成本,有必要計算每個範例 house 的平方誤差
例如,面積為1330 的房屋的實際價格 = 6,500,000 €。訓練後的目標函數預測的房價之間的差異為€7,032,478:差異(或誤差)為€532,478。您也可以在上圖中看到這種差異。每個價格區域訓練對的差異(或誤差)顯示為垂直紅色虛線。計算出經過訓練的目標函數的成本後,您需要對範例中每個房屋的平方誤差求和並計算主值。價格值越小
因此,theta 向量的參數將隨著演算法的每次迭代而改進。學習係數α指定每次迭代的計算次數。可以一直進行這些計算,直到找到「好的」theta 值。例如,下面的線性迴歸函數具有三個 theta 參數:
在每次迭代中,將為每個 theta 參數計算一個新值:、和。每次迭代之後,可以使用新的 theta 向量{θ 0 , θ 1 , θ 2 }創建一個新的、更合適的實作。清單-4顯示了梯度衰減演算法的 Java 程式碼。回歸函數的 Theta 將使用訓練資料、標記資料、學習率進行訓練。結果將是使用 theta 參數改進的目標函數。該方法將一次又一次地調用,傳遞新的目標函數和先前計算中的新 theta 參數。這些呼叫將重複,直到配置的目標函數達到最小穩定狀態:
在這種情況下,即使經過 500-600 次迭代後成本不再顯著下降,目標函數仍然不是最優的。這表明存在差異。在機器學習中,術語「不一致」用於表示學習演算法沒有找到資料中的潛在趨勢。根據現實生活經驗,較大房產的每平方公尺價格可能會下降。由此我們可以得出結論,用於目標函數學習過程的模型與資料擬合得不夠好。這種差異通常是由於模型過度簡化所造成的。這發生在我們的例子中,目標函數太簡單了,它使用單一參數進行分析 - 房子的面積。但這些資訊不足以準確預測房屋的價格。
使用多個特徵需要特徵縮放,特徵縮放用於標準化不同特徵的範圍。因此, size 2屬性的取值範圍明顯大於size屬性的值範圍。如果沒有特徵縮放,大小2將不適當地影響成本函數。size 2屬性所引入的誤差將明顯大於size屬性所引入的誤差。下面給出一個簡單的特徵縮放演算法:
然而,過度擬合模型在訓練資料上表現良好,但在真實的未知資料上表現不佳。有多種方法可以避免過度擬合。
目標函數估計
讓我們回想一下,目標函數hθ
,也稱為預測函數,是準備或訓練過程的結果。從數學上講,挑戰是找到一個以變數作為輸入х
並傳回預測值的函數у
。
(J(θ))
用於計算給定目標函數的誤差值或「成本」。
(i)
。誤差是範例中房屋的計算值у
與實際值之間的差距。 y
i
(J(θ))
,我們的目標函數的預測就越準確。清單3顯示了成本函數的簡單 Java 實現,它將目標函數、訓練資料列表以及與其關聯的標籤作為輸入。預測值將在循環中計算,誤差將透過減去實際價格值(從標籤中取得)來計算。隨後,將誤差的平方相加並計算誤差值。成本將作為類型值傳回double
:
清單-3
public static double cost(Function<ltDouble[], Double> targetFunction,
List<ltDouble[]> dataset,
List<ltDouble> labels) {
int m = dataset.size();
double sumSquaredErrors = 0;
// рассчет квадрата ошибки («разницы») для каждого тренировочного примера и //добавление его к сумме
for (int i = 0; i < m; i++) {
// получаем вектор признаков из текущего примера
Double[] featureVector = dataset.get(i);
// предсказываем meaning и вычисляем ошибку базируясь на реальном
//значении (метка)
double predicted = targetFunction.apply(featureVector);
double label = labels.get(i);
double gap = predicted - label;
sumSquaredErrors += Math.pow(gap, 2);
}
// Вычисляем и возращаем meaning ошибки (чем меньше тем лучше)
return (1.0 / (2 * m)) * sumSquaredErrors;
}
有興趣閱讀有關 Java 的內容嗎?加入Java 開發者小組! |
學習目標函數
雖然成本函數有助於評估目標函數和 theta 參數的質量,但您仍然需要找到最合適的 theta 參數。您可以為此使用梯度下降演算法。梯度下降
梯度下降最小化成本函數。這意味著它用於(J(θ))
根據訓練資料查找具有最小成本的 theta 參數。以下是計算新的、更合適的 theta 值的簡化演算法:
θ0
θ1
θ2
LinearRegressionFunction
(α)
train()
清單 4
public static LinearRegressionFunction train(LinearRegressionFunction targetFunction,
List<ltDouble[]> dataset,
List<ltDouble> labels,
double alpha) {
int m = dataset.size();
double[] thetaVector = targetFunction.getThetas();
double[] newThetaVector = new double[thetaVector.length];
// вычисление нового значения тета для каждого element тета массива
for (int j = 0; j < thetaVector.length; j++) {
// сумируем разницу ошибки * признак
double sumErrors = 0;
for (int i = 0; i < m; i++) {
Double[] featureVector = dataset.get(i);
double error = targetFunction.apply(featureVector) - labels.get(i);
sumErrors += error * featureVector[j];
}
//вычисляем новые значения тета
double gradient = (1.0 / m) * sumErrors;
newThetaVector[j] = thetaVector[j] - alpha * gradient;
}
return new LinearRegressionFunction(newThetaVector);
}
為了確保成本不斷降低,您可以J(θ)
在每個訓練步驟之後運行成本函數。每次迭代之後,成本應該會降低。如果這種情況沒有發生,則表示學習係數的值太大,演算法根本就錯過了最小值。在這種情況下,梯度衰減演算法會失敗。下圖顯示了使用新計算的 theta 參數(從起始 theta 向量 開始)的目標函數{1.0, 1.0}
。左列顯示了 50 次迭代後的預測函數圖;200 次重複後的中間列;以及 1000 次重複後的右列。從這些我們可以看到,每次迭代後價格都會下降,並且新的目標函數擬合得越來越好。重複 500-600 次後,theta 參數不再顯著變化,價格達到穩定的平台。此後,透過這種方式就無法提高目標函數的精確度。
添加功能並縮放它們
如果您發現您的目標函數與您要解決的問題不對應,則需要進行調整。修正不一致的常見方法是向特徵向量添加附加特徵。在房屋價格的範例中,您可以新增房間數量或房屋年齡等特徵。也就是說,{size}
您可以使用具有多個值的向量,而不是使用具有一個特徵值的向量來描述一棟房子,例如,{size, number-of-rooms, age}.
在某些情況下,可用訓練資料中的特徵數量不夠。那麼值得嘗試使用使用現有多項式特徵計算的多項式特徵。例如,您有機會擴展用於確定房屋價格的目標函數,使其包含平方公尺 (x2) 的計算特徵:
FeaturesScaling
該演算法在下面範例程式碼的 類別中實作。本課程FeaturesScaling
介紹了一種用於建立根據訓練資料調整的縮放函數的商業方法。在內部,訓練資料實例用於計算平均值、最小值和最大值。產生的函數採用特徵向量並產生具有縮放特徵的新特徵向量。特徵縮放對於學習過程和預測過程都是必要的,如下所示:
// создание массива данных
List<ltDouble[]> dataset = new ArrayList<>();
dataset.add(new Double[] { 1.0, 90.0, 8100.0 }); // feature vector of house#1
dataset.add(new Double[] { 1.0, 101.0, 10201.0 }); // feature vector of house#2
dataset.add(new Double[] { 1.0, 103.0, 10609.0 }); // ...
//...
// создание меток
List<ltDouble> labels = new ArrayList<>();
labels.add(249.0); // price label of house#1
labels.add(338.0); // price label of house#2
labels.add(304.0); // ...
//...
// создание расширенного списка признаков
Function<ltDouble[], Double[]> scalingFunc = FeaturesScaling.createFunction(dataset);
List<ltDouble[]> scaledDataset = dataset.stream().map(scalingFunc).collect(Collectors.toList());
// создаем функцию которая инициализирует теты и осуществляет обучение //используя коэффициент обучения 0.1
LinearRegressionFunction targetFunction = new LinearRegressionFunction(new double[] { 1.0, 1.0, 1.0 });
for (int i = 0; i < 10000; i++) {
targetFunction = Learner.train(targetFunction, scaledDataset, labels, 0.1);
}
// делаем предсказание стоимости дома с площадью 600 m2
Double[] scaledFeatureVector = scalingFunc.apply(new Double[] { 1.0, 600.0, 360000.0 });
double predictedPrice = targetFunction.apply(scaledFeatureVector);
隨著越來越多的特徵被添加,目標函數的擬合度也會增加,但要小心。如果你走得太遠並添加太多特徵,你最終可能會學習到一個過度擬合的目標函數。
過度匹配和交叉驗證
當目標函數或模型與訓練資料擬合得太好以至於捕獲了訓練資料中的雜訊或隨機變化時,就會發生過度擬合。下面最右圖顯示了過度擬合的範例:- 使用更大的資料集進行訓練。
- 使用較少的功能,如上圖所示。
- 使用考慮正規化的改進機器學習演算法。
- 訓練資料;
- 驗證數據;
- 測試數據。
機器學習工具與Weka框架
大多數框架和函式庫都提供了大量的機器學習演算法。此外,它們還提供了一個方便的高級介面來訓練、測試和處理資料模型。Weka 是最受歡迎的 JVM 框架之一。Weka 是一個實用的 Java 函式庫,包含用於驗證模型的圖形測試。下面的範例使用 Weka 庫建立包含特徵和標籤的訓練資料集。方法setClassIndex()
- 用於標記。在 Weka 中,標籤被定義為一個類別:
// определяем атрибуты для признаков и меток
ArrayList<ltAttribute> attributes = new ArrayList<>();
Attribute sizeAttribute = new Attribute("sizeFeature");
attributes.add(sizeAttribute);
Attribute squaredSizeAttribute = new Attribute("squaredSizeFeature");
attributes.add(squaredSizeAttribute);
Attribute priceAttribute = new Attribute("priceLabel");
attributes.add(priceAttribute);
// создаем и заполняем список признаков 5000 примеров
Instances trainingDataset = new Instances("trainData", attributes, 5000);
trainingDataset.setClassIndex(trainingSet.numAttributes() - 1);
Instance instance = new DenseInstance(3);
instance.setValue(sizeAttribute, 90.0);
instance.setValue(squaredSizeAttribute, Math.pow(90.0, 2));
instance.setValue(priceAttribute, 249.0);
trainingDataset.add(instance);
Instance instance = new DenseInstance(3);
instance.setValue(sizeAttribute, 101.0);
...
資料集和樣本物件可以保存並從檔案載入。Weka 使用ARFF(屬性關係檔格式),Weka 圖形基準支援此格式。此資料集用於訓練 Weka 中稱為分類器的目標函數。首先,您必須定義目標函數。下面的程式碼LinearRegression
將建立分類器的實例。此分類器將使用 進行訓練buildClassifier()
。此方法buildClassifier()
根據訓練資料選擇 theta 參數,以搜尋最佳目標模型。使用Weka,您不必擔心設定學習率或迭代次數。Weka 也獨立執行特徵縮放。
Classifier targetFunction = new LinearRegression();
targetFunction.buildClassifier(trainingDataset);
一旦完成這些設置,目標函數就可以用來預測房屋的價格,如下所示:
Instances unlabeledInstances = new Instances("predictionset", attributes, 1);
unlabeledInstances.setClassIndex(trainingSet.numAttributes() - 1);
Instance unlabeled = new DenseInstance(3);
unlabeled.setValue(sizeAttribute, 1330.0);
unlabeled.setValue(squaredSizeAttribute, Math.pow(1330.0, 2));
unlabeledInstances.add(unlabeled);
double prediction = targetFunction.classifyInstance(unlabeledInstances.get(0));
Weka 提供了一個類別Evaluation
來測試經過訓練的分類器或模型。在下面的程式碼中,使用選定的驗證資料數組來避免錯誤結果。測量結果(錯誤成本)將顯示在控制台上。通常,評估結果用於比較使用不同機器學習演算法或其變體訓練的模型:
Evaluation evaluation = new Evaluation(trainingDataset);
evaluation.evaluateModel(targetFunction, validationDataset);
System.out.println(evaluation.toSummaryString("Results", false));
上面的範例使用線性迴歸,它根據輸入值預測數值,例如房屋價格。線性迴歸支持連續數值的預測。要預測二進位值(“是”和“否”),您需要使用其他機器學習演算法。例如,決策樹、神經網路或邏輯迴歸。
// использование логистической регрессии
Classifier targetFunction = new Logistic();
targetFunction.buildClassifier(trainingSet);
例如,您可以使用其中一種演算法來預測電子郵件是否為垃圾郵件,或預測天氣,或預測房屋是否會暢銷。如果你想教你的演算法預測天氣或房子的銷售速度,你需要一個不同的資料集,例如topseller:
// использование атрибута маркера topseller instead of атрибута маркера цена
ArrayList<string> classVal = new ArrayList<>();
classVal.add("true");
classVal.add("false");
Attribute topsellerAttribute = new Attribute("topsellerLabel", classVal);
attributes.add(topsellerAttribute);
該資料集將用於訓練新的分類器topseller
。一旦經過訓練,預測呼叫應該會傳回一個可用於取得預測值的標記類別索引。
int idx = (int) targetFunction.classifyInstance(unlabeledInstances.get(0));
String prediction = classVal.get(idx);
GO TO FULL VERSION