mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 16:03:51 +01:00
Support providing an index predicate for upserts. (#13822)
This is useful to upsert against a table which has a unique partial index while avoiding conflicts.
This commit is contained in:
parent
742f9f9d78
commit
b2b0c85279
3 changed files with 25 additions and 7 deletions
1
changelog.d/13822.misc
Normal file
1
changelog.d/13822.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Support providing an index predicate clause when doing upserts.
|
|
@ -533,6 +533,7 @@ class BackgroundUpdater:
|
|||
index_name: name of index to add
|
||||
table: table to add index to
|
||||
columns: columns/expressions to include in index
|
||||
where_clause: A WHERE clause to specify a partial unique index.
|
||||
unique: true to make a UNIQUE index
|
||||
psql_only: true to only create this index on psql databases (useful
|
||||
for virtual sqlite tables)
|
||||
|
|
|
@ -1191,6 +1191,7 @@ class DatabasePool:
|
|||
keyvalues: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
insertion_values: Optional[Dict[str, Any]] = None,
|
||||
where_clause: Optional[str] = None,
|
||||
lock: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
|
@ -1203,6 +1204,7 @@ class DatabasePool:
|
|||
keyvalues: The unique key tables and their new values
|
||||
values: The nonunique columns and their new values
|
||||
insertion_values: additional key/values to use only when inserting
|
||||
where_clause: An index predicate to apply to the upsert.
|
||||
lock: True to lock the table when doing the upsert. Unused when performing
|
||||
a native upsert.
|
||||
Returns:
|
||||
|
@ -1213,7 +1215,12 @@ class DatabasePool:
|
|||
|
||||
if table not in self._unsafe_to_upsert_tables:
|
||||
return self.simple_upsert_txn_native_upsert(
|
||||
txn, table, keyvalues, values, insertion_values=insertion_values
|
||||
txn,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values=insertion_values,
|
||||
where_clause=where_clause,
|
||||
)
|
||||
else:
|
||||
return self.simple_upsert_txn_emulated(
|
||||
|
@ -1222,6 +1229,7 @@ class DatabasePool:
|
|||
keyvalues,
|
||||
values,
|
||||
insertion_values=insertion_values,
|
||||
where_clause=where_clause,
|
||||
lock=lock,
|
||||
)
|
||||
|
||||
|
@ -1232,6 +1240,7 @@ class DatabasePool:
|
|||
keyvalues: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
insertion_values: Optional[Dict[str, Any]] = None,
|
||||
where_clause: Optional[str] = None,
|
||||
lock: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
|
@ -1240,6 +1249,7 @@ class DatabasePool:
|
|||
keyvalues: The unique key tables and their new values
|
||||
values: The nonunique columns and their new values
|
||||
insertion_values: additional key/values to use only when inserting
|
||||
where_clause: An index predicate to apply to the upsert.
|
||||
lock: True to lock the table when doing the upsert.
|
||||
Returns:
|
||||
Returns True if a row was inserted or updated (i.e. if `values` is
|
||||
|
@ -1259,14 +1269,17 @@ class DatabasePool:
|
|||
else:
|
||||
return "%s = ?" % (key,)
|
||||
|
||||
# Generate a where clause of each keyvalue and optionally the provided
|
||||
# index predicate.
|
||||
where = [_getwhere(k) for k in keyvalues]
|
||||
if where_clause:
|
||||
where.append(where_clause)
|
||||
|
||||
if not values:
|
||||
# If `values` is empty, then all of the values we care about are in
|
||||
# the unique key, so there is nothing to UPDATE. We can just do a
|
||||
# SELECT instead to see if it exists.
|
||||
sql = "SELECT 1 FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join(_getwhere(k) for k in keyvalues),
|
||||
)
|
||||
sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where))
|
||||
sqlargs = list(keyvalues.values())
|
||||
txn.execute(sql, sqlargs)
|
||||
if txn.fetchall():
|
||||
|
@ -1277,7 +1290,7 @@ class DatabasePool:
|
|||
sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in values),
|
||||
" AND ".join(_getwhere(k) for k in keyvalues),
|
||||
" AND ".join(where),
|
||||
)
|
||||
sqlargs = list(values.values()) + list(keyvalues.values())
|
||||
|
||||
|
@ -1307,6 +1320,7 @@ class DatabasePool:
|
|||
keyvalues: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
insertion_values: Optional[Dict[str, Any]] = None,
|
||||
where_clause: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Use the native UPSERT functionality in PostgreSQL.
|
||||
|
@ -1316,6 +1330,7 @@ class DatabasePool:
|
|||
keyvalues: The unique key tables and their new values
|
||||
values: The nonunique columns and their new values
|
||||
insertion_values: additional key/values to use only when inserting
|
||||
where_clause: An index predicate to apply to the upsert.
|
||||
|
||||
Returns:
|
||||
Returns True if a row was inserted or updated (i.e. if `values` is
|
||||
|
@ -1331,11 +1346,12 @@ class DatabasePool:
|
|||
allvalues.update(values)
|
||||
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
|
||||
|
||||
sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
|
||||
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
|
||||
table,
|
||||
", ".join(k for k in allvalues),
|
||||
", ".join("?" for _ in allvalues),
|
||||
", ".join(k for k in keyvalues),
|
||||
f"WHERE {where_clause}" if where_clause else "",
|
||||
latter,
|
||||
)
|
||||
txn.execute(sql, list(allvalues.values()))
|
||||
|
|
Loading…
Reference in a new issue