Online Nonlinear Regression with Unscented Kalman Filter
========================================================

Like the EKF, the Unscented Kalman Filter (UKF) can be used for systems where the measurement or state process updates are nonlinear functions. The advantage of the UKF over the EKF is that you do not have to specify the Jacobian of the nonlinear update. The UKF instead estimates the state and its covariance with a deterministic sampling algorithm. You therefore have to choose a sampling algorithm and hyperparameters that suit your problem.

The example demonstrated here is the same as the one in the :ref:`previous section `.

Import the UKF and start a Spark session.

.. code-block:: scala

    import com.github.ozancicek.artan.ml.filter.UnscentedKalmanFilter
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions._
    import org.apache.spark.ml.linalg._

    val spark = SparkSession
      .builder()
      .appName("UKFRateSourceGLMLog")
      .getOrCreate()
    import spark.implicits._  // enables the $"colName" syntax used below

    val rowsPerSecond = 2  // rate of the synthetic input stream
    val numStates = 2      // number of independent filters, one per stateKey

For the UKF, we only need to define the nonlinear transformation. As with the EKF, the signature of this function should be ``(Vector, Matrix) => Vector``.

.. code-block:: scala

    // GLM with log link, states to be estimated are a, b
    // y = exp(a*x + b) + w, where w ~ N(0, 1)
    val a = 0.2
    val b = 0.7
    val noiseParam = 1.0

    // UDFs for generating the measurement vector ([y]) and the measurement model matrix ([[x, 1]])
    val measurementUDF = udf((x: Double, r: Double) => {
      val measurement = scala.math.exp(a * x + b) + r
      new DenseVector(Array(measurement))
    })

    val measurementModelUDF = udf((x: Double) => {
      new DenseMatrix(1, 2, Array(x, 1.0))
    })

    // Measurement function: exp(model * state). Unlike the EKF, no Jacobian function is needed.
    val measurementFunc = (in: Vector, model: Matrix) => {
      val measurement = model.multiply(in)
      measurement.values(0) = scala.math.exp(measurement.values(0))
      measurement
    }

    val filter = new UnscentedKalmanFilter()
      .setStateKeyCol("stateKey")                             // one filter per stateKey
      .setInitialStateMean(new DenseVector(Array(0.0, 0.0)))  // initial guess for (a, b)
      .setInitialStateCovariance(DenseMatrix.eye(2))
      .setMeasurementCol("measurement")
      .setMeasurementModelCol("measurementModel")
      .setProcessModel(DenseMatrix.eye(2))                    // a and b are constant over time
      .setProcessNoise(DenseMatrix.zeros(2, 2))
      .setMeasurementNoise(DenseMatrix.eye(1))                // matches w ~ N(0, 1)
      .setMeasurementFunction(measurementFunc)
      .setCalculateMahalanobis                                // adds the mahalanobis output column

Generate the data and run the query with a console sink.

.. code-block:: scala

    // The rate source emits rows with a monotonically increasing "value" column.
    // Even values go to state "0", odd values to state "1"; each state sees x = 0, 1, 2, ...
    val measurements = spark.readStream.format("rate")
      .option("rowsPerSecond", rowsPerSecond)
      .load()
      .withColumn("mod", $"value" % numStates)
      .withColumn("stateKey", $"mod".cast("String"))
      .withColumn("x", ($"value"/numStates).cast("Integer").cast("Double"))
      .withColumn("measurement", measurementUDF($"x", randn() * noiseParam))
      .withColumn("measurementModel", measurementModelUDF($"x"))

    val query = filter.transform(measurements)
      .writeStream
      .queryName("UKFRateSourceGLMLog")
      .outputMode("append")
      .format("console")
      .start()

    query.awaitTermination()

    /*
    -------------------------------------------
    Batch: 1
    -------------------------------------------
    +--------+----------+---------+--------------------+--------------------+-------------------+
    |stateKey|stateIndex|stepIndex|               state|            residual|        mahalanobis|
    +--------+----------+---------+--------------------+--------------------+-------------------+
    |       0|         1|        0|[[0.0,0.042083883...|[[0.1062957660590...|0.06584432237693608|
    |       0|         2|        0|[[-0.145367446951...|[[-6.532377989574...| 0.3788428065577319|
    |       0|         3|        0|[[0.1041022350732...|[[0.9724192360964...| 0.5627644378314648|
    |       0|         4|        0|[[0.2306342636805...|[[1.0403070080814...|  0.264478850278805|
    |       0|         5|        0|[[0.1063465161095...|[[-2.572317266578...| 0.3254493264520008|
    |       1|         1|        0|[[0.0,0.589622351...|[[1.4892722426408...| 0.9225214257075712|
    |       1|         2|        0|[[-0.204954508948...|[[-15.88844495110...| 0.5335348325127303|
    |       1|         3|        0|[[-0.118246670452...|[[0.4340293185286...|0.17419976429373318|
    |       1|         4|        0|[[0.1893660699514...|[[2.2655146428892...| 0.6555587631879435|
    |       1|         5|        0|[[-0.041656936742...|[[-5.686218438459...|  0.610362762749889|
    +--------+----------+---------+--------------------+--------------------+-------------------+

    -------------------------------------------
    Batch: 2
    -------------------------------------------
    +--------+----------+---------+--------------------+--------------------+--------------------+
    |stateKey|stateIndex|stepIndex|               state|            residual|         mahalanobis|
    +--------+----------+---------+--------------------+--------------------+--------------------+
    |       0|         6|        0|[[0.0535886042884...|[[-2.263659901210...|  0.1950673463505543|
    |       0|         7|        0|[[0.0234734490305...|[[-1.340034077244...|  0.1318700478960915|
    |       0|         8|        0|[[0.0672813961239...|[[1.9961239736383...| 0.22310849399461063|
    |       1|         6|        0|[[-0.041153233802...|[[0.0138346168400...|0.001907992146439...|
    |       1|         7|        0|[[0.0284426670012...|[[2.2444237109837...| 0.31476068999734547|
    |       1|         8|        0|[[0.0070700693878...|[[-1.172849191366...| 0.11280670334778167|
    +--------+----------+---------+--------------------+--------------------+--------------------+

    -------------------------------------------
    Batch: 3
    -------------------------------------------
    +--------+----------+---------+--------------------+--------------------+--------------------+
    |stateKey|stateIndex|stepIndex|               state|            residual|         mahalanobis|
    +--------+----------+---------+--------------------+--------------------+--------------------+
    |       0|         9|        0|[[0.0689014075795...|[[0.1174643793579...|0.009429133195598868|
    |       0|        10|        0|[[0.1079250363985...|[[3.2873115452680...|  0.2581465189858566|
    |       1|         9|        0|[[0.0851767804468...|[[4.0510944020424...|  0.4800965035588805|
    |       1|        10|        0|[[0.0976005513340...|[[1.2501195155840...| 0.08787380959737641|
    +--------+----------+---------+--------------------+--------------------+--------------------+

    -------------------------------------------
    Batch: 14
    -------------------------------------------
    +--------+----------+---------+--------------------+--------------------+-------------------+
    |stateKey|stateIndex|stepIndex|               state|            residual|        mahalanobis|
    +--------+----------+---------+--------------------+--------------------+-------------------+
    |       0|        25|        0|[[0.1979735110700...|[[-2.884962984114...| 1.6740538258342939|
    |       1|        25|        0|[[0.1996110381440...|[[-0.656398157846...|0.41074222096417723|
    +--------+----------+---------+--------------------+--------------------+-------------------+

    -------------------------------------------
    Batch: 15
    -------------------------------------------
    +--------+----------+---------+--------------------+--------------------+-------------------+
    |stateKey|stateIndex|stepIndex|               state|            residual|        mahalanobis|
    +--------+----------+---------+--------------------+--------------------+-------------------+
    |       0|        26|        0|[[0.1986075589232...|[[1.0770630811029...| 0.6503247720860308|
    |       1|        26|        0|[[0.1996956994673...|[[0.1826677631133...|0.11630923034534281|
    +--------+----------+---------+--------------------+--------------------+-------------------+
    */

In the later batches, the first component of the estimated state (a) is close to its true value of 0.2 for both state keys.

See `examples `_ for the full code.
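The deterministic sampling that replaces the Jacobian can be illustrated outside Spark. The sketch below is a plain-Scala version of the scaled (Van der Merwe) unscented transform, which is one common sigma-point scheme. It is not taken from artan. The hyperparameter names `alpha`, `beta` and `kappa` and their values are illustrative assumptions, and the object and method names are hypothetical. It places 2n + 1 sigma points around the prior mean of (a, b), pushes each one through the same log-link measurement y = exp(a*x + b), and recombines them with fixed weights into a predicted measurement mean and variance. To stay short, it assumes a diagonal prior covariance so that the matrix square root reduces to element-wise square roots.

```scala
// Hypothetical sketch of the scaled (Van der Merwe) unscented transform.
// Assumptions: 2-dimensional state (a, b); diagonal prior covariance;
// alpha, beta, kappa are illustrative textbook hyperparameters, not artan's settings.
object UnscentedTransformSketch {
  val n = 2            // state dimension
  val alpha = 1.0      // spread of the sigma points around the mean
  val beta = 2.0       // prior knowledge of the distribution (2 is optimal for Gaussians)
  val kappa = 1.0      // secondary scaling parameter
  val lambda: Double = alpha * alpha * (n + kappa) - n

  // Weights used to recombine the 2n + 1 transformed points into a mean
  val meanWeights: Array[Double] =
    Array(lambda / (n + lambda)) ++ Array.fill(2 * n)(1.0 / (2 * (n + lambda)))

  // Weights used for the covariance; only the central weight differs
  val covWeights: Array[Double] =
    Array(lambda / (n + lambda) + (1 - alpha * alpha + beta)) ++ meanWeights.drop(1)

  // Sigma points for mean m and covariance diag(p):
  // m, then m + sqrt((n + lambda) * p_i) e_i and m - sqrt((n + lambda) * p_i) e_i
  def sigmaPoints(m: Array[Double], p: Array[Double]): Seq[Array[Double]] = {
    val offsets = (0 until n).map { i =>
      val e = Array.fill(n)(0.0)
      e(i) = math.sqrt((n + lambda) * p(i))
      e
    }
    Seq(m) ++
      offsets.map(e => m.zip(e).map { case (v, d) => v + d }) ++
      offsets.map(e => m.zip(e).map { case (v, d) => v - d })
  }

  // Nonlinear measurement from the example: y = exp(a * x + b) for state (a, b)
  def measure(state: Array[Double], x: Double): Double =
    math.exp(state(0) * x + state(1))

  // Propagate every sigma point through the measurement, then recombine
  def predictMeasurement(m: Array[Double], p: Array[Double], x: Double): (Double, Double) = {
    val ys = sigmaPoints(m, p).map(s => measure(s, x))
    val mean = ys.zip(meanWeights).map { case (y, w) => w * y }.sum
    val variance = ys.zip(covWeights).map { case (y, w) => w * (y - mean) * (y - mean) }.sum
    (mean, variance)
  }
}
```

For example, `UnscentedTransformSketch.predictMeasurement(Array(0.0, 0.0), Array(1.0, 1.0), x = 0.0)` uses the example's initial state mean and covariance. It gives a predicted mean of roughly 1.64, which is close to the exact value E[exp(b)] = exp(0.5) for b ~ N(0, 1). No derivative of `measure` is needed anywhere, which is the property the UKF relies on.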
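Because of `.setCalculateMahalanobis`, the query output includes a `mahalanobis` column. For a one-dimensional measurement like the one in this example, the textbook Mahalanobis distance of a residual r with predicted residual variance s is |r| / sqrt(s). The helper below is a hypothetical illustration of that definition only. This section does not say whether artan reports this value or its square, or exactly how it computes s.

```scala
// Hypothetical helper (not part of artan): textbook Mahalanobis distance of a
// scalar residual whose predicted variance is residualVariance, i.e. sqrt(r * r / s).
object MahalanobisSketch {
  def distance1d(residual: Double, residualVariance: Double): Double = {
    require(residualVariance > 0.0, "residual variance must be positive")
    math.abs(residual) / math.sqrt(residualVariance)
  }
}
```

For example, a residual of 2.0 with variance 4.0 lies one standard deviation from the prediction, so `MahalanobisSketch.distance1d(2.0, 4.0)` is 1.0. Measurements with large distances are unlikely under the filter's current estimate.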