Tiramisu Cost Model

A deep learning model to predict the speedup obtained from applying a sequence of transformations on an input program.

Installation

Install the environment using the environment.yml file as follows:

conda env create -f environment.yml

This should create an environment called cost_model_env.

Whenever you want to use the model, you need to activate the environment as follows:

conda activate cost_model_env

Configuring the repository

All of the main scripts use Hydra for configuration management. To configure the repository, fill in the configuration template conf/config.yaml with the required paths and parameters. When running any of the scripts below, you can override any value from the configuration file on the command line, giving the parameter together with its section name. For example, to set the batch size used for data generation to 512:

python generate_dataset.py data_generation.batch_size=512
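To make the section.key=value syntax concrete, the sketch below mimics in plain Python how a dotted override addresses a value inside a nested configuration. It is an illustration only, not Hydra's implementation; the apply_override helper and the initial batch size of 1024 are hypothetical.

```python
import ast
from copy import deepcopy


def apply_override(config: dict, override: str) -> dict:
    """Return a copy of config with one 'section.key=value' override applied."""
    dotted_key, raw_value = override.split("=", 1)
    *sections, leaf = dotted_key.split(".")

    updated = deepcopy(config)
    node = updated
    for name in sections:
        node = node[name]  # raises KeyError if the section does not exist

    try:
        node[leaf] = ast.literal_eval(raw_value)  # e.g. "512" becomes 512
    except (ValueError, SyntaxError):
        node[leaf] = raw_value  # anything else is kept as a plain string
    return updated


# Hypothetical stand-in for part of conf/config.yaml
config = {"data_generation": {"batch_size": 1024}}
config = apply_override(config, "data_generation.batch_size=512")
print(config["data_generation"]["batch_size"])  # prints 512
```

The section prefix matters: without data_generation. in front, the key would not resolve to the nested value.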

Processing the dataset

Data loading is currently separated from training because loading is very time-consuming and should not be repeated for every training run. Instead, a script loads the raw data (JSON), extracts the representation of each data point, and saves the batched data in a .pt file that can be loaded directly into memory for training. We call this process data generation. To generate the dataset, run the Python script generate_dataset.py (after configuring the repository):

python generate_dataset.py
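The preprocess-once, load-many-times pattern described above can be sketched as follows. This is a minimal illustration, not the code in generate_dataset.py: it uses pickle from the standard library as a stand-in for the .pt file, and extract_representation, the "features" field, and the function names are all hypothetical.

```python
import json
import pickle
from pathlib import Path


def extract_representation(datapoint: dict) -> list:
    """Hypothetical placeholder for the real (expensive) feature extraction."""
    return [float(v) for v in datapoint["features"]]


def generate_dataset(raw_path: Path, out_path: Path, batch_size: int) -> None:
    """Parse the raw JSON once, batch the representations, and cache them."""
    datapoints = json.loads(raw_path.read_text())
    representations = [extract_representation(d) for d in datapoints]
    batches = [
        representations[i:i + batch_size]
        for i in range(0, len(representations), batch_size)
    ]
    with out_path.open("wb") as f:
        pickle.dump(batches, f)  # stand-in for saving a .pt file


def load_batches(out_path: Path) -> list:
    """Training reads the cached batches without touching the raw JSON."""
    with out_path.open("rb") as f:
        return pickle.load(f)
```

Because the batches are fixed when the file is written, changing the batch size means regenerating the dataset rather than only restarting training.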

Dataset sample

A sample from the dataset is provided in the dataset_samples folder as a pickle file. This sample contains approximately 80,000 data points, divided into a training set and a validation set. The training set includes 600 synthetic Tiramisu programs (~60,000 schedules), while the validation set consists of 125 synthetic programs (~20,000 schedules).

Training the model

To train the model, run the Python script train_model.py (after configuring the repository and generating the dataset):

python train_model.py

Using wandb for visualization

The repository supports Weights & Biases (wandb) for visualization. To enable it, log into wandb from the command line, then set the use_wandb parameter to True and specify a project name. The project does not have to exist in wandb beforehand. During training, progress can be followed on the wandb platform.
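As a rough illustration of how a flag such as use_wandb can gate logging, the sketch below returns either wandb.log or a plain stdout logger. It is an assumption about the general pattern, not the repository's actual training code; make_metric_logger, the project name, and the metric values are hypothetical.

```python
from typing import Callable


def make_metric_logger(use_wandb: bool, project: str) -> Callable[[dict], None]:
    """Return a function that records a dict of metrics."""
    if use_wandb:
        import wandb  # imported lazily, so the package is only needed when enabled

        wandb.init(project=project)  # requires a prior `wandb login`
        return wandb.log

    def log_to_stdout(metrics: dict) -> None:
        print(metrics)

    return log_to_stdout


# With the flag off, metrics are simply printed.
log = make_metric_logger(use_wandb=False, project="my_project")
log({"epoch": 1, "train_loss": 0.42})
```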

Evaluation of the trained model

To evaluate the trained model, run the Python script evaluate_model.py (after configuring the repository and generating the dataset):

python evaluate_model.py