diff --git a/.github/workflows/score_new_plugins.yml b/.github/workflows/score_new_plugins.yml index baceef061..158a1066c 100644 --- a/.github/workflows/score_new_plugins.yml +++ b/.github/workflows/score_new_plugins.yml @@ -20,12 +20,11 @@ permissions: write-all jobs: - changes_models_or_benchmarks: - name: Check if PR makes changes to /models or /benchmarks + process_submission: + name: If triggering PR alters /models or /benchmarks, initiates scoring for relevant plugins if: github.event.pull_request.merged == true runs-on: ubuntu-latest outputs: - PLUGIN_INFO: ${{ steps.getpluginfo.outputs.PLUGIN_INFO }} RUN_SCORING: ${{ steps.scoringneeded.outputs.RUN_SCORING }} steps: - name: Check out repository code @@ -37,6 +36,13 @@ jobs: uses: actions/setup-python@v4 with: python-version: 3.7 + + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-1 - name: Installing package dependencies run: | @@ -53,54 +59,71 @@ jobs: id: getpluginfo run: | echo "PLUGIN_INFO='$(python -c 'from brainscore_core.plugin_management.parse_plugin_changes import get_scoring_info; get_scoring_info("${{ env.CHANGED_FILES }}", "brainscore_vision")')'" >> $GITHUB_OUTPUT - + - name: Check if scoring needed id: scoringneeded run: | echo "RUN_SCORING=$(jq -r '.run_score' <<< ${{ steps.getpluginfo.outputs.PLUGIN_INFO }})" >> $GITHUB_OUTPUT - - get_submitter_info: - name: Get PR author email and (if web submission) Brain-Score user ID - runs-on: ubuntu-latest - needs: [changes_models_or_benchmarks] - if: needs.changes_models_or_benchmarks.outputs.RUN_SCORING == 'True' - env: - PLUGIN_INFO: ${{ needs.changes_models_or_benchmarks.outputs.PLUGIN_INFO }} - outputs: - PLUGIN_INFO: ${{ steps.add_email_to_pluginfo.outputs.PLUGIN_INFO }} - steps: - - name: Parse user ID from PR title and add to PLUGIN_INFO (WEB ONLY where we don't have access to the GitHub user) - id: add_uid_to_pluginfo - if: contains(github.event.pull_request.labels.*.name, 'automerge-web') + + - name: Find PR author email for non-web submissions + if: "!contains(github.event.pull_request.labels.*.name, 'automerge-web') && steps.scoringneeded.outputs.RUN_SCORING == 'True'" + uses: evvanErb/get-github-email-by-username-action@v2.0 + id: getemail + with: + github-username: ${{github.event.pull_request.user.login}} + token: ${{ secrets.GITHUB_TOKEN }} # Including token enables most reliable way to get a user's email + - name: Update PLUGIN_INFO for non-web submissions + if: "!contains(github.event.pull_request.labels.*.name, 'automerge-web') && steps.scoringneeded.outputs.RUN_SCORING == 'True'" + id: non_automerge_web + run: | + echo "The PR author email is ${{ steps.getemail.outputs.email }}" + echo "PLUGIN_INFO=$(<<<${{ steps.getpluginfo.outputs.PLUGIN_INFO }} tr -d "'" | jq -c '. + {email: "${{ steps.getemail.outputs.email }}"}')" >> $GITHUB_ENV + + - name: Update PLUGIN_INFO for automerge-web (find uid, public v. private, and bs email) + if: contains(github.event.pull_request.labels.*.name, 'automerge-web') && steps.scoringneeded.outputs.RUN_SCORING == 'True' + id: automerge_web run: | BS_UID="$(echo '${{ github.event.pull_request.title }}' | sed -E 's/.*\(user:([^)]+)\).*/\1/')" BS_PUBLIC="$(echo '${{ github.event.pull_request.title }}' | sed -E 's/.*\(public:([^)]+)\).*/\1/')" - echo "The Brain-Score user ID is $BS_UID" - echo "PLUGIN_INFO=$(<<<$PLUGIN_INFO tr -d "'" | jq -c ". 
+ {user_id: \"$BS_UID\", public: \"$BS_PUBLIC\"}")" >> $GITHUB_ENV + USER_EMAIL=$(python -c "from brainscore_core.submission.database import email_from_uid; from brainscore_core.submission.endpoints import UserManager; user_manager=UserManager(db_secret='${{ secrets.BSC_DATABASESECRET }}'); print(email_from_uid($BS_UID))") + echo "::add-mask::$USER_EMAIL" # Mask the USER_EMAIL + echo "PLUGIN_INFO=$(<<<${{ steps.getpluginfo.outputs.PLUGIN_INFO }} tr -d "'" | jq -c ". + {user_id: \"$BS_UID\", public: \"$BS_PUBLIC\", email: \"$USER_EMAIL\"}")" >> $GITHUB_ENV - - name: Get PR author email from GitHub username - id: getemail - uses: evvanErb/get-github-email-by-username-action@v2.0 - with: - github-username: ${{github.event.pull_request.user.login}} # PR author's username - token: ${{ secrets.GITHUB_TOKEN }} # Including token enables most reliable way to get a user's email - - name: Add PR author email to PLUGIN_INFO - id: add_email_to_pluginfo + - name: Write PLUGIN_INFO to a json file run: | - echo "The PR author email is ${{ steps.getemail.outputs.email }}" - echo "PLUGIN_INFO=$(<<<$PLUGIN_INFO tr -d "'" | jq -c '. + {author_email: "${{ steps.getemail.outputs.email }}"}')" >> $GITHUB_OUTPUT + echo "$PLUGIN_INFO" > plugin-info.json + + - name: Upload PLUGIN_INFO as an artifact + uses: actions/upload-artifact@v2 + with: + name: plugin-info + path: plugin-info.json + run_scoring: name: Score plugins runs-on: ubuntu-latest - needs: [changes_models_or_benchmarks, get_submitter_info] - if: needs.changes_models_or_benchmarks.outputs.RUN_SCORING == 'True' + needs: [process_submission] + if: needs.process_submission.outputs.RUN_SCORING == 'True' env: - PLUGIN_INFO: ${{ needs.get_submitter_info.outputs.PLUGIN_INFO }} JENKINS_USER: ${{ secrets.JENKINS_USER }} JENKINS_TOKEN: ${{ secrets.JENKINS_TOKEN }} JENKINS_TRIGGER: ${{ secrets.JENKINS_TRIGGER }} steps: + + - name: Download PLUGIN_INFO artifact + uses: actions/download-artifact@v2 + with: + name: plugin-info + path: artifact-directory + + - name: Set PLUGIN_INFO as an environment variable + run: | + PLUGIN_INFO=$(cat artifact-directory/plugin-info.json) + USER_EMAIL=$(echo "$PLUGIN_INFO" | jq -r '.email') + echo "::add-mask::$USER_EMAIL" # add a mask when bringing email back from artifact + echo "PLUGIN_INFO=${PLUGIN_INFO}" >> $GITHUB_ENV + - name: Add domain, public, competition, and model_type to PLUGIN_INFO run: | echo "PLUGIN_INFO=$(<<<$PLUGIN_INFO tr -d "'" | jq -c '. 
+ {domain: "vision", competition: "None", model_type: "Brain_Model"}')" >> $GITHUB_ENV diff --git a/brainscore_vision/model_helpers/activations/core.py b/brainscore_vision/model_helpers/activations/core.py index 3d982b5e2..f37631dbb 100644 --- a/brainscore_vision/model_helpers/activations/core.py +++ b/brainscore_vision/model_helpers/activations/core.py @@ -1,5 +1,8 @@ import copy import os +import cv2 +import tempfile +from typing import Dict, Tuple, List, Union import functools import logging @@ -8,6 +11,7 @@ import numpy as np from tqdm.auto import tqdm +import xarray as xr from brainio.assemblies import NeuroidAssembly, walk_coords from brainio.stimuli import StimulusSet @@ -32,17 +36,34 @@ def __init__(self, get_activations, preprocessing, identifier=False, batch_size= self.preprocess = preprocessing or (lambda x: x) self._stimulus_set_hooks = {} self._batch_activations_hooks = {} + self._microsaccade_helper = MicrosaccadeHelper() - def __call__(self, stimuli, layers, stimuli_identifier=None): + def __call__(self, stimuli, layers, stimuli_identifier=None, number_of_trials: int = 1, + require_variance: bool = False): """ :param stimuli_identifier: a stimuli identifier for the stored results file. False to disable saving. + :param number_of_trials: An integer that determines how many repetitions of the same model performs. + :param require_variance: A bool that asks models to output different responses to the same stimuli (i.e., + allows stochastic responses to identical stimuli, even in otherwise deterministic base models). + We here implement this using microsaccades. For more, see ... + """ + if require_variance: + self._microsaccade_helper.number_of_trials = number_of_trials # for use with microsaccades + if (self._microsaccade_helper.visual_degrees is None) and require_variance: + self._logger.debug("When using microsaccades for model commitments other than ModelCommitment, you should " + "set self.activations_model.set_visual_degrees(visual_degrees). Not doing so risks " + "breaking microsaccades.") if isinstance(stimuli, StimulusSet): - return self.from_stimulus_set(stimulus_set=stimuli, layers=layers, stimuli_identifier=stimuli_identifier) + function_call = functools.partial(self.from_stimulus_set, stimulus_set=stimuli) else: - return self.from_paths(stimuli_paths=stimuli, layers=layers, stimuli_identifier=stimuli_identifier) + function_call = functools.partial(self.from_paths, stimuli_paths=stimuli) + return function_call( + layers=layers, + stimuli_identifier=stimuli_identifier, + require_variance=require_variance) - def from_stimulus_set(self, stimulus_set, layers, stimuli_identifier=None): + def from_stimulus_set(self, stimulus_set, layers, stimuli_identifier=None, require_variance: bool = False): """ :param stimuli_identifier: a stimuli identifier for the stored results file. False to disable saving. 
None to use `stimulus_set.identifier` @@ -53,38 +74,49 @@ def from_stimulus_set(self, stimulus_set, layers, stimuli_identifier=None): stimulus_set = hook(stimulus_set) stimuli_paths = [str(stimulus_set.get_stimulus(stimulus_id)) for stimulus_id in stimulus_set['stimulus_id']] activations = self.from_paths(stimuli_paths=stimuli_paths, layers=layers, stimuli_identifier=stimuli_identifier) - activations = attach_stimulus_set_meta(activations, stimulus_set) + activations = attach_stimulus_set_meta(activations, + stimulus_set, + number_of_trials=self._microsaccade_helper.number_of_trials, + require_variance=require_variance) return activations - def from_paths(self, stimuli_paths, layers, stimuli_identifier=None): + def from_paths(self, stimuli_paths, layers, stimuli_identifier=None, require_variance=None): if layers is None: layers = ['logits'] if self.identifier and stimuli_identifier: fnc = functools.partial(self._from_paths_stored, - identifier=self.identifier, stimuli_identifier=stimuli_identifier) + identifier=self.identifier, + stimuli_identifier=stimuli_identifier, + require_variance=require_variance) else: self._logger.debug(f"self.identifier `{self.identifier}` or stimuli_identifier {stimuli_identifier} " f"are not set, will not store") fnc = self._from_paths - # In case stimuli paths are duplicates (e.g. multiple trials), we first reduce them to only the paths that need - # to be run individually, compute activations for those, and then expand the activations to all paths again. - # This is done here, before storing, so that we only store the reduced activations. - reduced_paths = self._reduce_paths(stimuli_paths) - activations = fnc(layers=layers, stimuli_paths=reduced_paths) - activations = self._expand_paths(activations, original_paths=stimuli_paths) + if require_variance: + activations = fnc(layers=layers, stimuli_paths=stimuli_paths, require_variance=require_variance) + else: + # When we are not asked for varying responses but receive `stimuli_paths` duplicates (e.g. multiple trials), + # we first reduce them to only the paths that need to be run individually, compute activations for those, + # and then expand the activations to all paths again. This is done here, before storing, so that we only + # store the reduced activations. 
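+ # For example (illustrative): stimuli_paths = ['a.png', 'b.png', 'a.png'] is reduced to the two unique paths, activations are computed once for each, and the result is then expanded back to three presentations in the original order.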
+ reduced_paths = self._reduce_paths(stimuli_paths) + activations = fnc(layers=layers, stimuli_paths=reduced_paths, require_variance=require_variance) + activations = self._expand_paths(activations, original_paths=stimuli_paths) return activations @store_xarray(identifier_ignore=['stimuli_paths', 'layers'], combine_fields={'layers': 'layer'}) - def _from_paths_stored(self, identifier, layers, stimuli_identifier, stimuli_paths): - return self._from_paths(layers=layers, stimuli_paths=stimuli_paths) + def _from_paths_stored(self, identifier, layers, stimuli_identifier, + stimuli_paths, number_of_trials: int = 1, require_variance: bool = False): + return self._from_paths(layers=layers, stimuli_paths=stimuli_paths, require_variance=require_variance) - def _from_paths(self, layers, stimuli_paths): + def _from_paths(self, layers, stimuli_paths, require_variance: bool = False): if len(layers) == 0: raise ValueError("No layers passed to retrieve activations from") self._logger.info('Running stimuli') - layer_activations = self._get_activations_batched(stimuli_paths, layers=layers, batch_size=self._batch_size) + layer_activations = self._get_activations_batched(stimuli_paths, layers=layers, batch_size=self._batch_size, + require_variance=require_variance) self._logger.info('Packaging into assembly') - return self._package(layer_activations, stimuli_paths) + return self._package(layer_activations=layer_activations, stimuli_paths=stimuli_paths, require_variance=require_variance) def _reduce_paths(self, stimuli_paths): return list(set(stimuli_paths)) @@ -95,7 +127,7 @@ def _expand_paths(self, activations, original_paths): sorted_x = activations_paths[argsort_indices] sorted_index = np.searchsorted(sorted_x, original_paths) index = [argsort_indices[i] for i in sorted_index] - return activations[{'stimulus_path': index}] + return activations[{'presentation': index}] def register_batch_activations_hook(self, hook): r""" @@ -125,31 +157,65 @@ def register_stimulus_set_hook(self, hook): self._stimulus_set_hooks[handle.id] = hook return handle - def _get_activations_batched(self, paths, layers, batch_size): - layer_activations = None + def _get_activations_batched(self, paths, layers, batch_size: int, require_variance: bool): + layer_activations = OrderedDict() for batch_start in tqdm(range(0, len(paths), batch_size), unit_scale=batch_size, desc="activations"): batch_end = min(batch_start + batch_size, len(paths)) batch_inputs = paths[batch_start:batch_end] - batch_activations = self._get_batch_activations(batch_inputs, layer_names=layers, batch_size=batch_size) - for hook in self._batch_activations_hooks.copy().values(): # copy to avoid handle re-enabling messing with the loop + + batch_activations = OrderedDict() + # compute activations on the entire batch one microsaccade shift at a time. + for shift_number in range(self._microsaccade_helper.number_of_trials): + activations = self._get_batch_activations(inputs=batch_inputs, + layer_names=layers, + batch_size=batch_size, + require_variance=require_variance, + trial_number=shift_number) + + for layer_name, layer_output in activations.items(): + batch_activations.setdefault(layer_name, []).append(layer_output) + + # concatenate all microsaccade shifts in this batch (for example, if the model microsaccaded 15 times, + # the 15 microsaccaded layer_outputs are concatenated to the batch here. 
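+ # Illustrative shape bookkeeping: with a full batch of 64 stimuli and number_of_trials=15, each entry of batch_activations holds a list of 15 arrays of shape (64, ...), which the loop below concatenates into a single array with 64 * 15 = 960 rows for this batch.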
+ for layer_name, layer_outputs in batch_activations.items(): + batch_activations[layer_name] = np.concatenate(layer_outputs) + + for hook in self._batch_activations_hooks.copy().values(): batch_activations = hook(batch_activations) - if layer_activations is None: - layer_activations = copy.copy(batch_activations) - else: - for layer_name, layer_output in batch_activations.items(): - layer_activations[layer_name] = np.concatenate((layer_activations[layer_name], layer_output)) + # add this batch to layer_activations + for layer_name, layer_output in batch_activations.items(): + layer_activations.setdefault(layer_name, []).append(layer_output) - return layer_activations + # concat all batches + for layer_name, layer_outputs in layer_activations.items(): + layer_activations[layer_name] = np.concatenate(layer_outputs) - def _get_batch_activations(self, inputs, layer_names, batch_size): + return layer_activations # this is all batches + + def _get_batch_activations(self, inputs, layer_names, batch_size: int, require_variance: bool = False, + trial_number: int = 1): inputs, num_padding = self._pad(inputs, batch_size) preprocessed_inputs = self.preprocess(inputs) + preprocessed_inputs = self._microsaccade_helper.translate_images(images=preprocessed_inputs, + image_paths=inputs, + trial_number=trial_number, + require_variance=require_variance) activations = self.get_activations(preprocessed_inputs, layer_names) assert isinstance(activations, OrderedDict) activations = self._unpad(activations, num_padding) + if require_variance: + self._microsaccade_helper.remove_temporary_files(preprocessed_inputs) return activations + def set_visual_degrees(self, visual_degrees: float): + """ + A method used by ModelCommitments to give the ActivationsExtractorHelper.MicrosaccadeHelper their visual + degrees for performing microsaccades. + """ + self._microsaccade_helper.visual_degrees = visual_degrees + + def _pad(self, batch_images, batch_size): num_images = len(batch_images) if num_images % batch_size == 0: @@ -161,11 +227,14 @@ def _pad(self, batch_images, batch_size): def _unpad(self, layer_activations, num_padding): return change_dict(layer_activations, lambda values: values[:-num_padding or None]) - def _package(self, layer_activations, stimuli_paths): + def _package(self, layer_activations, stimuli_paths, require_variance: bool): shapes = [a.shape for a in layer_activations.values()] self._logger.debug(f"Activations shapes: {shapes}") self._logger.debug("Packaging individual layers") - layer_assemblies = [self._package_layer(single_layer_activations, layer=layer, stimuli_paths=stimuli_paths) for + layer_assemblies = [self._package_layer(single_layer_activations, + layer=layer, + stimuli_paths=stimuli_paths, + require_variance=require_variance) for layer, single_layer_activations in tqdm(layer_activations.items(), desc='layer packaging')] # merge manually instead of using merge_data_arrays since `xarray.merge` is very slow with these large arrays # complication: (non)neuroid_coords are taken from the structure of layer_assemblies[0] i.e. 
the 1st assembly; @@ -182,17 +251,25 @@ def _package(self, layer_activations, stimuli_paths): for coord in neuroid_coords: neuroid_coords[coord][1] = np.concatenate((neuroid_coords[coord][1], layer_assembly[coord].values)) assert layer_assemblies[0].dims == layer_assembly.dims - for dim in set(layer_assembly.dims) - {'neuroid'}: - for coord in layer_assembly[dim].coords: - assert (layer_assembly[coord].values == nonneuroid_coords[coord][1]).all() + for coord, dims, values in walk_coords(layer_assembly): + if set(dims) == {'neuroid'}: + continue + assert (values == nonneuroid_coords[coord][1]).all() + neuroid_coords = {coord: (dims_values[0], dims_values[1]) # re-package as tuple instead of list for xarray for coord, dims_values in neuroid_coords.items()} model_assembly = type(layer_assemblies[0])(model_assembly, coords={**nonneuroid_coords, **neuroid_coords}, dims=layer_assemblies[0].dims) return model_assembly - def _package_layer(self, layer_activations, layer, stimuli_paths): - assert layer_activations.shape[0] == len(stimuli_paths) + def _package_layer(self, layer_activations: np.ndarray, layer: str, stimuli_paths: List[str], require_variance: bool = False): + # activation shape is larger if variance in responses is required from the model by a factor of number_of_trials + if require_variance: + runs_per_image = self._microsaccade_helper.number_of_trials + else: + runs_per_image = 1 + assert layer_activations.shape[0] == len(stimuli_paths) * runs_per_image + stimuli_paths = np.repeat(stimuli_paths, runs_per_image) activations, flatten_indices = flatten(layer_activations, return_index=True) # collapse for single neuroid dim flatten_coord_names = None if flatten_indices.shape[1] == 1: # fully connected, e.g. classifier @@ -209,17 +286,19 @@ def _package_layer(self, layer_activations, layer, stimuli_paths): self._logger.debug(f"Unknown layer activations shape {layer_activations.shape}, not inferring channels") # build assembly - coords = {'stimulus_path': stimuli_paths, + coords = {'stimulus_path': ('presentation', stimuli_paths), + **self._microsaccade_helper.build_microsaccade_coords(stimuli_paths), 'neuroid_num': ('neuroid', list(range(activations.shape[1]))), 'model': ('neuroid', [self.identifier] * activations.shape[1]), 'layer': ('neuroid', [layer] * activations.shape[1]), } + if flatten_coord_names: flatten_coords = {flatten_coord_names[i]: [sample_index[i] if i < flatten_indices.shape[1] else np.nan for sample_index in flatten_indices] for i in range(len(flatten_coord_names))} coords = {**coords, **{coord: ('neuroid', values) for coord, values in flatten_coords.items()}} - layer_assembly = NeuroidAssembly(activations, coords=coords, dims=['stimulus_path', 'neuroid']) + layer_assembly = NeuroidAssembly(activations, coords=coords, dims=['presentation', 'neuroid']) neuroid_id = [".".join([f"{value}" for value in values]) for values in zip(*[ layer_assembly[coord].values for coord in ['model', 'layer', 'neuroid_num']])] layer_assembly['neuroid_id'] = 'neuroid', neuroid_id @@ -232,6 +311,230 @@ def insert_attrs(self, wrapper): wrapper.register_stimulus_set_hook = self.register_stimulus_set_hook +class MicrosaccadeHelper: + """ + A class that allows ActivationsExtractorHelper to implement microsaccades. + + Human microsaccade amplitude varies by who you ask, an estimate might be <0.1 deg = 360 arcsec = 6arcmin. + Our motivation to make use of such microsaccades is to obtain multiple different neural activities to the + same input stimulus from non-stochastic models. 
This enables models to engage on e.g. psychophysical + functions which often require variance for the same stimulus. In the current implementation, + if `require_variance=True`, the model microsaccades in the preprocessed input space in sub-pixel increments, + the extent and position of which are determined by `self._visual_degrees`, and + `self.microsaccade_extent_degrees`. + + More information: + --> Rolfs 2009 "Microsaccades: Small steps on a long way" Vision Research, Volume 49, Issue 20, 15 + October 2009, Pages 2415-2441. + --> Haddad & Steinmann 1973 "The smallest voluntary saccade: Implications for fixation" Vision + Research Volume 13, Issue 6, June 1973, Pages 1075-1086, IN5-IN6. + Implemented by Ben Lonnqvist and Johannes Mehrer. + """ + def __init__(self): + self._logger = logging.getLogger(fullname(self)) + self.number_of_trials = 1 # for use with microsaccades. + self.microsaccade_extent_degrees = 0.05 # how many degrees models microsaccade by default + + # a dict that contains two dicts, one for representing microsaccades in pixels, and one in degrees. + # Each dict inside contain image paths and their respective microsaccades. For example + # {'pixels': {'abc.jpg': [(0, 0), (1.5, 2)]}, 'degrees': {'abc.jpg': [(0., 0.), (0.0075, 0.001)]}} + self.microsaccades = {'pixels': {}, 'degrees': {}} + # Model visual degrees. Used for computing microsaccades in the space of degrees rather than pixels + self.visual_degrees = None + + def translate_images(self, images: List[Union[str, np.ndarray]], image_paths: List[str], trial_number: int, + require_variance: bool) -> List[str]: + """ + Translate images according to selected microsaccades, if microsaccades are required. + + :param images: A list of images (in the case of tensorflow models), or a list of arrays (non-tf models). + :param image_paths: A list of image paths. Both `image_paths` and `images` are needed since while both tf and + non-tf models preprocess images before this point, non-tf models' preprocessed images + are fixed as arrays when fed into here. As such, simply returning `image_paths` for + non-tf models would require double-loading of the images, which does not seem like a + good idea. + """ + output_images = [] + for index, image_path in enumerate(image_paths): + # When microsaccades are not used, skip computing them and return the base images. + # This iteration could be entirely skipped, but recording microsaccades for all images regardless + # of whether variance is required or not is convenient for adding an extra presentation dimension + # in the layer assembly later to keep track of as much metadata as possible, to avoid layer assembly + # collapse, or to avoid otherwise extraneous mock dims. + # The method could further be streamlined by calling `self.get_image_with_shape()` and + # `self.select_microsaccade` for all images regardless of require_variance, but it seems like a bad + # idea to introduce cv2 image loading for all models and images, regardless of whether they are actually + # microsaccading. 
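+ # Rough scale of the shifts selected in the `else` branch below (illustrative numbers, assuming visual_degrees=8 and a 224-pixel-wide image): max_radius = (0.05 / 8) * 224 = 1.4 pixels, so every microsaccade stays within ~1.4 px of the image center.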
+ if not require_variance: + self.microsaccades['pixels'][image_path] = [(0., 0.)] + self.microsaccades['degrees'][image_path] = [(0., 0.)] + output_images.append(images[index]) + else: + # translate images according to microsaccades if we are using microsaccades + image, image_shape, image_is_channels_first = self.get_image_with_shape(images[index]) + microsaccade_location_pixels = self.select_microsaccade(image_path=image_path, + trial_number=trial_number, + image_shape=image_shape) + return_string = True if isinstance(images[index], str) else False + output_images.append(self.translate_image(image=image, + microsaccade_location=microsaccade_location_pixels, + image_shape=image_shape, + return_string=return_string, + image_is_channels_first=image_is_channels_first)) + return self.reshape_microsaccaded_images(output_images) + + def translate_image(self, image: str, microsaccade_location: Tuple[float, float], image_shape: Tuple[int, int], + return_string: bool, image_is_channels_first: bool) -> str: + """Translates and saves a temporary image to temporary_fp.""" + translated_image = self.translate(image=image, shift=microsaccade_location, image_shape=image_shape, + image_is_channels_first=image_is_channels_first) + if not return_string: # if the model accepts ndarrays after preprocessing, return one + return translated_image + else: # if the model accepts strings after preprocessing, write temp file + temp_file_descriptor, temporary_fp = tempfile.mkstemp(suffix=".png") + os.close(temp_file_descriptor) + if not cv2.imwrite(temporary_fp, translated_image): + raise Exception(f"cv2.imwrite failed: {temporary_fp}") + return temporary_fp + + def select_microsaccade(self, image_path: str, trial_number: int, image_shape: Tuple[int, int] + ) -> Tuple[float, float]: + """ + A function for generating a microsaccade location. The function returns a tuple of pixel shifts expanding from + the center of the image. + + Microsaccade locations are placed within a circle, evenly distributed across the entire area in a spiral, + from the center to the circumference. We keep track of microsaccades both on a pixel and visual angle basis, + but only pixel values are returned. This is because shifting the image using cv2 requires pixel representation. + """ + # if we did not already compute `self.microsaccades`, we build them first. + if image_path not in self.microsaccades.keys(): + self.build_microsaccades(image_path=image_path, image_shape=image_shape) + return self.microsaccades['pixels'][image_path][trial_number] + + def build_microsaccades(self, image_path: str, image_shape: Tuple[int, int]): + if image_shape[0] != image_shape[1]: + self._logger.debug('Input image is not a square. Image dimension 0 is used to calculate the ' + 'extent of microsaccades.') + + assert self.visual_degrees is not None, ( + 'self._visual_degrees is not set by the ModelCommitment, but microsaccades ' + 'are in use. Set activations_model visual degrees in your commitment after defining ' + 'your activations_model. For example, self.activations_model.set_visual_degrees' + '(visual_degrees). For detailed information, see ' + ':meth:`~brainscore_vision.model_helpers.activations.ActivationsExtractorHelper.' 
+ '__call__`,') + # compute the maximum radius of microsaccade extent in pixel space + radius_ratio = self.microsaccade_extent_degrees / self.visual_degrees + max_radius = radius_ratio * image_shape[0] # maximum radius in pixels, set in self.microsaccade_extent_degrees + + selected_microsaccades = {'pixels': [], 'degrees': []} + # microsaccades are placed in a spiral at sub-pixel increments + a = max_radius / np.sqrt(self.number_of_trials) # spiral coefficient to space microsaccades evenly + for i in range(self.number_of_trials): + r = np.sqrt(i / self.number_of_trials) * max_radius # compute radial distance for the i-th point + theta = a * np.sqrt(i) * 2 * np.pi / max_radius # compute angle for the i-th point + + # convert polar coordinates to Cartesian, centered on the image + x = r * np.cos(theta) + y = r * np.sin(theta) + + pixels_per_degree = self.calculate_pixels_per_degree_in_image(image_shape[0]) + selected_microsaccades['pixels'].append((x, y)) + selected_microsaccades['degrees'].append(self.convert_pixels_to_degrees((x, y), pixels_per_degree)) + + # to keep consistent with number_of_trials, we count trial_number from 1 instead of from 0 + self.microsaccades['pixels'][image_path] = selected_microsaccades['pixels'] + self.microsaccades['degrees'][image_path] = selected_microsaccades['degrees'] + + def unpack_microsaccade_coords(self, stimuli_paths: np.ndarray, pixels_or_degrees: str, dim: int): + """Unpacks microsaccades from stimuli_paths into a single list to conform with coord requirements.""" + assert pixels_or_degrees == 'pixels' or pixels_or_degrees == 'degrees' + unpacked_microsaccades = [] + for stimulus_path in stimuli_paths: + for microsaccade in self.microsaccades[pixels_or_degrees][stimulus_path]: + unpacked_microsaccades.append(microsaccade[dim]) + return unpacked_microsaccades + + def calculate_pixels_per_degree_in_image(self, image_width_pixels: int) -> float: + """Calculates the pixels per degree in the image, assuming the calculation based on image width.""" + pixels_per_degree = image_width_pixels / self.visual_degrees + return pixels_per_degree + + def build_microsaccade_coords(self, stimuli_paths: np.array) -> Dict: + return { + 'microsaccade_shift_x_pixels': ('presentation', self.unpack_microsaccade_coords( + np.unique(stimuli_paths), + pixels_or_degrees='pixels', + dim=0)), + 'microsaccade_shift_y_pixels': ('presentation', self.unpack_microsaccade_coords( + np.unique(stimuli_paths), + pixels_or_degrees='pixels', + dim=1)), + 'microsaccade_shift_x_degrees': ('presentation', self.unpack_microsaccade_coords( + np.unique(stimuli_paths), + pixels_or_degrees='degrees', + dim=0)), + 'microsaccade_shift_y_degrees': ('presentation', self.unpack_microsaccade_coords( + np.unique(stimuli_paths), + pixels_or_degrees='degrees', + dim=1)) + } + + @staticmethod + def convert_pixels_to_degrees(pixel_coords: Tuple[float, float], pixels_per_degree: float) -> Tuple[float, float]: + degrees_x = pixel_coords[0] / pixels_per_degree + degrees_y = pixel_coords[1] / pixels_per_degree + return degrees_x, degrees_y + + @staticmethod + def remove_temporary_files(temporary_file_paths: List[str]) -> None: + """ + This function is used to manually remove all temporary file paths. We do this instead of using implicit + python garbage collection to 1) ensure that tensorflow models have access to temporary files when needed; + 2) to make the point at which temporary files are removed explicit. 
+ """ + for temporary_file_path in temporary_file_paths: + if isinstance(temporary_file_path, str): # do not try to remove loaded images + try: + os.remove(temporary_file_path) + except FileNotFoundError: + pass + + @staticmethod + def translate(image: np.array, shift: Tuple[float, float], image_shape: Tuple[int, int], + image_is_channels_first: bool) -> np.array: + rows, cols = image_shape + # translation matrix + M = np.float32([[1, 0, shift[0]], [0, 1, shift[1]]]) + + if image_is_channels_first: + image = np.transpose(image, (1, 2, 0)) # cv2 expects channels last + # Apply translation, filling new line(s) with line(s) closest to it(them). + translated_image = cv2.warpAffine(image, M, (cols, rows), flags=cv2.INTER_LINEAR, # for sub-pixel shifts + borderMode=cv2.BORDER_REPLICATE) + if image_is_channels_first: + translated_image = np.transpose(translated_image, (2, 0, 1)) # convert the image back to channels-first + return translated_image + + @staticmethod + def get_image_with_shape(image: Union[str, np.ndarray]) -> Tuple[np.array, Tuple[int, int], bool]: + if isinstance(image, str): # tf models return strings after preprocessing + image = cv2.imread(image) + rows, cols, _ = image.shape # cv2 uses height, width, channels + image_is_channels_first = False + else: + _, rows, cols, = image.shape # pytorch and keras use channels, height, width + image_is_channels_first = True + return image, (rows, cols), image_is_channels_first + + @staticmethod + def reshape_microsaccaded_images(images: List) -> Union[List[str], np.ndarray]: + if any(isinstance(image, str) for image in images): + return images + return np.stack(images, axis=0) + + def change_dict(d, change_function, keep_name=False, multithread=False): if not multithread: map_fnc = map @@ -261,16 +564,37 @@ def lstrip_local(path): return path -def attach_stimulus_set_meta(assembly, stimulus_set): +def attach_stimulus_set_meta(assembly, stimulus_set, number_of_trials: int, require_variance: bool = False): stimulus_paths = [str(stimulus_set.get_stimulus(stimulus_id)) for stimulus_id in stimulus_set['stimulus_id']] stimulus_paths = [lstrip_local(path) for path in stimulus_paths] assembly_paths = [lstrip_local(path) for path in assembly['stimulus_path'].values] - assert (np.array(assembly_paths) == np.array(stimulus_paths)).all() - assembly['stimulus_path'] = stimulus_set['stimulus_id'].values + + # when microsaccades are used, we repeat stimulus_paths number_of_trials times to correctly populate the dim + if require_variance: + replication_factor = number_of_trials + else: + replication_factor = 1 + repeated_stimulus_paths = np.repeat(stimulus_paths, replication_factor) + assert (np.array(assembly_paths) == np.array(repeated_stimulus_paths)).all() + repeated_stimulus_ids = np.repeat(stimulus_set['stimulus_id'].values, replication_factor) + + if replication_factor > 1: + # repeat over the presentation dimension to accommodate multiple runs per stimulus + assembly = xr.concat([assembly for _ in range(replication_factor)], dim='presentation') + assembly = assembly.reset_index('presentation') + assembly['stimulus_path'] = ('presentation', repeated_stimulus_ids) assembly = assembly.rename({'stimulus_path': 'stimulus_id'}) + + assert (np.array(assembly_paths) == np.array(stimulus_paths)).all() + + all_columns = [] for column in stimulus_set.columns: - assembly[column] = 'stimulus_id', stimulus_set[column].values - assembly = assembly.stack(presentation=('stimulus_id',)) + repeated_values = np.repeat(stimulus_set[column].values, replication_factor) + 
assembly = assembly.assign_coords({column: ('presentation', repeated_values)}) # assign multiple coords at once + all_columns.append(column) + + presentation_coords = all_columns + [coord for coord, dims, values in walk_coords(assembly['presentation'])] + assembly = assembly.set_index(presentation=list(set(presentation_coords))) # assign MultiIndex return assembly diff --git a/brainscore_vision/model_helpers/brain_transformation/__init__.py b/brainscore_vision/model_helpers/brain_transformation/__init__.py index 47938c6d3..90df82460 100644 --- a/brainscore_vision/model_helpers/brain_transformation/__init__.py +++ b/brainscore_vision/model_helpers/brain_transformation/__init__.py @@ -24,6 +24,10 @@ def __init__(self, identifier, visual_degrees=8): self.layers = layers self.activations_model = activations_model + # We set the visual degrees of the ActivationsExtractorHelper here to avoid changing its signature. + # The ideal solution would be to not expose the _extractor of the activations_model here, but to change + # the signature of the ActivationsExtractorHelper. See https://github.com/brain-score/vision/issues/554 + self.activations_model._extractor.set_visual_degrees(visual_degrees) # for microsaccades self._visual_degrees = visual_degrees # region-layer mapping if region_layer_map is None: diff --git a/brainscore_vision/model_helpers/brain_transformation/neural.py b/brainscore_vision/model_helpers/brain_transformation/neural.py index 51cff1eb3..f1cb04356 100644 --- a/brainscore_vision/model_helpers/brain_transformation/neural.py +++ b/brainscore_vision/model_helpers/brain_transformation/neural.py @@ -23,7 +23,13 @@ def __init__(self, identifier, activations_model, region_layer_map, visual_degre def identifier(self): return self._identifier - def look_at(self, stimuli, number_of_trials=1): + def look_at(self, stimuli, number_of_trials=1, require_variance: bool = False): + """ + :param number_of_trials: An integer that determines how many repetitions of the same image the model performs. + :param require_variance: Whether to require models to return different activations for the same stimuli or not. 
+ For detailed information, see + :meth:`~brainscore_vision.model_helpers.activations.ActivationsExtractorHelper.__call__`, + """ layer_regions = {} for region in self.recorded_regions: layers = self.region_layer_map[region] @@ -31,13 +37,16 @@ def look_at(self, stimuli, number_of_trials=1): for layer in layers: assert layer not in layer_regions, f"layer {layer} has already been assigned for {layer_regions[layer]}" layer_regions[layer] = region - activations = self.run_activations( - stimuli, layers=list(layer_regions.keys()), number_of_trials=number_of_trials) + activations = self.run_activations(stimuli, + layers=list(layer_regions.keys()), + number_of_trials=number_of_trials, + require_variance=require_variance) activations['region'] = 'neuroid', [layer_regions[layer] for layer in activations['layer'].values] return activations - def run_activations(self, stimuli, layers, number_of_trials=1): - activations = self.activations_model(stimuli, layers=layers) + def run_activations(self, stimuli, layers, number_of_trials=1, require_variance=None): + activations = self.activations_model(stimuli, layers=layers, number_of_trials=number_of_trials, + require_variance=require_variance) return activations def start_task(self, task): diff --git a/brainscore_vision/model_helpers/check_submission/check_models.py b/brainscore_vision/model_helpers/check_submission/check_models.py index 49be433ae..b4e386e14 100644 --- a/brainscore_vision/model_helpers/check_submission/check_models.py +++ b/brainscore_vision/model_helpers/check_submission/check_models.py @@ -1,13 +1,12 @@ import numpy as np import os - -import brainscore_vision.metric_helpers from brainio.assemblies import NeuroidAssembly from brainio.stimuli import StimulusSet +from brainscore_vision import load_ceiling, load_metric, load_dataset from brainscore_vision.benchmark_helpers.neural_common import average_repetition, timebins_from_assembly +from brainscore_vision.benchmark_helpers.screen import place_on_screen from brainscore_vision.benchmarks import BenchmarkBase, ceil_score from brainscore_vision.metrics.internal_consistency import InternalConsistency -from brainscore_vision.metrics.regression_correlation import CrossRegressedCorrelation, pls_regression, pearsonr_correlation from brainscore_vision.model_helpers.brain_transformation import ModelCommitment, LayerSelection, RegionLayerMap from brainscore_vision.model_interface import BrainModel @@ -59,59 +58,30 @@ def check_processing(model_identifier, module): class _MockBenchmark(BenchmarkBase): def __init__(self): - assembly_repetition = get_assembly() - assert len(np.unique(assembly_repetition['region'])) == 1 + assembly_repetition = load_dataset("MajajHong2015.public").sel(region="IT").squeeze("time_bin") assert hasattr(assembly_repetition, 'repetition') self.region = 'IT' self.assembly = average_repetition(assembly_repetition) self._assembly = self.assembly self.timebins = timebins_from_assembly(self.assembly) - - self._similarity_metric = CrossRegressedCorrelation( - regression=pls_regression(), correlation=pearsonr_correlation(), - crossvalidation_kwargs=dict(stratification_coord=brainscore_vision.metric_helpers.Defaults.stratification_coord - if hasattr(self.assembly, brainscore_vision.metric_helpers.Defaults.stratification_coord) else None)) + self._similarity_metric = load_metric('pls', crossvalidation_kwargs=dict(stratification_coord='object_name')) identifier = f'{assembly_repetition.name}-layer_selection' - ceiler = InternalConsistency() + ceiler = 
load_ceiling('internal_consistency') super(_MockBenchmark, self).__init__(identifier=identifier, ceiling_func=lambda: ceiler(assembly_repetition), version='1.0') def __call__(self, candidate: BrainModel, do_behavior=False): + # adapt stimuli to visual degrees + stimuli = place_on_screen(self.assembly.stimulus_set, target_visual_degrees=candidate.visual_degrees(), + source_visual_degrees=8) # arbitrary choice for source degrees # Check neural recordings candidate.start_recording(self.region, time_bins=self.timebins) - source_assembly = candidate.look_at(self.assembly.stimulus_set) + source_assembly = candidate.look_at(stimuli) # Check behavioral tasks if do_behavior: candidate.start_task(BrainModel.Task.probabilities, self.assembly.stimulus_set) - candidate.look_at(self.assembly.stimulus_set) + candidate.look_at(stimuli) raw_score = self._similarity_metric(source_assembly, self.assembly) return ceil_score(raw_score, self.ceiling) - -def get_assembly(): - image_names = [] - for i in range(1, 21): - image_names.append(f'images/{i}.png') - assembly = NeuroidAssembly((np.arange(40 * 5) + np.random.standard_normal(40 * 5)).reshape((5, 40, 1)), - coords={'stimulus_id': ( - 'presentation', - image_names * 2), - 'object_name': ('presentation', ['a'] * 40), - 'repetition': ('presentation', ([1] * 20 + [2] * 20)), - 'neuroid_id': ('neuroid', np.arange(5)), - 'region': ('neuroid', ['IT'] * 5), - 'time_bin_start': ('time_bin', [70]), - 'time_bin_end': ('time_bin', [170]) - }, - dims=['neuroid', 'presentation', 'time_bin']) - labels = ['a'] * 10 + ['b'] * 10 - stimulus_set = StimulusSet([{'stimulus_id': image_names[i], 'object_name': 'a', 'image_label': labels[i]} - for i in range(20)]) - stimulus_set.stimulus_paths = {image_name: os.path.join(os.path.dirname(__file__), image_name) - for image_name in image_names} - stimulus_set.identifier = 'test' - assembly.attrs['stimulus_set'] = stimulus_set - assembly.attrs['stimulus_set_name'] = stimulus_set.identifier - assembly = assembly.squeeze("time_bin") - return assembly.transpose('presentation', 'neuroid') diff --git a/brainscore_vision/models/cornet_s_0_0_0/__init__.py b/brainscore_vision/models/cornet_s_0_0_0/__init__.py new file mode 100644 index 000000000..47fceee3c --- /dev/null +++ b/brainscore_vision/models/cornet_s_0_0_0/__init__.py @@ -0,0 +1,5 @@ +from brainscore_vision import model_registry +from brainscore_vision.model_helpers.brain_transformation import ModelCommitment +from .model import get_model, get_layers + +model_registry['cornet_s_0_0_0'] = lambda: ModelCommitment(identifier='cornet_s_0_0_0', activations_model=get_model('cornet_s_0_0_0'), layers=get_layers('cornet_s_0_0_0')) diff --git a/brainscore_vision/models/cornet_s_0_0_0/model.py b/brainscore_vision/models/cornet_s_0_0_0/model.py new file mode 100644 index 000000000..5fbbf7ed4 --- /dev/null +++ b/brainscore_vision/models/cornet_s_0_0_0/model.py @@ -0,0 +1,188 @@ +from brainscore_vision.model_helpers.check_submission import check_models +import functools +from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper +from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images +import torch +import numpy as np +from brainscore_vision.model_helpers.brain_transformation import ModelCommitment +import math +from collections import OrderedDict +from torch import nn + + +HASH = '1d3f7974' + + +class Flatten(nn.Module): + + """ + Helper module for flattening input tensor to 1-D for the use in Linear modules + """ + + def forward(self, x): + 
return x.view(x.size(0), -1) + + +class Identity(nn.Module): + + """ + Helper module that stores the current tensor. Useful for accessing by name + """ + + def forward(self, x): + return x + + +class CORblock_S(nn.Module): + + scale = 4 # scale of the bottleneck convolution channels + + def __init__(self, in_channels, out_channels, times=1): + super().__init__() + + self.times = times + + self.conv_input = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self.skip = nn.Conv2d(out_channels, out_channels, + kernel_size=1, stride=2, bias=False) + self.norm_skip = nn.BatchNorm2d(out_channels) + + self.conv1 = nn.Conv2d(out_channels, out_channels * self.scale, + kernel_size=1, bias=False) + self.nonlin1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(out_channels * self.scale, out_channels * self.scale, + kernel_size=3, stride=2, padding=1, bias=False) + self.nonlin2 = nn.ReLU(inplace=True) + + self.conv3 = nn.Conv2d(out_channels * self.scale, out_channels, + kernel_size=1, bias=False) + self.nonlin3 = nn.ReLU(inplace=True) + + self.output = Identity() # for an easy access to this block's output + + # need BatchNorm for each time step for training to work well + for t in range(self.times): + setattr(self, f'norm1_{t}', nn.BatchNorm2d(out_channels * self.scale)) + setattr(self, f'norm2_{t}', nn.BatchNorm2d(out_channels * self.scale)) + setattr(self, f'norm3_{t}', nn.BatchNorm2d(out_channels)) + + def forward(self, inp): + x = self.conv_input(inp) + + for t in range(self.times): + if t == 0: + skip = self.norm_skip(self.skip(x)) + self.conv2.stride = (2, 2) + else: + skip = x + self.conv2.stride = (1, 1) + + x = self.conv1(x) + x = getattr(self, f'norm1_{t}')(x) + x = self.nonlin1(x) + + x = self.conv2(x) + x = getattr(self, f'norm2_{t}')(x) + x = self.nonlin2(x) + + x = self.conv3(x) + x = getattr(self, f'norm3_{t}')(x) + + x += skip + x = self.nonlin3(x) + output = self.output(x) + + return output + + +def CORnet_S(): + model = nn.Sequential(OrderedDict([ + ('V1', nn.Sequential(OrderedDict([ # this one is custom to save GPU memory + ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False)), + ('norm1', nn.BatchNorm2d(64)), + ('nonlin1', nn.ReLU(inplace=True)), + ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, + bias=False)), + ('norm2', nn.BatchNorm2d(64)), + ('nonlin2', nn.ReLU(inplace=True)), + ('output', Identity()) + ]))), + ('V2', CORblock_S(64, 128, times=2)), + ('V4', CORblock_S(128, 256, times=4)), + ('IT', CORblock_S(256, 512, times=2)), + ('decoder', nn.Sequential(OrderedDict([ + ('avgpool', nn.AdaptiveAvgPool2d(1)), + ('flatten', Flatten()), + ('linear', nn.Linear(512, 1000)), + ('output', Identity()) + ]))) + ])) + + # weight initialization + for m in model.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + # nn.Linear is missing here because I originally forgot + # to add it during the training of this network + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + return model + + +def get_custom_cornet_s(): + model = CORnet_S() + model = torch.nn.DataParallel(model) + url = f'https://storage.googleapis.com/neurodp/0_0_0_ckpt.pt' + ckpt_data = torch.hub.load_state_dict_from_url(url) + model.load_state_dict(ckpt_data) + model = model.module + for param in model.parameters(): + param.requires_grad = True + + return model + + +def get_model_list(): + return ['cornet_s_0_0_0'] + + +def get_model(name): + assert name == 'cornet_s_0_0_0' + model = get_custom_cornet_s() + preprocessing = functools.partial(load_preprocess_images, image_size=224) + wrapper = PytorchWrapper(identifier='cornet_s_0_0_0', + model=model, + preprocessing=preprocessing) + wrapper.image_size = 224 + return wrapper + + +def get_layers(name): + assert name == 'cornet_s_0_0_0' + return ['V1', 'V2', 'V4', 'IT', 'decoder'] + + +def get_bibtex(model_identifier): + return """@inproceedings{KubiliusSchrimpf2019CORnet, + archivePrefix = {arXiv}, + arxivId = {1909.06161}, + author = {Kubilius, Jonas and Schrimpf, Martin and Hong, Ha and Majaj, Najib J. and Rajalingham, Rishi and Issa, Elias B. and Kar, Kohitij and Bashivan, Pouya and Prescott-Roy, Jonathan and Schmidt, Kailyn and Nayebi, Aran and Bear, Daniel and Yamins, Daniel L. K. and DiCarlo, James J.}, + booktitle = {Neural Information Processing Systems (NeurIPS)}, + editor = {Wallach, H. and Larochelle, H. and Beygelzimer, A. and D'Alch{\'{e}}-Buc, F. and Fox, E. and Garnett, R.}, + pages = {12785----12796}, + publisher = {Curran Associates, Inc.}, + title = {{Brain-Like Object Recognition with High-Performing Shallow Recurrent ANNs}}, + url = {http://papers.nips.cc/paper/9441-brain-like-object-recognition-with-high-performing-shallow-recurrent-anns}, + year = {2019} + } + """ + + +if __name__ == '__main__': + check_models.check_base_models(__name__) diff --git a/brainscore_vision/models/cornet_s_0_0_0/setup.py b/brainscore_vision/models/cornet_s_0_0_0/setup.py new file mode 100644 index 000000000..421914cfb --- /dev/null +++ b/brainscore_vision/models/cornet_s_0_0_0/setup.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from setuptools import setup, find_packages + +requirements = [ "torchvision", + "torch" +] + +setup( + packages=find_packages(exclude=['tests']), + include_package_data=True, + install_requires=requirements, + license="MIT license", + zip_safe=False, + keywords='brain-score template', + classifiers=[ + 'Development Status :: 2 - Pre-Alpha', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Natural Language :: English', + 'Programming Language :: Python :: 3.7', + ], + test_suite='tests', +) diff --git a/brainscore_vision/models/cornet_s_0_0_0/test.py b/brainscore_vision/models/cornet_s_0_0_0/test.py new file mode 100644 index 000000000..e594ba9e1 --- /dev/null +++ b/brainscore_vision/models/cornet_s_0_0_0/test.py @@ -0,0 +1 @@ +# Left empty as part of 2023 models migration diff --git a/brainscore_vision/models/cornet_s_0_1_0/__init__.py b/brainscore_vision/models/cornet_s_0_1_0/__init__.py new file mode 100644 index 000000000..0b50cf923 --- /dev/null +++ b/brainscore_vision/models/cornet_s_0_1_0/__init__.py @@ -0,0 +1,5 @@ +from brainscore_vision import model_registry +from brainscore_vision.model_helpers.brain_transformation import ModelCommitment +from .model 
import get_model, get_layers + +model_registry['cornet_s_0_1_0'] = lambda: ModelCommitment(identifier='cornet_s_0_1_0', activations_model=get_model('cornet_s_0_1_0'), layers=get_layers('cornet_s_0_1_0')) diff --git a/brainscore_vision/models/cornet_s_0_1_0/model.py b/brainscore_vision/models/cornet_s_0_1_0/model.py new file mode 100644 index 000000000..ece765616 --- /dev/null +++ b/brainscore_vision/models/cornet_s_0_1_0/model.py @@ -0,0 +1,201 @@ +from brainscore_vision.model_helpers.check_submission import check_models +import functools +from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper +from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images +import torch +import numpy as np +from brainscore_vision.model_helpers.brain_transformation import ModelCommitment +import math +from collections import OrderedDict +from torch import nn +import torch.nn.utils.prune as prune + + +HASH = '1d3f7974' + + +class Flatten(nn.Module): + + """ + Helper module for flattening input tensor to 1-D for the use in Linear modules + """ + + def forward(self, x): + return x.view(x.size(0), -1) + + +class Identity(nn.Module): + + """ + Helper module that stores the current tensor. Useful for accessing by name + """ + + def forward(self, x): + return x + + +class CORblock_S(nn.Module): + + scale = 4 # scale of the bottleneck convolution channels + + def __init__(self, in_channels, out_channels, times=1): + super().__init__() + + self.times = times + + self.conv_input = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self.skip = nn.Conv2d(out_channels, out_channels, + kernel_size=1, stride=2, bias=False) + self.norm_skip = nn.BatchNorm2d(out_channels) + + self.conv1 = nn.Conv2d(out_channels, out_channels * self.scale, + kernel_size=1, bias=False) + self.nonlin1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(out_channels * self.scale, out_channels * self.scale, + kernel_size=3, stride=2, padding=1, bias=False) + self.nonlin2 = nn.ReLU(inplace=True) + + self.conv3 = nn.Conv2d(out_channels * self.scale, out_channels, + kernel_size=1, bias=False) + self.nonlin3 = nn.ReLU(inplace=True) + + self.output = Identity() # for an easy access to this block's output + + # need BatchNorm for each time step for training to work well + for t in range(self.times): + setattr(self, f'norm1_{t}', nn.BatchNorm2d(out_channels * self.scale)) + setattr(self, f'norm2_{t}', nn.BatchNorm2d(out_channels * self.scale)) + setattr(self, f'norm3_{t}', nn.BatchNorm2d(out_channels)) + + def forward(self, inp): + x = self.conv_input(inp) + + for t in range(self.times): + if t == 0: + skip = self.norm_skip(self.skip(x)) + self.conv2.stride = (2, 2) + else: + skip = x + self.conv2.stride = (1, 1) + + x = self.conv1(x) + x = getattr(self, f'norm1_{t}')(x) + x = self.nonlin1(x) + + x = self.conv2(x) + x = getattr(self, f'norm2_{t}')(x) + x = self.nonlin2(x) + + x = self.conv3(x) + x = getattr(self, f'norm3_{t}')(x) + + x += skip + x = self.nonlin3(x) + output = self.output(x) + + return output + + +def CORnet_S(): + model = nn.Sequential(OrderedDict([ + ('V1', nn.Sequential(OrderedDict([ # this one is custom to save GPU memory + ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False)), + ('norm1', nn.BatchNorm2d(64)), + ('nonlin1', nn.ReLU(inplace=True)), + ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, + bias=False)), + ('norm2', nn.BatchNorm2d(64)), + ('nonlin2', 
nn.ReLU(inplace=True)), + ('output', Identity()) + ]))), + ('V2', CORblock_S(64, 128, times=2)), + ('V4', CORblock_S(128, 256, times=4)), + ('IT', CORblock_S(256, 512, times=2)), + ('decoder', nn.Sequential(OrderedDict([ + ('avgpool', nn.AdaptiveAvgPool2d(1)), + ('flatten', Flatten()), + ('linear', nn.Linear(512, 1000)), + ('output', Identity()) + ]))) + ])) + + # weight initialization + for m in model.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + # nn.Linear is missing here because I originally forgot + # to add it during the training of this network + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + return model + + +def get_custom_cornet_s(): + model = CORnet_S() + model = torch.nn.DataParallel(model) + regions = [model.module.V1, model.module.V2, model.module.V4, model.module.IT] + region_idx = 0 + region = regions[region_idx] + lesion_iters = 1 + retrain_epochs = 0 + + url = f'https://storage.googleapis.com/neurodp/{region_idx}_{lesion_iters}_{retrain_epochs}_ckpt.pt' + ckpt_data = torch.hub.load_state_dict_from_url(url) + + for _ in range(lesion_iters): + conv_layers = [module for module in region.modules() if isinstance(module, torch.nn.Conv2d)] + for x in conv_layers: + prune.random_unstructured(x, name='weight', amount=0.2) + + model.load_state_dict(ckpt_data) + model = model.module + for param in model.parameters(): + param.requires_grad = True + + return model + + +def get_model_list(): + return ['cornet_s_0_1_0'] + + +def get_model(name): + assert name == 'cornet_s_0_1_0' + model = get_custom_cornet_s() + preprocessing = functools.partial(load_preprocess_images, image_size=224) + wrapper = PytorchWrapper(identifier='cornet_s_0_1_0', + model=model, + preprocessing=preprocessing) + wrapper.image_size = 224 + return wrapper + + +def get_layers(name): + assert name == 'cornet_s_0_1_0' + return ['V1', 'V2', 'V4', 'IT', 'decoder'] + + +def get_bibtex(model_identifier): + return """@inproceedings{KubiliusSchrimpf2019CORnet, + archivePrefix = {arXiv}, + arxivId = {1909.06161}, + author = {Kubilius, Jonas and Schrimpf, Martin and Hong, Ha and Majaj, Najib J. and Rajalingham, Rishi and Issa, Elias B. and Kar, Kohitij and Bashivan, Pouya and Prescott-Roy, Jonathan and Schmidt, Kailyn and Nayebi, Aran and Bear, Daniel and Yamins, Daniel L. K. and DiCarlo, James J.}, + booktitle = {Neural Information Processing Systems (NeurIPS)}, + editor = {Wallach, H. and Larochelle, H. and Beygelzimer, A. and D'Alch{\'{e}}-Buc, F. and Fox, E. 
and Garnett, R.}, + pages = {12785----12796}, + publisher = {Curran Associates, Inc.}, + title = {{Brain-Like Object Recognition with High-Performing Shallow Recurrent ANNs}}, + url = {http://papers.nips.cc/paper/9441-brain-like-object-recognition-with-high-performing-shallow-recurrent-anns}, + year = {2019} + } + """ + + +if __name__ == '__main__': + check_models.check_base_models(__name__) diff --git a/brainscore_vision/models/cornet_s_0_1_0/setup.py b/brainscore_vision/models/cornet_s_0_1_0/setup.py new file mode 100644 index 000000000..421914cfb --- /dev/null +++ b/brainscore_vision/models/cornet_s_0_1_0/setup.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from setuptools import setup, find_packages + +requirements = [ "torchvision", + "torch" +] + +setup( + packages=find_packages(exclude=['tests']), + include_package_data=True, + install_requires=requirements, + license="MIT license", + zip_safe=False, + keywords='brain-score template', + classifiers=[ + 'Development Status :: 2 - Pre-Alpha', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Natural Language :: English', + 'Programming Language :: Python :: 3.7', + ], + test_suite='tests', +) diff --git a/brainscore_vision/models/cornet_s_0_1_0/test.py b/brainscore_vision/models/cornet_s_0_1_0/test.py new file mode 100644 index 000000000..e594ba9e1 --- /dev/null +++ b/brainscore_vision/models/cornet_s_0_1_0/test.py @@ -0,0 +1 @@ +# Left empty as part of 2023 models migration diff --git a/brainscore_vision/models/eBarlow_Vanilla/__init__.py b/brainscore_vision/models/eBarlow_Vanilla/__init__.py new file mode 100644 index 000000000..87ea2edaf --- /dev/null +++ b/brainscore_vision/models/eBarlow_Vanilla/__init__.py @@ -0,0 +1,9 @@ +from brainscore_vision import model_registry +from brainscore_vision.model_helpers.brain_transformation import ModelCommitment +from .model import get_model, get_layers + +model_registry["eBarlow_Vanilla"] = lambda: ModelCommitment( + identifier="eBarlow_Vanilla", + activations_model=get_model("eBarlow_Vanilla"), + layers=get_layers("eBarlow_Vanilla"), +) diff --git a/brainscore_vision/models/eBarlow_Vanilla/model.py b/brainscore_vision/models/eBarlow_Vanilla/model.py new file mode 100644 index 000000000..952b7e457 --- /dev/null +++ b/brainscore_vision/models/eBarlow_Vanilla/model.py @@ -0,0 +1,50 @@ +from brainscore_vision.model_helpers.check_submission import check_models +import functools +import os +from urllib.request import urlretrieve +import torchvision.models +from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper +from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images +from pathlib import Path +from brainscore_vision.model_helpers import download_weights +import torch + +# This is an example implementation for submitting resnet-50 as a pytorch model + +# Attention: It is important, that the wrapper identifier is unique per model! +# The results will otherwise be the same due to brain-scores internal result caching mechanism. +# Please load your pytorch model for usage in CPU. There won't be GPUs available for scoring your model. +# If the model requires a GPU, contact the brain-score team directly. 
+from brainscore_vision.model_helpers.check_submission import check_models
+
+
+def get_model_list():
+    return ["eBarlow_Vanilla"]
+
+
+def get_model(name):
+    assert name == "eBarlow_Vanilla"
+    model = torchvision.models.resnet50(pretrained=False)
+    url = "https://users.flatironinstitute.org/~tyerxa/equi_proj/training_checkpoints/classifiers/barlow/vanilla/classifier.pt"
+    fh = urlretrieve(url)
+    state_dict = torch.load(fh[0], map_location=torch.device("cpu"))
+    model.load_state_dict(state_dict)
+    preprocessing = functools.partial(load_preprocess_images, image_size=224)
+    wrapper = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing)
+    wrapper.image_size = 224
+    return wrapper
+
+
+def get_layers(name):
+    assert name == "eBarlow_Vanilla"
+
+    outs = ["conv1", "layer1", "layer2", "layer3", "layer4", "avgpool", "fc"]
+    return outs
+
+
+def get_bibtex(model_identifier):
+    return """xx"""
+
+
+if __name__ == "__main__":
+    check_models.check_base_models(__name__)
diff --git a/brainscore_vision/models/eBarlow_Vanilla/setup.py b/brainscore_vision/models/eBarlow_Vanilla/setup.py
new file mode 100644
index 000000000..421914cfb
--- /dev/null
+++ b/brainscore_vision/models/eBarlow_Vanilla/setup.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from setuptools import setup, find_packages
+
+requirements = [ "torchvision",
+                 "torch"
+]
+
+setup(
+    packages=find_packages(exclude=['tests']),
+    include_package_data=True,
+    install_requires=requirements,
+    license="MIT license",
+    zip_safe=False,
+    keywords='brain-score template',
+    classifiers=[
+        'Development Status :: 2 - Pre-Alpha',
+        'Intended Audience :: Developers',
+        'License :: OSI Approved :: MIT License',
+        'Natural Language :: English',
+        'Programming Language :: Python :: 3.7',
+    ],
+    test_suite='tests',
+)
diff --git a/brainscore_vision/models/eBarlow_Vanilla/test.py b/brainscore_vision/models/eBarlow_Vanilla/test.py
new file mode 100644
index 000000000..e594ba9e1
--- /dev/null
+++ b/brainscore_vision/models/eBarlow_Vanilla/test.py
@@ -0,0 +1 @@
+# Left empty as part of 2023 models migration
diff --git a/brainscore_vision/models/eBarlow_lmda_01/__init__.py b/brainscore_vision/models/eBarlow_lmda_01/__init__.py
new file mode 100644
index 000000000..3c6f2803a
--- /dev/null
+++ b/brainscore_vision/models/eBarlow_lmda_01/__init__.py
@@ -0,0 +1,9 @@
+from brainscore_vision import model_registry
+from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
+from .model import get_model, get_layers
+
+model_registry["eBarlow_lmda_01"] = lambda: ModelCommitment(
+    identifier="eBarlow_lmda_01",
+    activations_model=get_model("eBarlow_lmda_01"),
+    layers=get_layers("eBarlow_lmda_01"),
+)
diff --git a/brainscore_vision/models/eBarlow_lmda_01/model.py b/brainscore_vision/models/eBarlow_lmda_01/model.py
new file mode 100644
index 000000000..f0372240b
--- /dev/null
+++ b/brainscore_vision/models/eBarlow_lmda_01/model.py
@@ -0,0 +1,50 @@
+from brainscore_vision.model_helpers.check_submission import check_models
+import functools
+import os
+from urllib.request import urlretrieve
+import torchvision.models
+from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
+from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
+from pathlib import Path
+from brainscore_vision.model_helpers import download_weights
+import torch
+
+# This is an example implementation for submitting resnet-50 as a PyTorch model
+
+# Attention: It is important that the wrapper identifier is unique per model!
+# The results will otherwise be the same due to Brain-Score's internal result caching mechanism.
+# Please load your PyTorch model for use on CPU. There won't be GPUs available for scoring your model.
+# If the model requires a GPU, contact the brain-score team directly.
+from brainscore_vision.model_helpers.check_submission import check_models
+
+
+def get_model_list():
+    return ["eBarlow_lmda_01"]
+
+
+def get_model(name):
+    assert name == "eBarlow_lmda_01"
+    model = torchvision.models.resnet50(pretrained=False)
+    url = "https://users.flatironinstitute.org/~tyerxa/equi_proj/training_checkpoints/classifiers/barlow/equi_matched_0.1/classifier.pt"
+    fh = urlretrieve(url)
+    state_dict = torch.load(fh[0], map_location=torch.device("cpu"))
+    model.load_state_dict(state_dict)
+    preprocessing = functools.partial(load_preprocess_images, image_size=224)
+    wrapper = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing)
+    wrapper.image_size = 224
+    return wrapper
+
+
+def get_layers(name):
+    assert name == "eBarlow_lmda_01"
+
+    outs = ["conv1", "layer1", "layer2", "layer3", "layer4", "avgpool", "fc"]
+    return outs
+
+
+def get_bibtex(model_identifier):
+    return """xx"""
+
+
+if __name__ == "__main__":
+    check_models.check_base_models(__name__)
diff --git a/brainscore_vision/models/eBarlow_lmda_01/setup.py b/brainscore_vision/models/eBarlow_lmda_01/setup.py
new file mode 100644
index 000000000..421914cfb
--- /dev/null
+++ b/brainscore_vision/models/eBarlow_lmda_01/setup.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from setuptools import setup, find_packages
+
+requirements = [ "torchvision",
+                 "torch"
+]
+
+setup(
+    packages=find_packages(exclude=['tests']),
+    include_package_data=True,
+    install_requires=requirements,
+    license="MIT license",
+    zip_safe=False,
+    keywords='brain-score template',
+    classifiers=[
+        'Development Status :: 2 - Pre-Alpha',
+        'Intended Audience :: Developers',
+        'License :: OSI Approved :: MIT License',
+        'Natural Language :: English',
+        'Programming Language :: Python :: 3.7',
+    ],
+    test_suite='tests',
+)
diff --git a/brainscore_vision/models/eBarlow_lmda_01/test.py b/brainscore_vision/models/eBarlow_lmda_01/test.py
new file mode 100644
index 000000000..e594ba9e1
--- /dev/null
+++ b/brainscore_vision/models/eBarlow_lmda_01/test.py
@@ -0,0 +1 @@
+# Left empty as part of 2023 models migration
diff --git a/pyproject.toml b/pyproject.toml
index d77411e1e..51ef564ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
     "importlib-metadata<5", # workaround to https://github.com/brain-score/brainio/issues/28
     "scikit-learn", # for metric_helpers/transformations.py cross-validation
     "scipy", # for benchmark_helpers/properties_common.py
+    "opencv-python", # for microsaccades
     "h5py",
     "tqdm",
     "gitpython",
diff --git a/tests/test_model_helpers/activations/test___init__.py b/tests/test_model_helpers/activations/test___init__.py
index 86a1ee788..3d1306329 100644
--- a/tests/test_model_helpers/activations/test___init__.py
+++ b/tests/test_model_helpers/activations/test___init__.py
@@ -163,6 +163,20 @@ def tfslim_vgg16():
     pytest.param(tfslim_vgg16, ['vgg_16/pool5'], marks=pytest.mark.memory_intense),
 ]
 
+# exact microsaccades for pytorch_alexnet, grayscale.png, for 1 and 10 number_of_trials
+exact_microsaccades = {"x_degrees": {1: np.array([0.]),
+                                     10: np.array([0., -0.00639121, -0.02114204, -0.02616418, -0.02128906,
+                                                   -0.00941355, 0.00596172, 0.02166913, 0.03523793,
+                                                   0.04498976])},
+                      "y_degrees": {1: np.array([0.]),
+                                    10: np.array([0., 0.0144621, 0.00728107, -0.00808922, -0.02338324, -0.0340791,
+                                                  -0.03826824, -0.03578336, -0.02753704, -0.01503068])},
+                      "x_pixels": {1: np.array([0.]),
+                                   10: np.array([0., -0.17895397, -0.59197722, -0.73259714, -0.59609364, -0.26357934,
+                                                 0.16692818, 0.60673569, 0.98666196, 1.25971335])},
+                      "y_pixels": {1: np.array([0.]),
+                                   10: np.array([0., 0.40493885, 0.20386999, -0.22649819, -0.65473077, -0.95421482,
+                                                 -1.07151061, -1.00193403, -0.77103707, -0.42085896])}}
+
@@ -189,6 +203,68 @@ def test_from_image_path(model_ctr, layers, image_name, pca_components, logits):
     return activations
 
 
+@pytest.mark.parametrize("image_name", ['rgb.jpg', 'grayscale.png', 'grayscale2.jpg', 'grayscale_alpha.png',
+                                        'palletized.png'])
+@pytest.mark.parametrize(["model_ctr", "layers"], models_layers)
+@pytest.mark.parametrize("number_of_trials", [1, 3, 10])
+def test_require_variance_has_shift_coords(model_ctr, layers, image_name, number_of_trials):
+    stimulus_paths = [os.path.join(os.path.dirname(__file__), image_name)]
+    activations_extractor = model_ctr()
+    # when using microsaccades, the ModelCommitment sets its visual angle. Since this test skips the ModelCommitment,
+    # we set it here manually.
+    activations_extractor._extractor.set_visual_degrees(8.)
+
+    activations = activations_extractor(stimuli=stimulus_paths, layers=layers, number_of_trials=number_of_trials,
+                                        require_variance=True)
+
+    assert activations is not None
+    assert len(activations['microsaccade_shift_x_pixels']) == number_of_trials * len(stimulus_paths)
+    assert len(activations['microsaccade_shift_y_pixels']) == number_of_trials * len(stimulus_paths)
+    assert len(activations['microsaccade_shift_x_degrees']) == number_of_trials * len(stimulus_paths)
+    assert len(activations['microsaccade_shift_y_degrees']) == number_of_trials * len(stimulus_paths)
+
+
+@pytest.mark.parametrize("image_name", ['rgb.jpg', 'grayscale.png', 'grayscale2.jpg', 'grayscale_alpha.png',
+                                        'palletized.png'])
+@pytest.mark.parametrize(["model_ctr", "layers"], models_layers)
+@pytest.mark.parametrize("require_variance", [False, True])
+@pytest.mark.parametrize("number_of_trials", [1, 3, 10])
+def test_require_variance_presentation_length(model_ctr, layers, image_name, require_variance, number_of_trials):
+    stimulus_paths = [os.path.join(os.path.dirname(__file__), image_name)]
+    activations_extractor = model_ctr()
+    # when using microsaccades, the ModelCommitment sets its visual angle. Since this test skips the ModelCommitment,
+    # we set it here manually.
+    activations_extractor._extractor.set_visual_degrees(8.)
+
+    activations = activations_extractor(stimuli=stimulus_paths, layers=layers,
+                                        number_of_trials=number_of_trials, require_variance=require_variance)
+
+    assert activations is not None
+    if require_variance:
+        assert len(activations['presentation']) == number_of_trials
+    else:
+        assert len(activations['presentation']) == 1
+
+
+@pytest.mark.parametrize("image_name", ['rgb.jpg', 'grayscale.png', 'grayscale2.jpg', 'grayscale_alpha.png',
+                                        'palletized.png'])
+@pytest.mark.parametrize(["model_ctr", "layers"], models_layers)
+def test_temporary_file_handling(model_ctr, layers, image_name):
+    import tempfile
+    stimulus_paths = [os.path.join(os.path.dirname(__file__), image_name)]
+    activations_extractor = model_ctr()
+    # when using microsaccades, the ModelCommitment sets its visual angle. Since this test skips the ModelCommitment,
+    # we set it here manually.
+    activations_extractor._extractor.set_visual_degrees(8.)
+
+    activations = activations_extractor(stimuli=stimulus_paths, layers=layers, number_of_trials=2,
+                                        require_variance=True)
+    temp_files = [f for f in os.listdir(tempfile.gettempdir()) if f.startswith('temp') and f.endswith('.png')]
+
+    assert activations is not None
+    assert len(temp_files) == 0
+
+
 def _build_stimulus_set(image_names):
     stimulus_set = StimulusSet([{'stimulus_id': image_name, 'some_meta': image_name[::-1]}
                                 for image_name in image_names])
@@ -223,9 +299,51 @@ def test_exact_activations(pca_components):
                                             image_name='rgb.jpg', pca_components=pca_components, logits=False)
     path_to_expected = Path(__file__).parent / f'alexnet-rgb-{pca_components}.nc'
     expected = xr.load_dataarray(path_to_expected)
+
+    # Originally, the `stimulus_path` Index was used to index into xarrays in Brain-Score, but this was changed
+    # as a part of PR #492 to a MultiIndex to allow metadata to be attached to multiple repetitions of the same
+    # `stimulus_path`. Old .nc files need to be updated to use the `presentation` index instead of `stimulus_path`,
+    # and instead of changing the extant activations, this test was simply modified to simulate that.
+    expected = expected.rename({'stimulus_path': 'presentation'})
+
     assert (activations == expected).all()
 
 
+@pytest.mark.memory_intense
+@pytest.mark.parametrize("number_of_trials", [1, 10])
+def test_exact_microsaccades(number_of_trials):
+    image_name = 'grayscale.png'
+    stimulus_paths = [os.path.join(os.path.dirname(__file__), image_name)]
+    activations_extractor = pytorch_alexnet()
+    # when using microsaccades, the ModelCommitment sets its visual angle. Since this test skips the ModelCommitment,
+    # we set it here manually.
+    activations_extractor._extractor.set_visual_degrees(8.)
+    # the exact microsaccades were computed at this extent
+    assert activations_extractor._extractor._microsaccade_helper.microsaccade_extent_degrees == 0.05
+
+    activations = activations_extractor(stimuli=stimulus_paths, layers=['features.12'],
+                                        number_of_trials=number_of_trials, require_variance=True)
+
+    assert activations is not None
+    # test with np.isclose instead of ==: while the arrays are visually equal, == often fails due to float errors
+    assert np.isclose(activations['microsaccade_shift_x_degrees'].values,
+                      exact_microsaccades['x_degrees'][number_of_trials],
+                      rtol=1e-05,
+                      atol=1e-08).all()
+    assert np.isclose(activations['microsaccade_shift_y_degrees'].values,
+                      exact_microsaccades['y_degrees'][number_of_trials],
+                      rtol=1e-05,
+                      atol=1e-08).all()
+    assert np.isclose(activations['microsaccade_shift_x_pixels'].values,
+                      exact_microsaccades['x_pixels'][number_of_trials],
+                      rtol=1e-05,
+                      atol=1e-08).all()
+    assert np.isclose(activations['microsaccade_shift_y_pixels'].values,
+                      exact_microsaccades['y_pixels'][number_of_trials],
+                      rtol=1e-05,
+                      atol=1e-08).all()
+
+
 @pytest.mark.memory_intense
 @pytest.mark.parametrize(["model_ctr", "internal_layers"], [
     (pytorch_alexnet, ['features.12', 'classifier.5']),
diff --git a/tests/test_model_helpers/brain_transformation/test___init__.py b/tests/test_model_helpers/brain_transformation/test___init__.py
index 49fb55b1c..b61a201b8 100644
--- a/tests/test_model_helpers/brain_transformation/test___init__.py
+++ b/tests/test_model_helpers/brain_transformation/test___init__.py
@@ -1,9 +1,18 @@
+from unittest.mock import Mock
+
 from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
 from brainscore_vision.model_helpers.utils import fullname
 
 
 class TestVisualDegrees:
     def test_standard_commitment(self):
-        brain_model = ModelCommitment(identifier=fullname(self), activations_model=None,
+        # create a mock ActivationsExtractorHelper with a mock set_visual_degrees to avoid failing set_visual_degrees()
+        mock_extractor = Mock()
+        mock_extractor.set_visual_degrees = Mock()
+        mock_activations_model = Mock()
+        mock_activations_model._extractor = mock_extractor
+
+        # Initialize ModelCommitment with the mock activations_model
+        brain_model = ModelCommitment(identifier=fullname(self), activations_model=mock_activations_model,
                                       layers=['dummy'])
         assert brain_model.visual_degrees() == 8
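
Reviewer note: the eBarlow entries above are plain registry additions, so they can be smoke-tested locally without going through the scoring workflow. The sketch below is not part of the diff; it only uses names that appear in it (model_registry, the "eBarlow_Vanilla" entry, and ModelCommitment's visual_degrees()), and it assumes the brainscore_vision checkout from this PR is installed. Building the commitment calls get_model, which downloads the checkpoint referenced there.

    # minimal sketch, assuming the brainscore_vision checkout from this PR is installed
    import brainscore_vision.models.eBarlow_Vanilla  # executes the new __init__.py, which fills model_registry
    from brainscore_vision import model_registry

    brain_model = model_registry["eBarlow_Vanilla"]()  # the registered lambda builds the ModelCommitment (downloads weights)
    assert brain_model.visual_degrees() == 8  # default commitment, matching the assertion in the test above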
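
For anyone adapting the new microsaccade tests, the call pattern they exercise is worth restating on its own. This is a sketch under the same assumptions as the tests: a base-model wrapper such as the test module's pytorch_alexnet fixture (not shown in this diff), with the visual angle set manually because no ModelCommitment is involved.

    # sketch of the require_variance call pattern exercised by the new tests
    extractor = pytorch_alexnet()  # stands in for any base-model fixture from the test module
    extractor._extractor.set_visual_degrees(8.)  # normally done by the ModelCommitment

    activations = extractor(stimuli=['path/to/grayscale.png'], layers=['features.12'],
                            number_of_trials=10, require_variance=True)
    assert len(activations['presentation']) == 10  # one presentation per trial
    assert len(activations['microsaccade_shift_x_degrees']) == 10  # shift coordinates attached as metadata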