Binjian/tspace

tspace

Overview

tspace is a data pipeline framework for deep reinforcement learning, with IO interfaces, data processing and configuration. The current code base depicts an automotive implementation. The goal of the system is to increase the energy efficiency (reward) of a BEV by modifying parameters (action) of the powertrain controller, the VCU, based on observations of the vehicle (state), e.g. speed, acceleration, electric motor current, voltage, etc. The main features are:

  • works in both training and inference mode, supporting
    • coordinated ETL and ML pipelines,
    • online and offline training,
    • local and distributed training;
  • supports multiple models:
    • reinforcement learning models with DDPG,
    • recurrent models (RDPG) for time sequences of arbitrary length and
    • offline reinforcement learning with "Implicit Diffusion Q-Learning" (IDQL);
  • the data pipelines are compatible with both ETL and ML dataflows, with
    • support for multiple data sources (local CAN or remote cloud object storage),
    • stateful time sequence processing for sequential models and
    • support for NoSQL databases as well as local and cloud data storage.

Overview of tspace architecture

The diagram shows the basic architecture of tspace.

Avatar

Avatar is the entry point of tspace. It orchestrates the whole ETL and ML workflow.

  • It configures KvaserCAN, RemoteCAN, Cruncher, Agent, Model, Database, Pipeline.
  • It manages the scheduling of two primary threads in the first tier of cascaded threading pools in tspace.avatar.main.
  • It selects either KvaserCAN or RemoteCAN as the vehicle interface for reading the observation and applying the action.
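The interface selection can be pictured as a small factory over a common base class. This is a minimal sketch, assuming a configuration string chooses the interface; KvaserCAN and RemoteCAN are the names used in this README, while VehicleInterface's `describe` method, the `INTERFACES` mapping, `select_interface` and the mode strings are hypothetical and not the tspace API.

```python
from abc import ABC, abstractmethod


class VehicleInterface(ABC):
    """Common base for reading observations and applying actions (sketch)."""

    @abstractmethod
    def describe(self) -> str:
        """Return a short description of the transport."""


class KvaserCAN(VehicleInterface):
    def describe(self) -> str:
        return "local CAN via Kvaser and a UDP server"


class RemoteCAN(VehicleInterface):
    def describe(self) -> str:
        return "remote CAN via cloud object storage"


# Hypothetical mapping from a configuration string to an interface class.
INTERFACES = {"kvaser": KvaserCAN, "remote": RemoteCAN}


def select_interface(mode: str) -> VehicleInterface:
    """Instantiate the vehicle interface chosen in the configuration."""
    try:
        return INTERFACES[mode]()
    except KeyError:
        raise ValueError(f"unknown interface mode: {mode!r}") from None
```

Because the rest of the pipeline only talks to the common base class, swapping local and remote operation is a configuration change rather than a code change.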

KvaserCAN

KvaserCAN is implemented with Kvaser, which provides

  • a local interface for reading the observation (CAN messages of vehicle states) via Kvaser, using udp_context to get CAN messages as JSON data from a local UDP server. It then encodes the raw JSON data into a pandas.DataFrame and forwards it through the data pipeline to Cruncher.

  • a local interface for applying the action (flashing parameters) onto the vehicle ECU (VCU). Before sending the action, it decodes the action from the pandas.DataFrame into a packed string buffer and then sends it to the ECU by calling send_float_array from VehicleInterface.consume.

  • The control messages from the training HMI go through the same UDP port. VehicleInterface.hmi_control uses them to set the threading events that control the episodic training process.
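The two conversions KvaserCAN performs, JSON in and packed floats out, can be sketched with the standard library alone. This is a minimal sketch under stated assumptions: the JSON layout (`records`, `timestamp`, `signals`) and the little-endian float32 wire format are hypothetical, and `encode_observation`, `pack_action` and `unpack_action` are illustrative helpers, not tspace's or Kvaser's functions.

```python
import json
import struct


def encode_observation(raw: str) -> list[dict]:
    """Flatten a raw JSON payload into one row per timestamp.

    Assumed payload layout (hypothetical):
    {"records": [{"timestamp": ..., "signals": {"speed": ..., ...}}, ...]}
    """
    msg = json.loads(raw)
    return [{"timestamp": rec["timestamp"], **rec["signals"]} for rec in msg["records"]]


def pack_action(values: list[float]) -> bytes:
    """Pack action parameters as little-endian float32 (assumed wire format)."""
    return struct.pack(f"<{len(values)}f", *values)


def unpack_action(buf: bytes) -> list[float]:
    """Inverse of pack_action, e.g. for checking what was sent."""
    return list(struct.unpack(f"<{len(buf) // 4}f", buf))
```

The rows returned by `encode_observation` can be turned into a table with `pandas.DataFrame(rows)`, which matches the README's description of forwarding a DataFrame to Cruncher.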

RemoteCAN

RemoteCAN provides a remote interface to the vehicle via the cloud object storage system, to which the onboard TBox sends the vehicle data. It is implemented with Cloud.

Cruncher

It is the main pivot of the data pipeline, pre-processing the observation and post-processing the action:

  • Cruncher.filter receives the observation through the data pipeline from KvaserCAN or RemoteCAN. It pre-processes the input data into a timestamped quadruple $(t, s, a, r, s')$, i.e. (timestamp, state, action, reward, next state), and passes it to the reinforcement learning agent DPG, concretely its subclass DDPG or RDPG, which infers the optimal action under its current policy. After getting the agent's prediction, it encodes the result into an action object and forwards it to VehicleInterface.consume to be flashed onto the VCU.

  • It collects the critic and actor losses, the total reward of each episode, the running reward and the action at the end of each episode. It also saves the model checkpoint and the training log locally.
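Forming the timestamped quadruple amounts to pairing each observation with the state of the next one. The following is a minimal sketch, assuming the reward has already been computed per step upstream; `Step`, `Transition` and `make_transitions` are hypothetical names, not Cruncher's internals.

```python
from typing import NamedTuple, Sequence


class Step(NamedTuple):
    timestamp: float
    state: tuple
    action: tuple
    reward: float


class Transition(NamedTuple):
    timestamp: float
    state: tuple
    action: tuple
    reward: float
    next_state: tuple


def make_transitions(steps: Sequence[Step]) -> list[Transition]:
    """Pair each step with the state of the following step.

    Steps are sorted by timestamp first so that causality is preserved;
    the last step has no successor yet, so it yields no transition.
    """
    ordered = sorted(steps, key=lambda s: s.timestamp)
    return [
        Transition(cur.timestamp, cur.state, cur.action, cur.reward, nxt.state)
        for cur, nxt in zip(ordered, ordered[1:])
    ]
```

In a streaming setting the same pairing is done incrementally: the previous step is held back until the next observation arrives and supplies $s'$.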

Agent

It provides a wrapper for the reinforcement learning model with DPG:

  • RDPG provides methods to create, load or initialize the Recurrent Deterministic Policy Gradient model, or to restore checkpoints into it.
  • It provides the concrete methods for the abstract ones in the DPG interface.
  • RDPG.actor_predict_step is the inference method with graph optimization via tf.function.
  • RDPG.train_step is the training method with graph optimization via tf.function. It also applies the weight updates to the actor and critic networks.
  • RDPG.train samples a ragged minibatch of episodes with different lengths from the buffer. It handles time sequences of arbitrary length with truncated backpropagation through time (TBPTT): it splits the episodes into subsequences and loops over them, using Masking layers for the padded steps, and updates the weights with RDPG.train_step.
  • IDQL provides methods to create and initialize the Implicit Diffusion Q-learning model.
  • The model implementation is based on the jaxrl5 repository, using JAX and Flax.
  • It provides the concrete methods for the abstract ones in the DPG interface.
  • IDQL.actor_predict is the inference method.
  • IDQL.train is the training method. It samples a minibatch of tuples (state, action, reward, next state) from the buffer; jaxrl5 takes care of the weight updates of the actor, critic and value networks.
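The episode splitting behind TBPTT can be illustrated in plain Python. This is a minimal sketch, assuming one scalar feature per time step and a padding value of 0.0 (the default `mask_value` of `tf.keras.layers.Masking`); `split_for_tbptt` is a hypothetical helper, not RDPG.train itself.

```python
from math import ceil

PAD = 0.0  # matches the default mask_value of tf.keras.layers.Masking


def split_for_tbptt(episodes: list[list[float]], k: int):
    """Split a ragged batch of episodes into aligned, padded chunks of length k.

    Returns a list of (chunk, mask) pairs. Chunk i holds steps [i*k, (i+1)*k)
    of every episode, padded with PAD where an episode has already ended;
    the mask marks real steps with True and padding with False.
    """
    n_chunks = ceil(max(len(e) for e in episodes) / k)
    chunks = []
    for i in range(n_chunks):
        batch, mask = [], []
        for ep in episodes:
            seg = ep[i * k:(i + 1) * k]
            pad = k - len(seg)
            batch.append(seg + [PAD] * pad)
            mask.append([True] * len(seg) + [False] * pad)
        chunks.append((batch, mask))
    return chunks
```

In a TBPTT loop the chunks are fed to the training step in order, carrying the recurrent state from one chunk to the next within a batch and resetting it before the next batch. Gradients then only flow back k steps, which bounds memory regardless of episode length. The padding value must not occur in real data, otherwise genuine steps would be masked too.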

Model

It is the neural network model of the reinforcement learning agent. For now it is only implemented for RDPG, in SeqActor and SeqCritic.

SeqActor is the actor network with two recurrent LSTM layers, two dense layers and a Masking layer for handling ragged input sequences.

  • SeqActor.predict outputs the action given the state for inference, so the batch dimension has to be one.
  • SeqActor.evaluate_actions outputs the actions given a batch of states for training. It is used in the training loop to get the prediction of the target actor network for calculating the critic loss.
  • It handles ragged input sequences with the Masking layer and uses stateful recurrent layers for TBPTT.
  • For inference, SeqCritic is not used; only SeqActor is required.

SeqCritic is the critic network with two recurrent LSTM layers, two dense layers and a Masking layer for handling ragged input sequences.
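What the Masking layer sees can be reproduced without TensorFlow. This is a minimal sketch, assuming sequences of feature vectors padded with 0.0; it mirrors the documented behaviour of `tf.keras.layers.Masking`, which keeps a time step when any of its features differs from `mask_value`. The helpers `pad_ragged` and `keras_style_mask` are hypothetical, not part of SeqActor or SeqCritic.

```python
def pad_ragged(seqs: list[list[list[float]]], mask_value: float = 0.0):
    """Pad a ragged batch of [time][feature] sequences to the longest length."""
    max_len = max(len(s) for s in seqs)
    n_feat = len(seqs[0][0])
    return [
        s + [[mask_value] * n_feat for _ in range(max_len - len(s))]
        for s in seqs
    ]


def keras_style_mask(batch, mask_value: float = 0.0):
    """Keep a time step if ANY of its features differs from mask_value."""
    return [[any(x != mask_value for x in step) for step in seq] for seq in batch]
```

Downstream LSTM layers skip the masked steps, so padded positions do not change the recurrent state or contribute to the loss.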

Storage

Storage represents the data storage in the repository pattern, with two polymorphic abstraction layers, Buffer and Pool.

Buffer is an abstract class. It provides the agent's view of the data storage:

  • The agent uses the abstract methods Buffer.load, Buffer.save and Buffer.close to load data from the Pool, save data to the Pool, and close the connection to the Pool.
  • The abstract method Buffer.sample samples a minibatch from the Pool. Each subclass of Buffer has to implement it with an efficient sampling method, which depends on the underlying data storage system.
  • The concrete method Buffer.store stores the data of a whole episode into the Pool.
  • The concrete method Buffer.find simply calls Pool.find to find the data matching the given query.
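The split between abstract and concrete Buffer methods maps directly onto Python's `abc` module. This is a minimal sketch, assuming episodes are plain dicts and queries are equality filters; the method names follow this README, but their signatures, and the in-memory `ListPool` and `ListBuffer` classes, are hypothetical stand-ins for the MongoDB and Dask backends.

```python
import random
from abc import ABC, abstractmethod


class Pool(ABC):
    @abstractmethod
    def store(self, episode: dict) -> None: ...

    @abstractmethod
    def find(self, query: dict) -> list[dict]: ...


class ListPool(Pool):
    """Hypothetical in-memory Pool, for illustration only."""

    def __init__(self):
        self.items: list[dict] = []

    def store(self, episode: dict) -> None:
        self.items.append(episode)

    def find(self, query: dict) -> list[dict]:
        # An empty query matches everything.
        return [e for e in self.items if all(e.get(k) == v for k, v in query.items())]


class Buffer(ABC):
    def __init__(self, pool: Pool):
        self.pool = pool

    # Abstract: these depend on the concrete storage backend.
    @abstractmethod
    def load(self) -> None: ...

    @abstractmethod
    def save(self) -> None: ...

    @abstractmethod
    def close(self) -> None: ...

    @abstractmethod
    def sample(self, batch_size: int) -> list[dict]: ...

    # Concrete: shared by every Buffer, delegating to the Pool.
    def store(self, episode: dict) -> None:
        self.pool.store(episode)

    def find(self, query: dict) -> list[dict]:
        return self.pool.find(query)


class ListBuffer(Buffer):
    """Hypothetical concrete Buffer over an in-memory pool."""

    def load(self) -> None:
        pass  # nothing to load for an in-memory pool

    def save(self) -> None:
        pass  # nothing to persist

    def close(self) -> None:
        pass  # no connection to close

    def sample(self, batch_size: int) -> list[dict]:
        return random.sample(self.pool.find({}), batch_size)
```

The agent depends only on Buffer, so replacing MongoDB with Dask means supplying a different Buffer/Pool pair while the training code stays unchanged.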

MongoBuffer is a concrete class for the underlying NoSQL database MongoDB.

  • It implements the abstract methods required by the Buffer interface.
  • MongoBuffer.decode_batch_records prepares the sample batch data from MongoPool in a compliant format for agent training.
  • It can handle both DDPG record data type and RDPG episode data type.
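Decoding a sampled batch for the two data types comes down to regrouping per-sample documents into per-field columns. This is a minimal sketch, assuming each record document has `state`, `action`, `reward` and `next_state` fields and each episode document has a `steps` list of such records; that layout and the `data_type` argument are hypothetical, not the actual MongoBuffer.decode_batch_records.

```python
KEYS = ("state", "action", "reward", "next_state")


def decode_batch_records(docs: list[dict], data_type: str = "record"):
    """Regroup sampled documents into (states, actions, rewards, next_states)."""
    if data_type == "record":
        # DDPG: one transition per document -> one flat list per field.
        return tuple([d[k] for d in docs] for k in KEYS)
    if data_type == "episode":
        # RDPG: one episode per document -> ragged list of per-episode lists.
        return tuple([[step[k] for step in d["steps"]] for d in docs] for k in KEYS)
    raise ValueError(f"unknown data type: {data_type!r}")
```

The episode branch deliberately keeps episodes ragged; padding and masking are left to the sequence model.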

DaskBuffer is a concrete class for the distributed data storage system Dask.

  • It implements the abstract methods required by the Buffer interface.
  • DaskBuffer.decode_batch_records prepares the sample batch data from DaskPool in a compliant format for agent training.
  • It can handle both DDPG record data type and RDPG episode data type.

Pool is an abstract class. It is the interface for the underlying data storage. For the moment, it is implemented with MongoPool and DaskPool.

MongoPool is a concrete class for the underlying NoSQL database MongoDB with time series support. It handles both the record and the episode data type with MongoDB collection features.

  • It provides the interface to the MongoDB database with the pymongo library.
  • It implements the abstract methods required by the Pool interface.
  • MongoPool.store_record stores the record data into the MongoDB database for DDPG agent.
  • MongoPool.store_episode stores the episode data into the MongoDB database for RDPG agent.

DaskPool is an abstract class for the distributed data storage system Dask, since it has to use different backends: Parquet for the record data type and Avro for the episode data type.

  • It supports both local file storage and remote object storage with the dask library.
  • It defines the generic data type for the abstract methods required by the Pool interface. The concrete classes then specialize this generic type either as dask.DataFrame for the record data type or as dask.Bag for the episode data type.

The record data type is handled by a concrete class with the Parquet file format as backend storage.

AvroPool is a concrete class for the episode data type with the Avro file format as backend storage.

  • It implements the abstract methods required by the DaskPool interface and, in turn, by the Pool interface.
  • AvroPool.sample provides an efficient, unified sampling interface via Dask.Bag to an Avro storage, either local or remote.
  • AvroPool.get_query provides the query object through Dask indexing for the AvroPool.sample method.
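The query-then-sample pattern of AvroPool can be shown with a plain list standing in for the Dask Bag. This is a minimal sketch, assuming each episode carries a `meta` dict with vehicle, driver and start time; `EpisodeQuery`, its fields and `sample_episodes` are hypothetical and only illustrate the idea behind AvroPool.get_query and AvroPool.sample.

```python
import random
from dataclasses import dataclass
from datetime import datetime
from typing import Optional


@dataclass(frozen=True)
class EpisodeQuery:
    """Hypothetical query over episode meta information; None means 'any'."""
    vehicle: Optional[str] = None
    driver: Optional[str] = None
    start: Optional[datetime] = None  # inclusive lower bound on episode start
    end: Optional[datetime] = None    # exclusive upper bound on episode start

    def matches(self, meta: dict) -> bool:
        if self.vehicle is not None and meta["vehicle"] != self.vehicle:
            return False
        if self.driver is not None and meta["driver"] != self.driver:
            return False
        if self.start is not None and meta["start"] < self.start:
            return False
        if self.end is not None and meta["start"] >= self.end:
            return False
        return True


def sample_episodes(episodes, query: EpisodeQuery, batch_size: int, seed=None):
    """Filter episodes by their meta information, then sample without replacement."""
    candidates = [e for e in episodes if query.matches(e["meta"])]
    if len(candidates) < batch_size:
        raise ValueError("not enough episodes match the query")
    return random.Random(seed).sample(candidates, batch_size)
```

Filtering on meta information before sampling is what makes the configuration classes useful for indexing: only matching episodes need to be read from storage.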

Configuration

The configuration module provides all classes for the configuration of the tspace framework. Most of them serve as meta information for the observation data and are used later for indexing or grouping to enable efficient sampling. It includes

  • Truck, with the subclasses TruckInCloud and TruckInField, which provide different interfaces via the mixins TboxMixin and KvaserMixin. It provides a managed truck list and two dictionaries for quick access to the truck configuration;
  • Driver, with properties to be stored in the meta information of the observation data;
  • TripMessenger for the different HMI input sources;
  • CANMessenger for the different CAN message sources;
  • DBConfig for managing the database configuration.
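The Truck hierarchy with mixins, together with the managed list and lookup dictionaries, can be sketched with dataclasses. This is a minimal sketch, assuming a truck is identified by a vehicle id and a VIN; the class names come from this README, but the fields, the `interface` attribute, the example values and the names `TRUCKS`, `TRUCKS_BY_ID` and `TRUCKS_BY_VIN` are hypothetical.

```python
from dataclasses import dataclass


class KvaserMixin:
    """Marks trucks reached through the local Kvaser/UDP interface (sketch)."""
    interface = "kvaser"


class TboxMixin:
    """Marks trucks reached through the onboard TBox and the cloud (sketch)."""
    interface = "tbox"


@dataclass(frozen=True)
class Truck:
    vid: str  # hypothetical vehicle id
    vin: str  # hypothetical vehicle identification number


@dataclass(frozen=True)
class TruckInField(KvaserMixin, Truck):
    pass


@dataclass(frozen=True)
class TruckInCloud(TboxMixin, Truck):
    pass


# Managed truck list plus two dictionaries for O(1) lookup (example values).
TRUCKS = [TruckInField("VB1", "VIN0001"), TruckInCloud("VB7", "VIN0007")]
TRUCKS_BY_ID = {t.vid: t for t in TRUCKS}
TRUCKS_BY_VIN = {t.vin: t for t in TRUCKS}
```

Keeping the interface choice in a mixin lets the same Truck fields be reused whether the vehicle is reached locally or via the cloud.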

Scheduling

The scheduling of ETL and ML training and inference is carried out with two levels of cascaded threading pools.

Primary threading pool

The primary threading pool is managed by Avatar with two primary threads in tspace.avatar.main:

  • The first primary thread is for data capturing.
  • The second primary thread is for training and inference.
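The two primary threads can be pictured as two tasks in a two-worker thread pool that share an exit event. This is a minimal sketch using `concurrent.futures.ThreadPoolExecutor`; the functions `capture`, `train_or_infer` and `run_primary_pool` are hypothetical stand-ins for VehicleInterface.ignite and Cruncher.filter, not the code in tspace.avatar.main.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def capture(exit_event: threading.Event) -> str:
    """Stand-in for the data capturing thread (would call VehicleInterface.ignite)."""
    # A real capture thread would start the secondary pool here and
    # keep running until it is told to exit.
    exit_event.wait(timeout=5.0)
    return "capture stopped"


def train_or_infer(exit_event: threading.Event) -> str:
    """Stand-in for the training/inference thread (would call Cruncher.filter)."""
    # ... run episodes here ...
    exit_event.set()  # e.g. training finished: ask the capture thread to stop
    return "training finished"


def run_primary_pool() -> list:
    exit_event = threading.Event()
    with ThreadPoolExecutor(max_workers=2, thread_name_prefix="primary") as pool:
        futures = [
            pool.submit(capture, exit_event),
            pool.submit(train_or_infer, exit_event),
        ]
        return [f.result() for f in futures]
```

The shared event is the only coupling between the two primary threads, which keeps shutdown ordering simple: whichever side finishes first signals the other.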

Data capturing thread

The data capturing thread calls VehicleInterface.ignite, which is shared by Kvaser and Cloud. It simply starts a secondary threading pool containing six threads:

  • VehicleInterface.produce gets the raw data either from the local UDP server, as in Kvaser, or from the remote cloud object storage, as in Cloud, and forwards it to the raw data pipeline. In the case of Kvaser, it also gets the training HMI control messages from the same UDP server and puts them in the HMI data pipeline.
  • VehicleInterface.hmi_control manages the episodic state machine that controls the training and inference process.
  • VehicleInterface.countdown handles the episode end with a countdown timer, so that data capturing stays aligned with the episode end event.
  • VehicleInterface.filter transforms the raw input JSON object into a pandas.DataFrame and forwards it to the input data pipeline of the Cruncher.filter thread.
  • VehicleInterface.consume fetches the action object from the output data pipeline of the Cruncher.filter thread and has it flashed onto the vehicle ECU (VCU).
  • VehicleInterface.watch_dog provides a watchdog that monitors the health of the data capturing process and the training process. It triggers a system stop if the observation or action quality falls below a threshold.
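The episodic state machine driven by HMI messages can be written as a transition table. This is a minimal sketch, assuming the states and message names shown here (`begin`, `end`, `countdown_done`, `exit`), all of which are hypothetical; in tspace the countdown runs in its own thread, represented here only by the `countdown_done` message.

```python
from enum import Enum, auto


class EpisodeState(Enum):
    IDLE = auto()
    RUNNING = auto()
    COUNTDOWN = auto()
    EXITED = auto()


# (current state, HMI message) -> next state; message names are hypothetical.
TRANSITIONS = {
    (EpisodeState.IDLE, "begin"): EpisodeState.RUNNING,
    (EpisodeState.RUNNING, "end"): EpisodeState.COUNTDOWN,
    (EpisodeState.COUNTDOWN, "countdown_done"): EpisodeState.IDLE,
}


def step(state: EpisodeState, msg: str) -> EpisodeState:
    if msg == "exit":
        return EpisodeState.EXITED  # exit is accepted in any state
    # Unknown or out-of-order messages leave the state unchanged.
    return TRANSITIONS.get((state, msg), state)


def run(messages):
    """Feed HMI messages through the state machine and return the state trace."""
    state = EpisodeState.IDLE
    trace = [state]
    for msg in messages:
        state = step(state, msg)
        trace.append(state)
        if state is EpisodeState.EXITED:
            break
    return trace
```

Ignoring out-of-order messages, such as a second "begin" during a running episode, keeps noisy HMI input from corrupting episode boundaries.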

Model training and inference thread

The model training and inference thread calls Cruncher.filter. Importantly, all processing in this thread is done synchronously in order to preserve the order of the time sequence, and thus the causality of observation and action.

  • It gets the data through the input pipeline and delegates the data to the agent for training or inference.
  • After getting the prediction from the agent, it encodes the prediction result into an action object and forwards it through the output pipeline to VehicleInterface.consume to have it flashed on VCU.
  • It also controls the training loop and the inference loop, and manages the training log and model checkpoints.
  • This thread is synchronized with the threads in the secondary threading pool with pre-defined threading.Event: start_event, stop_event, flash_event, interrupt_event and exit_event.
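The synchronous processing in this thread can be sketched as a single loop over an input queue that stops on an event. This is a minimal sketch with `queue.Queue` and `threading.Event`; `filter_loop` and the `predict` callable are hypothetical stand-ins for Cruncher.filter and the agent, and only exit_event from the list of events above is modelled.

```python
import queue
import threading


def filter_loop(inbox: queue.Queue, outbox: queue.Queue,
                exit_event: threading.Event, predict) -> int:
    """Handle observations strictly one at a time, in arrival order.

    Returns the number of observations handled before exit_event was set.
    """
    handled = 0
    while not exit_event.is_set():
        try:
            obs = inbox.get(timeout=0.1)  # short timeout so exit is noticed
        except queue.Empty:
            continue
        action = predict(obs)  # synchronous: no other observation is processed meanwhile
        outbox.put(action)     # picked up by the consume thread for flashing
        handled += 1
    return handled
```

Because a single thread both reads observations and emits actions, each action is guaranteed to follow the observation it was computed from, which is the causality property described above.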

TODO

  1. Add time sequence embedding database support with LanceDB for TimeGPT
  2. Batch mode for large scale inference and training with Unit of Work pattern
  3. Add schemes for serializing generic time series data

How to use

Install

pip install tspace