A Python port of Flash Attention for MPS systems
Before you get your hopes up: no, I have not built this yet. I am putting this out there in the hope that some talented, determined person wants to carry on the legacy of Steve Jobs and help out all the Mac ML developers.
For context, I was looking at adding MPS support for flash attention to one of the open source libraries I contribute to, only to notice that no such implementation exists. After a preliminary dive, I now have a sense of how hard this project would be. One would need to:
- Write C bindings for the Swift code
- Create Python C extensions that call those C bindings
- Handle memory management between Python, C, and Metal
- Deal with Swift/Metal compilation as part of the Python package build, and all the debugging that comes along with this.
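To make the binding and memory-management steps above a bit more concrete, here is a rough sketch of what the Python side could look like if it went through `ctypes` rather than a hand-written C extension. Everything native in it is an assumption on my part: the library name `libflash_attn_mps.dylib`, the exported symbol `flash_attn_forward`, and its C signature are all hypothetical, since none of this exists yet. It also assumes a single head with `q`, `k`, and `v` of the same shape. When the library cannot be loaded, it falls back to a plain (non-flash) pure-Python attention, so the sketch still runs anywhere.

```python
import ctypes
import math

# Hypothetical file name: no such library exists yet.
HYPOTHETICAL_LIB = "libflash_attn_mps.dylib"


def load_native(path=HYPOTHETICAL_LIB):
    """Load the imagined C shim over the Swift/Metal kernels, or return None."""
    try:
        lib = ctypes.CDLL(path)
        fn = lib.flash_attn_forward  # hypothetical exported symbol
    except (OSError, AttributeError):
        return None
    # Assumed C signature (illustrative only):
    # int flash_attn_forward(const float *q, const float *k, const float *v,
    #                        float *out, int seq_len, int head_dim);
    float_ptr = ctypes.POINTER(ctypes.c_float)
    fn.argtypes = [float_ptr, float_ptr, float_ptr, float_ptr,
                   ctypes.c_int, ctypes.c_int]
    fn.restype = ctypes.c_int
    return lib


def attention_reference(q, k, v):
    """Plain softmax(q k^T / sqrt(d)) v on nested lists (not the flash variant)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        top = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * v_row[c] for w, v_row in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def attention(q, k, v, lib=None):
    """Call the native kernel if a library was loaded, else the reference."""
    if lib is None:
        return attention_reference(q, k, v)
    seq_len, head_dim = len(q), len(q[0])
    FloatArray = ctypes.c_float * (seq_len * head_dim)

    def flatten(matrix):
        return FloatArray(*(x for row in matrix for x in row))

    # Python owns these buffers; they must stay alive for the whole call.
    q_buf, k_buf, v_buf = flatten(q), flatten(k), flatten(v)
    out_buf = FloatArray()
    status = lib.flash_attn_forward(q_buf, k_buf, v_buf, out_buf,
                                    seq_len, head_dim)
    if status != 0:
        raise RuntimeError(f"native kernel returned status {status}")
    return [list(out_buf[i * head_dim:(i + 1) * head_dim])
            for i in range(seq_len)]


lib = load_native()  # None unless the hypothetical library is present
print(attention([[1.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]],
                [[1.0, 2.0], [3.0, 4.0]], lib))
```

Going through `ctypes` would sidestep writing a CPython extension by hand, but the hard parts from the list stay exactly where they are: someone still has to expose the Swift code behind a C ABI, decide who owns each buffer, and get the Metal build into the package.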
However, it seems like a very fun, challenging, and interesting project that would really pay off! I'm not sure I personally have the bandwidth to tackle something of this scale just yet, but I'd love to see someone be the flagbearer for porting stuff to MPS systems :))