Skip to content

Commit

Permalink
Major revision to extend SPADE with new capabilities. Now it is possi…
Browse files Browse the repository at this point in the history
…ble to

set a voting strategy for the pseudolabeler. It is possible to have a separate
number of GMM components per model. The `alpha` weight parameter can now be set
separately for positive and negative pseudo-labels.

See the updated README for more details.

PiperOrigin-RevId: 718463956
  • Loading branch information
raj-sinha authored and The spade_anomaly_detection Authors committed Jan 22, 2025
1 parent 0bcf797 commit 69d21b2
Show file tree
Hide file tree
Showing 14 changed files with 394 additions and 115 deletions.
12 changes: 11 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

## [0.4.0] - 2025-01-22

* Major revision to extend SPADE with new capabilities. Now it is possible to
set a voting strategy for the pseudolabeler. It is possible to have a separate
number of GMM components per model. The `alpha` weight parameter can now be set
separately for positive and negative pseudolabels.
* Allow labels to be arbitrary strings.
* Upgrade Docker base image to Tensorflow 2.17.

## [0.3.3] - 2024-08-05

* Add support for wildcards in GCS URIs in CSV data loader.
Expand Down Expand Up @@ -58,7 +67,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Initial release

[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.3...HEAD
[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.4.0...HEAD
[0.4.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.3...v0.4.0
[0.3.3]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.2...v0.3.3
[0.3.2]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.1...v0.3.2
[0.3.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.0...v0.3.1
Expand Down
36 changes: 26 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ The metric reported by the pipeline is model [AUC](https://developers.google.com

<span style="color:red;background-color:lightgrey">label_col_name (string)</span>: The name of the label column in the input BigQuery table.

<span style="color:red;background-color:lightgrey">labels_are_strings</span>: Whether the labels in the input dataset are strings or integers.

<span style="color:red;background-color:lightgrey">positive_data_value (integer)</span>: The value used in the label column to denote positive data - data points that are anomalous. “1” can be used, for example.

<span style="color:red;background-color:lightgrey">negative_data_value (integer)</span>: The value used in the label column to denote negative data - data points that are not anomalous. “0” can be used, for example.
Expand All @@ -99,17 +101,21 @@ one class classifier ensemble to label a point as negative. The higher this valu

<span style="color:yellow;background-color:lightgrey">data_test_gcs_uri</span>: Cloud Storage location to store the CSV data to be used for evaluating the supervised model. Note that the positive and negative label values must also be the same in this testing set. It is okay to have your test labels in that form, or use 1 for positive and 0 for negative. Use exactly one of BigQuery locations or GCS locations.

<span style="color:yellow;background-color:lightgrey">upload_only</span>: Use this setting in conjunction with `output_bigquery_table_path` or `data_output_gcs_uri`. When `True`, the algorithm will just upload the pseudo labeled data to the specified table, and will skip training a supervised model. When set to `False`, the algorithm will also train a supervised model and upload it to a GCS location. Default is `False`.
<span style="color:yellow;background-color:lightgrey">upload_only (bool)</span>: Use this setting in conjunction with `output_bigquery_table_path` or `data_output_gcs_uri`. When `True`, the algorithm will just upload the pseudo labeled data to the specified table, and will skip training a supervised model. When set to `False`, the algorithm will also train a supervised model and upload it to a GCS location. Default is `False`.

<span style="color:yellow;background-color:lightgrey">output_bigquery_table_path</span>: A complete BigQuery path in the form of 'project.dataset.table' to be used for uploading the pseudo labeled data. This includes features and new labels. By default, we will use the column names from the input_bigquery_table_path BigQuery table. Use exactly one of BigQuery locations or GCS locations.

<span style="color:yellow;background-color:lightgrey">data_output_gcs_uri</span>: Cloud Storage location used for uploading the pseudo labeled data as CSV. This includes features and new labels. By default, we will use the column names from the data_input_gcs_uri table. Use exactly one of BigQuery locations or GCS locations.

<span style="color:yellow;background-color:lightgrey">alpha (float)</span>: Sample weights for weighting the loss function, only for pseudo-labeled data from the occ ensemble. Original data that is labeled will have a weight of 1.0. By default, we use alpha = 1.0.
<span style="color:yellow;background-color:lightgrey">voting_strategy (bool)</span>: The voting strategy to use when determining if a data point is anomalous. By default, we use unanimous voting, meaning all the models in the ensemble need to agree in order to label a data point as anomalous.

<span style="color:yellow;background-color:lightgrey">alpha (float)</span>: Sample weights for weighting the loss function, only for positively pseudo-labeled data from the occ ensemble. Original data that is labeled will have a weight of 1.0. If this is provided and `alpha_negative_pseudolabels` is not provided, then this value will be used for both positive and negative pseudo-labeled data. By default, we use alpha = 1.0.

<span style="color:yellow;background-color:lightgrey">alpha_negative_pseudolabels (float)</span>: Sample weights for weighting the loss function, only for negatively pseudo-labeled data from the occ ensemble. Original data that is labeled will have a weight of 1.0. If this is not provided, then the `alpha` value will be used for both positive and negative pseudo-labeled data. By default, we use alpha_negative_pseudolabels = 1.0.

<span style="color:yellow;background-color:lightgrey">ensemble_count</span>: Integer representing the number of one class classifiers in the ensemble used for pseudo labeling unlabeled data points. The more models in the ensemble, the less likely it is for all the models to gain consensus, and thus will reduce the amount of labeled data points. By default, we use 5 one class classifiers.

<span style="color:yellow;background-color:lightgrey">n_components</span>: Integer representing the number of components to use in the one class classifier ensemble. By default, we use 1 component.
<span style="color:yellow;background-color:lightgrey">n_components</span>: The number of components to use in the one class classifier ensemble. By default, we use 1 component. Pass a single integer if all the ensemble models should have the same number of components. Pass a space-separated list of integers if you want to use different numbers of components for each model in the ensemble. By default, we use 1 component.

<span style="color:yellow;background-color:lightgrey">covariance_type</span>: String representing the covariance type to use in the one class classifier ensemble. By default, we use 'full' covariance. Note that when there are many components, a 'full' covariance matrix may not be suitable.

Expand Down Expand Up @@ -189,20 +195,30 @@ OUTPUT_BIGQUERY_TABLE_PATH=${5:-"${PROJECT_ID}.[bq-dataset].[bq-output-table]"}
DATA_OUTPUT_GCS_URI=${6:-""}
OUTPUT_GCS_URI=${7:-"gs://[gcs-bucket]/[model-folder]"}
LABEL_COL_NAME=${8:-"y"}
# The label column is of type float, these must match in order for array
# The label column is of type string, these must match in order for array
# filtering to work correctly.
POSITIVE_DATA_VALUE=${9:-"1"}
NEGATIVE_DATA_VALUE=${10:-"0"}
UNLABELED_DATA_VALUE=${11:-"-1"}
POSITIVE_THRESHOLD=${12:-".1"}
NEGATIVE_THRESHOLD=${13:-"95"}
TEST_BIGQUERY_TABLE_PATH=${14:-"${PROJECT_ID}.[bq-dataset].[bq-test-table]"}
DATA_TEST_GCS_URI=${15:-""}
TEST_LABEL_COL_NAME=${16:-"y"}
ALPHA=${17:-"1.0"}
ENSEMBLE_COUNT=${19:-"5"}
VERBOSE=${22:-"True"}
UPLOAD_ONLY=${23:-"False"}
TEST_DATASET_HOLDOUT_FRACTION=${15:-"0"}
DATA_TEST_GCS_URI=${16:-""}
TEST_LABEL_COL_NAME=${17:-"y"}
VOTING_STRATEGY=${18:-"UNANIMOUS"}
ALPHA=${19:-"1.0"}
ALPHA_NEGATIVE_PSEUDOLABELS=${20:-"1.0"}
BATCHES_PER_MODEL=${21:-"1"}
ENSEMBLE_COUNT=${22:-"5"}
N_COMPONENTS=${23:-"1"}
# N_COMPONENTS=${23:-"1,3,5,7,9"}
COVARIANCE_TYPE=${24:-"full"}
MAX_OCC_BATCH_SIZE=${25:-"50000"}
LABELING_AND_MODEL_TRAINING_BATCH_SIZE=${26:-"100000"}
LABELS_ARE_STRINGS=${27:-"True"}
VERBOSE=${28:-"True"}
UPLOAD_ONLY=${29:-"False"}
IMAGE_URI="us-docker.pkg.dev/[project_id]/spade-anomaly-detection/spade:latest"
Expand Down
2 changes: 1 addition & 1 deletion spade_anomaly_detection/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Google Cloud's optimized Tensorflow image
FROM gcr.io/deeplearning-platform-release/tf-cpu.2-10
FROM us-docker.pkg.dev/deeplearning-platform-release/gcr.io/tf2-cpu.2-17.py310

# Alternative Tensorflow image with GPU.
# FROM gcr.io/deeplearning-platform-release/tf-gpu.2-10:latest
Expand Down
2 changes: 1 addition & 1 deletion spade_anomaly_detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@

# A new PyPI release will be pushed every time `__version__` is increased.
# When changing this, also update the CHANGELOG.md.
__version__ = '0.3.3'
__version__ = '0.4.0'
1 change: 1 addition & 0 deletions spade_anomaly_detection/data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def setUp(self):
output_bigquery_table_path='',
data_output_gcs_uri='',
alpha=1.0,
alpha_negative_pseudolabels=1.0,
batches_per_model=1,
labeling_and_model_training_batch_size=None,
ensemble_count=5,
Expand Down
105 changes: 85 additions & 20 deletions spade_anomaly_detection/occ_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,34 @@
# Using typing instead of collections due to Vertex training containers
# not supporting them.

import dataclasses
from typing import Final, MutableMapping, Optional, Sequence

from absl import logging
import numpy as np
from sklearn import mixture
from spade_anomaly_detection import parameters
import tensorflow as tf


@dataclasses.dataclass
class PseudolabelsContainer:
"""Container to hold the outputs of the pseudolabeling process.
Attributes:
new_features: np.ndarray of features for the new pseudolabeled data.
new_labels: np.ndarray of labels for the new pseudolabeled data.
weights: np.ndarray of weights for the new pseudolabeled data.
pseudolabel_flags: np.ndarray of flags indicating whether the data point is
ground truth or pseudolabeled.
"""

new_features: np.ndarray
new_labels: np.ndarray
weights: np.ndarray
pseudolabel_flags: np.ndarray


_RANDOM_SEED: Final[int] = 42
_SHUFFLE_BUFFER_SIZE: Final[int] = 10_000
_LABEL_TYPE: Final[str] = 'INT64'
Expand Down Expand Up @@ -76,6 +96,9 @@ class GmmEnsemble:
precision when raising this value, and an increase in recall when lowering
it. Equavalent to saying the given data point needs to be X percentile or
greater in order to be considered anomalous.
voting_strategy: The voting strategy to use when determining if a data point
is anomalous. By default, we use unanimous voting, meaning all the models
in the ensemble need to agree in order to label a data point as anomalous.
unlabeled_record_count: The number of unlabeled records in the dataset.
negative_record_count: The number of negative records in the dataset.
unlabeled_data_value: The value used in the label column to denote unlabeled
Expand All @@ -90,13 +113,16 @@ class GmmEnsemble:
# TODO(b/247116870): Create dataclass when another OCC is added.
def __init__(
self,
n_components: int = 1,
n_components: tuple[int, ...] = (1,),
covariance_type: str = 'full',
init_params: str = 'kmeans',
max_iter: int = 100,
ensemble_count: int = 5,
positive_threshold: float = 1.0,
negative_threshold: float = 95.0,
voting_strategy: parameters.VotingStrategy = (
parameters.VotingStrategy.UNANIMOUS
),
random_seed: int = _RANDOM_SEED,
unlabeled_record_count: int | None = None,
negative_record_count: int | None = None,
Expand All @@ -111,6 +137,7 @@ def __init__(
self.ensemble_count = ensemble_count
self.positive_threshold = positive_threshold
self.negative_threshold = negative_threshold
self.voting_strategy = voting_strategy
self._random_seed = random_seed
self.unlabeled_record_count = unlabeled_record_count
self.negative_record_count = negative_record_count
Expand All @@ -122,14 +149,21 @@ def __init__(

self._warm_start = False

def _get_model(self) -> mixture.GaussianMixture:
def _get_model(self, idx: int) -> mixture.GaussianMixture:
"""Instantiates a Gaussian mixture model.
Args:
idx: The index of the model in the ensemble.
Returns:
Gaussian mixture model with class attributes.
"""
return mixture.GaussianMixture(
n_components=self.n_components,
n_components=(
self.n_components[idx]
if len(self.n_components) == self.ensemble_count
else self.n_components[0]
),
covariance_type=self.covariance_type,
init_params=self.init_params,
warm_start=self._warm_start,
Expand Down Expand Up @@ -249,8 +283,8 @@ def fit(
)
dataset_iterator = ds_batched.as_numpy_iterator()

for _ in range(self.ensemble_count):
model = self._get_model()
for idx in range(self.ensemble_count):
model = self._get_model(idx=idx)

for _ in range(batches_per_occ):
features, _ = dataset_iterator.next()
Expand All @@ -269,6 +303,26 @@ def fit(

return self.ensemble

def _vote(self, model_scores: np.ndarray) -> np.ndarray:
"""Votes on whether a data point is anomalous or not.
Args:
model_scores: The scores for each model in the ensemble for a given data
point. Can be the positive score or the negative score.
Returns:
True if the data point is anomalous, False otherwise.
"""
if self.voting_strategy == parameters.VotingStrategy.UNANIMOUS:
return model_scores == self.ensemble_count
elif self.voting_strategy == parameters.VotingStrategy.MAJORITY:
return model_scores > self.ensemble_count // 2
else:
raise ValueError(
f'Unsupported voting strategy: {self.voting_strategy}. Supported'
' strategies are UNANIMOUS and MAJORITY.'
)

def _score_unlabeled_data(
self,
unlabeled_features: np.ndarray,
Expand Down Expand Up @@ -310,8 +364,8 @@ def _score_unlabeled_data(
model_scores_pos += binary_scores_pos
model_scores_neg += binary_scores_neg

positive_indices = np.where(model_scores_pos == self.ensemble_count)[0]
negative_indices = np.where(model_scores_neg == self.ensemble_count)[0]
positive_indices = np.where(self._vote(model_scores_pos))[0]
negative_indices = np.where(self._vote(model_scores_neg))[0]

return {
'positive_indices': positive_indices,
Expand All @@ -326,8 +380,9 @@ def pseudo_label(
negative_data_value: str | int | None,
unlabeled_data_value: str | int,
alpha: float = 1.0,
alpha_negative_pseudolabels: float = 1.0,
verbose: Optional[bool] = False,
) -> Sequence[np.ndarray]:
) -> PseudolabelsContainer:
"""Labels unlabeled data using the trained ensemble of OCCs.
Args:
Expand All @@ -341,16 +396,18 @@ def pseudo_label(
data - data points that are not anomalous.
unlabeled_data_value: The value used in the label column to denote
unlabeled data.
alpha: This value is used to adjust the influence of the pseudo labeled
data in training the supervised model.
alpha: This value is used to adjust the influence of the positively pseudo
labeled data in training the supervised model.
alpha_negative_pseudolabels: This value is used to adjust the influence of
the negatively pseudo labeled data in training the supervised model.
verbose: Chooses the amount of logging info to display. This can be useful
when debugging model performance.
Returns:
A sequence including updated features (features for which we now have
A container including updated features (features for which we now have
labels for), updated labels (includes pseudo labeled positive and negative
values, as well as ground truth), the weights (correct alpha values)
for the new pseudo labeled data points, and a binary flag that indicates
for the new pseudo labeled data points, a binary flag that indicates
whether the data point is newly pseudolabeled, or ground truth. Labels are
in the format of 1 for positive and 0 for negative. Flag is 1 for
pseudo-labeled and 0 for ground truth.
Expand Down Expand Up @@ -390,13 +447,15 @@ def pseudo_label(
],
axis=0,
)
weights = np.concatenate([
np.repeat(alpha, len(new_positive_indices)),
np.repeat(alpha, len(new_negative_indices)),
np.ones([len(original_positive_idx)]),
np.ones([len(original_negative_idx)])
],
axis=0)
weights = np.concatenate(
[
np.repeat(alpha, len(new_positive_indices)),
np.repeat(alpha_negative_pseudolabels, len(new_negative_indices)),
np.ones([len(original_positive_idx)]),
np.ones([len(original_negative_idx)]),
],
axis=0,
)
pseudolabel_flags = np.concatenate(
[
np.ones(len(new_positive_indices)),
Expand All @@ -412,4 +471,10 @@ def pseudo_label(
len(new_positive_indices))
logging.info('Number of new negative labels: %s',
len(new_negative_indices))
return new_features, new_labels, weights, pseudolabel_flags

return PseudolabelsContainer(
new_features=new_features,
new_labels=new_labels,
weights=weights,
pseudolabel_flags=pseudolabel_flags,
)
Loading

0 comments on commit 69d21b2

Please sign in to comment.