Merge branch 'onnx'

Alexanders101 committed Mar 1, 2023
2 parents 5194ea0 + 3980549 commit bee0354
Showing 10 changed files with 314 additions and 50 deletions.
54 changes: 54 additions & 0 deletions README.md
@@ -103,6 +103,60 @@ you can evaluate the network again on the example dataset by running.
Note that the included example file is very small and you will likely not
see very good performance on it.

### Exporting

Once you are happy with your model, you can export it to an [ONNX](https://onnxruntime.ai/) file for use in external applications. Run `spanet.export` with the log directory and the desired output file. For example: `python -m spanet.export ./spanet_output/version_0 spanet.onnx`.

Note that only the neural network can be exported, and this network outputs the full reconstruction distributions for every event. Unfortunately, the reconstruction algorithm defined [here](spanet/network/prediction_selection.py) cannot be exported as part of the ONNX graph. If your target application uses Python, then you can simply reuse SPANet's selection algorithm; non-Python applications must define their own selection, such as the naive version sketched below.
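
For illustration, here is a minimal, deliberately naive selection written in Python. It independently picks the highest-probability jet assignment for a single particle and does **not** resolve collisions between particles the way SPANet's full algorithm does:

```python
import numpy as np

def naive_selection(assignment_probability: np.ndarray) -> tuple:
    """Select jet indices for one event and one particle.

    assignment_probability: the (N, N) joint distribution for a
    two-product particle, taken from the ONNX model's outputs.
    A plain argmax like this does NOT replicate the collision-aware
    selection implemented in prediction_selection.py.
    """
    flat_index = np.argmax(assignment_probability)
    return np.unravel_index(flat_index, assignment_probability.shape)
```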

You may examine all of the inputs and outputs with the following snippet:
```python
import onnxruntime  # we use ONNX Runtime to run inference on ONNX models

session = onnxruntime.InferenceSession(
    "./spanet.onnx",
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)

print("Inputs:", [input.name for input in session.get_inputs()])
print("Outputs:", [output.name for output in session.get_outputs()])
```

#### Inputs

| Input | Shape | DType |
|---------------------------|-------------|-------|
| {sequential_input_1}_data | (B, N1, D1) | float |
| {sequential_input_1}_mask | (B, N1) | bool |
| {sequential_input_2}_data | (B, N2, D2) | float |
| {sequential_input_2}_mask | (B, N2) | bool |
| {global_input_1}_data | (B, 1, D1) | float |
| {global_input_1}_mask | (B, 1) | bool |
| {global_input_2}_data | (B, 1, D2) | float |
| {global_input_2}_mask | (B, 1) | bool |

The ONNX model expects two inputs for every `INPUT` defined in the event file. Replace the values in braces with the names from your event file. Each `*_data` tensor contains the features for that input, and the features must be provided in the **exact order** in which they are defined in the event file. Note that global inputs require a dummy axis so that their shapes match those of the sequential inputs.
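
As a concrete example, the following sketch builds zero-filled inputs for a hypothetical model with a single sequential input named `Source`, using 10 jets and 4 features per jet; substitute the names, sizes, and feature values appropriate for your own model (see `session.get_inputs()` above):

```python
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("./spanet.onnx", providers=['CPUExecutionProvider'])

batch_size, num_jets, num_features = 1, 10, 4  # hypothetical sizes

inputs = {
    "Source_data": np.zeros((batch_size, num_jets, num_features), dtype=np.float32),
    "Source_mask": np.ones((batch_size, num_jets), dtype=bool)  # True marks a real jet
}

# Passing None for the output names returns every output,
# in the same order as session.get_outputs().
outputs = session.run(None, inputs)
```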

**Log Features:** Any features marked either `log` or `log_normalize` must have the preprocessing transformation `f(x) -> log(x + 1)` applied. You can skip this preprocessing and have the network perform it by specifying `--input-log-transform` during export. However, this operation is expensive to perform inside the graph, so we recommend applying it in your data pipeline for maximum efficiency.
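
A sketch of this preprocessing in NumPy, assuming you know which feature columns were marked `log` or `log_normalize` in your event file; `np.log1p(x)` computes `log(x + 1)` accurately near zero:

```python
import numpy as np

def preprocess_log_features(data: np.ndarray, log_columns: list) -> np.ndarray:
    """Apply f(x) = log(x + 1) to the marked columns of a (B, N, D) feature array."""
    data = data.copy()
    for column in log_columns:
        data[..., column] = np.log1p(data[..., column])
    return data
```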

#### Outputs

| Output | Shape | DType |
|-------------------------------------------|----------------|-------|
| {event_particle_1}_assignment_probability | (B, N, N, ...) | float |
| {event_particle_2}_assignment_probability | (B, N, N, ...) | float |
| {event_particle_1}_detection_probability | (B) | float |
| {event_particle_2}_detection_probability | (B) | float |
| {regression_target_1} | (B) | float |
| {regression_target_2} | (B) | float |
| {classification_target_1} | (B, C) | float |
| {classification_target_2} | (B, C) | float |

The ONNX model may produce any of the valid output heads. Each event particle you define produces an assignment distribution for its reconstruction. This distribution will be a singlet/doublet/triplet/etc. joint distribution, depending on the number of decay products defined for that particle, and its shape reflects the number of products. For example, if a particle has two decay products, then its `assignment_probability` will have a shape of `(B, N, N)`. Each particle also has an associated `detection_probability`, which indicates how likely the particle is to be reconstructable.

The additional outputs will only be present if you define `REGRESSION` or `CLASSIFICATION` outputs in the event file; each definition adds an extra output. The regression outputs simply contain the predicted value for each regression target. The classification outputs contain a distribution over possible classes for each target.
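
Continuing the inference sketch above, a classification head's `(B, C)` output can be reduced to a per-event class prediction with an argmax (`signal_class` is a hypothetical output name):

```python
import numpy as np

# Request a single classification head by name;
# the result is a (B, C) distribution over classes for each event.
[class_probabilities] = session.run(["signal_class"], inputs)
predicted_class = np.argmax(class_probabilities, axis=-1)  # shape (B,)
```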

**Log Probability vs. Probability:** For additional numerical stability, you may choose to output the log distributions, `log P(x)`, for all probability outputs instead. If you specify `--output-log-transform` in the export script, then the `*_assignment_probability` and `*_detection_probability` outputs will be replaced with `*_assignment_log_probability` and `*_detection_log_probability`. The classification outputs will also be log-probabilities, although their names will not change.
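
For example, one way to collect the outputs by name and recover plain probabilities from a model exported with `--output-log-transform` (the particle name `t1` is only a placeholder):

```python
import numpy as np

# Pair each returned array with its name so lookups are explicit.
output_names = [output.name for output in session.get_outputs()]
results = dict(zip(output_names, session.run(None, inputs)))

# Exponentiate to go from log P(x) back to P(x).
assignment = np.exp(results["t1_assignment_log_probability"])
detection = np.exp(results["t1_detection_log_probability"])
```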

## Citation
If you use this software for a publication, please cite the following:
```bibtex
…
```
1 change: 1 addition & 0 deletions spanet/dataset/types.py
@@ -158,6 +158,7 @@ class Batch(NamedTuple):
 class Outputs(NamedTuple):
     assignments: List[Tensor]
     detections: List[Tensor]
+    vectors: Dict[str, Tensor]
     regressions: Dict[str, Tensor]
     classifications: Dict[str, Tensor]

80 changes: 56 additions & 24 deletions spanet/evaluation.py
@@ -1,35 +1,53 @@
 from glob import glob
-from typing import Optional
+from typing import Optional, Union, Tuple

 import numpy as np
 import torch
-from tqdm import tqdm
+from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map

+from rich import progress
+
 from spanet import JetReconstructionModel, Options
-from spanet.dataset.types import Evaluation
+from spanet.dataset.types import Evaluation, Outputs, Source
 from spanet.network.jet_reconstruction.jet_reconstruction_network import extract_predictions

 from collections import defaultdict


-def tree_concatenate(tree):
+def dict_concatenate(tree):
     output = {}
     for key, value in tree.items():
         if isinstance(value, dict):
-            output[key] = tree_concatenate(value)
+            output[key] = dict_concatenate(value)
         else:
             output[key] = np.concatenate(value)

     return output


-def load_model(log_directory: str,
-               testing_file: Optional[str] = None,
-               event_info_file: Optional[str] = None,
-               batch_size: Optional[int] = None,
-               cuda: bool = False) -> JetReconstructionModel:
+def tree_concatenate(trees):
+    leaves = []
+    for tree in trees:
+        data, tree_spec = tree_flatten(tree)
+        leaves.append(data)
+
+    results = [np.concatenate(l) for l in zip(*leaves)]
+    return tree_unflatten(results, tree_spec)
+
+
+def load_model(
+        log_directory: str,
+        testing_file: Optional[str] = None,
+        event_info_file: Optional[str] = None,
+        batch_size: Optional[int] = None,
+        cuda: bool = False,
+        checkpoint: Optional[str] = None
+) -> JetReconstructionModel:
     # Load the best-performing checkpoint on validation data
-    checkpoint = sorted(glob(f"{log_directory}/checkpoints/epoch*"))[-1]
+    if checkpoint is None:
+        checkpoint = sorted(glob(f"{log_directory}/checkpoints/epoch*"))[-1]
     print(f"Loading: {checkpoint}")

     checkpoint = torch.load(checkpoint, map_location='cpu')
     checkpoint = checkpoint["state_dict"]
@@ -46,10 +64,6 @@ def load_model(log_directory: str,
     if batch_size is not None:
         options.batch_size = batch_size

-    # We need a testing file defined somewhere to continue
-    if options.testing_file is None or options.testing_file == "":
-        raise ValueError("No testing file found in model options or provided to test.py.")
-
     # Create model and disable all training operations for speed
     model = JetReconstructionModel(options)
     model.load_state_dict(checkpoint)
@@ -63,16 +77,26 @@ def load_model(log_directory: str,
     return model


-def evaluate_on_test_dataset(model: JetReconstructionModel) -> Evaluation:
+def evaluate_on_test_dataset(
+        model: JetReconstructionModel,
+        progress=progress,
+        return_full_output: bool = True
+) -> Union[Evaluation, Tuple[Evaluation, Outputs]]:
     full_assignments = defaultdict(list)
     full_assignment_probabilities = defaultdict(list)
     full_detection_probabilities = defaultdict(list)

     full_classifications = defaultdict(list)
     full_regressions = defaultdict(list)

-    for batch in tqdm(model.test_dataloader(), desc="Evaluating Model"):
-        sources = [[x[0].to(model.device), x[1].to(model.device)] for x in batch.sources]
+    full_outputs = []
+
+    dataloader = model.test_dataloader()
+    if progress:
+        dataloader = progress.track(model.test_dataloader(), description="Evaluating Model")
+
+    for batch in dataloader:
+        sources = tuple(Source(x[0].to(model.device), x[1].to(model.device)) for x in batch.sources)
         outputs = model.forward(sources)

         assignment_indices = extract_predictions([
@@ -126,11 +150,19 @@ def evaluate_on_test_dataset(model: JetReconstructionModel) -> Evaluation:
             for key, classification in classifications.items():
                 full_classifications[key].append(classification)

-    return Evaluation(
-        tree_concatenate(full_assignments),
-        tree_concatenate(full_assignment_probabilities),
-        tree_concatenate(full_detection_probabilities),
-        tree_concatenate(full_regressions),
-        tree_concatenate(full_classifications)
+        if return_full_output:
+            full_outputs.append(tree_map(lambda x: x.cpu().numpy(), outputs))
+
+    evaluation = Evaluation(
+        dict_concatenate(full_assignments),
+        dict_concatenate(full_assignment_probabilities),
+        dict_concatenate(full_detection_probabilities),
+        dict_concatenate(full_regressions),
+        dict_concatenate(full_classifications)
     )

+    if return_full_output:
+        return evaluation, tree_concatenate(full_outputs)
+
+    return evaluation

179 changes: 179 additions & 0 deletions spanet/export.py
@@ -0,0 +1,179 @@
from argparse import ArgumentParser
from typing import List

import torch
from torch.nn import functional as F
from torch.utils._pytree import tree_map

import pytorch_lightning as pl

from spanet import JetReconstructionModel
from spanet.dataset.types import Source
from spanet.evaluation import load_model


class WrappedModel(pl.LightningModule):
    def __init__(
            self,
            model: JetReconstructionModel,
            input_log_transform: bool = False,
            output_log_transform: bool = False,
            output_embeddings: bool = False
    ):
        super(WrappedModel, self).__init__()

        self.model = model
        self.input_log_transform = input_log_transform
        self.output_log_transform = output_log_transform
        self.output_embeddings = output_embeddings

    def apply_input_log_transform(self, sources):
        new_sources = []
        for (data, mask), name in zip(sources, self.model.event_info.input_names):
            new_data = torch.stack([
                mask * torch.log(data[:, :, i] + 1) if log_transformer else data[:, :, i]
                for i, log_transformer in enumerate(self.model.event_info.log_features(name))
            ], -1)

            new_sources.append(Source(new_data, mask))
        return new_sources

    def forward(self, sources: List[Source]):
        if self.input_log_transform:
            sources = self.apply_input_log_transform(sources)

        outputs = self.model(sources)

        if self.output_log_transform:
            assignments = [assignment for assignment in outputs.assignments]
            detections = [F.logsigmoid(detection) for detection in outputs.detections]

            classifications = [
                F.log_softmax(outputs.classifications[key], dim=-1)
                for key in self.model.training_dataset.classifications.keys()
            ]

        else:
            assignments = [assignment.exp() for assignment in outputs.assignments]
            detections = [torch.sigmoid(detection) for detection in outputs.detections]

            classifications = [
                F.softmax(outputs.classifications[key], dim=-1)
                for key in self.model.training_dataset.classifications.keys()
            ]

        regressions = [
            outputs.regressions[key]
            for key in self.model.training_dataset.regressions.keys()
        ]

        embedding_vectors = list(outputs.vectors.values()) if self.output_embeddings else []

        return *assignments, *detections, *regressions, *classifications, *embedding_vectors


def onnx_specification(model, output_log_transform: bool = False, output_embeddings: bool = False):
    input_names = []
    output_names = []

    dynamic_axes = {}

    for input_name in model.event_info.input_names:
        for input_type in ["data", "mask"]:
            current_input = f"{input_name}_{input_type}"
            input_names.append(current_input)
            dynamic_axes[current_input] = {
                0: 'batch_size',
                1: f'num_{input_name}'
            }

    for output_name in model.event_info.event_particles.names:
        if output_log_transform:
            output_names.append(f"{output_name}_assignment_log_probability")
        else:
            output_names.append(f"{output_name}_assignment_probability")

    for output_name in model.event_info.event_particles.names:
        if output_log_transform:
            output_names.append(f"{output_name}_detection_log_probability")
        else:
            output_names.append(f"{output_name}_detection_probability")

    for regression in model.training_dataset.regressions.keys():
        output_names.append(regression)

    for classification in model.training_dataset.classifications.keys():
        output_names.append(classification)

    if output_embeddings:
        output_names.append("EVENT/embedding_vector")

        for particle, products in model.event_info.product_particles.items():
            output_names.append(f"{particle}/PARTICLE/embedding_vector")

            for product in products:
                output_names.append(f"{particle}/{product}/embedding_vector")

    return input_names, output_names, dynamic_axes


def main(
        log_directory: str,
        output_file: str,
        input_log_transform: bool,
        output_log_transform: bool,
        output_embeddings: bool,
        gpu: bool
):
    model = load_model(log_directory, cuda=gpu)

    # Create wrapped model with flat inputs and outputs
    wrapped_model = WrappedModel(model, input_log_transform, output_log_transform, output_embeddings)
    wrapped_model.to(model.device)
    wrapped_model.eval()
    for parameter in wrapped_model.parameters():
        parameter.requires_grad_(False)

    input_names, output_names, dynamic_axes = onnx_specification(model, output_log_transform, output_embeddings)

    batch = next(iter(model.train_dataloader()))
    sources = batch.sources
    if gpu:
        sources = tree_map(lambda x: x.cuda(), batch.sources)
    sources = tree_map(lambda x: x[:1], sources)

    print("-" * 60)
    print(f"Compiling network to ONNX model: {output_file}")
    print("-" * 60)
    wrapped_model.to_onnx(
        output_file,
        sources,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=13
    )


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("log_directory", type=str,
                        help="Pytorch Lightning Log directory containing the checkpoint and options file.")

    parser.add_argument("output_file", type=str,
                        help="Name to output the ONNX model to.")

    parser.add_argument("-g", "--gpu", action="store_true",
                        help="Trace the network on a gpu.")

    parser.add_argument("--input-log-transform", action="store_true",
                        help="Exported model will apply log transformations to input features automatically.")

    parser.add_argument("--output-log-transform", action="store_true",
                        help="Exported model will output log probabilities. This is more numerically stable.")

    parser.add_argument("--output-embeddings", action="store_true",
                        help="Exported model will also output the embeddings for every part of the event.")

    arguments = parser.parse_args()
    main(**arguments.__dict__)