diff --git a/changelog.d/18635.feature b/changelog.d/18635.feature new file mode 100644 index 00000000000..af536f64d36 --- /dev/null +++ b/changelog.d/18635.feature @@ -0,0 +1 @@ +Support arbitrary profile fields. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 4958ab5e750..da392e115f3 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -93,9 +93,7 @@ async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDi if self.hs.is_mine(target_user): profileinfo = await self.store.get_profileinfo(target_user) - extra_fields = {} - if self.hs.config.experimental.msc4133_enabled: - extra_fields = await self.store.get_profile_fields(target_user) + extra_fields = await self.store.get_profile_fields(target_user) if ( profileinfo.display_name is None @@ -551,16 +549,16 @@ async def on_profile_query(self, args: JsonDict) -> JsonDict: # since then we send a null in the JSON response if avatar_url is not None: response["avatar_url"] = avatar_url - if self.hs.config.experimental.msc4133_enabled: - if just_field is None: - response.update(await self.store.get_profile_fields(user)) - elif just_field not in ( - ProfileFields.DISPLAYNAME, - ProfileFields.AVATAR_URL, - ): - response[just_field] = await self.store.get_profile_field( - user, just_field - ) + + if just_field is None: + response.update(await self.store.get_profile_fields(user)) + elif just_field not in ( + ProfileFields.DISPLAYNAME, + ProfileFields.AVATAR_URL, + ): + response[just_field] = await self.store.get_profile_field( + user, just_field + ) except StoreError as e: if e.code == 404: raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 8f3193fb470..a279db1cc5c 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -92,22 +92,22 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: "enabled": self.config.experimental.msc3664_enabled, } + disallowed_profile_fields = [] + response["capabilities"]["m.profile_fields"] = {"enabled": True} + if not self.config.registration.enable_set_displayname: + disallowed_profile_fields.append("displayname") + if not self.config.registration.enable_set_avatar_url: + disallowed_profile_fields.append("avatar_url") + if disallowed_profile_fields: + response["capabilities"]["m.profile_fields"]["disallowed"] = ( + disallowed_profile_fields + ) + + # For transition from unstable to stable identifiers. if self.config.experimental.msc4133_enabled: - response["capabilities"]["uk.tcpip.msc4133.profile_fields"] = { - "enabled": True, - } - - # Ensure this is consistent with the legacy m.set_displayname and - # m.set_avatar_url. - disallowed = [] - if not self.config.registration.enable_set_displayname: - disallowed.append("displayname") - if not self.config.registration.enable_set_avatar_url: - disallowed.append("avatar_url") - if disallowed: - response["capabilities"]["uk.tcpip.msc4133.profile_fields"][ - "disallowed" - ] = disallowed + response["capabilities"]["uk.tcpip.msc4133.profile_fields"] = response[ + "capabilities" + ]["m.profile_fields"] if self.config.experimental.msc4267_enabled: response["capabilities"]["org.matrix.msc4267.forget_forced_upon_leave"] = { diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index 8326d8017c9..243245f7393 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -57,161 +57,6 @@ def _read_propagate(hs: "HomeServer", request: SynapseRequest) -> bool: return propagate -class ProfileDisplaynameRestServlet(RestServlet): - PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True) - CATEGORY = "Event sending requests" - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.profile_handler = hs.get_profile_handler() - self.auth = hs.get_auth() - - async def on_GET( - self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: - requester_user = None - - if self.hs.config.server.require_auth_for_profile_requests: - requester = await self.auth.get_user_by_req(request) - requester_user = requester.user - - if not UserID.is_valid(user_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM - ) - - user = UserID.from_string(user_id) - await self.profile_handler.check_profile_query_allowed(user, requester_user) - - displayname = await self.profile_handler.get_displayname(user) - - ret = {} - if displayname is not None: - ret["displayname"] = displayname - - return 200, ret - - async def on_PUT( - self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: - if not UserID.is_valid(user_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM - ) - - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user = UserID.from_string(user_id) - is_admin = await self.auth.is_server_admin(requester) - - content = parse_json_object_from_request(request) - - try: - new_name = content["displayname"] - except Exception: - raise SynapseError( - 400, "Missing key 'displayname'", errcode=Codes.MISSING_PARAM - ) - - propagate = _read_propagate(self.hs, request) - - requester_suspended = ( - await self.hs.get_datastores().main.get_user_suspended_status( - requester.user.to_string() - ) - ) - - if requester_suspended: - raise SynapseError( - 403, - "Updating displayname while account is suspended is not allowed.", - Codes.USER_ACCOUNT_SUSPENDED, - ) - - await self.profile_handler.set_displayname( - user, requester, new_name, is_admin, propagate=propagate - ) - - return 200, {} - - -class ProfileAvatarURLRestServlet(RestServlet): - PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True) - CATEGORY = "Event sending requests" - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.profile_handler = hs.get_profile_handler() - self.auth = hs.get_auth() - - async def on_GET( - self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: - requester_user = None - - if self.hs.config.server.require_auth_for_profile_requests: - requester = await self.auth.get_user_by_req(request) - requester_user = requester.user - - if not UserID.is_valid(user_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM - ) - - user = UserID.from_string(user_id) - await self.profile_handler.check_profile_query_allowed(user, requester_user) - - avatar_url = await self.profile_handler.get_avatar_url(user) - - ret = {} - if avatar_url is not None: - ret["avatar_url"] = avatar_url - - return 200, ret - - async def on_PUT( - self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: - if not UserID.is_valid(user_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM - ) - - requester = await self.auth.get_user_by_req(request) - user = UserID.from_string(user_id) - is_admin = await self.auth.is_server_admin(requester) - - content = parse_json_object_from_request(request) - try: - new_avatar_url = content["avatar_url"] - except KeyError: - raise SynapseError( - 400, "Missing key 'avatar_url'", errcode=Codes.MISSING_PARAM - ) - - propagate = _read_propagate(self.hs, request) - - requester_suspended = ( - await self.hs.get_datastores().main.get_user_suspended_status( - requester.user.to_string() - ) - ) - - if requester_suspended: - raise SynapseError( - 403, - "Updating avatar URL while account is suspended is not allowed.", - Codes.USER_ACCOUNT_SUSPENDED, - ) - - await self.profile_handler.set_avatar_url( - user, requester, new_avatar_url, is_admin, propagate=propagate - ) - - return 200, {} - - class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True) CATEGORY = "Event sending requests" @@ -244,12 +89,19 @@ async def on_GET( return 200, ret -class UnstableProfileFieldRestServlet(RestServlet): +class ProfileFieldRestServlet(RestServlet): PATTERNS = [ + *client_patterns( + "/profile/(?P[^/]*)/(?Pdisplayname)", v1=True + ), + *client_patterns( + "/profile/(?P[^/]*)/(?Pavatar_url)", v1=True + ), re.compile( - r"^/_matrix/client/unstable/uk\.tcpip\.msc4133/profile/(?P[^/]*)/(?P[^/]*)" - ) + r"^/_matrix/client/v3/profile/(?P[^/]*)/(?P[^/]*)" + ), ] + CATEGORY = "Event sending requests" def __init__(self, hs: "HomeServer"): @@ -304,7 +156,10 @@ async def on_PUT( HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM ) - requester = await self.auth.get_user_by_req(request) + # Guest users are able to set their own displayname. + requester = await self.auth.get_user_by_req( + request, allow_guest=field_name == ProfileFields.DISPLAYNAME + ) user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester) @@ -366,7 +221,10 @@ async def on_DELETE( HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM ) - requester = await self.auth.get_user_by_req(request) + # Guest users are able to set their own displayname. + requester = await self.auth.get_user_by_req( + request, allow_guest=field_name == ProfileFields.DISPLAYNAME + ) user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester) @@ -413,11 +271,15 @@ async def on_DELETE( return 200, {} +class UnstableProfileFieldRestServlet(ProfileFieldRestServlet): + re.compile( + r"^/_matrix/client/unstable/uk\.tcpip\.msc4133/profile/(?P[^/]*)/(?P[^/]*)" + ) + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - # The specific displayname / avatar URL / custom field endpoints *must* appear - # before their corresponding generic profile endpoint. - ProfileDisplaynameRestServlet(hs).register(http_server) - ProfileAvatarURLRestServlet(hs).register(http_server) + # The specific field endpoint *must* appear before the generic profile endpoint. + ProfileFieldRestServlet(hs).register(http_server) ProfileRestServlet(hs).register(http_server) if hs.config.experimental.msc4133_enabled: UnstableProfileFieldRestServlet(hs).register(http_server) diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index 13831462e89..cdf31155fd9 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -130,6 +130,10 @@ def test_get_set_displayname_capabilities_displayname_disabled(self) -> None: self.assertEqual(channel.code, HTTPStatus.OK) self.assertFalse(capabilities["m.set_displayname"]["enabled"]) + self.assertTrue(capabilities["m.profile_fields"]["enabled"]) + self.assertEqual( + capabilities["m.profile_fields"]["disallowed"], ["displayname"] + ) @override_config({"enable_set_avatar_url": False}) def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None: @@ -141,6 +145,8 @@ def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None: self.assertEqual(channel.code, HTTPStatus.OK) self.assertFalse(capabilities["m.set_avatar_url"]["enabled"]) + self.assertTrue(capabilities["m.profile_fields"]["enabled"]) + self.assertEqual(capabilities["m.profile_fields"]["disallowed"], ["avatar_url"]) @override_config( { @@ -159,6 +165,10 @@ def test_get_set_displayname_capabilities_displayname_disabled_msc4133( self.assertEqual(channel.code, HTTPStatus.OK) self.assertFalse(capabilities["m.set_displayname"]["enabled"]) + self.assertTrue(capabilities["m.profile_fields"]["enabled"]) + self.assertEqual( + capabilities["m.profile_fields"]["disallowed"], ["displayname"] + ) self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"]) self.assertEqual( capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"], @@ -180,6 +190,8 @@ def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> No self.assertEqual(channel.code, HTTPStatus.OK) self.assertFalse(capabilities["m.set_avatar_url"]["enabled"]) + self.assertTrue(capabilities["m.profile_fields"]["enabled"]) + self.assertEqual(capabilities["m.profile_fields"]["disallowed"], ["avatar_url"]) self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"]) self.assertEqual( capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"], diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 708402b7929..49776d8e8cc 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -484,38 +484,34 @@ def test_msc4069_inhibit_propagation_like_default(self) -> None: # The client requested ?propagate=true, so it should have happened. self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif") - @unittest.override_config({"experimental_features": {"msc4133_enabled": True}}) def test_get_missing_custom_field(self) -> None: channel = self.make_request( "GET", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field", + f"/_matrix/client/v3/profile/{self.owner}/custom_field", ) self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - @unittest.override_config({"experimental_features": {"msc4133_enabled": True}}) def test_get_missing_custom_field_invalid_field_name(self) -> None: channel = self.make_request( "GET", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/[custom_field]", + f"/_matrix/client/v3/profile/{self.owner}/[custom_field]", ) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - @unittest.override_config({"experimental_features": {"msc4133_enabled": True}}) def test_get_custom_field_rejects_bad_username(self) -> None: channel = self.make_request( "GET", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{urllib.parse.quote('@alice:')}/custom_field", + f"/_matrix/client/v3/profile/{urllib.parse.quote('@alice:')}/custom_field", ) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - @unittest.override_config({"experimental_features": {"msc4133_enabled": True}}) def test_set_custom_field(self) -> None: channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field", + f"/_matrix/client/v3/profile/{self.owner}/custom_field", content={"custom_field": "test"}, access_token=self.owner_tok, ) @@ -523,7 +519,7 @@ def test_set_custom_field(self) -> None: channel = self.make_request( "GET", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field", + f"/_matrix/client/v3/profile/{self.owner}/custom_field", ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.json_body, {"custom_field": "test"}) @@ -531,7 +527,7 @@ def test_set_custom_field(self) -> None: # Overwriting the field should work. channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field", + f"/_matrix/client/v3/profile/{self.owner}/custom_field", content={"custom_field": "new_Value"}, access_token=self.owner_tok, ) @@ -539,7 +535,7 @@ def test_set_custom_field(self) -> None: channel = self.make_request( "GET", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field", + f"/_matrix/client/v3/profile/{self.owner}/custom_field", ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.json_body, {"custom_field": "new_Value"}) @@ -547,7 +543,7 @@ def test_set_custom_field(self) -> None: # Deleting the field should work. channel = self.make_request( "DELETE", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field", + f"/_matrix/client/v3/profile/{self.owner}/custom_field", content={}, access_token=self.owner_tok, ) @@ -555,12 +551,11 @@ def test_set_custom_field(self) -> None: channel = self.make_request( "GET", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field", + f"/_matrix/client/v3/profile/{self.owner}/custom_field", ) self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - @unittest.override_config({"experimental_features": {"msc4133_enabled": True}}) def test_non_string(self) -> None: """Non-string fields are supported for custom fields.""" fields = { @@ -574,7 +569,7 @@ def test_non_string(self) -> None: for key, value in fields.items(): channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}", + f"/_matrix/client/v3/profile/{self.owner}/{key}", content={key: value}, access_token=self.owner_tok, ) @@ -591,22 +586,20 @@ def test_non_string(self) -> None: for key, value in fields.items(): channel = self.make_request( "GET", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}", + f"/_matrix/client/v3/profile/{self.owner}/{key}", ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.json_body, {key: value}) - @unittest.override_config({"experimental_features": {"msc4133_enabled": True}}) def test_set_custom_field_noauth(self) -> None: channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field", + f"/_matrix/client/v3/profile/{self.owner}/custom_field", content={"custom_field": "test"}, ) self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.MISSING_TOKEN) - @unittest.override_config({"experimental_features": {"msc4133_enabled": True}}) def test_set_custom_field_size(self) -> None: """ Attempts to set a custom field name that is too long should get a 400 error. @@ -614,7 +607,7 @@ def test_set_custom_field_size(self) -> None: # Key is missing. channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/", + f"/_matrix/client/v3/profile/{self.owner}/", content={"": "test"}, access_token=self.owner_tok, ) @@ -625,7 +618,7 @@ def test_set_custom_field_size(self) -> None: key = "c" * 500 channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}", + f"/_matrix/client/v3/profile/{self.owner}/{key}", content={key: "test"}, access_token=self.owner_tok, ) @@ -634,7 +627,7 @@ def test_set_custom_field_size(self) -> None: channel = self.make_request( "DELETE", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}", + f"/_matrix/client/v3/profile/{self.owner}/{key}", content={key: "test"}, access_token=self.owner_tok, ) @@ -644,14 +637,13 @@ def test_set_custom_field_size(self) -> None: # Key doesn't match body. channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field", + f"/_matrix/client/v3/profile/{self.owner}/custom_field", content={"diff_key": "test"}, access_token=self.owner_tok, ) self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) - @unittest.override_config({"experimental_features": {"msc4133_enabled": True}}) def test_set_custom_field_profile_too_long(self) -> None: """ Attempts to set a custom field that would push the overall profile too large. @@ -664,7 +656,7 @@ def test_set_custom_field_profile_too_long(self) -> None: key = "a" channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}", + f"/_matrix/client/v3/profile/{self.owner}/{key}", content={key: "a" * 65498}, access_token=self.owner_tok, ) @@ -692,7 +684,7 @@ def test_set_custom_field_profile_too_long(self) -> None: key = "b" channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}", + f"/_matrix/client/v3/profile/{self.owner}/{key}", content={key: "1" + "a" * ADDITIONAL_CHARS}, access_token=self.owner_tok, ) @@ -722,7 +714,7 @@ def test_set_custom_field_profile_too_long(self) -> None: key = "b" channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}", + f"/_matrix/client/v3/profile/{self.owner}/{key}", content={key: "" + "a" * ADDITIONAL_CHARS}, access_token=self.owner_tok, ) @@ -732,17 +724,16 @@ def test_set_custom_field_profile_too_long(self) -> None: key = "a" channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}", + f"/_matrix/client/v3/profile/{self.owner}/{key}", content={key: ""}, access_token=self.owner_tok, ) self.assertEqual(channel.code, 200, channel.result) - @unittest.override_config({"experimental_features": {"msc4133_enabled": True}}) def test_set_custom_field_displayname(self) -> None: channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/displayname", + f"/_matrix/client/v3/profile/{self.owner}/displayname", content={"displayname": "test"}, access_token=self.owner_tok, ) @@ -751,11 +742,10 @@ def test_set_custom_field_displayname(self) -> None: displayname = self._get_displayname() self.assertEqual(displayname, "test") - @unittest.override_config({"experimental_features": {"msc4133_enabled": True}}) def test_set_custom_field_avatar_url(self) -> None: channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/avatar_url", + f"/_matrix/client/v3/profile/{self.owner}/avatar_url", content={"avatar_url": "mxc://test/good"}, access_token=self.owner_tok, ) @@ -764,12 +754,11 @@ def test_set_custom_field_avatar_url(self) -> None: avatar_url = self._get_avatar_url() self.assertEqual(avatar_url, "mxc://test/good") - @unittest.override_config({"experimental_features": {"msc4133_enabled": True}}) def test_set_custom_field_other(self) -> None: """Setting someone else's profile field should fail""" channel = self.make_request( "PUT", - f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.other}/custom_field", + f"/_matrix/client/v3/profile/{self.other}/custom_field", content={"custom_field": "test"}, access_token=self.owner_tok, )