Source code for sail_on_client.protocol.visual_test

"""Test for Visual Protocol."""

import logging
import os
import pickle as pkl
from typing import Union, Tuple, Dict, List
import ubelt as ub

from sail_on_client.harness.test_and_evaluation_harness import (
    TestAndEvaluationHarnessType,
)
from sail_on_client.protocol.visual_round import VisualRound
from sail_on_client.utils.decorators import skip_stage
from sail_on_client.protocol.ond_dataclasses import (
    AlgorithmAttributes as ONDAlgorithmAttributes,
)
from sail_on_client.protocol.condda_dataclasses import (
    AlgorithmAttributes as CONDDAAlgorithmAttributes,
)


log = logging.getLogger(__name__)


[docs]class VisualTest: """Class representing test in visual protocol."""
[docs] def __init__( self, algorithm_attributes: Union[ONDAlgorithmAttributes, CONDDAAlgorithmAttributes], data_root: str, domain: str, feature_dir: str, harness: TestAndEvaluationHarnessType, save_dir: str, session_id: str, skip_stages: List[str], use_consolidated_features: bool, use_saved_features: bool, ) -> None: """ Construct visual test. Args: algorithm_attributes: An instance of algorithm_attributes data_root: Root directory for the dataset domain: Name of the domain for the test feature_dir: Directory to save features harness: An Instance of harness used for T&E save_dir: The directory where features are saved session_id: Session identifier for the test skip_stages: List of stages that would be skipped use_consolidated_features: Flag for using consolidated features use_saved_features: Flag for using saved features Returns: None """ self.algorithm_attributes = algorithm_attributes self.data_root = data_root self.domain = domain self.harness = harness self.feature_dir = feature_dir self.save_dir = save_dir self.session_id = session_id self.skip_stages = skip_stages self.use_consolidated_features = use_consolidated_features self.use_saved_features = use_saved_features
[docs] def _restore_features(self, test_id: str) -> Tuple[Dict, Dict]: """ Private function to restore features. Args: test_id: An identifier for the test Returns: Tuple of dictionary with features and logits obtained from the feature extractor """ features_dict: Dict = {} logit_dict: Dict = {} algorithm_name = self.algorithm_attributes.name if self.use_saved_features: if os.path.isdir(self.feature_dir): if self.use_consolidated_features: feature_fname = f"{algorithm_name}_features.pkl" else: feature_fname = f"{test_id}_{algorithm_name}_features.pkl" feature_path = os.path.join(self.feature_dir, feature_fname) test_features = pkl.load(open(feature_path, "rb")) else: test_features = pkl.load(open(self.feature_dir, "rb")) features_dict = test_features["features_dict"] logit_dict = test_features["logit_dict"] return features_dict, logit_dict
def _aggregate_features_across_round( self, round_instance: VisualRound, feature_dict: Dict, logit_dict: Dict ) -> Tuple[Dict, Dict]: """ Aggregate features across multiple rounds. Args: round_instance: Instance of ond round feature_dict: Aggregated features until this function was called logit_dict: Aggregated logit until this function was called Return: Tuple of features and logits with features and logits from the round """ feature_dict.update(getattr(round_instance, "rfeature_dict", {})) logit_dict.update(getattr(round_instance, "rlogit_dict", {})) return feature_dict, logit_dict
[docs] @skip_stage("SaveFeatures") def _save_features( self, test_id: str, feature_dict: Dict, logit_dict: Dict ) -> None: """ Save features for a test. Args: test_id: An identifier for the test feature_dict: Features for the test logit_dict: Logit for the test Return: None """ ub.ensuredir(self.feature_dir) algorithm_name = self.algorithm_attributes.name feature_path = os.path.join( self.feature_dir, f"{test_id}_{algorithm_name}_features.pkl" ) log.info(f"Saving features in {feature_path}") with open(feature_path, "wb") as f: pkl.dump({"features_dict": feature_dict, "logit_dict": logit_dict}, f)