Saving and Restoring Attributes of an Algorithm

Introduction

We rely on the Checkpointer mixin to save and restore the attributes of an algorithm. The mixin can be added to an adapter associated with an algorithm to provide the save_attributes and restore_attributes functions. These functions can be used with configuration options to save and restore state in a pickle file.

Warning

Saving and restoring attributes is experimental and error-prone due to the lack of a concrete use case in the current evaluation.

Note

Saving features and saving attributes were created for different use cases. Saving features stores the features extracted for videos and restores them across multiple tests, so the features of a video need to be extracted only once and can then be reused. This mode of operation is not guaranteed when checkpointing, because attributes are not saved with videos as keys. Checkpointing can therefore be used to skip steps of a protocol, but it should not be used on a sequence of videos that differs from the original sequence; doing so results in undefined behavior.

Sample Detector and Adapter

Sample Detector

    def __init__(self) -> None:
        """
        Detector constructor.

        Takes no parameters; all state is initialised by ``MockONDAgent``.
        """
        MockONDAgent.__init__(self)

    def feature_extraction(
        self, toolset: Dict
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Copy the dummy values out of the toolset onto the detector.

        Args:
            toolset (dict): Dictionary containing parameters for different
                steps; must contain the keys ``dummy_dict``, ``dummy_list``,
                ``dummy_tuple``, ``dummy_tensor`` and ``dummy_val``

        Return:
            Tuple of two empty dictionaries
        """
        # Each toolset entry is stored under an attribute of the same name, in
        # this order; a missing key raises KeyError after the earlier entries
        # have already been assigned.
        for attr_name in (
            "dummy_dict",
            "dummy_list",
            "dummy_tuple",
            "dummy_tensor",
            "dummy_val",
        ):
            setattr(self, attr_name, toolset[attr_name])
        return {}, {}


class MockONDAdapterWithCheckpoint(Checkpointer, MockONDAgent):

Sample Adapter

"""Mocks mainly used for testing protocols."""

from sail_on_client.checkpointer import Checkpointer
from sail_on_client.agent.ond_agent import ONDAgent

from typing import Dict, Any, Tuple, Callable

import logging
import os
import shutil
import torch

# Module-level logger; the ``execute`` methods below log each step they run.
log = logging.getLogger(__name__)


class MockONDAgent(ONDAgent):
    """
    Mock Detector for OND Protocol.

    ``feature_extraction`` records the dataset path from the toolset. The world
    detection, novelty classification and novelty characterization steps then
    copy that file into the same directory (as ``wc.csv``, ``ncl.csv`` and
    ``nc.csv`` respectively) and return the path of the copy. The remaining
    steps are no-ops.
    """

    def __init__(self) -> None:
        """Construct Mock OND Detector."""
        super().__init__()
        # Step names accepted by ``execute`` mapped to the methods implementing
        # them. The keys are runtime identifiers (presumably the names the
        # protocol sends -- note the "NoveltyAdaption" spelling), so do not
        # rename them without checking the callers.
        self.step_dict: Dict[str, Callable] = {
            "Initialize": self.initialize,
            "FeatureExtraction": self.feature_extraction,
            "WorldDetection": self.world_detection,
            "NoveltyClassification": self.novelty_classification,
            "NoveltyAdaption": self.novelty_adaptation,
            "NoveltyCharacterization": self.novelty_characterization,
        }

    def _copy_dataset(self, filename: str) -> str:
        """
        Copy the dataset file into its own directory under a new name.

        Args:
            filename (str): Name of the copy, created next to ``self.dataset``

        Return:
            Path to the copied file
        """
        # ``self.dataset`` is assigned in ``feature_extraction``, so the steps
        # using this helper expect that step to have run first.
        dst_file = os.path.join(os.path.dirname(self.dataset), filename)
        shutil.copyfile(self.dataset, dst_file)
        return dst_file

    def initialize(self, toolset: Dict) -> None:
        """
        Algorithm Initialization.

        Args:
            toolset (dict): Dictionary containing parameters for different steps

        Return:
            None
        """
        pass

    def feature_extraction(
        self, toolset: Dict
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Feature extraction step for the algorithm.

        Records ``toolset["dataset"]`` as ``self.dataset`` for the later steps.

        Args:
            toolset (dict): Dictionary containing parameters for different steps

        Return:
            Tuple of two empty dictionaries
        """
        self.dataset = toolset["dataset"]
        return {}, {}

    def world_detection(self, toolset: Dict) -> str:
        """
        Detect change in world (novelty has been introduced).

        Args:
            toolset (dict): Dictionary containing parameters for different steps

        Return:
            path to csv file containing the results for change in world
        """
        return self._copy_dataset("wc.csv")

    def novelty_classification(self, toolset: Dict) -> str:
        """
        Classify data provided in known classes and unknown class.

        Args:
            toolset (dict): Dictionary containing parameters for different steps

        Return:
            path to csv file containing the results for novelty classification step
        """
        return self._copy_dataset("ncl.csv")

    def novelty_adaptation(self, toolset: Dict) -> None:
        """
        Update models based on novelty classification and characterization.

        Args:
            toolset (dict): Dictionary containing parameters for different steps

        Return:
            None
        """
        pass

    def novelty_characterization(self, toolset: Dict) -> str:
        """
        Characterize novelty by clustering different novel samples.

        Args:
            toolset (dict): Dictionary containing parameters for different steps

        Return:
            path to csv file containing the results for novelty characterization step
        """
        return self._copy_dataset("nc.csv")

    def execute(self, toolset: Dict, step_descriptor: str) -> Any:
        """
        Execute method used by the protocol to run different steps.

        Args:
            toolset (dict): Dictionary containing parameters for different steps
            step_descriptor (str): Name of the step, a key of ``self.step_dict``

        Return:
            Whatever the selected step method returns

        Raises:
            KeyError: If ``step_descriptor`` is not a known step name
        """
        log.info("Executing %s", step_descriptor)
        return self.step_dict[step_descriptor](toolset)


class MockONDAgentWithAttributes(MockONDAgent):
    """
    Mock Detector for testing checkpointing.

    ``feature_extraction`` copies five ``dummy_*`` entries from the toolset
    onto the instance so that they can be saved, restored and compared.
    """

    def __init__(self) -> None:
        """
        Detector constructor.

        Takes no parameters; all state is initialised by ``MockONDAgent``.
        """
        MockONDAgent.__init__(self)

    def feature_extraction(
        self, toolset: Dict
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Feature extraction step for the algorithm.

        Stores ``toolset["dummy_dict"]``, ``["dummy_list"]``, ``["dummy_tuple"]``,
        ``["dummy_tensor"]`` and ``["dummy_val"]`` as attributes of the same
        names. Unlike ``MockONDAgent.feature_extraction`` it does not record
        ``toolset["dataset"]``, so the inherited file-copying steps will fail
        with AttributeError unless something else sets ``self.dataset``.

        Args:
            toolset (dict): Dictionary containing parameters for different steps

        Return:
            Tuple of two empty dictionaries
        """
        self.dummy_dict = toolset["dummy_dict"]
        self.dummy_list = toolset["dummy_list"]
        self.dummy_tuple = toolset["dummy_tuple"]
        self.dummy_tensor = toolset["dummy_tensor"]
        self.dummy_val = toolset["dummy_val"]
        return {}, {}


class MockONDAdapterWithCheckpoint(Checkpointer, MockONDAgent):
    """
    Mock Adapter for testing checkpointing.

    Steps are forwarded to a wrapped ``MockONDAgentWithAttributes`` detector,
    and two adapters compare equal when their detectors hold equal ``dummy_*``
    attributes. Because ``__eq__`` is defined without ``__hash__``, instances
    are unhashable.
    """

    def __init__(self, toolset: Dict) -> None:
        """
        Adapter constructor.

        Args:
            toolset (dict): Dictionary containing parameters for the constructor,
                forwarded to ``Checkpointer.__init__``
        """
        MockONDAgent.__init__(self)
        Checkpointer.__init__(self, toolset)
        # All steps run on this detector (see ``execute``), and ``__eq__``
        # compares its attributes.
        self.detector = MockONDAgentWithAttributes()

    def get_config(self) -> Dict:
        """
        Get config for the plugin.

        Returns:
            The parent's config updated with the entries of ``self.toolset``
        """
        # NOTE(review): ``self.toolset`` is not assigned in this class;
        # presumably ``Checkpointer.__init__`` sets it -- confirm.
        config = super().get_config()
        config.update(self.toolset)
        return config

    def execute(self, toolset: Dict, step_descriptor: str) -> Any:
        """
        Execute method used by the protocol to run different steps.

        Args:
            toolset (dict): Dictionary containing parameters for different steps
            step_descriptor (str): Name of the step, a key of the detector's
                ``step_dict``

        Return:
            Whatever the selected step method of the detector returns

        Raises:
            KeyError: If ``step_descriptor`` is not a known step name
        """
        log.info("Executing %s", step_descriptor)
        return self.detector.step_dict[step_descriptor](toolset)

    def __eq__(self, other: object) -> bool:
        """
        Overriden method to compare two mock adapters.

        Both detectors must already have run ``feature_extraction``; otherwise
        reading the ``dummy_*`` attributes raises AttributeError.

        Args:
            other (MockONDAdapterWithCheckpoint): Another instance of mock adapter

        Return:
            True if both instances have same attributes, NotImplemented if
            ``other`` is not a MockONDAdapterWithCheckpoint
        """
        if not isinstance(other, MockONDAdapterWithCheckpoint):
            return NotImplemented

        mine, theirs = self.detector, other.detector
        return (
            mine.dummy_dict == theirs.dummy_dict
            and mine.dummy_list == theirs.dummy_list
            and mine.dummy_tuple == theirs.dummy_tuple
            # torch.equal requires identical shapes as well as elements; the
            # element-wise torch.eq broadcasts, so e.g. tensor([1]) would have
            # matched tensor([1, 1]) and non-broadcastable shapes would raise.
            and torch.equal(mine.dummy_tensor, theirs.dummy_tensor)
            and mine.dummy_val == theirs.dummy_val
        )