Source code for sail_on_client.protocol.condda_test

"""Test for CONDDA."""

import logging
from itertools import count
from typing import Dict, List

from sail_on_client.protocol.condda_dataclasses import (
    AlgorithmAttributes,
    InitializeParams,
)
from sail_on_client.protocol.visual_test import VisualTest
from sail_on_client.protocol.condda_round import CONDDARound
from sail_on_client.harness.test_and_evaluation_harness import (
    TestAndEvaluationHarnessType,
)
from sail_on_client.utils.utils import safe_remove
from sail_on_client.errors import RoundError


log = logging.getLogger(__name__)


class CONDDATest(VisualTest):
    """Class representing CONDDA Test."""

    def __init__(
        self,
        algorithm_attributes: AlgorithmAttributes,
        data_root: str,
        domain: str,
        feature_dir: str,
        harness: TestAndEvaluationHarnessType,
        save_dir: str,
        session_id: str,
        skip_stages: List[str],
        use_consolidated_features: bool,
        use_saved_features: bool,
    ) -> None:
        """
        Construct test for CONDDA.

        All arguments are forwarded unchanged to ``VisualTest.__init__``.

        Args:
            algorithm_attributes: An instance of algorithm_attributes
            data_root: Root directory for the dataset
            domain: Name of the domain for the test
            feature_dir: Directory where features are stored
            harness: An Instance of harness used for T&E
            save_dir: The directory where features are saved
            session_id: Session identifier for the test
            skip_stages: List of stages that would be skipped
            use_consolidated_features: Flag for using consolidated features
            use_saved_features: Flag for using saved features

        Returns:
            None
        """
        super().__init__(
            algorithm_attributes,
            data_root,
            domain,
            feature_dir,
            harness,
            save_dir,
            session_id,
            skip_stages,
            use_consolidated_features,
            use_saved_features,
        )

    def __call__(self, test_id: str) -> None:
        """
        Core logic for running test in CONDDA.

        Initializes the algorithm, restores features, then repeatedly requests
        rounds from the harness (starting at round 0) and runs each one until
        the harness raises ``RoundError``. Afterwards the test is marked
        complete on the harness and the features and logits aggregated across
        all rounds are saved.

        Args:
            test_id: An identifier for the test

        Returns:
            None
        """
        metadata = self.harness.get_test_metadata(self.session_id, test_id)
        # Falls back to an empty string when the metadata has no "red_light" entry.
        redlight_instance = metadata.get("red_light", "")

        # Initialize algorithm
        algorithm_instance = self.algorithm_attributes.instance
        algorithm_parameters = self.algorithm_attributes.parameters
        algorithm_init_params = InitializeParams(
            algorithm_parameters, self.session_id, test_id
        )
        algorithm_instance.execute(algorithm_init_params.get_toolset(), "Initialize")

        # Restore features
        features_dict, logit_dict = self._restore_features(test_id)

        # Initialize Round
        round_instance = CONDDARound(
            algorithm_instance,
            self.data_root,
            features_dict,
            self.harness,
            logit_dict,
            redlight_instance,
            self.session_id,
            self.skip_stages,
            test_id,
        )
        aggregated_features_dict: Dict = {}
        aggregated_logit_dict: Dict = {}

        # Run algorithm for multiple rounds; the loop is unbounded and ends
        # only when the harness signals that no further round exists.
        for round_id in count(0):
            log.info("Start round: %s", round_id)
            # see if there is another round available
            try:
                dataset = self.harness.dataset_request(
                    test_id, round_id, self.session_id
                )
            except RoundError:
                # no more rounds available, this test is done.
                break

            try:
                round_instance(dataset, round_id)
                (
                    aggregated_features_dict,
                    aggregated_logit_dict,
                ) = self._aggregate_features_across_round(
                    round_instance, aggregated_features_dict, aggregated_logit_dict
                )
            finally:
                # Clean up the dataset file for the round even if processing
                # the round raised, so failed runs do not leave files behind.
                safe_remove(dataset)
            log.info("Round complete: %s", round_id)

        self.harness.complete_test(self.session_id, test_id)
        self._save_features(test_id, aggregated_features_dict, aggregated_logit_dict)