Skip to content

Commit

Permalink
Merge pull request #6606 from hotosm/fastapi-refactor
Browse files Browse the repository at this point in the history
Project and notification duplication handled, organisation teams and …
  • Loading branch information
prabinoid authored Oct 25, 2024
2 parents 9c3ec17 + 18b2ecb commit ecb376a
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 120 deletions.
5 changes: 0 additions & 5 deletions backend/api/users/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,8 @@
from backend.models.dtos.user_dto import AuthUserDTO, UserSearchQuery
from backend.services.project_service import ProjectService
from backend.services.users.authentication_service import login_required

# from backend.services.users.authentication_service import token_auth
from backend.services.users.user_service import UserService

# from flask_restful import , current_app, request
# from schematics.exceptions import DataError


router = APIRouter(
prefix="/users",
Expand Down
4 changes: 2 additions & 2 deletions backend/api/users/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ async def get(
task_status=status,
start_date=start_date,
end_date=end_date,
page=request.query_params.get("page", 1),
page_size=request.query_params.get("page_size", 10),
page=int(request.query_params.get("page", 1)),
page_size=int(request.query_params.get("page_size", 10)),
sort_by=sort_by,
db=db,
)
Expand Down
6 changes: 3 additions & 3 deletions backend/models/dtos/message_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ class MessageDTO(BaseModel):
"""DTO used to define a message that will be sent to a user"""

message_id: Optional[int] = Field(None, alias="messageId")
subject: str = Field(min_length=1, alias="subject")
message: str = Field(min_length=1, alias="message")
from_user_id: int = Field(alias="fromUserId")
subject: Optional[str] = Field(min_length=1, alias="subject")
message: Optional[str] = Field(min_length=1, alias="message")
from_user_id: Optional[int] = Field(alias="fromUserId")
from_username: Optional[str] = Field("", alias="fromUsername")
display_picture_url: Optional[str] = Field("", alias="displayPictureUrl")
project_id: Optional[int] = Field(None, alias="projectId")
Expand Down
30 changes: 15 additions & 15 deletions backend/models/postgis/message.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
from sqlalchemy.sql.expression import false
from enum import Enum

from databases import Database
from loguru import logger
from sqlalchemy import (
BigInteger,
Boolean,
Column,
Integer,
String,
DateTime,
ForeignKey,
BigInteger,
Boolean,
ForeignKeyConstraint,
Integer,
String,
)
from sqlalchemy.orm import relationship
from loguru import logger
from enum import Enum
from sqlalchemy.sql.expression import false

from backend.db import Base, get_session
from backend.exceptions import NotFound
from backend.models.dtos.message_dto import MessageDTO, MessagesDTO
from backend.models.postgis.user import User
from backend.models.postgis.task import Task, TaskAction
from backend.models.postgis.project import Project
from backend.models.postgis.task import Task, TaskAction
from backend.models.postgis.user import User
from backend.models.postgis.utils import timestamp
from backend.db import Base, get_session
from databases import Database

session = get_session()

Expand Down Expand Up @@ -272,10 +272,10 @@ async def mark_all_messages_read(
AND read = FALSE
"""

params = {"user_id": user_id}

if message_type_filters:
query += " AND message_type = ANY(:message_type_filters)"
params["message_type_filters"] = message_type_filters

await db.execute(
query,
{"user_id": user_id, "message_type_filters": message_type_filters},
)
await db.execute(query, params)
14 changes: 7 additions & 7 deletions backend/models/postgis/organisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,20 +129,20 @@ async def update(organisation_dto: UpdateOrganisationDTO, db: Database):
if key not in ["organisation_id", "managers"]
}
set_clause = ", ".join(f"{key} = :{key}" for key in update_keys.keys())
update_query = f"""
UPDATE organisations
SET {set_clause}
WHERE id = :id
"""
await db.execute(update_query, values={**update_keys, "id": org_id})
if set_clause:
update_query = f"""
UPDATE organisations
SET {set_clause}
WHERE id = :id
"""
await db.execute(update_query, values={**update_keys, "id": org_id})

if organisation_dto.managers:
clear_managers_query = """
DELETE FROM organisation_managers
WHERE organisation_id = :id
"""
await db.execute(clear_managers_query, values={"id": org_id})

for manager_username in organisation_dto.managers:
user_query = "SELECT id FROM users WHERE username = :username"
user = await db.fetch_one(
Expand Down
13 changes: 10 additions & 3 deletions backend/models/postgis/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,9 +1457,16 @@ def get_project_teams(self):

return project_teams

def get_project_title(self, preferred_locale):
project_info = ProjectInfo.get_dto_for_locale(
self.id, preferred_locale, self.default_locale
# def get_project_title(self, preferred_locale):
# project_info = ProjectInfo.get_dto_for_locale(
# self.id, preferred_locale, self.default_locale
# )
# return project_info.name

@staticmethod
async def get_project_title(db: Database, project_id: int, preferred_locale):
project_info = await ProjectInfo.get_dto_for_locale(
db, project_id, preferred_locale
)
return project_info.name

Expand Down
83 changes: 33 additions & 50 deletions backend/models/postgis/project_info.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
# # from flask import current_app
from sqlalchemy.dialects.postgresql import TSVECTOR
from typing import List

from databases import Database
from sqlalchemy import (
Column,
String,
Integer,
ForeignKey,
Index,
inspect,
Integer,
String,
insert,
inspect,
update,
)
from backend.models.dtos.project_dto import ProjectInfoDTO
from sqlalchemy.dialects.postgresql import TSVECTOR

from backend.db import Base, get_session
from backend.models.dtos.project_dto import ProjectInfoDTO

session = get_session()
from databases import Database


class ProjectInfo(Base):
Expand Down Expand Up @@ -107,13 +109,10 @@ async def get_dto_for_locale(
:return: ProjectInfoDTO
:raises: ValueError if no info found for Default Locale
"""

# Define the SQL query to get project info by locale
query = """
SELECT * FROM project_info
WHERE project_id = :project_id AND locale = :locale
"""

# Execute the query for the requested locale
project_info = await db.fetch_one(
query, values={"project_id": project_id, "locale": locale}
Expand All @@ -140,7 +139,6 @@ async def get_dto_for_locale(
if locale == default_locale:
# Return the DTO for the default locale
return ProjectInfoDTO(**project_info)

# Define the SQL query to get project info by default locale for partial translations
query_default = """
SELECT * FROM project_info
Expand All @@ -157,46 +155,31 @@ async def get_dto_for_locale(
error_message = f"BAD DATA: no info for project {project_id}, locale: {locale}, default {default_locale}"
raise ValueError(error_message)

combined_info = {**default_locale_info, **project_info}
return ProjectInfoDTO(**combined_info)

# def get_dto(self, default_locale=ProjectInfoDTO()) -> ProjectInfoDTO:
# """
# Get DTO for current ProjectInfo
# :param default_locale: The default locale string for any empty fields
# """
# project_info_dto = ProjectInfoDTO()
# project_info_dto.locale = self.locale
# project_info_dto.name = self.name if self.name else default_locale.name
# project_info_dto.description = (
# self.description if self.description else default_locale.description
# )
# project_info_dto.short_description = (
# self.short_description
# if self.short_description
# else default_locale.short_description
# )
# project_info_dto.instructions = (
# self.instructions if self.instructions else default_locale.instructions
# )
# project_info_dto.per_task_instructions = (
# self.per_task_instructions
# if self.per_task_instructions
# else default_locale.per_task_instructions
# )

# return project_info_dto

# @staticmethod
# def get_dto_for_all_locales(project_id) -> List[ProjectInfoDTO]:
# locales = ProjectInfo.query.filter_by(project_id=project_id).all()

# project_info_dtos = []
# for locale in locales:
# project_info_dto = locale.get_dto()
# project_info_dtos.append(project_info_dto)

# return project_info_dtos
combined_info = ProjectInfoDTO(locale=project_info.locale)
combined_info.name = (
project_info.name if project_info.name else default_locale_info.name
)
combined_info.description = (
project_info.description
if project_info.description
else default_locale_info.description
)
combined_info.short_description = (
project_info.short_description
if project_info.short_description
else default_locale_info.short_description
)
combined_info.instructions = (
project_info.instructions
if project_info.instructions
else default_locale_info.instructions
)
combined_info.per_task_instructions = (
project_info.per_task_instructions
if project_info.per_task_instructions
else default_locale_info.per_task_instructions
)
return combined_info

# Function to get a single ProjectInfoDTO
async def get_project_info_dto(locale_record) -> ProjectInfoDTO:
Expand Down
20 changes: 13 additions & 7 deletions backend/services/messaging/message_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,12 +812,9 @@ async def get_all_messages(
m.read,
m.project_id,
u.username AS from_username,
u.picture_url AS display_picture_url,
pi.name AS project_title
u.picture_url AS display_picture_url
FROM
messages m
LEFT JOIN
project_info pi ON m.project_id = pi.project_id
LEFT JOIN
users u ON m.from_user_id = u.id
WHERE
Expand Down Expand Up @@ -859,13 +856,22 @@ async def get_all_messages(
messages_dto = MessagesDTO()
for msg in messages:
message_dict = dict(msg)
print(message_dict)
if message_dict["message_type"]:
message_dict["message_type"] = MessageType(
message_dict["message_type"]
).name
msg_dto = MessageDTO(**message_dict).dict(
exclude={"from_user_id"}, by_alias=True
)
if message_dict["project_id"]:
try:
message_dict["project_title"] = (
await Project.get_project_title(
db, message_dict["project_id"], locale
)
or ""
)
except:
pass
msg_dto = MessageDTO(**message_dict).copy(exclude={"from_user_id"})
messages_dto.user_messages.append(msg_dto)

total_count_query = """
Expand Down
47 changes: 27 additions & 20 deletions backend/services/project_search_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ async def create_search_query(db, user=None):
o.logo AS organisation_logo
FROM projects p
LEFT JOIN organisations o ON o.id = p.organisation_id
LEFT JOIN project_info pi ON pi.project_id = p.id
WHERE p.geometry IS NOT NULL
"""

Expand Down Expand Up @@ -148,7 +147,7 @@ async def create_result_dto(
project_info_dto = await ProjectInfo.get_dto_for_locale(
db, project.id, preferred_locale, project.default_locale
)
# project_obj = await Project.get(project.id, db)

list_dto = ListSearchResultDTO()
list_dto.project_id = project.id
list_dto.locale = project_info_dto.locale
Expand Down Expand Up @@ -255,10 +254,32 @@ async def _filter_projects(search_dto: ProjectSearchDTO, user, db: Database):
base_query, params = await ProjectSearchService.create_search_query(db, user)
# Initialize filter list and parameters dictionary
filters = []
# Filters based on search_dto
if search_dto.preferred_locale:
filters.append("pi.locale IN (:preferred_locale, 'en')")
params["preferred_locale"] = search_dto.preferred_locale

if search_dto.preferred_locale or search_dto.text_search:
subquery_filters = []
if search_dto.preferred_locale:
subquery_filters.append("locale IN (:preferred_locale, 'en')")
params["preferred_locale"] = search_dto.preferred_locale

if search_dto.text_search:
search_text = "".join(
char for char in search_dto.text_search if char not in "@|&!><\\():"
)
or_search = " | ".join([x for x in search_text.split(" ") if x])
subquery_filters.append(
"text_searchable @@ to_tsquery('english', :text_search) OR name ILIKE :text_search"
)
params["text_search"] = or_search

filters.append(
"""
p.id IN (
SELECT project_id
FROM project_info
WHERE {}
)
""".format(" AND ".join(subquery_filters))
)

if search_dto.project_statuses:
statuses = [
Expand Down Expand Up @@ -372,16 +393,6 @@ async def _filter_projects(search_dto: ProjectSearchDTO, user, db: Database):
for mapping_type in search_dto.mapping_types
)

if search_dto.text_search:
search_text = "".join(
char for char in search_dto.text_search if char not in "@|&!><\\():"
)
or_search = " | ".join([x for x in search_text.split(" ") if x])
filters.append(
"pi.text_searchable @@ to_tsquery('english', :text_search) OR pi.name ILIKE :text_search"
)
params["text_search"] = or_search

if search_dto.country:
filters.append(
"LOWER(:country) = ANY(ARRAY(SELECT LOWER(c) FROM unnest(p.country) AS c))"
Expand Down Expand Up @@ -716,15 +727,11 @@ async def _make_4326_polygon_from_bbox(
async def _get_area_sqm(polygon: Polygon, db: Database) -> float:
"""Get the area of the polygon in square meters."""
try:
# Convert the polygon to its WKT format

geometry_wkt = polygon.wkt

# Prepare the raw SQL query to calculate the area
query = "SELECT ST_Area(ST_Transform(ST_GeomFromText(:wkt, 4326), 3857)) AS area"
values = {"wkt": geometry_wkt}

# Execute the query asynchronously using encode databases
result = await db.fetch_one(query=query, values=values)
return result["area"]

Expand Down
Loading

0 comments on commit ecb376a

Please sign in to comment.