Skip to content

Commit

Permalink
Merge pull request #726 from shivdeep-singh-ibm/fix_semantic_ordering
Browse files Browse the repository at this point in the history
Multiple fixes for semantic order transform
  • Loading branch information
Param-S authored Oct 21, 2024
2 parents ed97afe + fdcf158 commit 5942f42
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,11 @@ def compute_exec_params_func(
"repo_lvl_store_ray_cpus": repo_lvl_store_ray_cpus,
"repo_lvl_store_ray_nworkers": repo_lvl_store_ray_nworkers,
"repo_lvl_sorting_algo": repo_lvl_sorting_algo,
"repo_lvl_stage_one_only": repo_lvl_stage_one_only,
"repo_lvl_sorting_enabled": repo_lvl_sorting_enabled,
"repo_lvl_output_by_langs": repo_lvl_output_by_langs,
"repo_lvl_combine_rows": repo_lvl_combine_rows,
}
if repo_lvl_stage_one_only == True:
res["repo_lvl_stage_one_only"] = ""
if repo_lvl_sorting_enabled == True:
res["repo_lvl_sorting_enabled"] = ""
if repo_lvl_output_by_langs == True:
res["repo_lvl_output_by_langs"] = ""
if repo_lvl_combine_rows == True:
res["repo_lvl_combine_rows"] = ""
return res


Expand Down
28 changes: 23 additions & 5 deletions transforms/code/repo_level_ordering/ray/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,31 @@ testing and IDE set up.

## Summary

