diff --git a/include/ircd/m/event/auth.h b/include/ircd/m/event/auth.h index 4b5da7529..b3acaa1c5 100644 --- a/include/ircd/m/event/auth.h +++ b/include/ircd/m/event/auth.h @@ -13,33 +13,20 @@ struct ircd::m::event::auth { + struct hookdata; struct refs; struct chain; - struct hookdata; using passfail = std::tuple; + using events_view = vector_view; IRCD_M_EXCEPTION(error, FAIL, http::UNAUTHORIZED) static bool is_power_event(const event &); + + static passfail check(std::nothrow_t, const event &, hookdata &); static passfail check(std::nothrow_t, const event &); static void check(const event &); }; -struct ircd::m::event::auth::hookdata -{ - event::prev prev; - vector_view auth_events; - const event *auth_create {nullptr}; - const event *auth_power {nullptr}; - const event *auth_join_rules {nullptr}; - const event *auth_member_target {nullptr}; - const event *auth_member_sender {nullptr}; - - bool allow {false}; - std::exception_ptr fail; - - hookdata(const event &, const vector_view &auth_events); -}; - /// Interface to the references made by other power events to this power /// event in the `auth_events`. This interface only deals with power events, /// it doesn't care if a non-power event referenced a power event. This does @@ -86,3 +73,22 @@ struct ircd::m::event::auth::chain :idx{idx} {} }; + +class ircd::m::event::auth::hookdata +{ + const event *find(const event::closure_bool &) const; + + public: + event::prev prev; + vector_view auth_events; + const event *auth_create {nullptr}; + const event *auth_power {nullptr}; + const event *auth_join_rules {nullptr}; + const event *auth_member_target {nullptr}; + const event *auth_member_sender {nullptr}; + + bool allow {false}; + std::exception_ptr fail; + + hookdata(const event &, const events_view &auth_events); +}; diff --git a/ircd/m_event.cc b/ircd/m_event.cc index 6f4162810..e64fcf2d8 100644 --- a/ircd/m_event.cc +++ b/ircd/m_event.cc @@ -1845,20 +1845,23 @@ ircd::m::event::auth::check(std::nothrow_t, event, {authv, j} }; - try - { - event_auth_hook(event, data); - } - catch(const FAIL &e) - { - data.allow = false; - data.fail = std::current_exception(); - } + return check(std::nothrow, event, data); +} - return - { - data.allow, data.fail - }; +ircd::m::event::auth::passfail +ircd::m::event::auth::check(std::nothrow_t, + const event &event, + hookdata &data) +try +{ + event_auth_hook(event, data); + return {data.allow, data.fail}; +} +catch(const FAIL &e) +{ + data.allow = false; + data.fail = std::current_exception(); + return {data.allow, data.fail}; } ircd::m::event::auth::hookdata::hookdata(const m::event &event, @@ -1871,41 +1874,55 @@ ircd::m::event::auth::hookdata::hookdata(const m::event &event, { auth_events } +,auth_create { - for(size_t i(0); i < auth_events.size(); ++i) + find([](const auto &event) { - const m::event &a(*auth_events.at(i)); - const auto &type(json::get<"type"_>(a)); - if(type == "m.room.create") - { - assert(!auth_create); - auth_create = &a; - } - else if(type == "m.room.power_levels") - { - assert(!auth_power); - auth_power = &a; - } - else if(type == "m.room.join_rules") - { - assert(!auth_join_rules); - auth_join_rules = &a; - } - else if(type == "m.room.member") - { - if(json::get<"sender"_>(event) == json::get<"state_key"_>(a)) - { - assert(!auth_member_sender); - auth_member_sender = &a; - } + return json::get<"type"_>(event) == "m.room.create"; + }) +} +,auth_power +{ + find([](const auto &event) + { + return json::get<"type"_>(event) == "m.room.power_levels"; + }) +} +,auth_join_rules +{ + find([](const auto &event) + { + return json::get<"type"_>(event) == "m.room.join_rules"; + }) +} +,auth_member_target +{ + find([&event](const auto &auth_event) + { + return json::get<"type"_>(auth_event) == "m.room.member" && + json::get<"state_key"_>(auth_event) == json::get<"state_key"_>(event); + }) +} +,auth_member_sender +{ + find([&event](const auto &auth_event) + { + return json::get<"type"_>(auth_event) == "m.room.member" && + json::get<"state_key"_>(auth_event) == json::get<"sender"_>(event); + }) +} +{ +} - if(json::get<"state_key"_>(event) == json::get<"state_key"_>(a)) - { - assert(!auth_member_target); - auth_member_target = &a; - } - } - } +const ircd::m::event * +ircd::m::event::auth::hookdata::find(const event::closure_bool &closure) +const +{ + for(const auto *const &event : auth_events) + if(likely(event) && closure(*event)) + return event; + + return nullptr; } /*