diff --git a/src/models/model.ts b/src/models/model.ts index 33ad284..441ab1d 100644 --- a/src/models/model.ts +++ b/src/models/model.ts @@ -149,7 +149,7 @@ class Model { version: number, accuracy?: number, tag?: string, - active?: boolean + active?: boolean, ) { this.modelsApiClient = modelsApiClient; this.name = name; @@ -168,7 +168,11 @@ class Model { * @returns {Array} - All feature descriptions of the model. */ describe(): Promise> { - return this.modelsApiClient.describeModel(this.name, this.project, this.version); + return this.modelsApiClient.describeModel( + this.name, + this.project, + this.version, + ); } /** @@ -177,8 +181,17 @@ class Model { * @param {string} unique_id - Optional unique id to filter the accuracy by. * @returns {Array} - Result. */ - describeAttribute(attribute: string, unique_id?: string): Promise> { - return this.modelsApiClient.describeModelAttribute(this.name, this.project, attribute, this.version, unique_id); + describeAttribute( + attribute: string, + unique_id?: string, + ): Promise> { + return this.modelsApiClient.describeModelAttribute( + this.name, + this.project, + attribute, + this.version, + unique_id, + ); } /** @@ -201,7 +214,7 @@ class Model { this.version, this.targetColumn, this.project, - options + options, ); } @@ -217,7 +230,7 @@ class Model { this.version, this.targetColumn, this.project, - options + options, ); } @@ -232,13 +245,13 @@ class Model { this.name, this.targetColumn, this.project, - options + options, ); } return this.modelsApiClient.retrainModel( this.name, this.targetColumn, - this.project + this.project, ); } @@ -251,6 +264,55 @@ class Model { finetune(integration: string, options: FinetuneOptions): Promise { return this.modelsApiClient.finetuneModel(this.name, this.project, options); } + /** + * List all versions of the model. + * + * @returns {Promise} - A promise that resolves to an array of ModelVersion objects. + */ + listVersions(): Promise { + return this.modelsApiClient.listVersions(this.project); + } + + /** + * Get a specific version of the model by its version number. + * + * @param {number} v - The version number to retrieve. + * @returns {Promise} - A promise that resolves to the requested ModelVersion. + */ + getVersion(v: number): Promise { + return this.modelsApiClient.getVersion( + Math.floor(v), + this.project, + this.name, + ); + } + + /** + * Drop a specific version of the model. + * + * @param {number} v - The version number to drop. + * @param {string} [project=this.project] - The project name. Defaults to the current project. + * @param {string} [model=this.name] - The model name. Defaults to the current model. + * @returns {Promise} - A promise that resolves when the version is dropped. + */ + dropVersion( + v: number, + project: string = this.project, + model: string = this.name, + ): Promise { + return this.modelsApiClient.dropVersion(Math.floor(v), project, model); + } + /** + * Sets the active version of the specified model within a given project. + * @param {number} v - The version number to set as active. + */ + setActiveVersion(v: number): Promise { + return this.modelsApiClient.setActiveVersion( + Math.floor(v), + this.project, + this, + ); + } /** * Creates a Model instance from a row returned from the MindsDB database. @@ -269,9 +331,56 @@ class Model { obj['version'], obj['accuracy'], obj['tag'], - obj['active'] + obj['active'], + ); + } +} + +/** + * Represents a MindsDB model with version and all supported operations. + */ +class ModelVersion extends Model { + /** + * Constructor for ModelVersion. + * + * @param {string} project - Name of the project the model belongs to. + * @param {object} data - Data containing the model details. + */ + constructor( + project: string, + data: { + modelsApiClient: ModelsApiClient; + name: string; + targetColumn: string; + status: string; + updateStatus: UpdateStatus; + version: number; + accuracy?: number; + tag?: string; + active?: boolean; + }, + ) { + super( + data.modelsApiClient, + data.name, + project, + data.targetColumn, + data.status, + data.updateStatus, + data.version, + data.accuracy, + data.tag, + data.active, ); + this.version = data.version; } } -export { Model, ModelFeatureDescription, ModelPrediction, ModelRow, ModelDescribeAttribute }; +export { + Model, + ModelFeatureDescription, + ModelPrediction, + ModelRow, + ModelDescribeAttribute, + ModelVersion, +}; diff --git a/src/models/modelsApiClient.ts b/src/models/modelsApiClient.ts index b49163e..4efb24f 100644 --- a/src/models/modelsApiClient.ts +++ b/src/models/modelsApiClient.ts @@ -1,4 +1,10 @@ -import { Model, ModelDescribeAttribute, ModelFeatureDescription, ModelPrediction } from './model'; +import { + Model, + ModelDescribeAttribute, + ModelFeatureDescription, + ModelPrediction, + ModelVersion, +} from './model'; import { BatchQueryOptions, QueryOptions } from './queryOptions'; import { FinetuneOptions, TrainingOptions } from './trainingOptions'; @@ -31,7 +37,7 @@ export default abstract class ModelsApiClient { abstract describeModel( name: string, project: string, - version?: number + version?: number, ): Promise>; /** @@ -48,7 +54,7 @@ export default abstract class ModelsApiClient { project: string, attribute: string, version?: number, - unique_id?: string + unique_id?: string, ): Promise>; /** @@ -74,7 +80,7 @@ export default abstract class ModelsApiClient { version: number, targetColumn: string, project: string, - options: QueryOptions + options: QueryOptions, ): Promise; /** @@ -91,7 +97,7 @@ export default abstract class ModelsApiClient { version: number, targetColumn: string, project: string, - options: BatchQueryOptions + options: BatchQueryOptions, ): Promise>; /** @@ -106,7 +112,7 @@ export default abstract class ModelsApiClient { name: string, targetColumn: string, project: string, - options: TrainingOptions + options: TrainingOptions, ): Promise; /** @@ -121,7 +127,7 @@ export default abstract class ModelsApiClient { name: string, targetColumn: string, project: string, - options?: TrainingOptions + options?: TrainingOptions, ): Promise; /** @@ -134,6 +140,56 @@ export default abstract class ModelsApiClient { abstract finetuneModel( name: string, project: string, - options?: FinetuneOptions + options?: FinetuneOptions, ): Promise; + /** + * Lists all versions of the model in the specified project. + * + * @param {string} project - The project to list the model versions from. + * @returns {Promise} - A promise that resolves to an array of ModelVersion objects. + */ + abstract listVersions(project: string): Promise; + + /** + * Gets a specific version of the model by its version number and name. + * + * @param {number} v - The version number to retrieve. + * @param {string} project - The project name. + * @param {string} name - The model name. + * @returns {Promise} - A promise that resolves to the requested ModelVersion. + * @throws {Error} - Throws an error if the version is not found. + */ + abstract getVersion( + v: number, + project: string, + name: string, + ): Promise; + + /** + * Drops a specific version of the model in the given project. + * + * @param {number} v - The version number to drop. + * @param {string} project - The project name. + * @param {string} model - The model name. + * @returns {Promise} - A promise that resolves when the version is dropped. + * @throws {MindsDbError} - Throws an error if something goes wrong during the operation. + */ + abstract dropVersion( + v: number, + project: string, + model: string, + ): Promise; + + /** + * Sets the active version of the specified model within a given project. + * @param {number} v - The version number to set as active. + * @param {string} project - The name of the project the model belongs to. + * @param {string} model - The name of the model for which to set the active version. + * @throws {MindsDbError} - If an error occurs while setting the active version. + */ + abstract setActiveVersion( + v: number, + project: string, + model: Model, + ): Promise; } diff --git a/src/models/modelsRestApiClient.ts b/src/models/modelsRestApiClient.ts index 0dea28f..90a69ea 100644 --- a/src/models/modelsRestApiClient.ts +++ b/src/models/modelsRestApiClient.ts @@ -8,10 +8,10 @@ import { ModelFeatureDescription, ModelPrediction, ModelRow, + ModelVersion, } from './model'; import { BatchQueryOptions, QueryOptions } from './queryOptions'; import { MindsDbError } from '../errors'; -import { version } from 'prettier'; /** Implementation of ModelsApiClient that goes through the REST API */ export default class ModelsRestApiClient extends ModelsApiClient { @@ -44,7 +44,7 @@ export default class ModelsRestApiClient extends ModelsApiClient { } private makeTrainingSelectClause( - options: TrainingOptions | FinetuneOptions + options: TrainingOptions | FinetuneOptions, ): string { const select = options['select']; if (select) { @@ -74,7 +74,7 @@ export default class ModelsRestApiClient extends ModelsApiClient { } private makeTrainingWindowHorizonClause( - trainingOptions: TrainingOptions + trainingOptions: TrainingOptions, ): string { const window = trainingOptions['window']; const horizon = trainingOptions['horizon']; @@ -101,7 +101,7 @@ export default class ModelsRestApiClient extends ModelsApiClient { // Escaping WHERE conditions is quite tricky. We should // come up with a better solution to indicate WHERE conditions // when querying so we aren't passing a raw string. - `AND ${o}` + `AND ${o}`, ) .join('\n'); } else { @@ -129,7 +129,7 @@ export default class ModelsRestApiClient extends ModelsApiClient { } private makeTrainingUsingClause( - options: FinetuneOptions | TrainingOptions + options: FinetuneOptions | TrainingOptions, ): string { const using = options['using']; if (!using) { @@ -159,7 +159,7 @@ export default class ModelsRestApiClient extends ModelsApiClient { override async getModel( name: string, project: string, - version?: number + version?: number, ): Promise { const selectQuery = `SELECT * FROM ${mysql.escapeId(project)}.models${ version ? '_versions' : '' @@ -182,7 +182,7 @@ export default class ModelsRestApiClient extends ModelsApiClient { const selectQuery = `SELECT * FROM ${mysql.escapeId(project)}.models`; const sqlQueryResult = await this.sqlClient.runQuery(selectQuery); return sqlQueryResult.rows.map((modelRow) => - Model.fromModelRow(modelRow as ModelRow, this) + Model.fromModelRow(modelRow as ModelRow, this), ); } @@ -196,11 +196,11 @@ export default class ModelsRestApiClient extends ModelsApiClient { override async describeModel( name: string, project: string, - version?: number + version?: number, ): Promise> { const describeQuery = `DESCRIBE ${mysql.escapeId(project)}.${mysql.escapeId( - name - )}.${ version ? `${mysql.escapeId(version.toString())}.` : ''}\`features\``; + name, + )}.${version ? `${mysql.escapeId(version.toString())}.` : ''}\`features\``; const sqlQueryResult = await this.sqlClient.runQuery(describeQuery); if (sqlQueryResult.rows.length === 0) { return []; @@ -222,11 +222,11 @@ export default class ModelsRestApiClient extends ModelsApiClient { project: string, attribute: string, version?: number, - unique_id?: string + unique_id?: string, ): Promise> { const describeQuery = `DESCRIBE ${mysql.escapeId(project)}.${mysql.escapeId( - name - )}.${ version ? `${mysql.escapeId(version.toString())}.` : ''}${mysql.escapeId(attribute)}${unique_id ? `.${mysql.escapeId(unique_id)}` : ''}`; + name, + )}.${version ? `${mysql.escapeId(version.toString())}.` : ''}${mysql.escapeId(attribute)}${unique_id ? `.${mysql.escapeId(unique_id)}` : ''}`; const sqlQueryResult = await this.sqlClient.runQuery(describeQuery); if (sqlQueryResult.rows.length === 0) { return []; @@ -243,7 +243,7 @@ export default class ModelsRestApiClient extends ModelsApiClient { */ override async deleteModel(name: string, project: string): Promise { const deleteQuery = `DROP MODEL ${mysql.escapeId(project)}.${mysql.escapeId( - name + name, )}`; const sqlQueryResult = await this.sqlClient.runQuery(deleteQuery); if (sqlQueryResult.error_message) { @@ -265,10 +265,10 @@ export default class ModelsRestApiClient extends ModelsApiClient { version: number, targetColumn: string, project: string, - options: QueryOptions + options: QueryOptions, ): Promise { const selectClause = `SELECT * FROM ${mysql.escapeId( - project + project, )}.${mysql.escapeId(name)}.${version}`; const whereClause = this.makeWhereClause(options['where'] || []); const usingClause = this.makeUsingClause(options['using'] || []); @@ -300,15 +300,15 @@ export default class ModelsRestApiClient extends ModelsApiClient { version: number, targetColumn: string, project: string, - options: BatchQueryOptions + options: BatchQueryOptions, ): Promise> { const selectClause = `SELECT m.${mysql.escapeId( - targetColumn + targetColumn, )} AS predicted, t.*, m.*`; const joinId = options['join']; const fromClause = `FROM ${mysql.escapeId(joinId)} AS t`; const joinClause = `JOIN ${mysql.escapeId(project)}.${mysql.escapeId( - name + name, )}.${version} AS m`; const whereClause = this.makeWhereClause(options['where'] || []); const limitClause = options['limit'] @@ -344,7 +344,7 @@ export default class ModelsRestApiClient extends ModelsApiClient { name: string, targetColumn: string, project: string, - trainingOptions: TrainingOptions + trainingOptions: TrainingOptions, ): Promise { const createClause = this.makeTrainingCreateClause(name, project); const fromClause = this.makeTrainingFromClause(trainingOptions); @@ -379,7 +379,7 @@ export default class ModelsRestApiClient extends ModelsApiClient { targetColumn, 'generating', 'up_to_date', - 1 + 1, ); } @@ -396,7 +396,7 @@ export default class ModelsRestApiClient extends ModelsApiClient { name: string, targetColumn: string, project: string, - trainingOptions?: TrainingOptions + trainingOptions?: TrainingOptions, ): Promise { const retrainClause = this.makeRetrainClause(name, project); let query = retrainClause; @@ -441,10 +441,10 @@ export default class ModelsRestApiClient extends ModelsApiClient { override async finetuneModel( name: string, project: string, - finetuneOptions: FinetuneOptions + finetuneOptions: FinetuneOptions, ): Promise { const finetuneClause = `FINETUNE ${mysql.escapeId(project)}.${mysql.escapeId( - name + name, )} FROM ${mysql.escapeId(finetuneOptions['integration'])}`; const selectClause = this.makeTrainingSelectClause(finetuneOptions); const usingClause = this.makeTrainingUsingClause(finetuneOptions); @@ -456,4 +456,82 @@ export default class ModelsRestApiClient extends ModelsApiClient { return Model.fromModelRow(sqlQueryResult.rows[0] as ModelRow, this); } + + /** + * List all versions of the model in the specified project. + * + * @param {string} project - The project to list the model versions from. + * @returns {Promise} - A promise that resolves to an array of ModelVersion objects. + */ + override async listVersions(project: string): Promise { + const allModels = await this.getAllModels(project); + return allModels.map((model: any) => new ModelVersion(project, model)); + } + + /** + * Get a specific version of the model by its version number and name. + * + * @param {number} v - The version number to retrieve. + * @param {string} project - The project name. + * @param {string} name - The model name. + * @returns {Promise} - A promise that resolves to the requested ModelVersion. + * @throws {Error} - Throws an error if the version is not found. + */ + override async getVersion( + v: number, + project: string, + name: string, + ): Promise { + const allModels = await this.listVersions(project); + for (const model of allModels) { + if (model.version === v && model.name === name) { + return model; + } + } + throw new Error('Version is not found'); + } + + /** + * Drop a specific version of the model in the given project. + * + * @param {number} v - The version number to drop. + * @param {string} project - The project name. + * @param {string} model - The model name. + * @returns {Promise} - A promise that resolves when the version is dropped. + * @throws {MindsDbError} - Throws an error if something goes wrong during the operation. + */ + override async dropVersion( + v: number, + project: string, + model: string, + ): Promise { + const deleteQuery = `DROP MODEL ${mysql.escapeId(project)}.${mysql.escapeId( + model, + )}.${mysql.escapeId(v)}`; + const sqlQueryResult = await this.sqlClient.runQuery(deleteQuery); + if (sqlQueryResult.error_message) { + throw new MindsDbError(sqlQueryResult.error_message); + } + } + + /** + * Sets the active version of the specified model within a given project. + * @param {number} v - The version number to set as active. + * @param {string} project - The name of the project the model belongs to. + * @param {Model} model - The model for which to set the active version. + * @throws {MindsDbError} - If an error occurs while setting the active version. + */ + override async setActiveVersion(v: number, project: string, model: Model) { + const query = `SET model_active = ${mysql.escapeId(project)}.${mysql.escapeId( + model.name, + )}.${mysql.escapeId(v.toString())};`; + await this.sqlClient + .runQuery(query) + .then( + async () => + (model = + (await this.getModel(model.name, project)) ?? + new ModelVersion(project, { ...model, version: v })), + ); + } }