Compare commits

...

3 Commits

2 changed files with 103 additions and 26 deletions

57
main.py
View File

@ -1,17 +1,19 @@
from functools import wraps import json
import logging import logging
import os import os
import re import re
import sqlite3 import sqlite3
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from functools import wraps
from pprint import pprint from pprint import pprint
from dotenv import load_dotenv
from telegram import Update, user
import telegram import telegram
from telegram.ext import (CallbackContext, CommandHandler, MessageHandler, from dotenv import load_dotenv
Updater) from telegram import (InlineKeyboardButton, InlineKeyboardMarkup, ParseMode,
Update)
from telegram.ext import (CallbackContext, CallbackQueryHandler,
CommandHandler, MessageHandler, Updater)
from models import Link from models import Link
@ -25,12 +27,13 @@ link_regex = re.compile(
def private_only(f): def private_only(f):
@wraps(f) @wraps(f)
def cb(update: Update, *args, **kwargs): def cb(self, update: Update, *args, **kwargs):
if update.effective_chat.type != 'private': if update.effective_chat.type != 'private':
return return
return f(update, *args, **kwargs) return f(self, update, *args, **kwargs)
return cb return cb
@dataclass @dataclass
class BotSettings: class BotSettings:
token: str token: str
@ -51,12 +54,13 @@ class Bot:
for f in dir(self): for f in dir(self):
if f.startswith("cmd_"): if f.startswith("cmd_"):
self.dispatcher.add_handler( self.dispatcher.add_handler(
CommandHandler(f[4:], private_only(getattr(self, f)))) CommandHandler(f[4:], getattr(self, f)))
def _register_handlers(self): def _register_handlers(self):
self.dispatcher.add_handler(MessageHandler( self.dispatcher.add_handler(MessageHandler(
filters=None, filters=None,
callback=self.message)) callback=self.message))
self.dispatcher.add_handler(CallbackQueryHandler(self.update))
@property @property
def db(self): def db(self):
@ -75,11 +79,13 @@ class Bot:
); );
""") """)
@private_only
def cmd_start(self, update: Update, context: CallbackContext): def cmd_start(self, update: Update, context: CallbackContext):
context.bot.send_message(chat_id=update.effective_chat.id, text='Hi!') context.bot.send_message(chat_id=update.effective_chat.id, text='Hi!')
# TODO: timezones # TODO: timezones
# TODO: ask to set time or something # TODO: ask to set time or something
@private_only
def cmd_test(self, update: Update, context: CallbackContext): def cmd_test(self, update: Update, context: CallbackContext):
l = Link( l = Link(
link="2137", link="2137",
@ -88,6 +94,7 @@ class Bot:
with self.db as db: with self.db as db:
l.create(db) l.create(db)
@private_only
def cmd_unread(self, update: Update, context: CallbackContext): def cmd_unread(self, update: Update, context: CallbackContext):
# TODO: ignore messages from group # TODO: ignore messages from group
user_id = update.effective_user.id user_id = update.effective_user.id
@ -98,10 +105,16 @@ class Bot:
unread = Link.get_unread(db, user_id) unread = Link.get_unread(db, user_id)
bot.send_message( bot.send_message(
user_id, f"**Your unread links as of {datetime.now().isoformat()}:", parse_mode="MARKDOWN") user_id, f"*Your unread links as of {datetime.now().isoformat()}:*", parse_mode=ParseMode.MARKDOWN)
for link in unread: for link in unread:
bot.send_message(user_id, link.link) keyboard = [
# TODO: button to read or postpone [InlineKeyboardButton(text='Mark as read',
callback_data=f"mark_as_read:{link.id}"),
InlineKeyboardButton(text='Delete', callback_data=f"delete:{link.id}")]
]
bot.send_message(user_id, link.link,
reply_markup=InlineKeyboardMarkup(keyboard))
# TODO: button to postpone
# TODO: button to mark all as read or postpone # TODO: button to mark all as read or postpone
def _natural_count(self, n, singular, plural): def _natural_count(self, n, singular, plural):
@ -111,8 +124,6 @@ class Bot:
@private_only @private_only
def message(self, update: Update, context: CallbackContext): def message(self, update: Update, context: CallbackContext):
if update.effective_chat.type != 'private':
return
user_id = update.effective_user.id user_id = update.effective_user.id
links = re.findall(link_regex, update.message.text) links = re.findall(link_regex, update.message.text)
with self.db as db: with self.db as db:
@ -125,6 +136,26 @@ class Bot:
context.bot.send_message( context.bot.send_message(
update.effective_chat.id, f"Added {self._natural_count(len(links), 'link', 'links')} to your list.") update.effective_chat.id, f"Added {self._natural_count(len(links), 'link', 'links')} to your list.")
@private_only
def update(self, update: Update, context: CallbackContext):
user_id = update.effective_user.id
action, link_id = update.callback_query.data.split(":", 1)
with self.db as db:
l = Link.get(db, user_id, link_id)
if l is None:
context.bot.send_message(
user_id, "Couldn't find the link you were looking for.")
return
if action == "mark_as_read":
l.mark_as_read(db)
update.callback_query.message.delete()
elif action == "delete":
l.delete(db)
update.callback_query.message.delete()
def run(self): def run(self):
self.updater.start_polling() self.updater.start_polling()

