Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add SafeNullComparisonPlugin plugin #1338

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions site/docs/plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ A plugin that converts snake_case identifiers in the database into camelCase in
### Deduplicate joins plugin

Plugin that removes duplicate joins from queries. You can read more about it in the [examples](/docs/recipes/deduplicate-joins) section or check the [API docs](https://kysely-org.github.io/kysely-apidoc/classes/DeduplicateJoinsPlugin.html).

### Safe null comparison plugin

A plugin that automatically converts `=`, `!=` and `<>` to the equivalent `is` and `is not` predicates depending on the value of the variable. [Learn more](https://kysely-org.github.io/kysely-apidoc/classes/SafeNullComparisonPlugin.html).
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ export * from './plugin/camel-case/camel-case-plugin.js'
export * from './plugin/deduplicate-joins/deduplicate-joins-plugin.js'
export * from './plugin/with-schema/with-schema-plugin.js'
export * from './plugin/parse-json-results/parse-json-results-plugin.js'
export * from './plugin/safe-null-comparison/safe-null-comparison-plugin.js'

export * from './operation-node/add-column-node.js'
export * from './operation-node/add-constraint-node.js'
Expand Down
39 changes: 39 additions & 0 deletions src/plugin/safe-null-comparison/safe-null-comparison-plugin.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { QueryResult } from '../../driver/database-connection.js'
import { RootOperationNode } from '../../query-compiler/query-compiler.js'
import { UnknownRow } from '../../util/type-utils.js'
import {
KyselyPlugin,
PluginTransformQueryArgs,
PluginTransformResultArgs,
} from '../kysely-plugin.js'
import { SafeNullComparisonTransformer } from './safe-null-comparison-transformer.js'

/**
* Plugin that handles NULL comparisons to prevent common SQL mistakes.
*
* In SQL, comparing values with NULL using standard comparison operators (=, !=, <>)
* always yields NULL, which is usually not what developers expect. The correct way
* to compare with NULL is using IS NULL and IS NOT NULL.
*
* When working with nullable variables (e.g. string | null), you need to be careful to
* manually handle these cases with conditional WHERE clauses. This plugins automatically
* applies the correct operator based on the value, allowing you to simply write `query.where('name', '=', name)`.
*
* The plugin transforms the following operators when comparing with NULL:
* - `=` becomes `IS`
* - `!=` becomes `IS NOT`
* - `<>` becomes `IS NOT`
*/
export class SafeNullComparisonPlugin implements KyselyPlugin {
readonly #transformer = new SafeNullComparisonTransformer()

transformQuery(args: PluginTransformQueryArgs): RootOperationNode {
return this.#transformer.transformNode(args.node)
}

transformResult(
args: PluginTransformResultArgs,
): Promise<QueryResult<UnknownRow>> {
return Promise.resolve(args.result)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { BinaryOperationNode } from '../../operation-node/binary-operation-node.js'
import { OperationNodeTransformer } from '../../operation-node/operation-node-transformer.js'
import { OperatorNode } from '../../operation-node/operator-node.js'
import { ValueNode } from '../../operation-node/value-node.js'

export class SafeNullComparisonTransformer extends OperationNodeTransformer {
protected transformBinaryOperation(
node: BinaryOperationNode,
): BinaryOperationNode {
const { operator, leftOperand, rightOperand } =
super.transformBinaryOperation(node)

if (
!ValueNode.is(rightOperand) ||
rightOperand.value !== null ||
!OperatorNode.is(operator)
) {
return node
}

const op = operator.operator
if (op !== '=' && op !== '!=' && op !== '<>') {
return node
}

return BinaryOperationNode.create(
leftOperand,
OperatorNode.create(op === '=' ? 'is' : 'is not'),
rightOperand,
)
}
}
244 changes: 244 additions & 0 deletions test/node/src/safe-null-comparison-plugin.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import { SafeNullComparisonPlugin } from '../../..'

import {
clearDatabase,
destroyTest,
initTest,
TestContext,
testSql,
insertDefaultDataSet,
DIALECTS,
} from './test-setup.js'

for (const dialect of DIALECTS) {
describe(`${dialect}: safe null comparison`, () => {
let ctx: TestContext

before(async function () {
ctx = await initTest(this, dialect)
})

beforeEach(async () => {
await insertDefaultDataSet(ctx)
})

afterEach(async () => {
await clearDatabase(ctx)
})

after(async () => {
await destroyTest(ctx)
})

it('should replace = with is for null values', async () => {
const query = ctx.db
.withPlugin(new SafeNullComparisonPlugin())
.selectFrom('person')
.where('first_name', '=', null)

testSql(query, dialect, {
postgres: {
sql: 'select from "person" where "first_name" is $1',
parameters: [null],
},
mysql: {
sql: 'select from `person` where `first_name` is ?',
parameters: [null],
},
mssql: {
sql: 'select from "person" where "first_name" is @1',
parameters: [null],
},
sqlite: {
sql: 'select from "person" where "first_name" is ?',
parameters: [null],
},
})
})

it('should not replace = with is for non-null values', async () => {
const query = ctx.db
.withPlugin(new SafeNullComparisonPlugin())
.selectFrom('person')
.where('first_name', '=', 'Foo')

testSql(query, dialect, {
postgres: {
sql: 'select from "person" where "first_name" = $1',
parameters: ['Foo'],
},
mysql: {
sql: 'select from `person` where `first_name` = ?',
parameters: ['Foo'],
},
mssql: {
sql: 'select from "person" where "first_name" = @1',
parameters: ['Foo'],
},
sqlite: {
sql: 'select from "person" where "first_name" = ?',
parameters: ['Foo'],
},
})
})

it('should replace != with is not for null values', async () => {
const query = ctx.db
.withPlugin(new SafeNullComparisonPlugin())
.selectFrom('person')
.where('first_name', '!=', null)

testSql(query, dialect, {
postgres: {
sql: 'select from "person" where "first_name" is not $1',
parameters: [null],
},
mysql: {
sql: 'select from `person` where `first_name` is not ?',
parameters: [null],
},
mssql: {
sql: 'select from "person" where "first_name" is not @1',
parameters: [null],
},
sqlite: {
sql: 'select from "person" where "first_name" is not ?',
parameters: [null],
},
})
})

it('should not replace != with is not for non-null values', async () => {
const query = ctx.db
.withPlugin(new SafeNullComparisonPlugin())
.selectFrom('person')
.where('first_name', '!=', 'Foo')

testSql(query, dialect, {
postgres: {
sql: 'select from "person" where "first_name" != $1',
parameters: ['Foo'],
},
mysql: {
sql: 'select from `person` where `first_name` != ?',
parameters: ['Foo'],
},
mssql: {
sql: 'select from "person" where "first_name" != @1',
parameters: ['Foo'],
},
sqlite: {
sql: 'select from "person" where "first_name" != ?',
parameters: ['Foo'],
},
})
})

it('should replace <> with is not for null values', async () => {
const query = ctx.db
.withPlugin(new SafeNullComparisonPlugin())
.selectFrom('person')
.where('first_name', '<>', null)

testSql(query, dialect, {
postgres: {
sql: 'select from "person" where "first_name" is not $1',
parameters: [null],
},
mysql: {
sql: 'select from `person` where `first_name` is not ?',
parameters: [null],
},
mssql: {
sql: 'select from "person" where "first_name" is not @1',
parameters: [null],
},
sqlite: {
sql: 'select from "person" where "first_name" is not ?',
parameters: [null],
},
})
})

it('should not replace <> with is not for non-null values', async () => {
const query = ctx.db
.withPlugin(new SafeNullComparisonPlugin())
.selectFrom('person')
.where('first_name', '<>', 'Foo')

testSql(query, dialect, {
postgres: {
sql: 'select from "person" where "first_name" <> $1',
parameters: ['Foo'],
},
mysql: {
sql: 'select from `person` where `first_name` <> ?',
parameters: ['Foo'],
},
mssql: {
sql: 'select from "person" where "first_name" <> @1',
parameters: ['Foo'],
},
sqlite: {
sql: 'select from "person" where "first_name" <> ?',
parameters: ['Foo'],
},
})
})

it('should replace = with is with multiple where clauses', async () => {
const query = ctx.db
.withPlugin(new SafeNullComparisonPlugin())
.selectFrom('person')
.where('first_name', '=', null)
.where('last_name', '=', null)

testSql(query, dialect, {
postgres: {
sql: 'select from "person" where "first_name" is $1 and "last_name" is $2',
parameters: [null, null],
},
mysql: {
sql: 'select from `person` where `first_name` is ? and `last_name` is ?',
parameters: [null, null],
},
mssql: {
sql: 'select from "person" where "first_name" is @1 and "last_name" is @2',
parameters: [null, null],
},
sqlite: {
sql: 'select from "person" where "first_name" is ? and "last_name" is ?',
parameters: [null, null],
},
})
})

it('should work with mixed null and non-null values', async () => {
const query = ctx.db
.withPlugin(new SafeNullComparisonPlugin())
.selectFrom('person')
.where('first_name', '=', null)
.where('last_name', '!=', null)
.where('last_name', '=', 'Foo')

testSql(query, dialect, {
postgres: {
sql: 'select from "person" where "first_name" is $1 and "last_name" is not $2 and "last_name" = $3',
parameters: [null, null, 'Foo'],
},
mysql: {
sql: 'select from `person` where `first_name` is ? and `last_name` is not ? and `last_name` = ?',
parameters: [null, null, 'Foo'],
},
mssql: {
sql: 'select from "person" where "first_name" is @1 and "last_name" is not @2 and "last_name" = @3',
parameters: [null, null, 'Foo'],
},
sqlite: {
sql: 'select from "person" where "first_name" is ? and "last_name" is not ? and "last_name" = ?',
parameters: [null, null, 'Foo'],
},
})
})
})
}
Loading