# Source code for sail_on_client.protocol.visual_round
"""Round for visual protocol."""
import logging
from typing import List, Any, Tuple, Dict
from sail_on_client.protocol.visual_dataclasses import (
FeatureExtractionParams,
WorldChangeDetectionParams,
)
from sail_on_client.harness.test_and_evaluation_harness import (
TestAndEvaluationHarnessType,
)
from sail_on_client.utils.utils import safe_remove
from sail_on_client.utils.decorators import skip_stage
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
class VisualRound:
    """Class with common elements for visual protocols."""

    def __init__(
        self,
        algorithm: Any,
        data_root: str,
        features_dict: Dict,
        harness: TestAndEvaluationHarnessType,
        logit_dict: Dict,
        redlight_instance: str,
        session_id: str,
        skip_stages: List[str],
        test_id: str,
    ) -> None:
        """
        Construct VisualRound.

        The constructor only stores its arguments on the instance; it performs
        no validation and no I/O.

        Args:
            algorithm: An instance of algorithm
            data_root: Root directory of the data
            features_dict: Dictionary with features for the entire dataset
            harness: An instance of the harness used for T&E
            logit_dict: Dictionary with logits for the entire dataset
            redlight_instance: The instance when the world changes
            session_id: Session id associated with the algorithm
            skip_stages: List of stages that are skipped
            test_id: Test id associated with the round

        Returns:
            None
        """
        self.algorithm = algorithm
        self.data_root = data_root
        self.features_dict = features_dict
        self.harness = harness
        self.logit_dict = logit_dict
        self.redlight_instance = redlight_instance
        self.session_id = session_id
        self.skip_stages = skip_stages
        self.test_id = test_id

    @staticmethod
    def get_instance_ids(dataset_path: str) -> List[str]:
        """
        Get instance ids from the dataset.

        Every line of the file becomes one entry with leading and trailing
        whitespace stripped. Blank lines are kept and show up as empty strings.

        Args:
            dataset_path: Path to text file with instances used in a round

        Returns:
            List of instance ids from the dataset
        """
        with open(dataset_path, "r") as dataset:
            # Iterating the handle yields the same lines as readlines()
            # without building an intermediate list first.
            return [line.strip() for line in dataset]

    # NOTE(review): skip_stage presumably bypasses this method when
    # "WorldDetection" is listed in self.skip_stages - confirm against the
    # decorator's implementation.
    @skip_stage("WorldDetection")
    def _run_world_change_detection(
        self,
        wcd_params: WorldChangeDetectionParams,
        round_id: int,
    ) -> None:
        """
        Private helper function for detecting that the world has changed.

        Runs the algorithm's "WorldDetection" step with the toolset taken from
        ``wcd_params``, posts the result to the harness under the "detection"
        key, and then hands the result to ``safe_remove``.

        Args:
            wcd_params: An instance of dataclass with parameters for world change detection
            round_id: Identifier for a round

        Returns:
            None
        """
        wd_result = self.algorithm.execute(wcd_params.get_toolset(), "WorldDetection")
        try:
            self.harness.post_results(
                {"detection": wd_result}, self.test_id, round_id, self.session_id
            )
        finally:
            # Clean up even when posting raises, so a failed round does not
            # leave the result artifact behind (wd_result is presumably a path
            # to a results file - confirm against the algorithm adapter).
            safe_remove(wd_result)