This transform does repository-level packing of data, arranging the files to prioritise semantic dependencies. This
was done to prepare long-context data for [Scaling Granite Code Models to 128K Context](https://arxiv.org/pdf/2407.13739).
Quoting the paper:

>To create long-context data, we develop a new approach that packs files from the same
repository together, arranging them to prioritize semantic dependencies. We identify these
dependencies by analyzing file imports and create a directed acyclic graph, where each
file is a node and edges represent API imports between files. After breaking any cycles
in the graph, we perform a topological sort to establish an ordering of files based on their
semantic dependencies. We then organize the files in a repository by placing documentation
and build files first, followed by the ordered set of files with semantic dependencies, and
finally the remaining non-connected files. These non-connected files are arranged according
to their folder structure, using a depth-first search to traverse the repository. Finally, we
determine the dominant programming language of a repository based on file extensions
and presence of build files, to organise repo-ordered files by programming languages


This transform can group the data by `repo_name` and apply additional transformations (such as sorting, grouping output by language, or combining rows) to the grouped data.
This transform requires the input data to have at least the following columns:

- repo name: Name of the repo, it is used for grouping in this transform.
- **repo name**: Name of the repo, it is used for grouping in this transform.

- title : Which is usually file path.
- **title** : Which is usually file path.

- language: Programming language of content
- **language**: Programming language of content

The input data for this transform should be in parquet format. The input data is expected to have code data arranged in rows
such that each row represents a file. The required columns in the input data should correspond to a) repository name b) file path
Expand Down Expand Up @@ -151,10 +168,11 @@ python src/repo_level_order_transform_ray.py \
--run_locally True \
--data_s3_cred "$s3_kreds" \
--data_s3_config "$s3_conf" \
--repo_lvl_store_type local \
--repo_lvl_store_backend_dir '/tmp/mystore' \
--repo_lvl_store_type ray \
--repo_lvl_combine_rows True \
--repo_lvl_sorting_enabled True \
--repo_lvl_store_ray_cpus 0.2 \
--repo_lvl_store_ray_nworkers 1 \
--repo_lvl_sorting_algo SORT_SEMANTIC \
--repo_lvl_output_by_langs True
```
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid
from typing import Callable

import pandas as pd
import pyarrow as pa
from dpk_repo_level_order.internal.check_languages import (
get_dominant_language_repo_packing,
Expand All @@ -20,26 +21,47 @@
SORT_SEMANTIC_NORMALISED = "SORT_SEMANTIC_NORMALISED"


def semantic_sort(
    df: pd.DataFrame, logger: logging.Logger, title_column_name: str, language_column_name: str
) -> pd.DataFrame:
    """Sort the rows of `df` by delegating to `sort_sem`.

    Args:
        df: Dataframe passed to `sort_sem` as `files_df`.
        logger: Logger forwarded to `sort_sem`.
        title_column_name: Name of the column holding file titles/paths.
        language_column_name: Name of the column holding file languages.

    Returns:
        Whatever `sort_sem` returns for the given arguments.
    """
    return sort_sem(
        files_df=df,
        logger=logger,
        title_column_name=title_column_name,
        language_column_name=language_column_name,
    )


def semantic_sort_normalised(
    df: pd.DataFrame, logger: logging.Logger, title_column_name: str, language_column_name: str
) -> pd.DataFrame:
    """Run `check_and_update_title` on `df`, then sort it with `sort_sem`.

    Args:
        df: Dataframe passed to `check_and_update_title` and then to
            `sort_sem` as `files_df`.
        logger: Logger forwarded to `sort_sem`.
        title_column_name: Name of the column holding file titles/paths.
        language_column_name: Name of the column holding file languages.

    Returns:
        Whatever `sort_sem` returns for the given arguments.
    """
    # The return value is ignored, so this presumably updates `df` in place.
    # NOTE(review): confirm against check_and_update_title.
    check_and_update_title(df)
    return sort_sem(
        files_df=df,
        logger=logger,
        title_column_name=title_column_name,
        language_column_name=language_column_name,
    )


def default_sort(
    df: pd.DataFrame, logger: logging.Logger, title_column_name: str, language_column_name: str
) -> pd.DataFrame:
    """Sort the rows of `df` by delegating to `sort_by_path`.

    Args:
        df: Dataframe to sort.
        logger: Logger forwarded to `sort_by_path`.
        title_column_name: Name of the column holding file titles/paths.
        language_column_name: Not used by this function; the parameter keeps
            the signature identical to `semantic_sort`.

    Returns:
        Whatever `sort_by_path` returns for the given arguments.
    """
    return sort_by_path(df=df, logger=logger, title_column_name=title_column_name)


def get_sorting_func(
sorting_algo: str, title_column_name: str, logger: logging.Logger, language_column_name: str
) -> Callable[[pa.Table], pa.Table]:
"""Get a sorting function based on the specified algorithm.
Args:
sorting_algo (str): The sorting algorithm to use.
title_column_name (str): The name of the column containing file
titles.
logger (logging.Logger): A logger object for logging messages.
language_column_name (str): The name of the column containing file
languages.
Returns:
Callable[[pa.Table, str], pa.Table]: A function that takes a PyArrow Table
and a file name as input and
returns a sorted PyArrow Table.
"""
if sorting_algo == SORT_SEMANTIC:
sort_by = semantic_sort
logger.info("semantic sort enabled")
Expand Down Expand Up @@ -74,7 +96,26 @@ def sorter(table: pa.Table, file_name: str) -> pa.Table:
return sorter


def get_dominant_language_func(language_column_name, title_column_name):
def get_dominant_language_func(language_column_name: str, title_column_name: str) -> Callable[[pa.Table, str], str]:
"""
This function takes two column names as input and returns a function
that can be applied to a pyarrow table.
The returned function determines the dominant programming language in
the pyarrow table and returns the filename with the detected language
prepended.
Args:
language_column_name (str): Name of the column containing the
programming languages.
title_column_name (str): Name of the column containing the file
titles/paths.
Returns:
Callable[[pa.Table, str], str]: A function that takes a table as
input and returns a new table with the filenames modified to include the
detected dominant language.
"""

def dominant_lang_per_repo(table: pa.Table, filename: str) -> str:
"""
This function takes a table whose rows are documents from a repo
Expand Down Expand Up @@ -137,6 +178,28 @@ def lang_distribution(grouping_column):


def get_transforming_func(sorting_func=None, superrows_func=None, filename_func=None, language_column_name="language"):
"""
This function takes three optional functions as input and returns a
function that can be applied to a pyarrow table and file name.
The returned function performs some transformation on the input table
and file name based on the provided functions.
Args:
sorting_func (Callable[[pa.Table, str], pa.Table]): A function that sorts the
rows of a table based on a column. Defaults to None.
superrows_func (Callable[[pa.Table, str, str], pa.Table]): A
function that creates new rows in a table based on the values of other
columns. Defaults to None.
filename_func (Callable[[pa.Table, str], str]): A function that modifies the
file name. Defaults to None.
language_column_name (str): The name of the column containing the
programming languages. Defaults to "language".
Returns:
callable: A function that takes a table and file name as input and
returns a list of transformed tables and file names.
"""

def my_transform(table, file_name):
out_table = table
if sorting_func:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,14 @@
}

# Repo-level ordering options, including the boolean switches, passed as
# ordinary key/value parameters.
repo_level_params = {
    "repo_lvl_sorting_enabled": True,
    "repo_lvl_sorting_algo": "SORT_SEMANTIC",
    "repo_lvl_store_type": "ray",
    "repo_lvl_output_by_langs": True,
    "repo_lvl_combine_rows": True,
}

# Build the simulated command line once, from the base params merged with the
# repo-level ones. Earlier assignments to sys.argv were overwritten by this
# one and had no effect, so they are dropped.
sys.argv = ParamsUtils.dict_to_req(d=params | repo_level_params)
# for arg in sys.argv:
# print(arg)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pyarrow as pa
from data_processing.data_access import DataAccessFactoryBase
from data_processing.transform import AbstractTableTransform, TransformConfiguration
from data_processing.utils import CLIArgumentProvider, get_logger
from data_processing.utils import CLIArgumentProvider, get_logger, str2bool
from data_processing_ray.runtime.ray import DefaultRayTransformRuntime, RayUtils
from data_processing_ray.runtime.ray.runtime_configuration import (
RayTransformRuntimeConfiguration,
Expand All @@ -27,6 +27,7 @@
create_store,
create_store_params,
init_store_params,
store_type_value_ray,
validate_store_params,
)
from ray.actor import ActorHandle
Expand Down Expand Up @@ -108,6 +109,7 @@ def __init__(self, config: dict[str, Any]):
self.grouping_column = config.get(grouping_column_key, repo_column_default_value)
store_params = config.get(store_params_key)
validate_store_params(store_params)
self.store_type = store_params[store_type_key]
self.store = create_store(store_params)
self.group_batch_size = group_batch_size

Expand All @@ -126,6 +128,16 @@ def _create_batches(self, data, batch_size=1):
batches.append(batch)
return batches

def _normalize_file_name_for_store(self, file_name):
    """
    Return the form of `file_name` that is recorded in the store.

    A store of type `store_type_value_ray` keeps the name unchanged, so a
    full path can be recorded. Every other store type gets only the base
    name: those stores use a flat filesystem as backend and cannot hold a
    full path.
    """
    uses_filesystem_backend = self.store_type != store_type_value_ray
    if uses_filesystem_backend:
        return os.path.basename(file_name)
    return file_name

def transform(self, table: pa.Table, file_name: str = None) -> tuple[list[pa.Table], dict[str, Any]]:
"""
This step is used to do groupby with respect to `self.grouping_column` and update
Expand All @@ -145,11 +157,8 @@ def transform(self, table: pa.Table, file_name: str = None) -> tuple[list[pa.Tab
grp_flow = {}
for group in batch:
# This supports only flat folder structure, so all
# files should be in the same folder
# since store uses filesystem as backend
# can't store full path in store since,
# store is currently flat filesystem.
file_name = os.path.basename(file_name)
# files should be in the same folder.
file_name = self._normalize_file_name_for_store(file_name)
grp_flow[group] = file_name
self.logger.debug(f"Updating {group} to store")

Expand Down Expand Up @@ -286,10 +295,15 @@ def _prepare_mapper_function(self):

def _prepare_inputs(self):
    """
    Build the list of `(repo, files)` pairs from the store contents.

    For a store of type `store_type_value_ray` the stored file entries are
    used as they are. For any other store type each entry is prefixed with
    `self.input_folder`.

    Returns:
        list of tuples `(repo, [file path, ...])`, one tuple per item
        yielded by `store.items_kv()`.
    """
    store = create_store(self.store_params)
    is_ray_store = self.store_params[store_type_key] == store_type_value_ray

    p_input = []
    for repo, files in store.items_kv():
        # Exactly one entry per repo: either the stored names as-is, or the
        # names joined onto the input folder.
        if is_ray_store:
            p_input.append((repo, [f"{file}" for file in files]))
        else:
            files_location = self.input_folder
            p_input.append((repo, [f"{files_location}/{file}" for file in files]))
    return p_input

def _group_and_sort(self):
Expand Down Expand Up @@ -361,8 +375,8 @@ def add_input_params(self, parser: ArgumentParser) -> None:
# See below for remove_from_metadata addition so that it is not reported.
parser.add_argument(
f"--{cli_prefix}{stage_one_only_key}",
action="store_true",
help="If this flag is set, transform only builds the repo grouping and doesn't write output",
type=lambda x: bool(str2bool(x)),
help="If this flag is True, transform only builds the repo grouping and doesn't write output",
)
parser.add_argument(
f"--{cli_prefix}{grouping_column_key}",
Expand Down Expand Up @@ -402,7 +416,7 @@ def add_input_params(self, parser: ArgumentParser) -> None:
parser.add_argument(
f"--{cli_prefix}{sorting_enable_key}",
default=sort_enable_default,
type=bool,
type=lambda x: bool(str2bool(x)),
help=f"Enables sorting of output by algorithm specified using {cli_prefix}{sorting_algo_key}. Defaults to SORT_BY_PATH if no algorithm is specified.",
)
parser.add_argument(
Expand All @@ -413,13 +427,13 @@ def add_input_params(self, parser: ArgumentParser) -> None:
)
parser.add_argument(
f"--{cli_prefix}{output_by_langs_key}",
type=bool,
type=lambda x: bool(str2bool(x)),
default=output_by_lang_default,
help="If specified, output is grouped into programming language folders.",
)
parser.add_argument(
f"--{cli_prefix}{output_superrows_key}",
type=bool,
type=lambda x: bool(str2bool(x)),
default=superrows_default,
help="If specified, output rows per repo are combined to form a single repo",
)
Expand Down

0 comments on commit 5942f42

Please sign in to comment.