[D] Handling tens of thousands of classes

Hello!

I'm working on an NLP classification project with roughly 13k classes. The best model I have so far is a fine-tuned LLM encoder, but with that many classes it is very slow. Searching for ways to deal with this, I found two:

  1. Hierarchical Softmax
  2. Negative Sampling
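For concreteness, here is the kind of thing I have in mind for (1): a two-level "hierarchical" softmax on top of the encoder's pooled output, so training only scores one cluster head plus the classes inside the target's cluster. The grouping of classes into contiguous-id clusters is an arbitrary assumption for illustration (word2vec builds a Huffman tree instead):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoLevelSoftmax(nn.Module):
    """Two-level hierarchical softmax sketch: classes are split into
    equal-size clusters; the loss is CE over clusters plus CE over the
    classes *within* the target's cluster, so per-step cost drops from
    O(C) to roughly O(sqrt(C)) with ~sqrt(C) clusters."""

    def __init__(self, hidden_dim, num_classes, num_clusters):
        super().__init__()
        assert num_classes % num_clusters == 0
        self.per = num_classes // num_clusters
        self.cluster_head = nn.Linear(hidden_dim, num_clusters)
        # one (per, hidden_dim) weight block per cluster
        self.class_weight = nn.Parameter(
            torch.randn(num_clusters, self.per, hidden_dim) * 0.02)
        self.class_bias = nn.Parameter(torch.zeros(num_clusters, self.per))

    def forward(self, hidden, target):
        # hidden: (batch, hidden_dim), target: (batch,) class ids
        cluster = target // self.per   # which cluster the label lives in
        within = target % self.per     # index inside that cluster
        cluster_loss = F.cross_entropy(self.cluster_head(hidden), cluster)
        w = self.class_weight[cluster]  # (batch, per, hidden_dim)
        b = self.class_bias[cluster]    # (batch, per)
        logits = torch.einsum('bd,bkd->bk', hidden, w) + b
        within_loss = F.cross_entropy(logits, within)
        return cluster_loss + within_loss
```

With 13k classes and ~114 clusters this would score roughly 230 logits per example instead of 13k, which is why I'm interested in it.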

However, both seem to have been used almost exclusively in the context of word2vec training, so I wonder: is there a reason they would not work for "classical" classification? Or is my kind of problem just too rare?

Also, I found very few PyTorch implementations of these, let alone ones built on transformers... Is that because there is something better? If not, do you know of any recent implementations?
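In the absence of an existing implementation, the simplest version of (2) I could come up with is a sampled-softmax loss that scores the true class against K uniformly sampled negatives instead of all 13k classes. Uniform sampling (and ignoring rare collisions with the target) are my own simplifying assumptions here; is something like this reasonable?

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SampledSoftmaxLoss(nn.Module):
    """Negative-sampling sketch: approximate full-softmax cross-entropy
    by ranking the true class against K random negative classes, so each
    step touches 1+K rows of the output matrix instead of all of them."""

    def __init__(self, hidden_dim, num_classes, num_negatives=100):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, hidden_dim) * 0.02)
        self.bias = nn.Parameter(torch.zeros(num_classes))
        self.num_classes = num_classes
        self.num_negatives = num_negatives

    def forward(self, hidden, target):
        # hidden: (batch, hidden_dim), target: (batch,) class ids
        batch = hidden.size(0)
        # uniform negatives; may occasionally collide with the target,
        # which I'm ignoring here for simplicity
        neg = torch.randint(0, self.num_classes,
                            (batch, self.num_negatives), device=hidden.device)
        # candidate ids: true class in column 0, negatives after it
        cand = torch.cat([target.unsqueeze(1), neg], dim=1)  # (batch, 1+K)
        w = self.weight[cand]  # (batch, 1+K, hidden_dim)
        b = self.bias[cand]    # (batch, 1+K)
        logits = torch.einsum('bd,bkd->bk', hidden, w) + b
        # the true class sits at index 0 of every row
        labels = torch.zeros(batch, dtype=torch.long, device=hidden.device)
        return F.cross_entropy(logits, labels)
```

At inference time I would still compute all 13k logits once, since the sampling is only there to cheapen training.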

Thank you in advance!