Uczenie maszynowe dla programistów Java, część 1
Oszacowanie funkcji celu
Przypomnijmy, że funkcja celu
hθ
, zwana także funkcją predykcyjną, jest wynikiem procesu przygotowawczego lub szkoleniowego. Matematycznie wyzwanie polega na znalezieniu funkcji, która przyjmuje zmienną jako dane wejściowe
х
i zwraca przewidywaną wartość
у
.
W uczeniu maszynowym funkcja kosztu
(J(θ))
służy do obliczenia wartości błędu lub „kosztu” danej funkcji celu.
Funkcja kosztu pokazuje, jak dobrze model pasuje do danych uczących. Aby określić koszt funkcji celu pokazanej powyżej, należy obliczyć błąd kwadratowy każdego przykładowego domu
(i)
. Błąd to odległość pomiędzy obliczoną wartością
у
a rzeczywistą wartością
y
domu z przykładu
i
.
Przykładowo realna cena domu o powierzchni
1330 = 6 500 000 € . Różnica między przewidywaną ceną domu na podstawie przeszkolonej funkcji celu wynosi
7 032 478 EUR : różnica (lub błąd) wynosi
532 478 EUR . Tę różnicę widać także na powyższym wykresie. Różnica (lub błąd) jest pokazana jako pionowe przerywane czerwone linie dla każdej pary szkoleniowej obszaru cenowego. Po obliczeniu kosztu wyszkolonej funkcji celu należy zsumować błąd kwadratowy dla każdego domu w przykładzie i obliczyć wartość główną. Im mniejsza wartość ceny
(J(θ))
, tym dokładniejsze będą przewidywania naszej funkcji celu. Listing
3 przedstawia prostą implementację funkcji kosztu w języku Java, która przyjmuje jako dane wejściowe funkcję celu, listę danych szkoleniowych i powiązane z nimi etykiety. Wartości predykcji zostaną obliczone w pętli, a błąd zostanie obliczony poprzez odjęcie rzeczywistej wartości ceny (pobranej z etykiety). Następnie zostanie zsumowany kwadrat błędów i obliczona wartość błędu. Koszt zostanie zwrócony jako wartość typu
double
:
Lista-3
public static double cost(Function<ltDouble[], Double> targetFunction,
List<ltDouble[]> dataset,
List<ltDouble> labels) {
int m = dataset.size();
double sumSquaredErrors = 0;
for (int i = 0; i < m; i++) {
Double[] featureVector = dataset.get(i);
double predicted = targetFunction.apply(featureVector);
double label = labels.get(i);
double gap = predicted - label;
sumSquaredErrors += Math.pow(gap, 2);
}
return (1.0 / (2 * m)) * sumSquaredErrors;
}
Nauka funkcji celu
Chociaż funkcja kosztu pomaga ocenić jakość funkcji docelowej i parametrów theta, nadal trzeba znaleźć najbardziej odpowiednie parametry theta. Można do tego użyć algorytmu opadania gradientowego.
Zejście gradientowe
Zejście gradientowe minimalizuje funkcję kosztu. Oznacza to, że służy do znalezienia parametrów theta, które mają minimalny koszt,
(J(θ))
na podstawie danych szkoleniowych. Oto uproszczony algorytm obliczania nowych, bardziej odpowiednich wartości theta:
Zatem parametry wektora theta będą się poprawiać z każdą iteracją algorytmu. Współczynnik uczenia się α określa liczbę obliczeń w każdej iteracji. Obliczenia te można prowadzić do momentu znalezienia „dobrych” wartości theta. Na przykład poniższa funkcja regresji liniowej ma trzy parametry theta:
Przy każdej iteracji zostanie obliczona nowa wartość dla każdego z parametrów theta: , i . Po każdej iteracji można utworzyć nową, bardziej odpowiednią implementację, używając nowego wektora theta
{θ 0 , θ 1 , θ 2 } . Listing
-4 przedstawia kod Java dla algorytmu zaniku gradientu. Theta dla funkcji regresji będzie trenowana przy użyciu danych uczących, danych markerowych i szybkości uczenia się . Rezultatem będzie ulepszona funkcja celu wykorzystująca parametry theta. Metoda będzie wywoływana wielokrotnie, przekazując nową funkcję celu i nowe parametry theta z poprzednich obliczeń. Wywołania te będą powtarzane, aż skonfigurowana funkcja celu osiągnie minimalny plateau:
θ0
θ1
θ2
LinearRegressionFunction
(α)
train()
Lista-4
public static LinearRegressionFunction train(LinearRegressionFunction targetFunction,
List<ltDouble[]> dataset,
List<ltDouble> labels,
double alpha) {
int m = dataset.size();
double[] thetaVector = targetFunction.getThetas();
double[] newThetaVector = new double[thetaVector.length];
for (int j = 0; j < thetaVector.length; j++) {
double sumErrors = 0;
for (int i = 0; i < m; i++) {
Double[] featureVector = dataset.get(i);
double error = targetFunction.apply(featureVector) - labels.get(i);
sumErrors += error * featureVector[j];
}
double gradient = (1.0 / m) * sumErrors;
newThetaVector[j] = thetaVector[j] - alpha * gradient;
}
return new LinearRegressionFunction(newThetaVector);
}
Aby mieć pewność, że koszt stale maleje, możesz uruchomić funkcję kosztu
J(θ)
po każdym etapie uczenia. Po każdej iteracji koszt powinien spadać. Jeśli tak się nie dzieje, oznacza to, że wartość współczynnika uczenia jest za duża i algorytm po prostu przeoczył wartość minimalną. W takim przypadku algorytm zaniku gradientu zawodzi. Poniższe wykresy przedstawiają funkcję celu przy użyciu nowych, obliczonych parametrów theta, zaczynając od początkowego wektora theta
{1.0, 1.0}
. Lewa kolumna przedstawia wykres funkcji predykcji po 50 iteracjach; środkowa kolumna po 200 powtórzeniach; i prawa kolumna po 1000 powtórzeń. Widzimy z nich, że cena maleje po każdej iteracji, a nowa funkcja celu pasuje coraz lepiej. Po 500-600 powtórzeniach parametry theta nie zmieniają się już znacząco, a cena osiąga stabilny plateau. Po tym nie można w ten sposób poprawić dokładności funkcji celu.
W tym przypadku, mimo że koszt nie spada już znacząco po 500-600 iteracjach, funkcja celu nadal nie jest optymalna. Wskazuje to na
rozbieżność . W uczeniu maszynowym termin „niespójność” używany jest do wskazania, że algorytm uczenia się nie znajduje podstawowych trendów w danych. Bazując na rzeczywistych doświadczeniach, można spodziewać się obniżki ceny za metr kwadratowy w przypadku większych nieruchomości. Na tej podstawie możemy stwierdzić, że model zastosowany w procesie uczenia się funkcji celu nie jest wystarczająco dobrze dopasowany do danych. Rozbieżność wynika często z nadmiernego uproszczenia modelu. Tak się stało w naszym przypadku, funkcja celu jest zbyt prosta i do analizy wykorzystuje jeden parametr - powierzchnię domu. Ale ta informacja nie wystarczy, aby dokładnie przewidzieć cenę domu.
Dodawanie funkcji i ich skalowanie
Jeśli okaże się, że funkcja celu nie odpowiada problemowi, który próbujesz rozwiązać, należy ją skorygować. Typowym sposobem skorygowania niespójności jest dodanie dodatkowych funkcji do wektora cech. Na przykładzie ceny domu możesz dodać takie cechy jak liczba pokoi czy wiek domu. Oznacza to, że zamiast używać wektora o jednej wartości cechy
{size}
do opisu domu, można użyć wektora o kilku wartościach, na przykład:
{size, number-of-rooms, age}.
W niektórych przypadkach liczba cech w dostępnych danych uczących jest niewystarczająca. Wtedy warto spróbować skorzystać z cech wielomianowych, które są obliczane na podstawie już istniejących. Na przykład masz możliwość rozszerzenia funkcji celu służącej do określenia ceny domu, tak aby zawierała ona obliczoną cechę metrów kwadratowych (x2):
Korzystanie z wielu funkcji wymaga
skalowania funkcji , które służy do standaryzacji zakresu różnych funkcji. Zatem zakres wartości atrybutu
rozmiar 2 jest znacznie większy niż zakres wartości atrybutu rozmiar. Bez skalowania cech
wielkość 2 będzie nadmiernie wpływać na funkcję kosztu. Błąd wprowadzony przez atrybut
rozmiar 2 będzie znacznie większy niż błąd wprowadzony przez atrybut rozmiar. Poniżej podano prosty algorytm skalowania cech:
Algorytm ten jest zaimplementowany w klasie
FeaturesScaling
w przykładowym kodzie poniżej. Na zajęciach
FeaturesScaling
przedstawiono komercyjną metodę tworzenia funkcji skalującej dostrojonej do danych uczących. Wewnętrznie instancje danych szkoleniowych służą do obliczania wartości średniej, minimalnej i maksymalnej. Wynikowa funkcja pobiera wektor cech i tworzy nowy ze skalowanymi cechami. Skalowanie cech jest konieczne zarówno w procesie uczenia się, jak i w procesie przewidywania, jak pokazano poniżej:
List<ltDouble[]> dataset = new ArrayList<>();
dataset.add(new Double[] { 1.0, 90.0, 8100.0 });
dataset.add(new Double[] { 1.0, 101.0, 10201.0 });
dataset.add(new Double[] { 1.0, 103.0, 10609.0 });
List<ltDouble> labels = new ArrayList<>();
labels.add(249.0);
labels.add(338.0);
labels.add(304.0);
Function<ltDouble[], Double[]> scalingFunc = FeaturesScaling.createFunction(dataset);
List<ltDouble[]> scaledDataset = dataset.stream().map(scalingFunc).collect(Collectors.toList());
LinearRegressionFunction targetFunction = new LinearRegressionFunction(new double[] { 1.0, 1.0, 1.0 });
for (int i = 0; i < 10000; i++) {
targetFunction = Learner.train(targetFunction, scaledDataset, labels, 0.1);
}
Double[] scaledFeatureVector = scalingFunc.apply(new Double[] { 1.0, 600.0, 360000.0 });
double predictedPrice = targetFunction.apply(scaledFeatureVector);
W miarę dodawania coraz większej liczby funkcji zwiększa się dopasowanie do funkcji celu, ale należy zachować ostrożność. Jeśli posuniesz się za daleko i dodasz zbyt wiele funkcji, może się okazać, że nauczysz się funkcji celu, która będzie nadmiernie dopasowana.
Nadmierne dopasowanie i weryfikacja krzyżowa
Przeuczenie ma miejsce, gdy funkcja celu lub model zbyt dobrze pasuje do danych uczących, do tego stopnia, że wychwytuje szum lub przypadkowe zmiany w danych uczących. Przykład nadmiernego dopasowania pokazano na wykresie znajdującym się skrajnie po prawej stronie poniżej:
Jednak model nadmiernie dopasowany działa bardzo dobrze na danych uczących, ale będzie działał słabo na rzeczywistych nieznanych danych. Istnieje kilka sposobów na uniknięcie nadmiernego dopasowania.
- Użyj większego zestawu danych do szkolenia.
- Używaj mniejszej liczby funkcji, jak pokazano na powyższych wykresach.
- Użyj ulepszonego algorytmu uczenia maszynowego, który uwzględnia regularyzację.
Jeśli algorytm predykcyjny nadmiernie dopasowuje się do danych uczących, konieczne jest wyeliminowanie cech, które nie wpływają korzystnie na jego dokładność. Trudność polega na znalezieniu cech, które mają bardziej znaczący wpływ na dokładność przewidywań niż inne. Jak pokazano na wykresach, nadmierne dopasowanie można określić wizualnie za pomocą wykresów. Działa to dobrze w przypadku wykresów z 2 lub 3 współrzędnymi. Wykreślanie i ocena wykresu staje się trudna, jeśli używasz więcej niż 2 funkcji. W przypadku walidacji krzyżowej po zakończeniu procesu uczenia modele są ponownie testowane przy użyciu danych nieznanych algorytmowi. Dostępne oznakowane dane należy podzielić na 3 zbiory:
- dane treningowe;
- dane weryfikacyjne;
- dane testowe.
W takim przypadku 60 proc. oznaczonych rekordów charakteryzujących domy należy wykorzystać w procesie uczenia wariantów docelowego algorytmu. Po zakończeniu procesu uczenia połowa pozostałych danych (nieużywanych wcześniej) powinna zostać wykorzystana do sprawdzenia, czy przeszkolony algorytm docelowy dobrze radzi sobie z nieznanymi danymi. Zazwyczaj do użycia wybierany jest algorytm, który działa lepiej niż inne. Pozostałe dane służą do obliczenia wartości błędu dla ostatecznie wybranego modelu. Istnieją inne techniki sprawdzania krzyżowego, takie jak
k-fold . Nie będę ich jednak opisywać w tym artykule.
Narzędzia do uczenia maszynowego i framework Weka
Większość frameworków i bibliotek udostępnia obszerny zbiór algorytmów uczenia maszynowego. Ponadto zapewniają wygodny interfejs wysokiego poziomu do uczenia, testowania i przetwarzania modeli danych. Weka to jeden z najpopularniejszych frameworków dla JVM. Weka to praktyczna biblioteka Java zawierająca testy graficzne służące do sprawdzania poprawności modeli. Poniższy przykład wykorzystuje bibliotekę Weka do utworzenia zbioru danych szkoleniowych zawierającego funkcje i etykiety. Metoda
setClassIndex()
- do znakowania. W Weka etykieta jest zdefiniowana jako klasa:
ArrayList<ltAttribute> attributes = new ArrayList<>();
Attribute sizeAttribute = new Attribute("sizeFeature");
attributes.add(sizeAttribute);
Attribute squaredSizeAttribute = new Attribute("squaredSizeFeature");
attributes.add(squaredSizeAttribute);
Attribute priceAttribute = new Attribute("priceLabel");
attributes.add(priceAttribute);
Instances trainingDataset = new Instances("trainData", attributes, 5000);
trainingDataset.setClassIndex(trainingSet.numAttributes() - 1);
Instance instance = new DenseInstance(3);
instance.setValue(sizeAttribute, 90.0);
instance.setValue(squaredSizeAttribute, Math.pow(90.0, 2));
instance.setValue(priceAttribute, 249.0);
trainingDataset.add(instance);
Instance instance = new DenseInstance(3);
instance.setValue(sizeAttribute, 101.0);
...
Zestaw danych i przykładowy obiekt można zapisać i wczytać z pliku. Weka korzysta z formatu
ARFF (Attribute Relation File Format), który jest obsługiwany w testach graficznych Weka. Ten zbiór danych służy do uczenia funkcji celu zwanej klasyfikatorem w Weka. Przede wszystkim należy zdefiniować funkcję celu. Poniższy kod
LinearRegression
utworzy instancję klasyfikatora. Ten klasyfikator zostanie przeszkolony przy użyciu metody
buildClassifier()
. Metoda
buildClassifier()
dobiera parametry theta na podstawie danych uczących w poszukiwaniu najlepszego modelu docelowego. Dzięki Weka nie musisz się martwić ustawieniem szybkości uczenia się ani liczby iteracji. Weka wykonuje również niezależne skalowanie funkcji.
Classifier targetFunction = new LinearRegression();
targetFunction.buildClassifier(trainingDataset);
Po dokonaniu tych ustawień funkcję celu można wykorzystać do przewidywania ceny domu, jak pokazano poniżej:
Instances unlabeledInstances = new Instances("predictionset", attributes, 1);
unlabeledInstances.setClassIndex(trainingSet.numAttributes() - 1);
Instance unlabeled = new DenseInstance(3);
unlabeled.setValue(sizeAttribute, 1330.0);
unlabeled.setValue(squaredSizeAttribute, Math.pow(1330.0, 2));
unlabeledInstances.add(unlabeled);
double prediction = targetFunction.classifyInstance(unlabeledInstances.get(0));
Weka udostępnia klasę
Evaluation
do testowania przeszkolonego klasyfikatora lub modelu. W poniższym kodzie zastosowano wybraną tablicę danych walidacyjnych, aby uniknąć fałszywych wyników. Wyniki pomiaru (koszt błędu) zostaną wyświetlone na konsoli. Zazwyczaj wyniki oceny służą do porównywania modeli wyszkolonych przy użyciu różnych algorytmów uczenia maszynowego lub ich odmian:
Evaluation evaluation = new Evaluation(trainingDataset);
evaluation.evaluateModel(targetFunction, validationDataset);
System.out.println(evaluation.toSummaryString("Results", false));
W powyższym przykładzie zastosowano regresję liniową, która przewiduje wartości liczbowe, takie jak cena domu, na podstawie wartości wejściowych. Regresja liniowa wspiera przewidywanie ciągłych wartości liczbowych. Aby przewidzieć wartości binarne („Tak” i „Nie”), musisz użyć innych algorytmów uczenia maszynowego. Na przykład drzewo decyzyjne, sieci neuronowe lub regresja logistyczna.
Classifier targetFunction = new Logistic();
targetFunction.buildClassifier(trainingSet);
Możesz na przykład użyć jednego z tych algorytmów, aby przewidzieć, czy wiadomość e-mail jest spamem, przewidzieć pogodę lub przewidzieć, czy dom będzie się dobrze sprzedawać. Jeśli chcesz nauczyć swój algorytm przewidywania pogody lub tego, jak szybko sprzeda się dom, potrzebujesz innego zestawu danych, np.:
topseller:
ArrayList<string> classVal = new ArrayList<>();
classVal.add("true");
classVal.add("false");
Attribute topsellerAttribute = new Attribute("topsellerLabel", classVal);
attributes.add(topsellerAttribute);
Ten zbiór danych zostanie wykorzystany do uczenia nowego klasyfikatora
topseller
. Po przeszkoleniu wywołanie przewidywania powinno zwrócić indeks klasy tokenu, którego można użyć do uzyskania przewidywanej wartości.
int idx = (int) targetFunction.classifyInstance(unlabeledInstances.get(0));
String prediction = classVal.get(idx);
Wniosek
Chociaż uczenie maszynowe jest ściśle powiązane ze statystyką i wykorzystuje wiele koncepcji matematycznych, zestaw narzędzi do uczenia maszynowego pozwala rozpocząć integrację uczenia maszynowego z programami bez głębokiej znajomości matematyki. Jednak im lepiej rozumiesz podstawowe algorytmy uczenia maszynowego, takie jak algorytm regresji liniowej, który omówiliśmy w tym artykule, tym łatwiej będzie wybrać odpowiedni algorytm i dostroić go w celu uzyskania optymalnej wydajności.
Tłumaczenie z języka angielskiego. Autor: Gregor Roth, architekt oprogramowania, JavaWorld.
GO TO FULL VERSION