Online Linear Regression with Recursive Least Squares filter

The Recursive Least Squares (RLS) filter solves the least squares problem without requiring the complete data set for training. It updates the model sequentially from a stream of observations, which makes it useful for streaming applications.
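The sequential update can be illustrated outside Spark with a minimal pure-Python sketch of the textbook RLS recursion with a forgetting factor. This is not Artan's implementation: the names `rls_update` and `fit_stream`, the initial matrix `P = 1e7 * I`, and the noise seed are assumptions made for illustration only.

```python
import math
import random


def rls_update(theta, P, x, z, lam=0.99):
    """Fold one observation (features x, label z) into the estimate.

    theta: current parameter estimate (list of floats)
    P:     current symmetric matrix tracking estimate uncertainty (list of lists)
    lam:   forgetting factor in (0, 1]; smaller values discount old samples faster
    """
    n = len(theta)
    # P x, reused for the gain and for the P update (valid because P is symmetric)
    Px = [sum(P[i][j] * x[j] for j in range(n)) for i in range(n)]
    denom = lam + sum(x[i] * Px[i] for i in range(n))
    gain = [v / denom for v in Px]
    # Prediction error of the current estimate on the new sample
    residual = z - sum(x[i] * theta[i] for i in range(n))
    new_theta = [theta[i] + gain[i] * residual for i in range(n)]
    new_P = [[(P[i][j] - gain[i] * Px[j]) / lam for j in range(n)] for i in range(n)]
    return new_theta, new_P


def fit_stream(samples, n=3, p0=1e7, lam=0.99):
    """Run rls_update over an iterable of (x, z) pairs, one sample at a time."""
    theta = [0.0] * n
    # Assumed initialization: zero estimate and a large diagonal P
    P = [[p0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for x, z in samples:
        theta, P = rls_update(theta, P, x, z, lam)
    return theta


# Same model as the examples in this page: z = a*x + b*y + c + w, w ~ N(0, 1)
a, b, c = 0.5, 0.2, 1.2
rng = random.Random(0)  # assumed seed, only for repeatability
noisy = (([float(k), math.sqrt(k), 1.0],
          a * k + b * math.sqrt(k) + c + rng.gauss(0.0, 1.0)) for k in range(200))
print(fit_stream(noisy))
```

Each call needs only the previous `theta` and `P`, so the full history of observations never has to be stored.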

Scala

Import the RLS filter and Spark, then start a Spark session.

import com.github.ozancicek.artan.ml.filter.RecursiveLeastSquaresFilter
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.ml.linalg._


val spark = SparkSession
  .builder
  .appName("RLSExample")
  .getOrCreate

import spark.implicits._

Multiple models are trained by mapping samples to models. Each label and features pair can be associated with a different model by creating a ‘key’ column and specifying it with setStateKeyCol. If no key column is specified, a single model is trained.

Training data is generated using the streaming rate source, which produces consecutive numbers with timestamps. These consecutive numbers are binned into different models and then used to generate the label and features vectors.

// Number of independent models, one per stateKey (matches the ten keys in the sample output)
val numStates = 10

// Simple linear model, states to be estimated are a, b and c
// z = a*x + b*y + c + w, where w ~ N(0, 1)

val a = 0.5
val b = 0.2
val c = 1.2
val noiseParam = 1.0
val featuresSize = 3

val featuresUDF = udf((x: Double, y: Double) => {
    new DenseVector(Array(x, y, 1.0))
})

val labelUDF = udf((x: Double, y: Double, w: Double) => {
    a*x + b*y + c + w
})

val features = spark.readStream.format("rate")
  .option("rowsPerSecond", 10)
  .load()
  .withColumn("mod", $"value" % numStates)
  .withColumn("stateKey", $"mod".cast("String"))
  .withColumn("x", ($"value"/numStates).cast("Integer").cast("Double"))
  .withColumn("y", sqrt($"x"))
  .withColumn("label", labelUDF($"x", $"y", randn() * noiseParam))
  .withColumn("features", featuresUDF($"x", $"y"))

The estimated state distribution is written to the state struct column. The model parameters are available as a vector in the state.mean field. Along with the state column, the stateKey and stateIndex columns identify the different models and the running index of their updates.

// Floor each model parameter to two decimals for compact console output
val truncate = udf((state: DenseVector) => state.values.map(t => math.floor(t * 100) / 100))

val filter = new RecursiveLeastSquaresFilter()
  .setStateKeyCol("stateKey")
  .setFeatureSize(3)
  .setInitialEstimate(new DenseVector(Array(0.0, 0.0, 0.0)))
  .setRegularizationMatrixFactor(10E6)
  .setForgettingFactor(0.99)

val query = filter.transform(features)
  .select($"stateKey", $"stateIndex", truncate($"state.mean").alias("modelParameters"))
  .writeStream
  .queryName("RLSRateSourceOLS")
  .outputMode("append")
  .format("console")
  .start()

query.awaitTermination()

/*
Batch: 65
-------------------------------------------
+--------+----------+-------------------+
|stateKey|stateIndex|    modelParameters|
+--------+----------+-------------------+
|       7|        68|[0.54, -0.19, 1.98]|
|       3|        68|  [0.5, 0.11, 1.41]|
|       8|        68|[0.53, -0.13, 1.89]|
|       0|        68| [0.46, 0.53, 0.34]|
|       5|        68|   [0.5, 0.2, 1.05]|
|       6|        68| [0.45, 0.68, 0.18]|
|       9|        68|[0.53, -0.15, 1.82]|
|       1|        68|  [0.5, 0.09, 2.17]|
|       4|        68| [0.51, 0.11, 1.17]|
|       2|        68|  [0.48, 0.35, 0.9]|
+--------+----------+-------------------+

-------------------------------------------
Batch: 66
-------------------------------------------
+--------+----------+-------------------+
|stateKey|stateIndex|    modelParameters|
+--------+----------+-------------------+
|       7|        69|[0.54, -0.18, 1.96]|
|       3|        69| [0.49, 0.19, 1.28]|
|       8|        69|[0.53, -0.19, 1.99]|
|       0|        69|  [0.45, 0.6, 0.23]|
|       5|        69| [0.51, 0.14, 1.15]|
|       6|        69| [0.45, 0.71, 0.14]|
|       9|        69| [0.53, -0.1, 1.75]|
|       1|        69| [0.49, 0.15, 2.09]|
|       4|        69|  [0.51, 0.1, 1.18]|
|       2|        69| [0.49, 0.25, 1.04]|
+--------+----------+-------------------+
*/

See examples for the full code

Python

Import the RLS filter and Spark, then start a Spark session.

from artan.filter import RecursiveLeastSquaresFilter
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vectors

spark = SparkSession.builder.appName("RLSExample").getOrCreate()

Each features and label pair can be associated with a different model by creating a key column and specifying it with setStateKeyCol. If no key column is specified, a single model is trained. Training data is generated using the streaming rate source, which produces consecutive numbers with timestamps. These consecutive numbers are binned into different models and then used to generate the label and features vectors.
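To make the binning concrete, the following plain-Python sketch mirrors the column expressions of the streaming query on a handful of rate-source values. The helper name `to_sample` is hypothetical, and the sketch assumes non-negative values, as the rate source produces.

```python
import math

num_states = 10


def to_sample(value):
    """Map one rate-source value to (stateKey, x, y) as the query does."""
    state_key = str(value % num_states)   # which model the sample belongs to
    x = float(value // num_states)        # integer part of value / num_states
    y = math.sqrt(x)
    return state_key, x, y


# Values 0..9 land in ten different models with x = 0.0,
# values 10..19 revisit the same models with x = 1.0, and so on.
for value in (0, 1, 9, 10, 27):
    print(value, to_sample(value))
```

Each model therefore receives x = 0, 1, 2, ... in order, which is consistent with all keys reporting the same stateIndex within a batch of the sample output.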

num_states = 10
# Simple linear model, parameters to be estimated are a, b and c
# z = a*x + b*y + c + w, where w ~ N(0, 1)
a = 0.5
b = 0.2
c = 1.2
noise_param = 1
features_size = 3
label_expression = F.col("x") * a + F.col("y") * b + c + F.col("w")

input_df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()\
    .withColumn("mod", F.col("value") % num_states)\
    .withColumn("stateKey", F.col("mod").cast("String"))\
    .withColumn("x", (F.col("value")/num_states).cast("Integer").cast("Double"))\
    .withColumn("y", F.sqrt("x"))\
    .withColumn("bias", F.lit(1.0))\
    .withColumn("w", F.randn(0) * noise_param)\
    .withColumn("label", label_expression)

assembler = VectorAssembler(inputCols=["x", "y", "bias"], outputCol="features")

measurements = assembler.transform(input_df)

The estimated state distribution is written to the state struct column. The model parameters are available as a vector in the state.mean field. Along with the state column, the stateKey and stateIndex columns identify the different models and the running index of their updates.

rls = RecursiveLeastSquaresFilter()\
    .setStateKeyCol("stateKey")\
    .setFeatureSize(3)\
    .setInitialEstimate(Vectors.dense([0.0, 0.0, 0.0]))\
    .setRegularizationMatrixFactor(10E6)\
    .setForgettingFactor(0.99)

query = rls.transform(measurements)\
    .writeStream\
    .queryName("RLSRateSourceOLS")\
    .outputMode("append")\
    .format("console")\
    .start()

query.awaitTermination()

"""
-------------------------------------------
Batch: 30
-------------------------------------------
+--------+----------+--------------------+
|stateKey|stateIndex|               state|
+--------+----------+--------------------+
|       7|        42|[[0.4911266440390...|
|       3|        42|[[0.4912998991072...|
|       8|        42|[[0.4836819761355...|
|       0|        42|[[0.5604206240212...|
|       5|        42|[[0.5234529160112...|
|       6|        42|[[0.5543561214337...|
|       9|        42|[[0.4085256071251...|
|       1|        42|[[0.4831233161778...|
|       4|        42|[[0.5283651158175...|
|       2|        42|[[0.4393527335453...|
+--------+----------+--------------------+

-------------------------------------------
Batch: 31
-------------------------------------------
+--------+----------+--------------------+
|stateKey|stateIndex|               state|
+--------+----------+--------------------+
|       7|        43|[[0.4949646265364...|
|       3|        43|[[0.5051874312281...|
|       8|        43|[[0.4697275993015...|
|       0|        43|[[0.5407062556163...|
|       5|        43|[[0.5223665417204...|
|       6|        43|[[0.5438141213982...|
|       9|        43|[[0.3951488184173...|
|       1|        43|[[0.4639848681905...|
|       4|        43|[[0.5232375369727...|
|       2|        43|[[0.4618607402587...|
+--------+----------+--------------------+

"""
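The console sink cuts off the state struct in the output above. The Scala example avoids this by flooring the values of state.mean to two decimals before printing. A plain-Python counterpart of that truncation logic might look like the sketch below; the function name `truncate` mirrors the Scala UDF, the `decimals` parameter is an addition for illustration, and wrapping the function in a pyspark UDF is left out.

```python
import math


def truncate(values, decimals=2):
    """Floor each value to the given number of decimals, like the Scala UDF.

    Flooring rounds toward negative infinity, so -0.187 becomes -0.19.
    """
    scale = 10 ** decimals
    return [math.floor(v * scale) / scale for v in values]


print(truncate([0.5432, -0.187, 1.259]))  # -> [0.54, -0.19, 1.25]
```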

See examples for the full code