Skip to content

Commit

Permalink
fix: formatting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Yash-1511 committed Dec 23, 2024
1 parent f66d90a commit 30c7a5d
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 65 deletions.
33 changes: 10 additions & 23 deletions api/core/tools/provider/builtin/x/tools/get_user_timeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Get User Timeline Tool for X (Twitter)
"""

from datetime import datetime
from typing import Any, Union

Expand Down Expand Up @@ -41,20 +42,16 @@ def _convert_tweet_to_dict(self, tweet: tweepy.Tweet) -> dict[str, Any]:
tweet_dict[field] = value.isoformat()
else:
tweet_dict[field] = value

# Handle media attachments
if "attachments" in tweet_dict and "media_keys" in tweet_dict["attachments"]:
media_keys = tweet_dict["attachments"]["media_keys"]
tweet_dict["media"] = []

for media_key in media_keys:
media_info = {
"media_key": media_key,
"type": None,
"url": None
}
media_info = {"media_key": media_key, "type": None, "url": None}
tweet_dict["media"].append(media_info)

return tweet_dict

def _invoke(
Expand All @@ -66,9 +63,7 @@ def _invoke(
try:
username = tool_parameters.get("username", "").strip().lstrip("@")
if not username:
return ToolInvokeMessage(
message="Username is required", status="error"
)
return ToolInvokeMessage(message="Username is required", status="error")

max_results = self._validate_max_results(tool_parameters.get("max_results"))

Expand Down Expand Up @@ -135,9 +130,7 @@ def _invoke(
username=username,
)
if not user_response.data:
return ToolInvokeMessage(
message=f"User @{username} not found", status="error"
)
return ToolInvokeMessage(message=f"User @{username} not found", status="error")

user_data = user_response.data
# Get user's tweets
Expand All @@ -147,9 +140,7 @@ def _invoke(
tweet_fields=tweet_fields,
user_fields=user_fields,
media_fields=media_fields,
exclude=[
"retweets"
], # Exclude retweets to get more original content
exclude=["retweets"], # Exclude retweets to get more original content
)

print(tweets_response.data)
Expand Down Expand Up @@ -183,11 +174,7 @@ def _invoke(
}
)
except tweepy.TweepyException as te:
return ToolInvokeMessage(
message=f"Twitter API error: {str(te)}", status="error"
)
return ToolInvokeMessage(message=f"Twitter API error: {str(te)}", status="error")

except Exception as e:
return ToolInvokeMessage(
message=f"Error retrieving user timeline: {str(e)}", status="error"
)
return ToolInvokeMessage(message=f"Error retrieving user timeline: {str(e)}", status="error")
8 changes: 2 additions & 6 deletions api/core/tools/provider/builtin/x/tools/like_tweet.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ def _invoke(
)

except tweepy.TweepyException as te:
return ToolInvokeMessage(
message=f"Twitter API error: {str(te)}", status="error"
)
return ToolInvokeMessage(message=f"Twitter API error: {str(te)}", status="error")

except Exception as e:
return ToolInvokeMessage(
message=f"Error performing {action} action: {str(e)}", status="error"
)
return ToolInvokeMessage(message=f"Error performing {action} action: {str(e)}", status="error")
22 changes: 6 additions & 16 deletions api/core/tools/provider/builtin/x/tools/media_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def _invoke(
# Get file from parameters
media_file = tool_parameters.get("media_file")
if not media_file:
return ToolInvokeMessage(
message="No media file provided", status="error"
)
return ToolInvokeMessage(message="No media file provided", status="error")

# Validate file type
if not self._validate_file_type(media_file.type):
Expand All @@ -60,22 +58,18 @@ def _invoke(
# Download file content
file_content = download(media_file)
if not file_content:
return ToolInvokeMessage(
message="Failed to download media file", status="error"
)
return ToolInvokeMessage(message="Failed to download media file", status="error")

# Upload media
media = api.media_upload(
filename=media_file.filename or "media", # Use original filename if available
file=io.BytesIO(file_content)
file=io.BytesIO(file_content),
)

# Set alt text if provided
alt_text = tool_parameters.get("alt_text")
if alt_text and media.media_id:
api.create_media_metadata(
media_id=media.media_id, alt_text=alt_text
)
api.create_media_metadata(media_id=media.media_id, alt_text=alt_text)

response_data = {
"media_id": str(media.media_id),
Expand All @@ -87,11 +81,7 @@ def _invoke(
return self.create_json_message(response_data)

except tweepy.TweepyException as te:
return ToolInvokeMessage(
message=f"Twitter API error: {str(te)}", status="error"
)
return ToolInvokeMessage(message=f"Twitter API error: {str(te)}", status="error")

except Exception as e:
return ToolInvokeMessage(
message=f"Error uploading media: {str(e)}", status="error"
)
return ToolInvokeMessage(message=f"Error uploading media: {str(e)}", status="error")
24 changes: 6 additions & 18 deletions api/core/tools/provider/builtin/x/tools/post_tweet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,15 @@ def _parse_comma_separated_list(self, value: Optional[str]) -> list[str]:
return []
return [item.strip() for item in value.split(",") if item.strip()]

def _validate_poll_options(
self, options_str: Optional[str], duration: Optional[float]
) -> tuple[list[str], int]:
def _validate_poll_options(self, options_str: Optional[str], duration: Optional[float]) -> tuple[list[str], int]:
"""Validate poll options and duration"""
options = self._parse_comma_separated_list(options_str)

if not 2 <= len(options) <= 4:
raise ValueError("Poll must have between 2 and 4 options")

try:
duration_int = (
int(duration) if duration is not None else 1440
) # Default 24 hours
duration_int = int(duration) if duration is not None else 1440 # Default 24 hours
except (ValueError, TypeError):
raise ValueError("Poll duration must be a valid number")

Expand Down Expand Up @@ -77,9 +73,7 @@ def _invoke(
poll_options_str = tool_parameters.get("poll_options")
poll_duration = tool_parameters.get("poll_duration_minutes")
if poll_options_str:
poll_options, poll_duration = self._validate_poll_options(
poll_options_str, poll_duration
)
poll_options, poll_duration = self._validate_poll_options(poll_options_str, poll_duration)
else:
poll_options, poll_duration = [], None

Expand Down Expand Up @@ -114,14 +108,8 @@ def _invoke(
)

except ValueError as ve:
return ToolInvokeMessage(
message=f"Validation error: {str(ve)}", status="error"
)
return ToolInvokeMessage(message=f"Validation error: {str(ve)}", status="error")
except tweepy.TweepyException as te:
return ToolInvokeMessage(
message=f"Twitter API error: {str(te)}", status="error"
)
return ToolInvokeMessage(message=f"Twitter API error: {str(te)}", status="error")
except Exception as e:
return ToolInvokeMessage(
message=f"Unexpected error: {str(e)}", status="error"
)
return ToolInvokeMessage(message=f"Unexpected error: {str(e)}", status="error")
3 changes: 2 additions & 1 deletion api/core/tools/provider/builtin/x/x.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
X (Twitter) Tool Provider
"""

from typing import Any

from core.tools.errors import ToolProviderCredentialValidationError
Expand All @@ -26,4 +27,4 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None:
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
raise ToolProviderCredentialValidationError(str(e))
2 changes: 1 addition & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ zhipuai = "~2.1.5"
# Related transparent dependencies with pinned version
# required by main implementations
############################################################
tweepy = "^4.14.0"
tweepy = "~4.14.0"
[tool.poetry.group.indirect.dependencies]
kaleido = "0.2.1"
rank-bm25 = "~0.2.2"
Expand Down

0 comments on commit 30c7a5d

Please sign in to comment.