前言 在实际开发中,根据业务拼接SQL
所需要考虑的内容太多了。于是,有没有一种办法,可以像MyBatisPlus
一样通过配置注解实现SQL
注入呢?
就像是:
1 2 @mybatis.select("select * from user where id = #{id}" ) def get_user (id ): ...
那可就降低了好多工作量。
P.S.:本文并不希望完全复现MyBatisPlus
的所有功能,能够基本配置SQL
注解就基本能够完成大部分工作了。
实现思路 那我们这么考虑:
首先,我们需要定义一个类,类中给一个或者多个装饰器;
我们先在类内定义一个字符串,这个字符串能够配置到指定的DTO
类,用于存储结果;
我们针对装饰器中的SQL
字符串进行解析,解析到其中的变量个数与名称;
我们针对被装饰的函数进行解析,与SQL
变量进行匹配;
替换变量;
执行SQL
;
听起来并不难。我们一步步来。
定义一个类 首先定义:
1 2 3 4 5 class Student : def __init__ (self, name, age ): self.name = name self.age = age
为了简化操作,这个类就不放在任意位置了,直接放在dto
文件夹下,后续导入这个类也就直接从dto
文件夹中引入,就不考虑做这个包名定位的接口了。
当然,为了更方便后续的操作,我们需要在dto
文件夹中定义一个__init__.py
文件,用于对外暴露这个类:
1 2 3 from dto.student import Student__all__ = ["Student" ]
最后呢,我们为了方便这个类的序列化,让他能够变成dict
类型,加一些魔法函数:
1 2 3 4 5 6 7 8 9 10 11 12 class Student : def __init__ (self, name, age ): self.name = name self.age = age def __iter__ (self ): for key, value in self.__dict__.items(): yield key, value def __getitem__ (self, key ): return getattr (self, key) def keys (self ): return self.__dict__.keys()
当然,一个项目里面肯定不止这一个返回结果,所以各位也可以这么操作:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 class CommonResult : def __init__ (self ): ... def __iter__ (self ): for key, value in self.__dict__.items(): yield key, value def __getitem__ (self, key ): return getattr (self, key) def keys (self ): return self.__dict__.keys() from dto.common import CommonResultclass Student (CommonResult ): def __init__ (self, name, age ): self.name = name self.age = age
至于实际业务中还有很多复杂的联立等操作需要新的类,受限于篇幅,就不展开了。如果能够把本篇看懂的话,相信各位也没什么其他的困难了。
然后开始手撸这个微型框架 1 2 3 4 5 6 7 8 from pydantic import BaseModel, Fieldclass DBManager (BaseModel ): base_type: str = Field(..., description="数据库表名" ) link: str = Field(..., description="数据库连接地址" ) local_generator: Any = Field(..., description="实体类实例化解析生成器" ) def search (query_template ): ...
在这里呢,我们定义了一个DBManager
作为父类,要求后面的子类必须有:
str
类型的base_type
,表示返回结果类的名称;
str
类型的link
,表示数据库连接地址;
Any
类型的local_generator
,表示实体类实例化解析生成器,- 任意返回值的query
方法,用于执行SQL
。
为什么一定要用BaseModel
定义?直接定义self.xxx
不好吗?
因为这样会看起来代码量很大(逃)
看着差不多。
根据字符串获取到所定义的DTO
类 考虑到实际上我们所有的方法都需要特定到具体的位置,所以这个方法还是直接写到DBManager
类中,这样子类就不需要再重写了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 from pydantic import BaseModel, Fieldclass DBManager (BaseModel ): base_type: str = Field(..., description="数据库表名" ) link: str = Field(..., description="数据库连接地址" ) local_generator: Any = Field(..., description="实体类实例化解析生成器" ) def search (query_template ): ... def import_class_from_package (self, package_name, class_name ): _package = importlib.import_module(package_name) if class_name not in _package.__all__: raise ImportError(f"{class_name} not found in {package_name} " ) cls = getattr (_package, class_name) if cls is not None : return cls else : raise ImportError(f"{class_name} not found in {package_name} " )
这样子类就可以调用这个方法获得所需的类了。
构建返回结果 既然都已经能够动态导入类了,那我把返回结果导入到Student
中,没问题吧?
其中需要注意的是,我这边采用的数据库驱动是sqlalchemy
,所以构造返回结果所需要的参数是sqlalchemy
的Row
类型。
同样的,为了减少子类重写的代码量,直接在父类给出来:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 from pydantic import BaseModel, Fieldfrom sqlalchemy.engine.row import Rowclass DBManager (BaseModel ): base_type: str = Field(..., description="数据库表名" ) link: str = Field(..., description="数据库连接地址" ) local_generator: Any = Field(..., description="实体类实例化解析生成器" ) def search (query_template ): ... def import_class_from_package (self, package_name, class_name ): ... def build_obj (self, row: Row ): return self.local_generator(**row._asdict()) if self.local_generator else None
装饰器 那么接下来就是重头戏了,怎么定义这个装饰器。
我们先构建一个子类:
1 2 3 4 5 6 7 8 9 10 class StudentDBManager (DBManager ): base_type: ClassVar[str ] = "Student" link: ClassVar[str ] = 'sqlite:///school.db' local_generator: ClassVar[Any ] = None """ 自定义PyMyBatis """ def __init__ (self ): StudentDBManager.local_generator = self.import_class_from_package("dto" , self.base_type)
在这里,首先需要注意的是,需要用ClassVar
修饰,将变量名定义为类内成员变量,否则无法使用self.xxx
访问。
其次,我们利用base_type
指定返回值对应的DTO
类、link
指定数据库连接地址,local_generator
指定实体类实例化解析生成器。
在这个类实例化的过程中,我们还需要进一步构建local_generator
,也就是动态执行from xxx import xxx
。
然后定义一个装饰器:
1 2 3 4 5 6 7 def query (query_template: str ): def decorator (func ): @wraps(func ) def wrapper (*args, **kwargs ): return func(*args, **kwargs) return wrapper return decorator
这可以算得上是比较基础的模板了。至于之后怎么改,管他呢,先套公式。
在这里,我们首先定义的装饰器是decorator
,没有参数;其次再用query
装饰器包装,从而给无参的装饰器给一个参数,从而接收一个SQL
字符串参数。
好的,我们再进一步。
解析字符串,获得变量 首先当然是解析SQL
字符串,获得变量。如何做呢?为了简便,这里直接采用正则匹配的方式:
1 2 3 4 5 6 7 8 9 10 def query (self, query_template ): def decorator (func ): param_pattern = re.compile (r"#{(\w+)}" ) required_params = set (param_pattern.findall(query_template)) @wraps(func ) def wrapper (*args, **kwargs ): return func(*args, **kwargs) return wrapper return decorator
没啥问题。
接下来,调用的时候,我们需要检测是否完整给出了SQL
字符串所需的参数。
我们考虑到,如果但凡SQL
中的参数有变化,方法就会有变化,因此每个SQL
都有一个方法也太麻烦了。主要是这么多相似的方法起方法名太烦了
所以,直接上反射 ,获取 调用 的时侯传入的参数。
值得注意的是,这里说的是 调用 的时候。因为Python
中 定义 方法的时候可以使用**kargs
传入多个参数,但是如果反射直接获取到 定义 的参数,将会只有一个kargs
,这显然不是我们所希望的。
所以,再加一些:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 def query (self, query_template ): def decorator (func ): param_pattern = re.compile (r"#{(\w+)}" ) required_params = set (param_pattern.findall(query_template)) @wraps(func ) def wrapper (*args, **kwargs ): sig = inspect.signature(func) bound_args = sig.bind_partial(*args, **kwargs) bound_args.apply_defaults() provided_params = set (bound_args.arguments.keys()) | set (kwargs.keys()) missing_params = required_params - provided_params if missing_params: raise ValueError(f"Missing required parameters: {', ' .join(missing_params)} " ) return func(*args, **kwargs) return wrapper return decorator
这下应该就能够适配到所有的SQL
情况了。
SQL字符串拼接 接下来就是直接替换值了。但是,拼接真的就是对的吗?我们不光是需要考虑不同的变量有着不同的植入格式,同时也需要考虑到植入过程中可能的SQL
注入问题。
所以,我们就直接采用sqlalchemy
的text
函数,对SQL
进行拼接与赋值。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 def query (self, query_template ): def decorator (func ): param_pattern = re.compile (r"#{(\w+)}" ) required_params = set (param_pattern.findall(query_template)) @wraps(func ) def wrapper (*args, **kwargs ): sig = inspect.signature(func) bound_args = sig.bind_partial(*args, **kwargs) bound_args.apply_defaults() provided_params = set (bound_args.arguments.keys()) | set (kwargs.keys()) missing_params = required_params - provided_params if missing_params: raise ValueError(f"Missing required parameters: {', ' .join(missing_params)} " ) sql_query = text(query_template.replace("#{" , ":" ).replace("}" , "" )) print (f"Executing SQL: {sql_query} " ) return func(*args, **kwargs) return wrapper return decorator
好了,到这一步也就基本完成了。最后,我们根据数据库存储数据的特点,最后修整一下查询的格式细节,就可以了:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 def query (self, query_template ): def decorator (func ): param_pattern = re.compile (r"#{(\w+)}" ) required_params = set (param_pattern.findall(query_template)) @wraps(func ) def wrapper (*args, **kwargs ): sig = inspect.signature(func) bound_args = sig.bind_partial(*args, **kwargs) bound_args.apply_defaults() provided_params = set (bound_args.arguments.keys()) | set (kwargs.keys()) missing_params = required_params - provided_params if missing_params: raise ValueError(f"Missing required parameters: {', ' .join(missing_params)} " ) sql_query = text(query_template.replace("#{" , ":" ).replace("}" , "" )) print (f"Executing SQL: {sql_query} " ) params = bound_args.arguments.copy() for key, value in params.items(): if isinstance (value, datetime): params[key] = value.strftime('%Y-%m-%d' ) engine = create_engine(self.link) with engine.connect() as conn: result = conn.execute(sql_query, params) search_result = [self.create_item_obj(row) for row in result] return search_result return wrapper return decorator
就是这样,我们就完成了这样一个装饰器。
使用装饰器 使用过程,其实就可以类比@Service
中的调用了。而如果拿Python
举例的话,其实更像Flask
的app.route
。于是我们可以这么使用:
1 2 3 sbd = StudentDBManager() @sbd.query("SELECT * FROM student WHERE id = #{id}" ) def find_student_by_id (**kargs ): ...
这也就实现了一个方法。
当然,他也没那么智能。虽然写起来是这样,但是依然相当于:
1 2 3 sbd = StudentDBManager() @sbd.query("SELECT * FROM student WHERE id = #{id}" ) def find_student_by_id (id : str ): ...
只是说,我们并不需要重复地去写驱动罢了。