From d71fe712cacbfe19536282a6fd876667d0159687 Mon Sep 17 00:00:00 2001
From: Pravali Uppugunduri
Date: Wed, 16 Oct 2024 06:00:01 +0000
Subject: [PATCH] Auto-capture requirements

---
 albatross_test.py                         | 159 ++++++++++++++++++
 .../serve/detector/dependency_manager.py  |  75 +++++----
 .../serve/detector/pickle_dependencies.py |  87 +++++---
 3 files changed, 265 insertions(+), 56 deletions(-)
 create mode 100644 albatross_test.py

diff --git a/albatross_test.py b/albatross_test.py
new file mode 100644
index 0000000000..fa9867d5a4
--- /dev/null
+++ b/albatross_test.py
@@ -0,0 +1,159 @@
+import sys
+
+print(sys.path)
+sys.path.append("/home/upravali/telemetry/sagemaker-python-sdk/src/sagemaker")
+sys.path.append("/home/upravali/langchain/langchain-aws/libs/aws/")
+print("Updated sys.path:", sys.path)
+
+import json
+import os
+import time
+
+from sagemaker.serve.builder.model_builder import ModelBuilder
+from sagemaker.serve.builder.schema_builder import SchemaBuilder
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+import langchain_aws
+import langchain_core
+from langchain_aws import ChatBedrockConverse
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+
+INPUTS = {
+    'CPU': {
+        'INFERENCE_IMAGE': '763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.4.0-cpu-py311-ubuntu22.04-sagemaker',
+        'INSTANCE_TYPE': 'ml.m5.xlarge'
+    },
+    'GPU': {
+        'INFERENCE_IMAGE': '763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.4.0-gpu-py311-cu124-ubuntu22.04-sagemaker',
+        'INSTANCE_TYPE': 'ml.g5.xlarge'
+    },
+    'SERVICE': {
+        'ROLE': 'arn:aws:iam::971812153697:role/upravali-test-role'
+    }
+}
+
+
+def deploy(device):
+
+    class CustomerInferenceSpec(InferenceSpec):
+
+        def load(self, model_dir):
+            from langchain_aws import ChatBedrockConverse
+            from langchain_core.prompts import ChatPromptTemplate
+            from langchain_core.output_parsers import StrOutputParser
+            return \
+                ChatPromptTemplate.from_messages(
+                    [
+                        (
+                            "system",
+                            "You are a verbose assistant that gives long-winded responses of at least 500 words for every comment/question.",
+                        ),
+                        ("human", "{input}"),
+                    ]
+                ) | \
+                ChatBedrockConverse(
+                    model='anthropic.claude-3-sonnet-20240229-v1:0',
+                    temperature=0,
+                    region_name='us-west-2'
+                ) | \
+                StrOutputParser()
+
+        def invoke(self, x, model):
+            return model.invoke({'input': x['input']}) if x['stream'].lower() != 'true' \
+                else model.stream({'input': x['input']})
+
+    model = ModelBuilder(
+        ##################################################################
+        # can be defined by either the service or the customer
+        ##################################################################
+        name=f'model-{int(time.time())}',
+
+        ##################################################################
+        # the service should define these
+        ##################################################################
+        image_uri=INPUTS[device]['INFERENCE_IMAGE'],
+        env_vars={
+            'TS_DISABLE_TOKEN_AUTHORIZATION': 'true'  # ABSOLUTELY NECESSARY
+        },
+
+        ##################################################################
+        # the customer should define these
+        ##################################################################
+        schema_builder=SchemaBuilder(
+            json.dumps({
+                'stream': 'true',
+                'input': 'hello'
+            }),
+            ""
+        ),
+        inference_spec=CustomerInferenceSpec(),  # won't unpickle correctly if the local Python version and the DLC's don't match
+        dependencies={
+            "auto": True,
+            # 'requirements': './inference/code/requirements2.txt'
+        },
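+        # Sketch of the "auto": True flow, per the dependency_manager and
+        # pickle_dependencies changes below: build() pickles this
+        # InferenceSpec to serve.pkl, then runs pickle_dependencies.py in a
+        # subprocess to intersect the modules the unpickled object pulls
+        # into sys.modules with `pip list`, and writes the pinned result to
+        # requirements.txt.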
+        role_arn=INPUTS['SERVICE']['ROLE']
+    ).build()
+    endpoint = model.deploy(
+        initial_instance_count=1,
+        instance_type=INPUTS[device]['INSTANCE_TYPE'],
+    )
+    return (model, endpoint)
+
+
+###################################################################################################
+#
+# PoC DEMO CODE ONLY
+#
+# Note: invoke vs. invoke_stream matters
+###################################################################################################
+def invoke(endpoint, x):
+    return endpoint.predict(x)
+
+
+def invoke_stream(endpoint, x):
+    res = endpoint.predict_stream(x)
+    print(str(res))  # generator
+    return res
+
+
+def clean(model, endpoint):
+    try:
+        endpoint.delete_endpoint()
+    except Exception as e:
+        print(e)
+
+    try:
+        model.delete_model()
+    except Exception as e:
+        print(e)
+
+
+def main(device):
+    print("before deploying")
+    model, endpoint = deploy(device)
+    print("after deploying")
+
+    while True:
+        x = input(">>> ")
+        if x == 'exit':
+            break
+        try:
+            if json.loads(x)['stream'].lower() == 'true':
+                for chunk in invoke_stream(endpoint, x):
+                    print(
+                        str(chunk, encoding='utf-8'),
+                        end="",
+                        flush=True
+                    )
+                print()
+            else:
+                print(invoke(endpoint, x))
+        except Exception as e:
+            print(e)
+
+    clean(model, endpoint)
+
+
+if __name__ == '__main__':
+    os.environ['AWS_DEFAULT_REGION'] = 'us-west-2'
+    main('CPU')
\ No newline at end of file
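# Example REPL payloads for main() above (shapes mirror the SchemaBuilder sample
# input; values are illustrative):
#   {"stream": "false", "input": "hello"}  -> invoke(): one full string response
#   {"stream": "true",  "input": "hello"}  -> invoke_stream(): UTF-8 chunks
#                                             printed as they arrive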
diff --git a/src/sagemaker/serve/detector/dependency_manager.py b/src/sagemaker/serve/detector/dependency_manager.py
index e72a84da30..d4e3e73ee1 100644
--- a/src/sagemaker/serve/detector/dependency_manager.py
+++ b/src/sagemaker/serve/detector/dependency_manager.py
@@ -1,37 +1,19 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-#     http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-"""SageMaker model builder dependency managing module.
-
-This must be kept independent of SageMaker PySDK
-"""
-
-from __future__ import absolute_import
-
-from pathlib import Path
 import logging
 import subprocess
 import sys
 import re
+from pathlib import Path

 _SUPPORTED_SUFFIXES = [".txt"]

-# TODO : Move PKL_FILE_NAME to common location
 PKL_FILE_NAME = "serve.pkl"

 logger = logging.getLogger(__name__)


 def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool = False):
-    """Placeholder docstring"""
+    """Capture dependencies, printing each step for debugging."""
+    print(f"Capturing dependencies: {dependencies}, work_dir: {work_dir}, capture_all: {capture_all}")
+
     path = work_dir.joinpath("requirements.txt")
     if "auto" in dependencies and dependencies["auto"]:
         command = [
@@ -45,6 +27,8 @@
         if capture_all:
             command.append("--capture_all")
+
+        print(f"Running subprocess with command: {command}")

         subprocess.run(
             command,
@@ -55,62 +39,83 @@
         with open(path, "r") as f:
             autodetect_depedencies = f.read().splitlines()
         autodetect_depedencies.append("sagemaker[huggingface]>=2.199")
+        print(f"Auto-detected dependencies: {autodetect_depedencies}")
     else:
         autodetect_depedencies = ["sagemaker[huggingface]>=2.199"]
+        print(f"No auto-detection, using default dependencies: {autodetect_depedencies}")

     module_version_dict = _parse_dependency_list(autodetect_depedencies)
+    print(f"Parsed auto-detected dependencies: {module_version_dict}")

     if "requirements" in dependencies:
         module_version_dict = _process_customer_provided_requirements(
             requirements_file=dependencies["requirements"],
             module_version_dict=module_version_dict
         )
+        print(f"After processing customer-provided requirements: {module_version_dict}")
+
     if "custom" in dependencies:
         module_version_dict = _process_custom_dependencies(
             custom_dependencies=dependencies.get("custom"),
             module_version_dict=module_version_dict
         )
+        print(f"After processing custom dependencies: {module_version_dict}")
+
     with open(path, "w") as f:
         for module, version in module_version_dict.items():
             f.write(f"{module}{version}\n")
+    print(f"Final dependencies written to {path}")
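+    # Merge precedence above (each later dict.update() wins), shown on
+    # hypothetical values:
+    #   auto-detected:     {"numpy": "==1.26.4"}
+    #   requirements file: numpy>=1.24      -> value becomes ">=1.24"
+    #   custom:            ["numpy==2.0.0"] -> final line is "numpy==2.0.0"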


 def _process_custom_dependencies(custom_dependencies: list, module_version_dict: dict):
-    """Placeholder docstring"""
+    """Process custom dependencies, printing each step."""
+    print(f"Processing custom dependencies: {custom_dependencies}")
+
     custom_module_version_dict = _parse_dependency_list(custom_dependencies)
+    print(f"Parsed custom dependencies: {custom_module_version_dict}")
+
     module_version_dict.update(custom_module_version_dict)
+    print(f"Updated module_version_dict with custom dependencies: {module_version_dict}")
+
     return module_version_dict


 def _process_customer_provided_requirements(requirements_file: str, module_version_dict: dict):
-    """Placeholder docstring"""
+    """Process customer-provided requirements, printing each step."""
+    print(f"Processing customer-provided requirements from file: {requirements_file}")
+
     requirements_file = Path(requirements_file)
     if not requirements_file.is_file() or not _is_valid_requirement_file(requirements_file):
         raise Exception(f"Path: {requirements_file} to requirements.txt doesn't exist")

     logger.debug("Packaging provided requirements.txt from %s", requirements_file)
     with open(requirements_file, "r") as f:
         custom_dependencies = f.read().splitlines()
+
+    print(f"Customer-provided dependencies: {custom_dependencies}")
     module_version_dict.update(_parse_dependency_list(custom_dependencies))
+    print(f"Updated module_version_dict with customer-provided requirements: {module_version_dict}")
+
     return module_version_dict


 def _is_valid_requirement_file(path):
-    """Placeholder docstring"""
-    # In the future, we can also check the if the content of customer provided file has valid format
+    """Check whether the requirements file is valid, printing the result."""
+    print(f"Validating requirement file: {path}")
+
     for suffix in _SUPPORTED_SUFFIXES:
         if path.name.endswith(suffix):
+            print(f"File {path} is valid with suffix {suffix}")
             return True
+
+    print(f"File {path} is not valid")
     return False


 def _parse_dependency_list(depedency_list: list) -> dict:
-    """Placeholder docstring"""
-
-    # Divide a string into 2 part, first part is the module name
-    # and second part is its version constraint or the url
-    # checkout tests/unit/sagemaker/serve/detector/test_dependency_manager.py
-    # for examples
+    """Parse the dependency list into a module-to-constraint dict, printing the result."""
+    print(f"Parsing dependency list: {depedency_list}")
+
     pattern = r"^([\w.-]+)(@[^,\n]+|((?:[<>=!~]=?[\w.*-]+,?)+)?)$"
     module_version_dict = {}

     for dependency in depedency_list:
@@ -119,10 +124,10 @@
         match = re.match(pattern, dependency)
         if match:
             package = match.group(1)
-            # Group 2 is either a URL or version constraint, if present
             url_or_version = match.group(2) if match.group(2) else ""
             module_version_dict.update({package: url_or_version})
         else:
             module_version_dict.update({dependency: ""})
-
+
+    print(f"Parsed module_version_dict: {module_version_dict}")
     return module_version_dict
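# How _parse_dependency_list splits entries, worked out from the regex above
# (inputs and outputs are illustrative):
#   "numpy==1.26.4"                -> {"numpy": "==1.26.4"}
#   "torch>=2.0,<3.0"              -> {"torch": ">=2.0,<3.0"}
#   "my-pkg@https://host/pkg.whl"  -> {"my-pkg": "@https://host/pkg.whl"}
#   "sagemaker"                    -> {"sagemaker": ""}
#   "sagemaker[huggingface]>=2.199" has no regex match ("[" is outside [\w.-]),
#   so it falls through to the else branch and is kept whole with "".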
diff --git a/src/sagemaker/serve/detector/pickle_dependencies.py b/src/sagemaker/serve/detector/pickle_dependencies.py
index 5a1cd43869..6a27c74b26 100644
--- a/src/sagemaker/serve/detector/pickle_dependencies.py
+++ b/src/sagemaker/serve/detector/pickle_dependencies.py
@@ -1,6 +1,7 @@
-"""Load a pickled object to detect the dependencies it requires"""
-
-from __future__ import absolute_import
+import logging
+import subprocess
+import sys
+import re
 from pathlib import Path
 from typing import List
 import argparse
@@ -10,11 +11,9 @@
 import inspect
 import itertools
 import subprocess
-import sys
 import tqdm
+import ast

-# non native imports. Ideally add as little as possible here
-# because it will add to requirements.txt
 import cloudpickle
 import boto3

@@ -22,22 +21,25 @@
 def get_all_files_for_installed_packages_pip(packages: List[str]):
-    """Placeholder docstring"""
+    """Yield the raw `pip show -f` record for each installed package."""
+    print(f"Fetching files for installed packages: {packages}")
     proc = subprocess.Popen(pipcmd + ["show", "-f"] + packages, stdout=subprocess.PIPE)
     with proc.stdout:
         lines = []
         for line in iter(proc.stdout.readline, b""):
             if line == b"---\n":
+                print(f"Package details: {lines}")
                 yield lines
                 lines = []
             else:
                 lines.append(line)
         yield lines
-    proc.wait(timeout=10)  # wait for the subprocess to exit
+    proc.wait(timeout=10)


 def get_all_files_for_installed_packages(packages: List[str]):
-    """Placeholder docstring"""
+    """Map each installed package to the set of files it owns."""
+    print(f"Processing installed packages: {packages}")
     ret = {}
     for rawmsg in get_all_files_for_installed_packages_pip(packages):
         parser = email.parser.BytesParser(policy=email.policy.default)
@@ -47,13 +49,13 @@
         ret[msg.get("Name")] = {
             Path(msg.get("Location")).joinpath(x) for x in msg.get("Files").split()
         }
-
+        print(f"Package {msg.get('Name')} with files: {ret[msg.get('Name')]}")
     return ret


 def batched(iterable, n):
-    """Batch data into tuples of length n. The last batch may be shorter."""
-    # batched('ABCDEFG', 3) --> ABC DEF G
+    """Batch data into tuples of length n; the last batch may be shorter."""
+    print(f"Batching data into groups of {n}")
     if n < 1:
         raise ValueError("n must be at least one")
     it = iter(iterable)
@@ -61,30 +63,40 @@
         batch = tuple(itertools.islice(it, n))
         if not batch:
             break
+        print(f"Batch: {batch}")
         yield batch


 def get_all_installed_packages():
-    """Placeholder docstring"""
+    """Get all installed packages from the local env via `pip list`."""
+    print("Fetching all installed packages...")
     proc = subprocess.run(pipcmd + ["list", "--format", "json"], stdout=subprocess.PIPE, check=True)
-    return json.loads(proc.stdout)
+    all_packages = json.loads(proc.stdout)
+    print(f"All installed packages: {all_packages}")
+    return all_packages


 def map_package_names_to_files(package_names: List[str]):
-    """Placeholder docstring"""
+    """Map package names to their files."""
+    print(f"Mapping package names to files for: {package_names}")
     m = {}
     batch_size = 20
     with tqdm.tqdm(total=len(package_names), desc="Scanning for dependencies", ncols=100) as pbar:
         for pkg_names in batched(package_names, batch_size):
             m.update(get_all_files_for_installed_packages(list(pkg_names)))
             pbar.update(batch_size)
+            print(f"Processed batch: {pkg_names}")
+
+    print(f"Package name to file map: {m}")
     return m


 def get_currently_used_packages():
-    """Placeholder docstring"""
+    """Get the packages that the loaded pickle actually imported."""
+    print("Fetching currently used packages...")
     all_installed_packages = get_all_installed_packages()
     package_to_file_names = map_package_names_to_files([x["name"] for x in all_installed_packages])
+    # print(f"package_to_file_names: {package_to_file_names}")

     currently_used_files = {
         Path(m.__file__)
@@ -92,20 +104,40 @@
         for m in sys.modules.values()
         if inspect.ismodule(m) and hasattr(m, "__file__") and m.__file__
     }

+    print(f"Currently used files: {currently_used_files}")
+
     currently_used_packages = set()
     for file in currently_used_files:
         for package in package_to_file_names:
             if file in package_to_file_names[package]:
+                print(f"file: {file}")
+                print(f"package: {package}")
                 currently_used_packages.add(package)
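+    # The matching step above, on hypothetical values: if sys.modules holds
+    # numpy loaded from .../site-packages/numpy/__init__.py, and pip maps
+    # "numpy" to a file set containing that path, then "numpy" lands in
+    # currently_used_packages.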
+    # for module in sys.modules.values():
+    #     if inspect.ismodule(module):
+    #         for _, obj in inspect.getmembers(module):
+    #             if inspect.ismethod(obj) or inspect.isfunction(obj):
+    #                 source_code = inspect.getsource(obj)
+    #                 import_nodes = [node for node in ast.walk(ast.parse(source_code)) if isinstance(node, ast.Import)]
+    #                 for import_node in import_nodes:
+    #                     for alias in import_node.names:
+    #                         package_name = alias.name.split('.')[0]
+    #                         if package_name in package_to_file_names:
+    #                             currently_used_packages.add(package_name)
+
+    print(f"Currently used packages: {currently_used_packages}")
     return currently_used_packages


 def get_requirements_for_pkl_file(pkl_path: Path, dest: Path):
-    """Placeholder docstring"""
+    """Generate pinned requirements for a pickled object."""
+    print(f"Loading pickled file from {pkl_path}")
     with open(pkl_path, mode="rb") as file:
         cloudpickle.load(file)

     currently_used_packages = get_currently_used_packages()
+    print(f"Currently used packages after loading pkl: {currently_used_packages}")

     with open(dest, mode="w+") as out:
         for x in get_all_installed_packages():
@@ -115,24 +147,28 @@
             if name == "boto3":
                 boto3_version = boto3.__version__
                 out.write(f"boto3=={boto3_version}\n")
+                print(f"Added boto3=={boto3_version} to requirements")
             elif name in currently_used_packages:
                 out.write(f"{name}=={version}\n")
+                print(f"Added {name}=={version} to requirements")


 def get_all_requirements(dest: Path):
-    """Placeholder docstring"""
+    """Write every installed package, pinned, to dest."""
+    print(f"Getting all requirements and saving to {dest}")
     all_installed_packages = get_all_installed_packages()
     with open(dest, mode="w+") as out:
         for package_info in all_installed_packages:
             name = package_info.get("name")
             version = package_info.get("version")
             out.write(f"{name}=={version}\n")
+            print(f"Added {name}=={version} to requirements")


 def parse_args():
-    """Placeholder docstring"""
+    """Parse command-line arguments."""
+    print("Parsing command-line arguments...")
     parser = argparse.ArgumentParser(
         prog="pkl_requirements", description="Generates a requirements.txt for a cloudpickle file"
     )
@@ -144,17 +180,26 @@
         help="capture all dependencies in current environment",
     )
     args = parser.parse_args()
+    print(f"Arguments parsed: {args}")
     return (Path(args.pkl_path), Path(args.dest), args.capture_all)


 def main():
-    """Placeholder docstring"""
+    """Capture requirements for a pickle file, or for the whole environment."""
+    print("Starting the main function...")
     pkl_path, dest, capture_all = parse_args()
     if capture_all:
+        print(f"Capturing all requirements to {dest}")
         get_all_requirements(dest)
     else:
+        print(f"Capturing requirements for pkl file {pkl_path} to {dest}")
         get_requirements_for_pkl_file(pkl_path, dest)


 if __name__ == "__main__":
     main()
+
+'''
+Flow note: capture_all is False on the default path, so we build requirements for
+the pkl file: load serve.pkl, collect the currently used packages from sys.modules,
+and intersect them with everything pip reports as installed.
+'''
\ No newline at end of file
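# Possible invocations (a sketch only: the add_argument calls are elided in the
# hunk above, so the --pkl_path/--dest option names are assumed from the parsed
# attribute names; --capture_all is confirmed by capture_dependencies):
#   python pickle_dependencies.py --pkl_path work_dir/serve.pkl --dest work_dir/requirements.txt
#   python pickle_dependencies.py --pkl_path work_dir/serve.pkl --dest work_dir/requirements.txt --capture_all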