validate_data.py
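
"""Validate dataset format for LLM inference.

Loads the dataset described by a YAML config, takes a small sample of
rows (applying the configured UDF first, if any), and checks that the
ID column holds non-empty strings and that the input column matches the
configured api_type: a plain string for 'completion', or an OpenAI-style
message list for 'chat-completion'.
"""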
import argparse
import gc
import os
import sys
from itertools import islice

from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from pydantic import TypeAdapter, ValidationError

from inference_hive.config import load_config_for_validation
from inference_hive.data_utils import load_data
from inference_hive import udf


def validate_input_data_format(sample_rows, input_column_name: str, id_column_name: str, api_type: str, log_samples: bool = True):
    """Validate that the input data format matches the expected API type format."""
    if len(sample_rows) == 0:
        logger.warning("Empty dataset, skipping format validation")
        return

    # Build the adapter once; it validates a messages list against OpenAI's schema
    message_adapter = TypeAdapter(list[ChatCompletionMessageParam])

    for i, row in enumerate(sample_rows):
        if input_column_name not in row:
            raise ValueError(f"Column '{input_column_name}' not found in dataset. Available columns: {list(row.keys())}")
        if id_column_name not in row:
            raise ValueError(f"Column '{id_column_name}' not found in dataset. Available columns: {list(row.keys())}")
        data = row[input_column_name]
        row_id = row[id_column_name]

        # Validate that the ID is a non-empty string
        if row_id is None:
            raise ValueError(f"ID column '{id_column_name}' contains None value in row {i}")
        if not isinstance(row_id, str):
            raise ValueError(
                f"ID column '{id_column_name}' must contain strings. "
                f"Found {type(row_id).__name__} in row {i}: {repr(row_id)}. "
                f"Please convert your ID column to string dtype before running inference."
            )
        if row_id.strip() == "":
            raise ValueError(f"ID column '{id_column_name}' contains empty string in row {i}")

        if api_type == "completion":
            # For the completion API, the input must be a prompt string
            if not isinstance(data, str):
                raise ValueError(
                    f"For api_type='completion', input data must be strings. "
                    f"Found {type(data).__name__} in row {i}: {data}"
                )
        elif api_type == "chat-completion":
            # For the chat-completion API, the input must be a non-empty list of messages
            if not isinstance(data, list):
                raise ValueError(
                    f"For api_type='chat-completion', input data must be a list of messages. "
                    f"Found {type(data).__name__} in row {i}: {data}"
                )
            if len(data) == 0:
                raise ValueError(f"Empty message list found in row {i}")
            try:
                # Validate the entire messages list using OpenAI's pydantic models
                message_adapter.validate_python(data)
            except ValidationError as e:
                raise ValueError(
                    f"Invalid message format in row {i}. Messages must conform to OpenAI's ChatCompletionMessageParam format. "
                    f"Common issues: missing 'role' or 'content' fields, invalid role values, or incorrect data types.\n"
                    f"Validation error: {str(e)}\n"
                    f"Message data: {data}"
                ) from e
        else:
            raise ValueError(f"Invalid API type: {api_type}")

    logger.info(f"Input data format validation passed for api_type='{api_type}' with string ID column '{id_column_name}'")

    if log_samples:
        logger.info("Sample rows:")
        logger.info("=" * 80)
        for i, row in enumerate(sample_rows):
            row_id = row[id_column_name]
            data = row[input_column_name]
            logger.info("-" * 40)
            logger.info(f"Sample {i+1}/{len(sample_rows)}")
            logger.info(f"{row_id=}")
            if api_type == "completion":
                # For the completion API, show the prompt string
                logger.info(f"Prompt:\n{data}")
            elif api_type == "chat-completion":
                # For the chat-completion API, show each message; 'content' can be
                # absent for some valid message types, so use .get()
                logger.info("Messages:")
                for j, message in enumerate(data):
                    role = message['role']
                    content = message.get('content')
                    logger.info(f"[{j+1}]\n{role=}\n{content=}")
        logger.info("=" * 80)
def validate_dataset_from_config(config_path: str, shard: int | None = None, num_shards: int | None = None):
    """
    Validate dataset format based on a config file.

    Args:
        config_path: Path to the YAML configuration file
        shard: Optional shard number to validate (if None, validates the full dataset)
        num_shards: Total number of shards (required if shard is specified)
    """
    logger.info(f"Loading configuration from: {config_path}")
    config = load_config_for_validation(config_path)
    logger.info("Loading dataset for validation...")
    logger.info(f"Config: {config}")
    ds = load_data(config, shard, num_shards)
    logger.info(f"Dataset {shard=} loaded: {len(ds)} rows")
    logger.info(f"{ds}")

    # Validate the dataset format using a small sample (apply the UDF first if configured)
    logger.info("Starting data validation...")
    sample_size = 10
    raw_sample_rows = list(islice(ds, sample_size))
    if config.apply_udf:
        try:
            udf_func = getattr(udf, config.apply_udf)
        except AttributeError:
            raise ValueError(f"UDF function '{config.apply_udf}' not found in udf.py")
        sample_rows = [udf_func(row, **(config.apply_udf_kwargs or {})) for row in raw_sample_rows]
    else:
        sample_rows = raw_sample_rows

    validate_input_data_format(
        sample_rows,
        config.input_column_name,
        config.id_column_name,
        config.api_type,
    )
    logger.info("✓ Data validation completed successfully!")

    # Proactively free dataset/Arrow resources before exiting, to prevent segfaults
    try:
        del ds
    except Exception:
        pass
    return True
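
# A minimal sketch of what a UDF referenced by config.apply_udf might look
# like, inferred from the call above (row in, row out, keyword arguments
# taken from config.apply_udf_kwargs). The function name and field names
# are hypothetical:
#
#   def to_chat_messages(row: dict, system_prompt: str = "") -> dict:
#       row["messages"] = [
#           {"role": "system", "content": system_prompt},
#           {"role": "user", "content": row["text"]},
#       ]
#       return row
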
def main():
    parser = argparse.ArgumentParser(description="Validate dataset format for LLM inference")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML configuration file",
    )
    parser.add_argument(
        "--shard",
        type=int,
        default=None,
        help="Shard number to validate (optional, validates full dataset if not specified)",
    )
    parser.add_argument(
        "--num-shards",
        type=int,
        default=None,
        help="Total number of shards (required if --shard is specified)",
    )
    args = parser.parse_args()

    # Validate shard parameters
    if args.shard is not None and args.num_shards is None:
        logger.error("--num-shards must be specified when --shard is provided")
        sys.exit(1)
    if args.shard is not None and args.num_shards is not None:
        if args.shard >= args.num_shards or args.shard < 0:
            logger.error(f"shard ({args.shard}) must be between 0 and {args.num_shards - 1}")
            sys.exit(1)

    try:
        validate_dataset_from_config(args.config, args.shard, args.num_shards)
        logger.info("Data validation passed! Dataset is ready for inference.")
        gc.collect()
        # os._exit skips normal interpreter teardown, avoiding crashes from
        # Arrow/dataset cleanup at shutdown
        os._exit(0)
    except Exception as e:
        logger.error(f"Data validation failed: {e}")
        os._exit(1)


if __name__ == "__main__":
    main()
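
# Example invocations (config path is illustrative):
#   python validate_data.py --config configs/my_run.yaml
#   python validate_data.py --config configs/my_run.yaml --shard 0 --num-shards 8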