Compare commits
No commits in common. "13bc3321babea5f80aa8336cf5bc79f2bd520046" and "5d7698462a37d49a90d3336ab00c00e57a6a764e" have entirely different histories.
13bc3321ba
...
5d7698462a
57
main.py
57
main.py
|
|
@ -1,19 +1,17 @@
|
||||||
import json
|
from functools import wraps
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import wraps
|
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
import telegram
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from telegram import (InlineKeyboardButton, InlineKeyboardMarkup, ParseMode,
|
from telegram import Update, user
|
||||||
Update)
|
import telegram
|
||||||
from telegram.ext import (CallbackContext, CallbackQueryHandler,
|
from telegram.ext import (CallbackContext, CommandHandler, MessageHandler,
|
||||||
CommandHandler, MessageHandler, Updater)
|
Updater)
|
||||||
|
|
||||||
from models import Link
|
from models import Link
|
||||||
|
|
||||||
|
|
@ -27,13 +25,12 @@ link_regex = re.compile(
|
||||||
|
|
||||||
def private_only(f):
|
def private_only(f):
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def cb(self, update: Update, *args, **kwargs):
|
def cb(update: Update, *args, **kwargs):
|
||||||
if update.effective_chat.type != 'private':
|
if update.effective_chat.type != 'private':
|
||||||
return
|
return
|
||||||
return f(self, update, *args, **kwargs)
|
return f(update, *args, **kwargs)
|
||||||
return cb
|
return cb
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BotSettings:
|
class BotSettings:
|
||||||
token: str
|
token: str
|
||||||
|
|
@ -54,13 +51,12 @@ class Bot:
|
||||||
for f in dir(self):
|
for f in dir(self):
|
||||||
if f.startswith("cmd_"):
|
if f.startswith("cmd_"):
|
||||||
self.dispatcher.add_handler(
|
self.dispatcher.add_handler(
|
||||||
CommandHandler(f[4:], getattr(self, f)))
|
CommandHandler(f[4:], private_only(getattr(self, f))))
|
||||||
|
|
||||||
def _register_handlers(self):
|
def _register_handlers(self):
|
||||||
self.dispatcher.add_handler(MessageHandler(
|
self.dispatcher.add_handler(MessageHandler(
|
||||||
filters=None,
|
filters=None,
|
||||||
callback=self.message))
|
callback=self.message))
|
||||||
self.dispatcher.add_handler(CallbackQueryHandler(self.update))
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db(self):
|
def db(self):
|
||||||
|
|
@ -79,13 +75,11 @@ class Bot:
|
||||||
);
|
);
|
||||||
""")
|
""")
|
||||||
|
|
||||||
@private_only
|
|
||||||
def cmd_start(self, update: Update, context: CallbackContext):
|
def cmd_start(self, update: Update, context: CallbackContext):
|
||||||
context.bot.send_message(chat_id=update.effective_chat.id, text='Hi!')
|
context.bot.send_message(chat_id=update.effective_chat.id, text='Hi!')
|
||||||
# TODO: timezones
|
# TODO: timezones
|
||||||
# TODO: ask to set time or something
|
# TODO: ask to set time or something
|
||||||
|
|
||||||
@private_only
|
|
||||||
def cmd_test(self, update: Update, context: CallbackContext):
|
def cmd_test(self, update: Update, context: CallbackContext):
|
||||||
l = Link(
|
l = Link(
|
||||||
link="2137",
|
link="2137",
|
||||||
|
|
@ -94,7 +88,6 @@ class Bot:
|
||||||
with self.db as db:
|
with self.db as db:
|
||||||
l.create(db)
|
l.create(db)
|
||||||
|
|
||||||
@private_only
|
|
||||||
def cmd_unread(self, update: Update, context: CallbackContext):
|
def cmd_unread(self, update: Update, context: CallbackContext):
|
||||||
# TODO: ignore messages from group
|
# TODO: ignore messages from group
|
||||||
user_id = update.effective_user.id
|
user_id = update.effective_user.id
|
||||||
|
|
@ -105,16 +98,10 @@ class Bot:
|
||||||
unread = Link.get_unread(db, user_id)
|
unread = Link.get_unread(db, user_id)
|
||||||
|
|
||||||
bot.send_message(
|
bot.send_message(
|
||||||
user_id, f"*Your unread links as of {datetime.now().isoformat()}:*", parse_mode=ParseMode.MARKDOWN)
|
user_id, f"**Your unread links as of {datetime.now().isoformat()}:", parse_mode="MARKDOWN")
|
||||||
for link in unread:
|
for link in unread:
|
||||||
keyboard = [
|
bot.send_message(user_id, link.link)
|
||||||
[InlineKeyboardButton(text='Mark as read',
|
# TODO: button to read or postpone
|
||||||
callback_data=f"mark_as_read:{link.id}"),
|
|
||||||
InlineKeyboardButton(text='Delete', callback_data=f"delete:{link.id}")]
|
|
||||||
]
|
|
||||||
bot.send_message(user_id, link.link,
|
|
||||||
reply_markup=InlineKeyboardMarkup(keyboard))
|
|
||||||
# TODO: button to postpone
|
|
||||||
# TODO: button to mark all as read or postpone
|
# TODO: button to mark all as read or postpone
|
||||||
|
|
||||||
def _natural_count(self, n, singular, plural):
|
def _natural_count(self, n, singular, plural):
|
||||||
|
|
@ -124,6 +111,8 @@ class Bot:
|
||||||
|
|
||||||
@private_only
|
@private_only
|
||||||
def message(self, update: Update, context: CallbackContext):
|
def message(self, update: Update, context: CallbackContext):
|
||||||
|
if update.effective_chat.type != 'private':
|
||||||
|
return
|
||||||
user_id = update.effective_user.id
|
user_id = update.effective_user.id
|
||||||
links = re.findall(link_regex, update.message.text)
|
links = re.findall(link_regex, update.message.text)
|
||||||
with self.db as db:
|
with self.db as db:
|
||||||
|
|
@ -136,26 +125,6 @@ class Bot:
|
||||||
context.bot.send_message(
|
context.bot.send_message(
|
||||||
update.effective_chat.id, f"Added {self._natural_count(len(links), 'link', 'links')} to your list.")
|
update.effective_chat.id, f"Added {self._natural_count(len(links), 'link', 'links')} to your list.")
|
||||||
|
|
||||||
@private_only
|
|
||||||
def update(self, update: Update, context: CallbackContext):
|
|
||||||
user_id = update.effective_user.id
|
|
||||||
action, link_id = update.callback_query.data.split(":", 1)
|
|
||||||
with self.db as db:
|
|
||||||
l = Link.get(db, user_id, link_id)
|
|
||||||
if l is None:
|
|
||||||
context.bot.send_message(
|
|
||||||
user_id, "Couldn't find the link you were looking for.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if action == "mark_as_read":
|
|
||||||
l.mark_as_read(db)
|
|
||||||
update.callback_query.message.delete()
|
|
||||||
elif action == "delete":
|
|
||||||
l.delete(db)
|
|
||||||
update.callback_query.message.delete()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.updater.start_polling()
|
self.updater.start_polling()
|
||||||
|
|
||||||
|
|
|
||||||
72
models.py
72
models.py
|
|
@ -1,8 +1,7 @@
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from os import link
|
from typing import Tuple
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -13,39 +12,6 @@ class Link:
|
||||||
read_at: datetime = None
|
read_at: datetime = None
|
||||||
added_at: datetime = None
|
added_at: datetime = None
|
||||||
|
|
||||||
def as_dict(self):
|
|
||||||
def timestamp(t): return t.timestamp() if t is not None else None
|
|
||||||
return {
|
|
||||||
"id": self.id,
|
|
||||||
"user_id": self.user_id,
|
|
||||||
"read_at": timestamp(self.read_at),
|
|
||||||
"added_at": timestamp(self.added_at),
|
|
||||||
"link": self.link,
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, dict: Dict) -> 'Link':
|
|
||||||
def convert_date(t): return datetime.fromtimestamp(
|
|
||||||
t) if t is not None else None
|
|
||||||
return Link(
|
|
||||||
id=dict.get('id', None),
|
|
||||||
user_id=dict["user_id"],
|
|
||||||
read_at=convert_date(dict.get("read_at", None)),
|
|
||||||
added_at=convert_date(dict.get("added_at", None)),
|
|
||||||
link=dict["link"]
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _from_tuple(cls, tuple) -> 'Link':
|
|
||||||
id, link, user_id, read_at, added_at = tuple
|
|
||||||
return cls(
|
|
||||||
id=id,
|
|
||||||
link=link,
|
|
||||||
user_id=user_id,
|
|
||||||
read_at=read_at,
|
|
||||||
added_at=added_at
|
|
||||||
)
|
|
||||||
|
|
||||||
def create(self, db: sqlite3.Connection):
|
def create(self, db: sqlite3.Connection):
|
||||||
self.added_at = datetime.now()
|
self.added_at = datetime.now()
|
||||||
cur = db.cursor()
|
cur = db.cursor()
|
||||||
|
|
@ -59,35 +25,23 @@ class Link:
|
||||||
assert int(r[0][0])
|
assert int(r[0][0])
|
||||||
self.id = r[0][0]
|
self.id = r[0][0]
|
||||||
|
|
||||||
def mark_as_read(self, db: sqlite3.Connection):
|
@classmethod
|
||||||
assert self.id is not None
|
def _from_tuple(cls, tuple) -> 'Link':
|
||||||
self.read_at = datetime.now()
|
id, link, user_id, read_at, added_at = tuple
|
||||||
cur = db.cursor()
|
return cls(
|
||||||
cur.execute(
|
id=id,
|
||||||
"UPDATE links SET read_at=? WHERE user_id=? AND id=?", (self.read_at, self.user_id, self.id))
|
link=link,
|
||||||
|
user_id=user_id,
|
||||||
def delete(self, db: sqlite3.Connection):
|
read_at=read_at,
|
||||||
assert self.id is not None
|
added_at=added_at
|
||||||
cur = db.cursor()
|
)
|
||||||
cur.execute(
|
|
||||||
"DELETE FROM links WHERE user_id=? AND id=?", (self.user_id, self.id))
|
|
||||||
self.id = None
|
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get(cls, db: sqlite3.Connection, where: str, values: Tuple = ()) -> List['Link']:
|
def _get(cls, db: sqlite3.Connection, where: str, values: Tuple = ()) -> 'Link':
|
||||||
cur = db.cursor()
|
cur = db.cursor()
|
||||||
rows = cur.execute(f"SELECT * FROM links WHERE {where}", values)
|
rows = cur.execute(f"SELECT * FROM links WHERE {where}", values)
|
||||||
return map(cls._from_tuple, rows)
|
return map(cls._from_tuple, rows)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_unread(cls, db: sqlite3.Connection, user_id: str) -> List['Link']:
|
def get_unread(cls, db: sqlite3.Connection, user_id: str):
|
||||||
return cls._get(db, "user_id = ? AND read_at IS NULL", (user_id,))
|
return cls._get(db, "user_id = ? AND read_at IS NULL", (user_id,))
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(cls, db: sqlite3.Connection, user_id: str, link_id: int) -> 'Link':
|
|
||||||
rows = list(cls._get(db, "user_id = ? AND id = ?", (user_id, link_id)))
|
|
||||||
if len(rows) < 1:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return rows[0]
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue