Skip to content

Commit

Permalink
Merge pull request #99 from ArnavK-09/feat/mode/version-control
Browse files Browse the repository at this point in the history
feat(models): set of functions to manage model versions!
  • Loading branch information
setohe0909 authored Nov 6, 2024
2 parents 064da88 + aa87076 commit d5f75f6
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 42 deletions.
129 changes: 119 additions & 10 deletions src/models/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class Model {
version: number,
accuracy?: number,
tag?: string,
active?: boolean
active?: boolean,
) {
this.modelsApiClient = modelsApiClient;
this.name = name;
Expand All @@ -168,7 +168,11 @@ class Model {
* @returns {Array<ModelFeatureDescription>} - All feature descriptions of the model.
*/
describe(): Promise<Array<ModelFeatureDescription>> {
return this.modelsApiClient.describeModel(this.name, this.project, this.version);
return this.modelsApiClient.describeModel(
this.name,
this.project,
this.version,
);
}

/**
Expand All @@ -177,8 +181,17 @@ class Model {
* @param {string} unique_id - Optional unique id to filter the accuracy by.
* @returns {Array<ModelDescribeAttribute>} - Result.
*/
describeAttribute(attribute: string, unique_id?: string): Promise<Array<ModelDescribeAttribute>> {
return this.modelsApiClient.describeModelAttribute(this.name, this.project, attribute, this.version, unique_id);
describeAttribute(
attribute: string,
unique_id?: string,
): Promise<Array<ModelDescribeAttribute>> {
return this.modelsApiClient.describeModelAttribute(
this.name,
this.project,
attribute,
this.version,
unique_id,
);
}

/**
Expand All @@ -201,7 +214,7 @@ class Model {
this.version,
this.targetColumn,
this.project,
options
options,
);
}

Expand All @@ -217,7 +230,7 @@ class Model {
this.version,
this.targetColumn,
this.project,
options
options,
);
}

Expand All @@ -232,13 +245,13 @@ class Model {
this.name,
this.targetColumn,
this.project,
options
options,
);
}
return this.modelsApiClient.retrainModel(
this.name,
this.targetColumn,
this.project
this.project,
);
}

Expand All @@ -251,6 +264,55 @@ class Model {
finetune(integration: string, options: FinetuneOptions): Promise<Model> {
return this.modelsApiClient.finetuneModel(this.name, this.project, options);
}
/**
* List all versions of the model.
*
* @returns {Promise<ModelVersion[]>} - A promise that resolves to an array of ModelVersion objects.
*/
listVersions(): Promise<ModelVersion[]> {
return this.modelsApiClient.listVersions(this.project);
}

/**
* Get a specific version of the model by its version number.
*
* @param {number} v - The version number to retrieve.
* @returns {Promise<ModelVersion>} - A promise that resolves to the requested ModelVersion.
*/
getVersion(v: number): Promise<ModelVersion> {
return this.modelsApiClient.getVersion(
Math.floor(v),
this.project,
this.name,
);
}

/**
* Drop a specific version of the model.
*
* @param {number} v - The version number to drop.
* @param {string} [project=this.project] - The project name. Defaults to the current project.
* @param {string} [model=this.name] - The model name. Defaults to the current model.
* @returns {Promise<void>} - A promise that resolves when the version is dropped.
*/
dropVersion(
v: number,
project: string = this.project,
model: string = this.name,
): Promise<void> {
return this.modelsApiClient.dropVersion(Math.floor(v), project, model);
}
/**
* Sets the active version of the specified model within a given project.
* @param {number} v - The version number to set as active.
*/
setActiveVersion(v: number): Promise<void> {
return this.modelsApiClient.setActiveVersion(
Math.floor(v),
this.project,
this,
);
}

/**
* Creates a Model instance from a row returned from the MindsDB database.
Expand All @@ -269,9 +331,56 @@ class Model {
obj['version'],
obj['accuracy'],
obj['tag'],
obj['active']
obj['active'],
);
}
}

/**
* Represents a MindsDB model with version and all supported operations.
*/
class ModelVersion extends Model {
/**
* Constructor for ModelVersion.
*
* @param {string} project - Name of the project the model belongs to.
* @param {object} data - Data containing the model details.
*/
constructor(
project: string,
data: {
modelsApiClient: ModelsApiClient;
name: string;
targetColumn: string;
status: string;
updateStatus: UpdateStatus;
version: number;
accuracy?: number;
tag?: string;
active?: boolean;
},
) {
super(
data.modelsApiClient,
data.name,
project,
data.targetColumn,
data.status,
data.updateStatus,
data.version,
data.accuracy,
data.tag,
data.active,
);
this.version = data.version;
}
}

export { Model, ModelFeatureDescription, ModelPrediction, ModelRow, ModelDescribeAttribute };
export {
Model,
ModelFeatureDescription,
ModelPrediction,
ModelRow,
ModelDescribeAttribute,
ModelVersion,
};
72 changes: 64 additions & 8 deletions src/models/modelsApiClient.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import { Model, ModelDescribeAttribute, ModelFeatureDescription, ModelPrediction } from './model';
import {
Model,
ModelDescribeAttribute,
ModelFeatureDescription,
ModelPrediction,
ModelVersion,
} from './model';
import { BatchQueryOptions, QueryOptions } from './queryOptions';
import { FinetuneOptions, TrainingOptions } from './trainingOptions';

Expand Down Expand Up @@ -31,7 +37,7 @@ export default abstract class ModelsApiClient {
abstract describeModel(
name: string,
project: string,
version?: number
version?: number,
): Promise<Array<ModelFeatureDescription>>;

/**
Expand All @@ -48,7 +54,7 @@ export default abstract class ModelsApiClient {
project: string,
attribute: string,
version?: number,
unique_id?: string
unique_id?: string,
): Promise<Array<ModelDescribeAttribute>>;

/**
Expand All @@ -74,7 +80,7 @@ export default abstract class ModelsApiClient {
version: number,
targetColumn: string,
project: string,
options: QueryOptions
options: QueryOptions,
): Promise<ModelPrediction>;

/**
Expand All @@ -91,7 +97,7 @@ export default abstract class ModelsApiClient {
version: number,
targetColumn: string,
project: string,
options: BatchQueryOptions
options: BatchQueryOptions,
): Promise<Array<ModelPrediction>>;

/**
Expand All @@ -106,7 +112,7 @@ export default abstract class ModelsApiClient {
name: string,
targetColumn: string,
project: string,
options: TrainingOptions
options: TrainingOptions,
): Promise<Model>;

/**
Expand All @@ -121,7 +127,7 @@ export default abstract class ModelsApiClient {
name: string,
targetColumn: string,
project: string,
options?: TrainingOptions
options?: TrainingOptions,
): Promise<Model>;

/**
Expand All @@ -134,6 +140,56 @@ export default abstract class ModelsApiClient {
abstract finetuneModel(
name: string,
project: string,
options?: FinetuneOptions
options?: FinetuneOptions,
): Promise<Model>;
/**
* Lists all versions of the model in the specified project.
*
* @param {string} project - The project to list the model versions from.
* @returns {Promise<ModelVersion[]>} - A promise that resolves to an array of ModelVersion objects.
*/
abstract listVersions(project: string): Promise<ModelVersion[]>;

/**
* Gets a specific version of the model by its version number and name.
*
* @param {number} v - The version number to retrieve.
* @param {string} project - The project name.
* @param {string} name - The model name.
* @returns {Promise<ModelVersion>} - A promise that resolves to the requested ModelVersion.
* @throws {Error} - Throws an error if the version is not found.
*/
abstract getVersion(
v: number,
project: string,
name: string,
): Promise<ModelVersion>;

/**
* Drops a specific version of the model in the given project.
*
* @param {number} v - The version number to drop.
* @param {string} project - The project name.
* @param {string} model - The model name.
* @returns {Promise<void>} - A promise that resolves when the version is dropped.
* @throws {MindsDbError} - Throws an error if something goes wrong during the operation.
*/
abstract dropVersion(
v: number,
project: string,
model: string,
): Promise<void>;

/**
* Sets the active version of the specified model within a given project.
* @param {number} v - The version number to set as active.
* @param {string} project - The name of the project the model belongs to.
* @param {string} model - The name of the model for which to set the active version.
* @throws {MindsDbError} - If an error occurs while setting the active version.
*/
abstract setActiveVersion(
v: number,
project: string,
model: Model,
): Promise<void>;
}
Loading

0 comments on commit d5f75f6

Please sign in to comment.