Deep learning models are increasingly deployed on edge Internet of Things (IoT) devices. However, these models typically operate under supervised conditions and fail to recognize classes that were not seen during training. To address this, zero-shot learning (ZSL) aims to classify data from unseen classes with the help of semantic information. Foundation models (FMs) trained on web-scale data have shown impressive ZSL capability in natural language processing and visual understanding. However, leveraging FMs' generalized knowledge for zero-shot IoT sensing using signals such as mmWave, IMU, and Wi-Fi has not been fully investigated. In this work, we align the IoT data embeddings with the semantic embeddings generated by an FM's text encoder for zero-shot IoT sensing. To exploit the physical principles governing the generation of IoT sensor signals and derive more effective prompts for semantic embedding extraction, we propose using cross-attention to combine a learnable soft prompt, which is optimized automatically on training data, with an auxiliary hard prompt that encodes domain knowledge of the IoT sensing task. To address the bias of IoT embeddings toward seen classes, which is caused by the lack of unseen-class data during training, we propose using data augmentation to synthesize unseen-class IoT data for fine-tuning the IoT feature extractor and embedding projector. We evaluate our approach on multiple IoT sensing tasks. Results show that our approach achieves superior open-set detection and generalized zero-shot learning performance compared with various baselines.
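To make the alignment idea concrete, the following is a minimal sketch of how a projected IoT embedding could be matched against per-class text embeddings by cosine similarity. It is an illustration under stated assumptions, not the repository's implementation: the function names `cosine_similarity` and `zero_shot_predict`, the 3-dimensional toy vectors, and the class names are all hypothetical, and a real pipeline would obtain the embeddings from the IoT feature extractor and the FM's text encoder.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return sum(a * b for a, b in zip(u, v)) / (norm_u * norm_v)


def zero_shot_predict(iot_embedding, class_text_embeddings):
    """Return (best_label, scores) for one IoT embedding.

    class_text_embeddings maps each class name (seen or unseen) to its
    semantic embedding; the predicted class is the most similar one.
    """
    scores = {
        label: cosine_similarity(iot_embedding, text_embedding)
        for label, text_embedding in class_text_embeddings.items()
    }
    best_label = max(scores, key=scores.get)
    return best_label, scores


# Hypothetical toy embeddings; real ones would come from trained encoders.
toy_text_embeddings = {
    "walking": [1.0, 0.0, 0.0],
    "running": [0.0, 1.0, 0.0],
    "jumping": [0.0, 0.0, 1.0],
}
toy_iot_embedding = [0.1, 0.2, 0.9]

label, scores = zero_shot_predict(toy_iot_embedding, toy_text_embeddings)
print(label, scores)
```

Because unseen classes only need a text embedding, adding a new class to `toy_text_embeddings` requires no IoT samples of that class, which is the property zero-shot classification relies on.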
- Create and activate a Python virtual environment.
- Install all requirements:
pip install -r requirements.txt
- USC-HAD: Change `dataset_path` in `./settings/USC.yaml` to your USC-HAD directory path (delete the `*.m` and `*.txt` files at the top level of that directory beforehand).
- PAMAP2: Get the path of the `Protocol` subdirectory of the PAMAP2 dataset, and change `dataset_path` in `./settings/pamap.yaml` to your own `Protocol` path.
- MM-Fi
  - mmWave: We use the official `filtered_mmwave` data for training and testing. Change `dataset_path` in `./settings/mmwave.yaml` to your `filtered_mmwave` directory path.
  - Wi-Fi: TBD
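The USC-HAD step asks for the `*.m` and `*.txt` files at the top level of the dataset directory to be deleted before training. A small helper like the one below could do that cleanup; it is a sketch, not part of the repository. The function name `remove_top_level_files` and its `dry_run` default are assumptions made here, and it deliberately leaves subdirectories untouched.

```python
from pathlib import Path


def remove_top_level_files(dataset_dir, patterns=("*.m", "*.txt"), dry_run=True):
    """List (and optionally delete) files matching `patterns` directly
    inside `dataset_dir`. Files in subdirectories are never touched.

    With dry_run=True (the default) nothing is deleted; the matches are
    only returned so they can be reviewed first.
    """
    root = Path(dataset_dir)
    if not root.is_dir():
        raise NotADirectoryError(f"{root} is not a directory")

    # Path.glob with a pattern that has no '**' is non-recursive,
    # so only first-level entries are matched.
    matched = sorted(
        path
        for pattern in patterns
        for path in root.glob(pattern)
        if path.is_file()
    )

    if not dry_run:
        for path in matched:
            path.unlink()
    return matched
```

Calling `remove_top_level_files("<your USC-HAD path>")` first shows what would be removed; passing `dry_run=False` performs the deletion.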
- quick train
python main.py --config_choose <dataset_config>
- train on previously saved log/data
python main.py --config_choose <dataset_config> --back_up_path <path_to_saved_log_or_data>
- view logs: check the logs in the `./logs` directory
- only foundation model
python main.py --config_choose <dataset_config> --back_up_path <path_to_saved_log_or_data> --test_model_path <path_to_saved_model>
- local supervised model + foundation model
# Train the local model first (--back_up_path should point to the same data used for the previously trained foundation model)
python main_sup.py --config_choose <dataset_config> --back_up_path <path_to_saved_log_or_data>
# Inference local + foundation model
python main.py --config_choose <dataset_config> --back_up_path <path_to_saved_log_or_data> --test_model_path <path_to_saved_fm_model> --local_model_path <path_to_saved_local_model>
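When several runs need to be scripted, the commands above can be assembled programmatically. The sketch below only uses the flags shown in this README (`--config_choose`, `--back_up_path`, `--test_model_path`, `--local_model_path`); the helper name `build_command` and the example values `my_config` and `./saved_run` are hypothetical placeholders, not names taken from the repository.

```python
import shlex


def build_command(config, back_up_path=None, test_model_path=None,
                  local_model_path=None, script="main.py"):
    """Build the argument list for one training or inference run.

    Only the options that are provided are appended, mirroring the
    command variants listed in this README.
    """
    command = ["python", script, "--config_choose", config]
    optional_flags = [
        ("--back_up_path", back_up_path),
        ("--test_model_path", test_model_path),
        ("--local_model_path", local_model_path),
    ]
    for flag, value in optional_flags:
        if value is not None:
            command += [flag, str(value)]
    return command


# Hypothetical values, for illustration only.
quick_train = build_command("my_config")
resume_train = build_command("my_config", back_up_path="./saved_run")
train_local = build_command("my_config", back_up_path="./saved_run",
                            script="main_sup.py")

for cmd in (quick_train, resume_train, train_local):
    # shlex.join quotes arguments so the line can be pasted into a shell.
    print(shlex.join(cmd))
```

Each returned list can be passed directly to `subprocess.run(cmd, check=True)`, which avoids shell quoting problems with paths that contain spaces.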
- full and detailed documentation
- Wi-Fi data preparation
- configuration guide
- code cleanup and reorganization
- data augmentation guide