View File

@ -1,7 +1,8 @@
import sqlite3 import sqlite3
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Tuple from os import link
from typing import Dict, List, Tuple
@dataclass @dataclass
@ -12,6 +13,39 @@ class Link:
read_at: datetime = None read_at: datetime = None
added_at: datetime = None added_at: datetime = None
def as_dict(self):
def timestamp(t): return t.timestamp() if t is not None else None
return {
"id": self.id,
"user_id": self.user_id,
"read_at": timestamp(self.read_at),
"added_at": timestamp(self.added_at),
"link": self.link,
}
@classmethod
def from_dict(cls, dict: Dict) -> 'Link':
def convert_date(t): return datetime.fromtimestamp(
t) if t is not None else None
return Link(
id=dict.get('id', None),
user_id=dict["user_id"],
read_at=convert_date(dict.get("read_at", None)),
added_at=convert_date(dict.get("added_at", None)),
link=dict["link"]
)
@classmethod
def _from_tuple(cls, tuple) -> 'Link':
id, link, user_id, read_at, added_at = tuple
return cls(
id=id,
link=link,
user_id=user_id,
read_at=read_at,
added_at=added_at
)
def create(self, db: sqlite3.Connection): def create(self, db: sqlite3.Connection):
self.added_at = datetime.now() self.added_at = datetime.now()
cur = db.cursor() cur = db.cursor()
@ -25,23 +59,35 @@ class Link:
assert int(r[0][0]) assert int(r[0][0])
self.id = r[0][0] self.id = r[0][0]
@classmethod def mark_as_read(self, db: sqlite3.Connection):
def _from_tuple(cls, tuple) -> 'Link': assert self.id is not None
id, link, user_id, read_at, added_at = tuple self.read_at = datetime.now()
return cls( cur = db.cursor()
id=id, cur.execute(
link=link, "UPDATE links SET read_at=? WHERE user_id=? AND id=?", (self.read_at, self.user_id, self.id))
user_id=user_id,
read_at=read_at, def delete(self, db: sqlite3.Connection):
added_at=added_at assert self.id is not None
) cur = db.cursor()
cur.execute(
"DELETE FROM links WHERE user_id=? AND id=?", (self.user_id, self.id))
self.id = None
@classmethod @classmethod
def _get(cls, db: sqlite3.Connection, where: str, values: Tuple = ()) -> 'Link': def _get(cls, db: sqlite3.Connection, where: str, values: Tuple = ()) -> List['Link']:
cur = db.cursor() cur = db.cursor()
rows = cur.execute(f"SELECT * FROM links WHERE {where}", values) rows = cur.execute(f"SELECT * FROM links WHERE {where}", values)
return map(cls._from_tuple, rows) return map(cls._from_tuple, rows)
@classmethod @classmethod
def get_unread(cls, db: sqlite3.Connection, user_id: str): def get_unread(cls, db: sqlite3.Connection, user_id: str) -> List['Link']:
return cls._get(db, "user_id = ? AND read_at IS NULL", (user_id,)) return cls._get(db, "user_id = ? AND read_at IS NULL", (user_id,))
@classmethod
def get(cls, db: sqlite3.Connection, user_id: str, link_id: int) -> 'Link':
rows = list(cls._get(db, "user_id = ? AND id = ?", (user_id, link_id)))
if len(rows) < 1:
return None
else:
return rows[0]