forked from MirrorHub/synapse
Support providing an index predicate for upserts. (#13822)
This is useful to upsert against a table which has a unique partial index while avoiding conflicts.
This commit is contained in:
parent
742f9f9d78
commit
b2b0c85279
3 changed files with 25 additions and 7 deletions
1
changelog.d/13822.misc
Normal file
1
changelog.d/13822.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Support providing an index predicate clause when doing upserts.
|
|
@ -533,6 +533,7 @@ class BackgroundUpdater:
|
||||||
index_name: name of index to add
|
index_name: name of index to add
|
||||||
table: table to add index to
|
table: table to add index to
|
||||||
columns: columns/expressions to include in index
|
columns: columns/expressions to include in index
|
||||||
|
where_clause: A WHERE clause to specify a partial unique index.
|
||||||
unique: true to make a UNIQUE index
|
unique: true to make a UNIQUE index
|
||||||
psql_only: true to only create this index on psql databases (useful
|
psql_only: true to only create this index on psql databases (useful
|
||||||
for virtual sqlite tables)
|
for virtual sqlite tables)
|
||||||
|
|
|
@ -1191,6 +1191,7 @@ class DatabasePool:
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
values: Dict[str, Any],
|
values: Dict[str, Any],
|
||||||
insertion_values: Optional[Dict[str, Any]] = None,
|
insertion_values: Optional[Dict[str, Any]] = None,
|
||||||
|
where_clause: Optional[str] = None,
|
||||||
lock: bool = True,
|
lock: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -1203,6 +1204,7 @@ class DatabasePool:
|
||||||
keyvalues: The unique key tables and their new values
|
keyvalues: The unique key tables and their new values
|
||||||
values: The nonunique columns and their new values
|
values: The nonunique columns and their new values
|
||||||
insertion_values: additional key/values to use only when inserting
|
insertion_values: additional key/values to use only when inserting
|
||||||
|
where_clause: An index predicate to apply to the upsert.
|
||||||
lock: True to lock the table when doing the upsert. Unused when performing
|
lock: True to lock the table when doing the upsert. Unused when performing
|
||||||
a native upsert.
|
a native upsert.
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -1213,7 +1215,12 @@ class DatabasePool:
|
||||||
|
|
||||||
if table not in self._unsafe_to_upsert_tables:
|
if table not in self._unsafe_to_upsert_tables:
|
||||||
return self.simple_upsert_txn_native_upsert(
|
return self.simple_upsert_txn_native_upsert(
|
||||||
txn, table, keyvalues, values, insertion_values=insertion_values
|
txn,
|
||||||
|
table,
|
||||||
|
keyvalues,
|
||||||
|
values,
|
||||||
|
insertion_values=insertion_values,
|
||||||
|
where_clause=where_clause,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.simple_upsert_txn_emulated(
|
return self.simple_upsert_txn_emulated(
|
||||||
|
@ -1222,6 +1229,7 @@ class DatabasePool:
|
||||||
keyvalues,
|
keyvalues,
|
||||||
values,
|
values,
|
||||||
insertion_values=insertion_values,
|
insertion_values=insertion_values,
|
||||||
|
where_clause=where_clause,
|
||||||
lock=lock,
|
lock=lock,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1232,6 +1240,7 @@ class DatabasePool:
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
values: Dict[str, Any],
|
values: Dict[str, Any],
|
||||||
insertion_values: Optional[Dict[str, Any]] = None,
|
insertion_values: Optional[Dict[str, Any]] = None,
|
||||||
|
where_clause: Optional[str] = None,
|
||||||
lock: bool = True,
|
lock: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -1240,6 +1249,7 @@ class DatabasePool:
|
||||||
keyvalues: The unique key tables and their new values
|
keyvalues: The unique key tables and their new values
|
||||||
values: The nonunique columns and their new values
|
values: The nonunique columns and their new values
|
||||||
insertion_values: additional key/values to use only when inserting
|
insertion_values: additional key/values to use only when inserting
|
||||||
|
where_clause: An index predicate to apply to the upsert.
|
||||||
lock: True to lock the table when doing the upsert.
|
lock: True to lock the table when doing the upsert.
|
||||||
Returns:
|
Returns:
|
||||||
Returns True if a row was inserted or updated (i.e. if `values` is
|
Returns True if a row was inserted or updated (i.e. if `values` is
|
||||||
|
@ -1259,14 +1269,17 @@ class DatabasePool:
|
||||||
else:
|
else:
|
||||||
return "%s = ?" % (key,)
|
return "%s = ?" % (key,)
|
||||||
|
|
||||||
|
# Generate a where clause of each keyvalue and optionally the provided
|
||||||
|
# index predicate.
|
||||||
|
where = [_getwhere(k) for k in keyvalues]
|
||||||
|
if where_clause:
|
||||||
|
where.append(where_clause)
|
||||||
|
|
||||||
if not values:
|
if not values:
|
||||||
# If `values` is empty, then all of the values we care about are in
|
# If `values` is empty, then all of the values we care about are in
|
||||||
# the unique key, so there is nothing to UPDATE. We can just do a
|
# the unique key, so there is nothing to UPDATE. We can just do a
|
||||||
# SELECT instead to see if it exists.
|
# SELECT instead to see if it exists.
|
||||||
sql = "SELECT 1 FROM %s WHERE %s" % (
|
sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where))
|
||||||
table,
|
|
||||||
" AND ".join(_getwhere(k) for k in keyvalues),
|
|
||||||
)
|
|
||||||
sqlargs = list(keyvalues.values())
|
sqlargs = list(keyvalues.values())
|
||||||
txn.execute(sql, sqlargs)
|
txn.execute(sql, sqlargs)
|
||||||
if txn.fetchall():
|
if txn.fetchall():
|
||||||
|
@ -1277,7 +1290,7 @@ class DatabasePool:
|
||||||
sql = "UPDATE %s SET %s WHERE %s" % (
|
sql = "UPDATE %s SET %s WHERE %s" % (
|
||||||
table,
|
table,
|
||||||
", ".join("%s = ?" % (k,) for k in values),
|
", ".join("%s = ?" % (k,) for k in values),
|
||||||
" AND ".join(_getwhere(k) for k in keyvalues),
|
" AND ".join(where),
|
||||||
)
|
)
|
||||||
sqlargs = list(values.values()) + list(keyvalues.values())
|
sqlargs = list(values.values()) + list(keyvalues.values())
|
||||||
|
|
||||||
|
@ -1307,6 +1320,7 @@ class DatabasePool:
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
values: Dict[str, Any],
|
values: Dict[str, Any],
|
||||||
insertion_values: Optional[Dict[str, Any]] = None,
|
insertion_values: Optional[Dict[str, Any]] = None,
|
||||||
|
where_clause: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Use the native UPSERT functionality in PostgreSQL.
|
Use the native UPSERT functionality in PostgreSQL.
|
||||||
|
@ -1316,6 +1330,7 @@ class DatabasePool:
|
||||||
keyvalues: The unique key tables and their new values
|
keyvalues: The unique key tables and their new values
|
||||||
values: The nonunique columns and their new values
|
values: The nonunique columns and their new values
|
||||||
insertion_values: additional key/values to use only when inserting
|
insertion_values: additional key/values to use only when inserting
|
||||||
|
where_clause: An index predicate to apply to the upsert.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns True if a row was inserted or updated (i.e. if `values` is
|
Returns True if a row was inserted or updated (i.e. if `values` is
|
||||||
|
@ -1331,11 +1346,12 @@ class DatabasePool:
|
||||||
allvalues.update(values)
|
allvalues.update(values)
|
||||||
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
|
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
|
||||||
|
|
||||||
sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
|
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
|
||||||
table,
|
table,
|
||||||
", ".join(k for k in allvalues),
|
", ".join(k for k in allvalues),
|
||||||
", ".join("?" for _ in allvalues),
|
", ".join("?" for _ in allvalues),
|
||||||
", ".join(k for k in keyvalues),
|
", ".join(k for k in keyvalues),
|
||||||
|
f"WHERE {where_clause}" if where_clause else "",
|
||||||
latter,
|
latter,
|
||||||
)
|
)
|
||||||
txn.execute(sql, list(allvalues.values()))
|
txn.execute(sql, list(allvalues.values()))
|
||||||
|
|
Loading…
Reference in a new issue