Skip to content

Commit 8616294

Browse files
committed
Mark all tracked users as dirty on expired SSS connections
See matrix-org/matrix-rust-sdk#3965 for more information. Requires `Extension.onRequest` to be `async`.
1 parent 1fd6675 commit 8616294

File tree

6 files changed

+77
-54
lines changed

6 files changed

+77
-54
lines changed

spec/integ/sliding-sync-sdk.spec.ts

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -640,11 +640,13 @@ describe("SlidingSyncSdk", () => {
640640
client!.crypto!.stop();
641641
});
642642

643-
it("gets enabled on the initial request only", () => {
644-
expect(ext.onRequest(true)).toEqual({
643+
it("gets enabled all the time", async () => {
644+
expect(await ext.onRequest(true)).toEqual({
645+
enabled: true,
646+
});
647+
expect(await ext.onRequest(false)).toEqual({
645648
enabled: true,
646649
});
647-
expect(ext.onRequest(false)).toEqual(undefined);
648650
});
649651

650652
it("can update device lists", () => {
@@ -686,11 +688,13 @@ describe("SlidingSyncSdk", () => {
686688
ext = findExtension("account_data");
687689
});
688690

689-
it("gets enabled on the initial request only", () => {
690-
expect(ext.onRequest(true)).toEqual({
691+
it("gets enabled all the time", async () => {
692+
expect(await ext.onRequest(true)).toEqual({
693+
enabled: true,
694+
});
695+
expect(await ext.onRequest(false)).toEqual({
691696
enabled: true,
692697
});
693-
expect(ext.onRequest(false)).toEqual(undefined);
694698
});
695699

696700
it("processes global account data", async () => {
@@ -814,8 +818,12 @@ describe("SlidingSyncSdk", () => {
814818
ext = findExtension("to_device");
815819
});
816820

817-
it("gets enabled with a limit on the initial request only", () => {
818-
const reqJson: any = ext.onRequest(true);
821+
it("gets enabled all the time", async () => {
822+
let reqJson: any = await ext.onRequest(true);
823+
expect(reqJson.enabled).toEqual(true);
824+
expect(reqJson.limit).toBeGreaterThan(0);
825+
expect(reqJson.since).toBeUndefined();
826+
reqJson = await ext.onRequest(false);
819827
expect(reqJson.enabled).toEqual(true);
820828
expect(reqJson.limit).toBeGreaterThan(0);
821829
expect(reqJson.since).toBeUndefined();
@@ -826,7 +834,7 @@ describe("SlidingSyncSdk", () => {
826834
next_batch: "12345",
827835
events: [],
828836
});
829-
expect(ext.onRequest(false)).toEqual({
837+
expect(await ext.onRequest(false)).toMatchObject({
830838
since: "12345",
831839
});
832840
});
@@ -910,11 +918,13 @@ describe("SlidingSyncSdk", () => {
910918
ext = findExtension("typing");
911919
});
912920

913-
it("gets enabled on the initial request only", () => {
914-
expect(ext.onRequest(true)).toEqual({
921+
it("gets enabled all the time", async () => {
922+
expect(await ext.onRequest(true)).toEqual({
923+
enabled: true,
924+
});
925+
expect(await ext.onRequest(false)).toEqual({
915926
enabled: true,
916927
});
917-
expect(ext.onRequest(false)).toEqual(undefined);
918928
});
919929

920930
it("processes typing notifications", async () => {
@@ -1035,11 +1045,13 @@ describe("SlidingSyncSdk", () => {
10351045
ext = findExtension("receipts");
10361046
});
10371047

1038-
it("gets enabled on the initial request only", () => {
1039-
expect(ext.onRequest(true)).toEqual({
1048+
it("gets enabled all the time", async () => {
1049+
expect(await ext.onRequest(true)).toEqual({
1050+
enabled: true,
1051+
});
1052+
expect(await ext.onRequest(false)).toEqual({
10401053
enabled: true,
10411054
});
1042-
expect(ext.onRequest(false)).toEqual(undefined);
10431055
});
10441056

10451057
it("processes receipts", async () => {

spec/integ/sliding-sync.spec.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ describe("SlidingSync", () => {
104104
};
105105
const ext: Extension<any, any> = {
106106
name: () => "custom_extension",
107-
onRequest: (initial) => {
108-
return { initial: initial };
107+
onRequest: async (_) => {
108+
return { initial: true };
109109
},
110110
onResponse: async (res) => {
111111
return;
@@ -827,7 +827,7 @@ describe("SlidingSync", () => {
827827

828828
const extPre: Extension<any, any> = {
829829
name: () => preExtName,
830-
onRequest: (initial) => {
830+
onRequest: async (initial) => {
831831
return onPreExtensionRequest(initial);
832832
},
833833
onResponse: (res) => {
@@ -837,7 +837,7 @@ describe("SlidingSync", () => {
837837
};
838838
const extPost: Extension<any, any> = {
839839
name: () => postExtName,
840-
onRequest: (initial) => {
840+
onRequest: async (initial) => {
841841
return onPostExtensionRequest(initial);
842842
},
843843
onResponse: (res) => {
@@ -852,7 +852,7 @@ describe("SlidingSync", () => {
852852

853853
const callbackOrder: string[] = [];
854854
let extensionOnResponseCalled = false;
855-
onPreExtensionRequest = () => {
855+
onPreExtensionRequest = async () => {
856856
return extReq;
857857
};
858858
onPreExtensionResponse = async (resp) => {
@@ -892,7 +892,7 @@ describe("SlidingSync", () => {
892892
});
893893

894894
it("should be able to send nothing in an extension request/response", async () => {
895-
onPreExtensionRequest = () => {
895+
onPreExtensionRequest = async () => {
896896
return undefined;
897897
};
898898
let responseCalled = false;
@@ -927,7 +927,7 @@ describe("SlidingSync", () => {
927927

928928
it("is possible to register extensions after start() has been called", async () => {
929929
slidingSync.registerExtension(extPost);
930-
onPostExtensionRequest = () => {
930+
onPostExtensionRequest = async () => {
931931
return extReq;
932932
};
933933
let responseCalled = false;

src/common-crypto/CryptoBackend.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,15 @@ export interface SyncCryptoCallbacks {
177177
* @param syncState - information about the completed sync.
178178
*/
179179
onSyncCompleted(syncState: OnSyncCompletedData): void;
180+
181+
/**
182+
* Mark all tracked user's device lists as dirty.
183+
*
184+
* This method will cause additional /keys/query requests on the server, so should be used only
185+
* when the client has desynced tracking device list deltas from the server.
186+
* In MSC4186: Simplified Sliding Sync, this can happen when the server expires the connection.
187+
*/
188+
markAllTrackedUsersAsDirty(): Promise<void>;
180189
}
181190

182191
/**

src/crypto/index.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3445,6 +3445,13 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
34453445
}
34463446
}
34473447

3448+
/**
3449+
* Implementation of {@link CryptoApi#markAllTrackedUsersAsDirty}.
3450+
*/
3451+
public async markAllTrackedUsersAsDirty(): Promise<void> {
3452+
// no op: we only expect rust crypto to be used in MSC4186.
3453+
}
3454+
34483455
/**
34493456
* Trigger the appropriate invalidations and removes for a given
34503457
* device list

src/sliding-sync-sdk.ts

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,16 @@ class ExtensionE2EE implements Extension<ExtensionE2EERequest, ExtensionE2EEResp
7575
return ExtensionState.PreProcess;
7676
}
7777

78-
public onRequest(isInitial: boolean): ExtensionE2EERequest | undefined {
79-
if (!isInitial) {
80-
return undefined;
78+
public async onRequest(isInitial: boolean): Promise<ExtensionE2EERequest> {
79+
if (isInitial) {
80+
// In SSS, the `?pos=` contains the stream position for device list updates.
81+
// If we do not have a `?pos=` (e.g because we forgot it, or because the server
82+
// invalidated our connection) then we MUST invlaidate all device lists because
83+
// the server will not tell us the delta. This will then cause UTDs as we will fail
84+
// to encrypt for new devices. This is an expensive call, so we should
85+
// really really remember `?pos=` wherever possible.
86+
logger.log("ExtensionE2EE: invalidating all device lists due to missing 'pos'");
87+
await this.crypto.markAllTrackedUsersAsDirty();
8188
}
8289
return {
8390
enabled: true, // this is sticky so only send it on the initial request
@@ -127,15 +134,12 @@ class ExtensionToDevice implements Extension<ExtensionToDeviceRequest, Extension
127134
return ExtensionState.PreProcess;
128135
}
129136

130-
public onRequest(isInitial: boolean): ExtensionToDeviceRequest {
131-
const extReq: ExtensionToDeviceRequest = {
137+
public async onRequest(isInitial: boolean): Promise<ExtensionToDeviceRequest> {
138+
return {
132139
since: this.nextBatch !== null ? this.nextBatch : undefined,
140+
limit: 100,
141+
enabled: true,
133142
};
134-
if (isInitial) {
135-
extReq["limit"] = 100;
136-
extReq["enabled"] = true;
137-
}
138-
return extReq;
139143
}
140144

141145
public async onResponse(data: ExtensionToDeviceResponse): Promise<void> {
@@ -209,10 +213,7 @@ class ExtensionAccountData implements Extension<ExtensionAccountDataRequest, Ext
209213
return ExtensionState.PostProcess;
210214
}
211215

212-
public onRequest(isInitial: boolean): ExtensionAccountDataRequest | undefined {
213-
if (!isInitial) {
214-
return undefined;
215-
}
216+
public async onRequest(isInitial: boolean): Promise<ExtensionAccountDataRequest> {
216217
return {
217218
enabled: true,
218219
};
@@ -279,10 +280,7 @@ class ExtensionTyping implements Extension<ExtensionTypingRequest, ExtensionTypi
279280
return ExtensionState.PostProcess;
280281
}
281282

282-
public onRequest(isInitial: boolean): ExtensionTypingRequest | undefined {
283-
if (!isInitial) {
284-
return undefined; // don't send a JSON object for subsequent requests, we don't need to.
285-
}
283+
public async onRequest(isInitial: boolean): Promise<ExtensionTypingRequest> {
286284
return {
287285
enabled: true,
288286
};
@@ -318,13 +316,10 @@ class ExtensionReceipts implements Extension<ExtensionReceiptsRequest, Extension
318316
return ExtensionState.PostProcess;
319317
}
320318

321-
public onRequest(isInitial: boolean): ExtensionReceiptsRequest | undefined {
322-
if (isInitial) {
323-
return {
324-
enabled: true,
325-
};
326-
}
327-
return undefined; // don't send a JSON object for subsequent requests, we don't need to.
319+
public async onRequest(isInitial: boolean): Promise<ExtensionReceiptsRequest> {
320+
return {
321+
enabled: true,
322+
};
328323
}
329324

330325
public async onResponse(data: ExtensionReceiptsResponse): Promise<void> {

src/sliding-sync.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,10 @@ export interface Extension<Req extends {}, Res extends {}> {
229229
/**
230230
* A function which is called when the request JSON is being formed.
231231
* Returns the data to insert under this key.
232-
* @param isInitial - True when this is part of the initial request (send sticky params)
232+
* @param isInitial - True when this is part of the initial request.
233233
* @returns The request JSON to send.
234234
*/
235-
onRequest(isInitial: boolean): Req | undefined;
235+
onRequest(isInitial: boolean): Promise<Req>;
236236
/**
237237
* A function which is called when there is response JSON under this extension.
238238
* @param data - The response JSON under the extension name.
@@ -471,11 +471,11 @@ export class SlidingSync extends TypedEventEmitter<SlidingSyncEvent, SlidingSync
471471
this.extensions[ext.name()] = ext;
472472
}
473473

474-
private getExtensionRequest(): Record<string, object | undefined> {
474+
private async getExtensionRequest(isInitial: boolean): Promise<Record<string, object | undefined>> {
475475
const ext: Record<string, object | undefined> = {};
476-
Object.keys(this.extensions).forEach((extName) => {
477-
ext[extName] = this.extensions[extName].onRequest(true);
478-
});
476+
for (const extName in this.extensions) {
477+
ext[extName] = await this.extensions[extName].onRequest(isInitial);
478+
}
479479
return ext;
480480
}
481481

@@ -582,7 +582,7 @@ export class SlidingSync extends TypedEventEmitter<SlidingSyncEvent, SlidingSync
582582
pos: currentPos,
583583
timeout: this.timeoutMS,
584584
clientTimeout: this.timeoutMS + BUFFER_PERIOD_MS,
585-
extensions: this.getExtensionRequest(),
585+
extensions: await this.getExtensionRequest(currentPos === undefined),
586586
};
587587
// check if we are (un)subscribing to a room and modify request this one time for it
588588
const newSubscriptions = difference(this.desiredRoomSubscriptions, this.confirmedRoomSubscriptions);

0 commit comments

Comments
 (0)