Project for CS224W: GNNs for playing the Wiki Game.
Authors: Michael Rybalkin, Cary Xiao, Noah Islam
This project uses the English Wikipedia hyperlink network dataset found here: https://snap.stanford.edu/data/enwiki-2013.html. To set up, install the Python requirements and download the dataset with `download_dataset.sh`:

```sh
pip install -r requirements.txt
./download_dataset.sh
```
- If you would like to play the Wiki Game yourself on the `n=1000` subset of the Wikipedia dataset, run `./human_player.py 1000`. The first time you do this, the subsampled graph needs to be generated, which takes a few minutes; afterwards it is cached. Before playing, you must provide the names of the starting and ending articles. Run `./node_names.py 1000` to get a list of all valid node names. If you don't want to choose the start and end nodes, run in baseline mode with `./human_player.py 1000 b` to play 20 games with randomly selected (seeded) start and end nodes.
- If you would like to watch the node2vec agent play the Wiki Game, you first need to generate node2vec embeddings by running `./gen_node2vec_embeddings.py 1000`. Then, to simulate 20 trials with the simple node2vec agent (which moves to the neighbor whose embedding has the highest cosine similarity with the target's embedding), run `./node2vec_player.py 1000 20`.
- If you would like to run one of the GNN-based approaches (MLP or GraphSAGE), you must first generate the node2vec embeddings if you have not already done so. Precomputed embeddings for `n=1000` and `n=10000` are already in the `embeddings` dir. Pretrained checkpoints are committed at `checkpoints/best_model.pt` (MLP) and `checkpoints/graphsage/best_model.pt` (GraphSAGE). To run them, try for example `./gnn_player.py 1000 b checkpoints/best_model.pt`. If you would rather train your own models on our training data, run the train script, e.g. `python3 gnn/train.py --model mlp --epochs 50`. Generated checkpoints are saved to the `checkpoints` dir; you can then run `gnn_player.py` with the checkpoint's filepath as the third argument.
- If you would like to compare the paths taken by different agents, use `./visualize_test_results.py`. Running a player script in baseline mode produces a file containing the paths taken by the player, and every such file placed in the directory `players_to_visualize` is rendered by the visualization script. Each approach gets a unique color, and nodes/edges traversed by multiple approaches are drawn in purple. If the graphs are too cluttered, try removing some of the files from `players_to_visualize`.
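The decision rule of the simple node2vec agent (move to the neighbor whose embedding has the highest cosine similarity with the target's embedding) can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the code in `node2vec_player.py`: the graph is assumed to be a dict mapping each node id to a list of neighbor ids, embeddings are assumed to be rows of a NumPy array indexed by node id, and the names `pick_next` and `play_greedy` as well as the `max_steps` cap are hypothetical.

```python
import numpy as np


def cosine_similarity(a, b):
    """Cosine similarity of two 1-D vectors (0.0 if either has zero norm)."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(a @ b / denom) if denom > 0 else 0.0


def pick_next(current, target, graph, emb):
    """Return the neighbor of `current` whose embedding is most similar to the
    target's embedding, or None if `current` has no outgoing links."""
    neighbors = graph.get(current, [])
    if not neighbors:
        return None
    scores = [cosine_similarity(emb[v], emb[target]) for v in neighbors]
    return neighbors[int(np.argmax(scores))]


def play_greedy(start, target, graph, emb, max_steps=50):
    """Greedily follow pick_next from start; return the path taken, or None if
    the agent gets stuck or exceeds max_steps without reaching the target."""
    path = [start]
    current = start
    for _ in range(max_steps):
        if current == target:
            return path
        current = pick_next(current, target, graph, emb)
        if current is None:
            return None
        path.append(current)
    return path if current == target else None


# Toy example: 4 nodes with hand-picked 2-D "embeddings" (illustrative values only).
graph = {0: [1, 2], 1: [0, 3], 2: [0], 3: [1]}
emb = np.array([[1.0, 0.0], [0.7, 0.7], [0.9, -0.1], [0.0, 1.0]])
print(play_greedy(0, 3, graph, emb))  # [0, 1, 3]
```

Because the rule is purely greedy, the agent can bounce between the same few articles; the hypothetical `max_steps` cap is one simple way to guarantee that a trial ends.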
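The coloring rule used by the visualization (one color per approach, purple where approaches overlap) can be sketched as follows. This is a hypothetical helper, not code from `visualize_test_results.py`: it assumes a caller-supplied palette, treats edges as directed `(u, v)` pairs, and, for brevity, gives each approach a single path (with multiple trials you would pass the concatenation of all of that approach's nodes or edges).

```python
def assign_colors(items_by_approach, palette, shared="purple"):
    """Color each item (a node or an edge) by the approaches that visited it.

    items_by_approach maps approach name -> iterable of visited items.
    Items visited by exactly one approach get that approach's palette color;
    items visited by two or more approaches get the `shared` color.
    """
    visitors = {}
    for approach, items in items_by_approach.items():
        for item in items:
            visitors.setdefault(item, set()).add(approach)
    return {
        item: palette[next(iter(who))] if len(who) == 1 else shared
        for item, who in visitors.items()
    }


def path_edges(path):
    """Consecutive directed (u, v) pairs along a path."""
    return list(zip(path, path[1:]))


# Hypothetical example: two approaches whose paths share the first hop.
paths = {"bfs": [0, 1, 2], "gnn": [0, 1, 3]}
palette = {"bfs": "red", "gnn": "blue"}
node_colors = assign_colors(paths, palette)
edge_colors = assign_colors({name: path_edges(p) for name, p in paths.items()}, palette)
print(node_colors)  # {0: 'purple', 1: 'purple', 2: 'red', 3: 'blue'}
print(edge_colors)  # {(0, 1): 'purple', (1, 2): 'red', (1, 3): 'blue'}
```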
- `download_dataset.sh`: Downloads the dataset and sets up the repo directory for the other scripts.
- `util.py`: Contains utils used by other scripts. It is also runnable and takes an optional argument `n` (usage: `./util.py 1000`). When run, the script subsamples the full dataset to include only the `n` highest-degree articles and saves the objects to `pickles/...`. Edges are treated as bidirectional when degree is counted (both incoming and outgoing links count). Defaults to `n=1000` when unspecified.
- `human_player.py`: CLI demo of the Wiki Game. Prompts for a start and target article title, then lists all neighboring article titles and prompts the user to make a selection. After the target is reached, displays the path of articles taken from start to target. Takes an optional argument `n` (usage: `./human_player.py 1000`), which subsamples the dataset to include only the `n` highest-degree articles; if `n` is not specified, uses the full dataset. To play 20 trials and write the results to a file, add an additional argument like so: `./human_player.py 1000 b`.
- `node2vec_player.py`: Simulates the simple node2vec agent playing the Wiki Game. Takes two args: the first is `n`, the number of nodes; the second (optional) is the number of trials to run in baseline mode (defaults to 20). Example usage: `./node2vec_player.py 1000 20`. In baseline mode, writes the results of the baseline trials to the filesystem.
- `gnn_player.py`: Simulates the GNN agent playing the Wiki Game. Takes four args. The first is `n`, the number of nodes (defaults to `n=1000`). The second is `b` for baseline mode, in which the GNN is evaluated on a number of trials selected randomly with a seed. The third is the filepath to the checkpoint of the trained GNN. The fourth is the number of trials to use in baseline mode (defaults to 20). Example usage: `./gnn_player.py 1000 b checkpoints/best_model.pt 200`. In baseline mode, writes the results of the baseline trials to the filesystem.
- `visualize_test_results.py`: Visualization script for comparing the paths of different approaches. Visualizes all path files in the dir `players_to_visualize`. The black edges showing untraversed edges between visited nodes can be disabled with the boolean flag `SHOW_GRAY_EDGES`. This script assumes that all path files were generated in baseline mode using the same (src, dst) node pairs and the same number of trials; the number of trials in the path files must be consistent with `NUM_PROMPTS`. Example usage: `./visualize_test_results.py`.
- `node_names.py`: Takes an argument `n` (usage: `./node_names.py 1000`) and prints the names of all nodes in the subsampled dataset of the top `n` highest-degree nodes. Requires that the dataset has already been generated.
- `gen_node2vec_embeddings.py`: Uses node2vec to create an embedding for each node in the subgraph and saves the embeddings in the `embeddings` folder. Takes the parameter `n`, the number of nodes in the graph. Usage: `./gen_node2vec_embeddings.py 1000`.
- `baseline.py`: Plays 20 trials of the Wiki Game using BFS, then reports the shortest path length and the number of nodes visited by BFS.
- `graph_stats.py`: A simple script that calculates the mean, standard deviation, and maximum for each strategy in both the 20-game and 200-game benchmarks. To use it, move all `.pkl` files you want stats for into the `test_results/` directory and run `python3 graph_stats.py`.
- `graph_train_loss.py`: A simple script that plots the per-epoch loss graphs shown in the training section of our Medium post. To get each graph, run `python3 ./graph_train_loss.py <.json file>`, where you specify the path to one of the two JSON files in `checkpoint/`.
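The subsampling step in `util.py` (keep the `n` highest-degree articles, counting both incoming and outgoing links) can be sketched in plain Python as below. This is an illustrative sketch assuming the graph is available as a list of directed `(src, dst)` id pairs; the function name `top_n_subgraph` is hypothetical, and the tie-breaking among equal-degree nodes is an assumption that may differ from what `util.py` does.

```python
from collections import Counter


def top_n_subgraph(edges, n):
    """Keep the n highest-degree nodes and the directed edges among them.

    Degree counts both outgoing and incoming links: each directed edge (u, v)
    adds 1 to the degree of u and 1 to the degree of v.
    """
    degree = Counter()
    for u, v in edges:
        degree[u] += 1
        degree[v] += 1
    # Counter.most_common orders equal counts by first appearance, so ties at
    # the cutoff are broken by edge-list order (an assumption of this sketch).
    keep = {node for node, _ in degree.most_common(n)}
    sub_edges = [(u, v) for u, v in edges if u in keep and v in keep]
    return keep, sub_edges


edges = [(0, 1), (1, 0), (0, 2), (2, 1), (3, 0), (4, 0)]
keep, sub_edges = top_n_subgraph(edges, 3)
print(sorted(keep))  # [0, 1, 2]
print(sub_edges)     # [(0, 1), (1, 0), (0, 2), (2, 1)]
```

Counting both directions means an article that many pages link to ranks highly even if it has few outgoing links of its own.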
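A BFS baseline like the one in `baseline.py` can be sketched as below. It is a minimal sketch assuming an adjacency dict of node ids; the name `bfs_shortest_path` is hypothetical, and here "visited" counts every node BFS discovered before reaching the target, which may not match exactly how `baseline.py` counts.

```python
from collections import deque


def bfs_shortest_path(graph, start, target):
    """Breadth-first search over an adjacency dict.

    Returns (path, num_visited): the shortest start-to-target path as a list of
    nodes (None if unreachable) and the number of nodes BFS discovered.
    """
    parent = {start: None}  # also serves as the visited set
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == target:
            # Follow parent pointers back to start to rebuild the path.
            path = []
            while node is not None:
                path.append(node)
                node = parent[node]
            return path[::-1], len(parent)
        for nbr in graph.get(node, []):
            if nbr not in parent:
                parent[nbr] = node
                queue.append(nbr)
    return None, len(parent)


graph = {0: [1, 2], 1: [3], 2: [3], 3: [4], 4: []}
path, visited = bfs_shortest_path(graph, 0, 4)
print(path, len(path) - 1, visited)  # [0, 1, 3, 4] 3 5
```

Since BFS explores nodes in order of distance from the start, the returned path has the minimum possible number of clicks, which makes it a natural lower bound for the other agents; the visited count shows how much of the graph it had to explore to find that path.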
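The per-strategy summary computed by `graph_stats.py` amounts to something like the following. This sketch assumes the path lengths have already been loaded from the `.pkl` files into a dict mapping strategy name to a list of path lengths (the loading step and the name `summarize` are hypothetical), and it uses the population standard deviation, which may differ from the script's choice.

```python
import statistics


def summarize(path_lengths_by_strategy):
    """Return {strategy: (mean, std, max)} of each strategy's path lengths.

    Uses the population standard deviation (statistics.pstdev); whether
    graph_stats.py uses the population or sample version is an assumption.
    """
    return {
        name: (statistics.mean(lengths), statistics.pstdev(lengths), max(lengths))
        for name, lengths in path_lengths_by_strategy.items()
    }


# Illustrative numbers only, not results from this project.
stats = summarize({"bfs": [2, 3, 4], "node2vec": [3, 3, 3]})
for name, (mean, std, longest) in stats.items():
    print(f"{name}: mean={mean:.2f} std={std:.2f} max={longest}")
# bfs: mean=3.00 std=0.82 max=4
# node2vec: mean=3.00 std=0.00 max=3
```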