Skip to content

Commit

Permalink
fix unsafe
Browse files Browse the repository at this point in the history
  • Loading branch information
cirospaciari committed Jan 20, 2025
1 parent 230ab9f commit 02ccb9c
Showing 1 changed file with 55 additions and 7 deletions.
62 changes: 55 additions & 7 deletions src/js/bun/sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ enum SQLQueryResultMode {
values = 1,
raw = 2,
}
const escapeIdentifier = function escape(str) {
return '"' + str.replace(/"/g, '""').replace(/\./g, '"."') + '"';
};
class SQLResultArray extends PublicArray {
static [Symbol.toStringTag] = "SQLResults";

Expand All @@ -43,8 +46,8 @@ const _queryStatus = Symbol("status");
const _handler = Symbol("handler");
const _strings = Symbol("strings");
const _values = Symbol("values");
const _allowUnsafeTransaction = Symbol("allowUnsafeTransaction");
const _poolSize = Symbol("poolSize");
const _flags = Symbol("flags");
const PublicPromise = Promise;
type TransactionCallback = (sql: (strings: string, ...values: any[]) => Query) => Promise<any>;

Expand Down Expand Up @@ -83,14 +86,19 @@ function normalizeSSLMode(value: string): SSLMode {
throw $ERR_INVALID_ARG_VALUE("sslmode", value);
}

enum SQLQueryFlags {
none = 0,
allowUnsafeTransaction = 1 << 0,
unsafe = 1 << 1,
}
function getQueryHandle(query) {
let handle = query[_handle];
if (!handle) {
try {
query[_handle] = handle = doCreateQuery(
query[_strings],
query[_values],
query[_allowUnsafeTransaction],
query[_flags] & SQLQueryFlags.allowUnsafeTransaction,
query[_poolSize],
);
} catch (err) {
Expand Down Expand Up @@ -132,7 +140,7 @@ class Query extends PublicPromise {
this[_poolSize] = poolSize;
this[_strings] = strings;
this[_values] = values;
this[_allowUnsafeTransaction] = allowUnsafeTransaction;
this[_flags] = allowUnsafeTransaction;
}

async [_run]() {
Expand Down Expand Up @@ -915,16 +923,22 @@ function doCreateQuery(strings, values, allowUnsafeTransaction, poolSize) {
for (let i = 0; i < values.length; i++) {
const value = values[i];
if (value instanceof Query) {
const sub_strings = value[_strings];
let sub_strings = value[_strings];
var is_unsafe = value[_flags] & SQLQueryFlags.unsafe;

if (typeof sub_strings === "string") {
// just a single fixed query string fragment
if (!is_unsafe) {
// identifier
sub_strings = escapeIdentifier(sub_strings);
}
//@ts-ignore
final_strings.push(strings[strings_idx] + sub_strings + strings[strings_idx + 1]);
strings_idx += 2; // we merged 2 strings into 1
// in this case we dont have values to merge
} else {
// complex fragment, we need to merge values
const sub_values = value[_values];

if (final_strings.length > 0) {
// complex not the first
const current_idx = final_strings.length - 1;
Expand Down Expand Up @@ -1271,7 +1285,15 @@ function SQL(o, e = {}) {
}
function queryFromPool(strings, values) {
try {
return new Query(strings, values, false, connectionInfo.max, queryFromPoolHandler);
return new Query(strings, values, SQLQueryFlags.none, connectionInfo.max, queryFromPoolHandler);
} catch (err) {
return Promise.reject(err);
}
}

function unsafeQuery(strings, values) {
try {
return new Query(strings, values, SQLQueryFlags.unsafe, connectionInfo.max, queryFromPoolHandler);
} catch (err) {
return Promise.reject(err);
}
Expand Down Expand Up @@ -1301,7 +1323,7 @@ function SQL(o, e = {}) {
const query = new Query(
strings,
values,
true,
SQLQueryFlags.allowUnsafeTransaction,
connectionInfo.max,
queryFromTransactionHandler.bind(pooledConnection, transactionQueries),
);
Expand All @@ -1311,6 +1333,22 @@ function SQL(o, e = {}) {
return Promise.reject(err);
}
}
function unsafeQueryFromTransaction(strings, values, pooledConnection, transactionQueries) {
try {
const query = new Query(
strings,
values,
SQLQueryFlags.allowUnsafeTransaction | SQLQueryFlags.unsafe,
connectionInfo.max,
queryFromTransactionHandler.bind(pooledConnection, transactionQueries),
);
transactionQueries.add(query);
return query;
} catch (err) {
return Promise.reject(err);
}
}

function onTransactionDisconnected(err) {
const reject = this.reject;
this.connectionState |= ReservedConnectionState.closed;
Expand Down Expand Up @@ -1353,6 +1391,9 @@ function SQL(o, e = {}) {
// we use the same code path as the transaction sql
return queryFromTransaction(strings, values, pooledConnection, state.queries);
}
reserved_sql.unsafe = (string, args = []) => {
return unsafeQueryFromTransaction(string, args, pooledConnection, state.queries);
};
reserved_sql.connect = () => {
if (state.connectionState & ReservedConnectionState.closed) {
return Promise.reject(connectionClosedError());
Expand Down Expand Up @@ -1667,6 +1708,9 @@ function SQL(o, e = {}) {

return queryFromTransaction(strings, values, pooledConnection, state.queries);
}
transaction_sql.unsafe = (string, args = []) => {
return unsafeQueryFromTransaction(string, args, pooledConnection, state.queries);
};
// reserve is allowed to be called inside transaction connection but will return a new reserved connection from the pool and will not be part of the transaction
// this matchs the behavior of the postgres package
transaction_sql.reserve = () => sql.reserve();
Expand Down Expand Up @@ -1896,6 +1940,10 @@ function SQL(o, e = {}) {
return queryFromPool(strings, values);
}

sql.unsafe = (string, args = []) => {
return unsafeQuery(string, args);
};

sql.reserve = () => {
if (pool.closed) {
return Promise.reject(connectionClosedError());
Expand Down

0 comments on commit 02ccb9c

Please sign in to comment.