Source code for sail_on_client.protocol.ond_round

"""Round for OND."""

import logging
from typing import List, Any, Dict, Union

from sail_on_client.protocol.ond_dataclasses import (
    NoveltyClassificationParams,
    NoveltyAdaptationParams,
)
from sail_on_client.protocol.visual_dataclasses import (
    FeatureExtractionParams,
    WorldChangeDetectionParams,
)
from sail_on_client.protocol.visual_round import VisualRound
from sail_on_client.harness.test_and_evaluation_harness import (
    TestAndEvaluationHarnessType,
)
from sail_on_client.utils.utils import safe_remove
from sail_on_client.utils.decorators import skip_stage


# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)


class ONDRound(VisualRound):
    """Class Representing a round in OND."""

    def __init__(
        self,
        algorithm: Any,
        data_root: str,
        features_dict: Dict,
        harness: TestAndEvaluationHarnessType,
        logit_dict: Dict,
        redlight_instance: str,
        session_id: str,
        skip_stages: List[str],
        test_id: str,
    ) -> None:
        """
        Construct round for OND.

        All arguments are forwarded unchanged, in the same order, to the
        ``VisualRound`` constructor.

        Args:
            algorithm: An instance of algorithm
            data_root: Root directory of the data
            features_dict: Dictionary with features for the entire dataset
            harness: An instance of the harness used for T&E
            logit_dict: Dictionary with logits for the entire dataset
            redlight_instance: The instance when the world changes
            session_id: Session id associated with the algorithm
            skip_stages: List of stages that are skipped
            test_id: Test id associated with the round

        Returns:
            None
        """
        super().__init__(
            algorithm,
            data_root,
            features_dict,
            harness,
            logit_dict,
            redlight_instance,
            session_id,
            skip_stages,
            test_id,
        )

    @skip_stage("NoveltyClassification")
    def _run_novelty_classification(
        self, nc_params: NoveltyClassificationParams, round_id: int
    ) -> None:
        """
        Private helper function for novelty classification.

        Executes the "NoveltyClassification" stage of the algorithm and posts
        its result to the harness under the "classification" key.

        Args:
            nc_params: An instance of dataclass with parameters for
                       novelty classification
            round_id: Identifier for a round

        Returns:
            None
        """
        ncl_result = self.algorithm.execute(
            nc_params.get_toolset(), "NoveltyClassification"
        )
        self.harness.post_results(
            {"classification": ncl_result}, self.test_id, round_id, self.session_id
        )
        # NOTE(review): presumably ncl_result is a path to a results file that
        # is no longer needed once posted -- confirm against safe_remove.
        safe_remove(ncl_result)

    @skip_stage("EvaluateRoundwise")
    def _evaluate_roundwise(self, round_id: int) -> Union[Dict, None]:
        """
        Compute roundwise accuracy.

        Args:
            round_id: Identifier for a round

        Returns:
            Dictionary with accuracy metrics for round, as reported by the
            harness. NOTE(review): presumably None when the
            "EvaluateRoundwise" stage is skipped -- confirm against skip_stage.
        """
        return self.harness.evaluate_round_wise(
            self.test_id, round_id, self.session_id
        )

    @skip_stage("NoveltyAdaptation")
    def _run_novelty_adaptation(self, na_params: NoveltyAdaptationParams) -> Any:
        """
        Private helper function for adaptation.

        Args:
            na_params: An instance of dataclass with parameters for adaptation

        Returns:
            Whatever the algorithm returns for the "NoveltyAdaptation" stage.
            The value is passed through unchanged; ``__call__`` ignores it.
        """
        return self.algorithm.execute(na_params.get_toolset(), "NoveltyAdaptation")

    def __call__(self, dataset: str, round_id: int) -> Union[Dict, None]:
        """
        Core logic for running round in OND.

        Stages run in a fixed order: feature extraction, world change
        detection, novelty classification, roundwise evaluation and finally
        novelty adaptation.

        Args:
            dataset: Path to a file with the dataset for the round
            round_id: An Identifier for a round

        Returns:
            Score for the round, i.e. the result of the roundwise evaluation
            (may be None, as permitted by the return annotation).
        """
        # Run feature extraction
        fe_params = FeatureExtractionParams(dataset, self.data_root, round_id)
        instance_ids = ONDRound.get_instance_ids(dataset)
        rfeature_dict, rlogit_dict = self._run_feature_extraction(
            fe_params, instance_ids
        )

        # Run World Change Detection
        wc_params = WorldChangeDetectionParams(
            rfeature_dict, rlogit_dict, round_id, self.redlight_instance
        )
        self._run_world_change_detection(wc_params, round_id)

        # Run Novelty Classification
        nc_params = NoveltyClassificationParams(rfeature_dict, rlogit_dict, round_id)
        self._run_novelty_classification(nc_params, round_id)

        # Compute metrics for the round; evaluation happens before adaptation
        # so the score reflects the algorithm state used for this round.
        round_score = self._evaluate_roundwise(round_id)

        # Run Novelty Adaptation
        na_params = NoveltyAdaptationParams(round_id)
        self._run_novelty_adaptation(na_params)
        return round_score