Перейти к основному содержимому
Перейти к основному содержимому

stochasticLogisticRegression

stochasticLogisticRegression

Добавлено в версии: v20.1

Функция реализует стохастическую логистическую регрессию. Может использоваться для задач бинарной классификации, поддерживает те же пользовательские параметры, что и stochasticLinearRegression, и работает аналогично.

Использование

Функция используется в два шага:

  1. Обучение модели

Для подбора параметров можно использовать такой запрос:

CREATE TABLE IF NOT EXISTS train_data
(
    param1 Float64,
    param2 Float64,
    target Float64
) ENGINE = Memory;

CREATE TABLE your_model ENGINE = Memory AS SELECT
stochasticLogisticRegression(0.1, 0.0, 5, 'SGD')(target, x1, x2)
AS state FROM train_data;

Здесь также необходимо вставить данные в таблицу train_data. Количество параметров не является фиксированным и зависит только от количества аргументов, переданных в logisticRegressionState. Все они должны быть числовыми значениями. Обратите внимание, что столбец с целевым значением (которое требуется научиться предсказывать) вставляется первым аргументом.

Прогнозируемые метки должны находиться в диапазоне [-1, 1].

  1. Прогнозирование

Используя сохраненное состояние, можно предсказать вероятность того, что объект имеет метку 1.

WITH (SELECT state FROM your_model) AS model SELECT
evalMLMethod(model, param1, param2) FROM test_data

Запрос вернёт столбец с вероятностями. Обратите внимание, что первый аргумент evalMLMethod — это объект AggregateFunctionState, а следующие — столбцы признаков.

Также можно задать границу вероятности, которая определяет принадлежность элементов к различным меткам.

SELECT result < 1.1 AND result > 0.5 FROM
(WITH (SELECT state FROM your_model) AS model SELECT
evalMLMethod(model, param1, param2) AS result FROM test_data)

Тогда результатом будут метки.

test_data — таблица, аналогичная train_data, но может не содержать целевого значения.

Синтаксис

stochasticLogisticRegression([learning_rate, l2_regularization_coef, mini_batch_size, method])(target, x1, x2, ...)

Аргументы

  • learning_rate — Коэффициент, определяющий длину шага при выполнении шага градиентного спуска. Слишком большое значение может привести к бесконечным значениям весов модели. Значение по умолчанию — 0.00001. Float64
  • l2_regularization_coef — коэффициент L2-регуляризации, который помогает предотвратить переобучение. По умолчанию значение — 0.1. Float64
  • mini_batch_size — задаёт количество элементов, для которых будут вычисляться и суммироваться градиенты при выполнении одного шага градиентного спуска. Чистый стохастический спуск использует один элемент, однако использование небольших батчей (порядка 10 элементов) делает шаги градиентного спуска более стабильными. Значение по умолчанию — 15. UInt64
  • method — метод для обновления весов: Adam (по умолчанию), SGD, Momentum, Nesterov. Momentum и Nesterov требуют немного больше вычислений и памяти, однако они полезны с точки зрения скорости сходимости и устойчивости стохастических градиентных методов. String
  • target — целевые метки бинарной классификации. Должны находиться в диапазоне [-1, 1]. Float
  • x1, x2, ... — значения признаков (независимые переменные). Все должны быть числами. Float

Возвращаемое значение

Возвращает веса обученной модели логистической регрессии. Для получения предсказаний используйте evalMLMethod, которая возвращает вероятности того, что объект имеет метку 1. Array(Float64)

Примеры

Обучение модели

CREATE TABLE your_model
ENGINE = MergeTree
ORDER BY tuple()
AS SELECT
stochasticLogisticRegressionState(1.0, 1.0, 10, 'SGD')(target, x1, x2)
AS state FROM train_data
Saves trained model state to table

Создание прогнозов

WITH (SELECT state FROM your_model) AS model
SELECT
evalMLMethod(model, x1, x2)
FROM test_data
Returns probability values for test data

Классификация с порогом

SELECT result < 1.1 AND result > 0.5
FROM (
WITH (SELECT state FROM your_model) AS model SELECT
evalMLMethod(model, x1, x2) AS result FROM test_data)
Returns binary classification labels using probability threshold

См. также