Skip to content

Commit

Permalink
fix: Use Django ORM .db_manager() instead of .using() (#23595)
Browse files Browse the repository at this point in the history
  • Loading branch information
Twixes committed Jul 11, 2024
1 parent b564696 commit e4eb50e
Show file tree
Hide file tree
Showing 11 changed files with 46 additions and 38 deletions.
8 changes: 4 additions & 4 deletions posthog/api/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def my_flags(self, request: request.Request, **kwargs):
methods=["GET"], detail=False, throttle_classes=[FeatureFlagThrottle], required_scopes=["feature_flag:read"]
)
def local_evaluation(self, request: request.Request, **kwargs):
feature_flags: QuerySet[FeatureFlag] = FeatureFlag.objects.using(DATABASE_FOR_LOCAL_EVALUATION).filter(
feature_flags: QuerySet[FeatureFlag] = FeatureFlag.objects.db_manager(DATABASE_FOR_LOCAL_EVALUATION).filter(
team_id=self.team_id, deleted=False, active=True
)

Expand All @@ -548,7 +548,7 @@ def local_evaluation(self, request: request.Request, **kwargs):
if should_send_cohorts:
seen_cohorts_cache = {
cohort.pk: cohort
for cohort in Cohort.objects.using(DATABASE_FOR_LOCAL_EVALUATION).filter(
for cohort in Cohort.objects.db_manager(DATABASE_FOR_LOCAL_EVALUATION).filter(
team_id=self.team_id, deleted=False
)
}
Expand Down Expand Up @@ -591,7 +591,7 @@ def local_evaluation(self, request: request.Request, **kwargs):
cohort = seen_cohorts_cache[id]
else:
cohort = (
Cohort.objects.using(DATABASE_FOR_LOCAL_EVALUATION)
Cohort.objects.db_manager(DATABASE_FOR_LOCAL_EVALUATION)
.filter(id=id, team_id=self.team_id, deleted=False)
.first()
)
Expand All @@ -611,7 +611,7 @@ def local_evaluation(self, request: request.Request, **kwargs):
],
"group_type_mapping": {
str(row.group_type_index): row.group_type
for row in GroupTypeMapping.objects.using(DATABASE_FOR_LOCAL_EVALUATION).filter(
for row in GroupTypeMapping.objects.db_manager(DATABASE_FOR_LOCAL_EVALUATION).filter(
team_id=self.team_id
)
},
Expand Down
30 changes: 16 additions & 14 deletions posthog/api/test/test_decide.py
Original file line number Diff line number Diff line change
Expand Up @@ -3660,16 +3660,16 @@ class TestDecideUsesReadReplica(TransactionTestCase):
databases = {"default", "replica"}

def setup_user_and_team_in_db(self, dbname: str = "default"):
organization = Organization.objects.using(dbname).create(
organization = Organization.objects.db_manager(dbname).create(
name="Org 1", slug=f"org-{dbname}-{random.randint(1, 1000000)}"
)
team = Team.objects.using(dbname).create(organization=organization, name="Team 1 org 1")
user = User.objects.using(dbname).create(
team = Team.objects.db_manager(dbname).create(organization=organization, name="Team 1 org 1")
user = User.objects.db_manager(dbname).create(
email=f"test-{random.randint(1, 100000)}@posthog.com",
password="password",
first_name="first_name",
)
OrganizationMembership.objects.using(dbname).create(
OrganizationMembership.objects.db_manager(dbname).create(
user=user,
organization=organization,
level=OrganizationMembership.Level.OWNER,
Expand All @@ -3681,7 +3681,7 @@ def setup_flags_in_db(self, dbname, team, user, flags, persons):
created_flags = []
created_persons = []
for flag in flags:
f = FeatureFlag.objects.using(dbname).create(
f = FeatureFlag.objects.db_manager(dbname).create(
team=team,
rollout_percentage=flag.get("rollout_percentage") or None,
filters=flag.get("filters") or {},
Expand All @@ -3692,13 +3692,13 @@ def setup_flags_in_db(self, dbname, team, user, flags, persons):
)
created_flags.append(f)
for person in persons:
p = Person.objects.using(dbname).create(
p = Person.objects.db_manager(dbname).create(
team=team,
properties=person["properties"],
)
created_persons.append(p)
for distinct_id in person["distinct_ids"]:
PersonDistinctId.objects.using(dbname).create(person=p, distinct_id=distinct_id, team=team)
PersonDistinctId.objects.db_manager(dbname).create(person=p, distinct_id=distinct_id, team=team)

return created_flags, created_persons

Expand Down Expand Up @@ -4132,15 +4132,15 @@ def test_feature_flags_v3_consistent_flags(self, mock_is_connected):
) # assigned by distinct_id hash

# new person, merged from old distinct ID
PersonDistinctId.objects.using("default").create(person=person, distinct_id="other_id", team=self.team)
PersonDistinctId.objects.db_manager("default").create(person=person, distinct_id="other_id", team=self.team)
# hash key override already exists
FeatureFlagHashKeyOverride.objects.using("default").create(
FeatureFlagHashKeyOverride.objects.db_manager("default").create(
team=self.team,
person=person,
hash_key="example_id",
feature_flag_key="beta-feature",
)
FeatureFlagHashKeyOverride.objects.using("default").create(
FeatureFlagHashKeyOverride.objects.db_manager("default").create(
team=self.team,
person=person,
hash_key="example_id",
Expand Down Expand Up @@ -4305,7 +4305,7 @@ def test_feature_flags_v3_consistent_flags_with_write_on_hash_key_overrides(self
) # assigned by distinct_id hash

# new person, merged from old distinct ID
PersonDistinctId.objects.using("default").create(person=person, distinct_id="other_id", team=self.team)
PersonDistinctId.objects.db_manager("default").create(person=person, distinct_id="other_id", team=self.team)

# request with hash key overrides and _new_ writes should go to main database
with self.assertNumQueries(8, using="replica"), self.assertNumQueries(9, using="default"):
Expand Down Expand Up @@ -4388,10 +4388,12 @@ def test_feature_flags_v2_with_groups(self, mock_is_connected):
]
self.setup_flags_in_db("replica", team, user, flags, persons)

GroupTypeMapping.objects.using("replica").create(team=self.team, group_type="organization", group_type_index=0)
GroupTypeMapping.objects.using("default").create(team=self.team, group_type="project", group_type_index=1)
GroupTypeMapping.objects.db_manager("replica").create(
team=self.team, group_type="organization", group_type_index=0
)
GroupTypeMapping.objects.db_manager("default").create(team=self.team, group_type="project", group_type_index=1)

Group.objects.using("replica").create(
Group.objects.db_manager("replica").create(
team_id=self.team.pk,
group_type_index=0,
group_key="foo",
Expand Down
2 changes: 1 addition & 1 deletion posthog/models/cohort/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def get_dependent_cohorts(
if not current_cohort:
continue
else:
current_cohort = Cohort.objects.using(using_database).get(
current_cohort = Cohort.objects.db_manager(using_database).get(
pk=cohort_id, team_id=cohort.team_id, deleted=False
)
seen_cohorts_cache[cohort_id] = current_cohort
Expand Down
6 changes: 3 additions & 3 deletions posthog/models/feature_flag/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def transform_cohort_filters_for_easy_evaluation(
if not cohort:
return self.conditions
else:
cohort = Cohort.objects.using(using_database).get(
cohort = Cohort.objects.db_manager(using_database).get(
pk=cohort_id, team_id=self.team_id, deleted=False
)
seen_cohorts_cache[cohort_id] = cohort
Expand Down Expand Up @@ -284,7 +284,7 @@ def get_cohort_ids(
if not cohort:
continue
else:
cohort = Cohort.objects.using(using_database).get(
cohort = Cohort.objects.db_manager(using_database).get(
pk=cohort_id, team_id=self.team_id, deleted=False
)
seen_cohorts_cache[cohort_id] = cohort
Expand Down Expand Up @@ -407,7 +407,7 @@ def set_feature_flags_for_team_in_cache(
all_feature_flags = feature_flags
else:
all_feature_flags = list(
FeatureFlag.objects.using(using_database).filter(team_id=team_id, active=True, deleted=False)
FeatureFlag.objects.db_manager(using_database).filter(team_id=team_id, active=True, deleted=False)
)

serialized_flags = MinimalFeatureFlagSerializer(all_feature_flags, many=True).data
Expand Down
16 changes: 9 additions & 7 deletions posthog/models/feature_flag/flag_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def group_types_to_indexes(self) -> dict[GroupTypeName, GroupTypeIndex]:
raise DatabaseError("Failed to fetch group type mapping previously, not trying again.")
try:
with execute_with_timeout(FLAG_MATCHING_QUERY_TIMEOUT_MS, DATABASE_FOR_FLAG_MATCHING):
group_type_mapping_rows = GroupTypeMapping.objects.using(DATABASE_FOR_FLAG_MATCHING).filter(
group_type_mapping_rows = GroupTypeMapping.objects.db_manager(DATABASE_FOR_FLAG_MATCHING).filter(
team_id=self.team_id
)
return {row.group_type: row.group_type_index for row in group_type_mapping_rows}
Expand Down Expand Up @@ -412,12 +412,14 @@ def query_conditions(self) -> dict[str, bool]:
with execute_with_timeout(FLAG_MATCHING_QUERY_TIMEOUT_MS * 2, DATABASE_FOR_FLAG_MATCHING):
all_conditions: dict = {}
team_id = self.feature_flags[0].team_id
person_query: QuerySet = Person.objects.using(DATABASE_FOR_FLAG_MATCHING).filter(
person_query: QuerySet = Person.objects.db_manager(DATABASE_FOR_FLAG_MATCHING).filter(
team_id=team_id,
persondistinctid__distinct_id=self.distinct_id,
persondistinctid__team_id=team_id,
)
basic_group_query: QuerySet = Group.objects.using(DATABASE_FOR_FLAG_MATCHING).filter(team_id=team_id)
basic_group_query: QuerySet = Group.objects.db_manager(DATABASE_FOR_FLAG_MATCHING).filter(
team_id=team_id
)
group_query_per_group_type_mapping: dict[GroupTypeIndex, tuple[QuerySet, list[str]]] = {}
# :TRICKY: Create a queryset for each group type that uniquely identifies a group, based on the groups passed in.
# If no groups for a group type are passed in, we can skip querying for that group type,
Expand Down Expand Up @@ -546,7 +548,7 @@ def condition_eval(key, condition):
if not self.cohorts_cache and any(feature_flag.uses_cohorts for feature_flag in self.feature_flags):
all_cohorts = {
cohort.pk: cohort
for cohort in Cohort.objects.using(DATABASE_FOR_FLAG_MATCHING).filter(
for cohort in Cohort.objects.db_manager(DATABASE_FOR_FLAG_MATCHING).filter(
team_id=team_id, deleted=False
)
}
Expand Down Expand Up @@ -692,7 +694,7 @@ def get_feature_flag_hash_key_overrides(

if not person_id_to_distinct_id_mapping:
person_and_distinct_ids = list(
PersonDistinctId.objects.using(using_database)
PersonDistinctId.objects.db_manager(using_database)
.filter(distinct_id__in=distinct_ids, team_id=team_id)
.values_list("person_id", "distinct_id")
)
Expand All @@ -703,7 +705,7 @@ def get_feature_flag_hash_key_overrides(
person_ids = list(person_id_to_distinct_id.keys())

for feature_flag, override, _ in sorted(
FeatureFlagHashKeyOverride.objects.using(using_database)
FeatureFlagHashKeyOverride.objects.db_manager(using_database)
.filter(person_id__in=person_ids, team_id=team_id)
.values_list("feature_flag_key", "hash_key", "person_id"),
key=lambda x: 1 if person_id_to_distinct_id.get(x[2], "") == distinct_ids[0] else -1,
Expand Down Expand Up @@ -1027,7 +1029,7 @@ def get_all_properties_with_math_operators(
cohort_id = int(cast(Union[str, int], prop.value))
if cohorts_cache.get(cohort_id) is None:
queried_cohort = (
Cohort.objects.using(DATABASE_FOR_FLAG_MATCHING)
Cohort.objects.db_manager(DATABASE_FOR_FLAG_MATCHING)
.filter(pk=cohort_id, team_id=team_id, deleted=False)
.first()
)
Expand Down
2 changes: 1 addition & 1 deletion posthog/models/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def bootstrap(
"""Instead of doing the legwork of creating an organization yourself, delegate the details with bootstrap."""
from .project import Project # Avoiding circular import

with transaction.atomic():
with transaction.atomic(using=self.db):
organization = Organization.objects.create(**kwargs)
_, team = Project.objects.create_with_team(organization=organization, team_fields=team_fields)
organization_membership: Optional[OrganizationMembership] = None
Expand Down
2 changes: 1 addition & 1 deletion posthog/models/person/person.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class PersonManager(models.Manager):
def create(self, *args: Any, **kwargs: Any):
with transaction.atomic():
with transaction.atomic(using=self.db):
if not kwargs.get("distinct_ids"):
return super().create(*args, **kwargs)
distinct_ids = kwargs.pop("distinct_ids")
Expand Down
2 changes: 1 addition & 1 deletion posthog/models/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ProjectManager(models.Manager):
def create_with_team(self, team_fields: Optional[dict] = None, **kwargs) -> tuple["Project", "Team"]:
from .team import Team

with transaction.atomic():
with transaction.atomic(using=self.db):
common_id = Team.objects.increment_id_sequence()
project = self.create(id=common_id, **kwargs)
team = Team.objects.create(
Expand Down
4 changes: 2 additions & 2 deletions posthog/models/team/team.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def create_with_data(self, user: Any = None, default_dashboards: bool = True, **
def create(self, *args, **kwargs) -> "Team":
from ..project import Project

with transaction.atomic():
with transaction.atomic(using=self.db):
if "id" not in kwargs:
kwargs["id"] = self.increment_id_sequence()
if kwargs.get("project") is None and kwargs.get("project_id") is None:
Expand All @@ -115,7 +115,7 @@ def create(self, *args, **kwargs) -> "Team":
project_kwargs["organization_id"] = organization_id
if name := kwargs.get("name"):
project_kwargs["name"] = name
kwargs["project"] = Project.objects.create(id=kwargs["id"], **project_kwargs)
kwargs["project"] = Project.objects.db_manager(self.db).create(id=kwargs["id"], **project_kwargs)
return super().create(*args, **kwargs)

def get_team_from_token(self, token: Optional[str]) -> Optional["Team"]:
Expand Down
2 changes: 1 addition & 1 deletion posthog/plugins/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_decide_site_apps(team: "Team", using_database: str = "default") -> list[
from posthog.models import PluginConfig, PluginSourceFile

sources = (
PluginConfig.objects.using(using_database)
PluginConfig.objects.db_manager(using_database)
.filter(
team=team,
enabled=True,
Expand Down
10 changes: 7 additions & 3 deletions posthog/queries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,17 @@ def property_to_Q(
if cohorts_cache is not None:
if cohorts_cache.get(cohort_id) is None:
queried_cohort = (
Cohort.objects.using(using_database).filter(pk=cohort_id, team_id=team_id, deleted=False).first()
Cohort.objects.db_manager(using_database)
.filter(pk=cohort_id, team_id=team_id, deleted=False)
.first()
)
cohorts_cache[cohort_id] = queried_cohort or ""

cohort = cohorts_cache[cohort_id]
else:
cohort = Cohort.objects.using(using_database).filter(pk=cohort_id, team_id=team_id, deleted=False).first()
cohort = (
Cohort.objects.db_manager(using_database).filter(pk=cohort_id, team_id=team_id, deleted=False).first()
)

if not cohort:
# Don't match anything if cohort doesn't exist
Expand All @@ -312,7 +316,7 @@ def property_to_Q(
if cohort.is_static:
return Q(
Exists(
CohortPeople.objects.using(using_database)
CohortPeople.objects.db_manager(using_database)
.filter(
cohort_id=cohort_id,
person_id=OuterRef("id"),
Expand Down

0 comments on commit e4eb50e

Please sign in to comment.