-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bounty] PyTorch & HuggingFace Interface #139
base: main
Are you sure you want to change the base?
Conversation
Hey, sorry for the delay. I haven't had a chance to check this properly yet. I'll be able to look next week. |
Sounds good. Let me know anything needed. Thank you |
exo/inference/pytorch/model/hf.py
Outdated
|
||
# self.past_key_values = DynamicCache() | ||
|
||
def forward_layers( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I really like this approach of generalising this so it works for other models without having to explicitly implement them.
Can you write a test for a model with a different architecture to make sure this generalises e.g. recurrent Gemma?
I wonder if we need a little bit of model-specific behaviour to enable this in general?
exo/inference/pytorch/helpers.py
Outdated
async for chunk in response.content.iter_chunked(8192): | ||
f.write(chunk) | ||
|
||
async def download_files(urls: List[str], output_paths: List[Path]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I can remove
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove
Sorry forgot to. Will do that now.
exo/inference/pytorch/inference.py
Outdated
self.shard = None | ||
self.model = None | ||
self.tokenizer = None | ||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these the only options? I think supporting e.g. Mac with mps would be great since then you can run heterogeneous clusters.
One thing to try at some point would be mixing MLX and PyTorch and see if they are interoperable with exactly the same model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With pytorch I don't think mac is fully rolled out yet. There seems to be some work arounds but CUDA and CPU are the only options on the pytorch download website. pytorch even stopped ROCm support for AMD
They have a nightly for testing MPS https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about this in the official "stable" docs: https://pytorch.org/docs/stable/notes/mps.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will try that but currently no mac to test. When I get through these other fixes though I can definitely add it for you or other mac users to test.
exo/inference/pytorch/model/hf.py
Outdated
# Load the model | ||
self.full_model = AutoModelForCausalLM.from_pretrained( | ||
shard.model_id, | ||
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, are these the only options? Would want support across other platforms
exo/inference/pytorch/model/hf.py
Outdated
|
||
layers.append(layer) | ||
|
||
self.full_model.model.layers = nn.ModuleList(layers) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does the peak memory usage look like here? I'm not sure of the specifics of python if this is going to hold each layer twice. Not sure but perhaps setting them in place would be more memory efficient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They shouldn't be held twice as when the ensure_shard function is called in the infer_prompt or infer_tensor the init class function is called which loads the needed layers each time depending on the shard. Will make sure about memory limits though and usage.
exo/inference/pytorch/model/hf.py
Outdated
|
||
|
||
# Load the model | ||
self.full_model = AutoModelForCausalLM.from_pretrained( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this download the entire model?
We have code to selectively download the model from HuggingFace so you don't have to download all layers on every device: exo/download
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also this won't work with our download progress code. We show in the TUI what the download progress of the model is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this download the entire model? We have code to selectively download the model from HuggingFace so you don't have to download all layers on every device:
exo/download
Will look at using that code because yes it currently does download all the model
exo/inference/pytorch/model/utils.py
Outdated
import torch | ||
from torch.nn import functional as F | ||
|
||
def sample_logits(logits, temperature=1.0, top_k=0, top_p=1.0, alpha_f=0.0, alpha_p=0.0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be imported from somewhere rather than copy-pasta into the codebase? It looks like boilerplate code from somewhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I was testing it as the default values but will clean that part up. I will set it in the Interface class settings to be used.
& .\.venv\Scripts\Activate.ps1 | ||
|
||
# Install the package in the virtual environment | ||
pip install . |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Windows? Did this work on windows? Curious if it works there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was testing on Windows but couldn't fully get it working right. Will test again and make sure as I switched to using Linux to do further dev
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@AlexCheema I can confirm that it is working on Windows, but there are a few issues:
-
PyTorch 2.4 doesn't install. In order to get it working, it needs the nightly build
-
In main.py
for s in [signal.SIGINT, signal.SIGTERM]:
loop.add_signal_handler(s, handle_exit)
isn't supported on windows.
If I change handle_exit()
to:
def handle_exit():
asyncio.ensure_future(shutdown(loop))
if platform.system() != "Windows":
for s in [signal.SIGINT, signal.SIGTERM]:
loop.add_signal_handler(s, handle_exit)
else:
# On Windows, we can only reliably catch SIGINT (Ctrl+C)
signal.signal(signal.SIGINT, lambda signum, frame: handle_exit())
- Getting some kind of network error between the GUI and the backend
It seems to work. It's insanely slow for me though (no GPU... the raspberry pis are much faster 😄). Windows changes perhaps out of scope for this PR though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wowzah!
This is exciting. A lot of the community were mad we didn't support Windows. We have a bounty here if you want to get it working there once this one is merged: #186
Great work. You clearly thought about this and implemented a really nice solution. I particularly like the generalisation of model splitting, rather than doing each one separately. Take a look through the comments I left. |
The main thing I want to address and test is device support. We can make this the default inference engine if it works reliably across many devices. On that point, if we can automate the bootstrapping of the environment for each user (e.g. install drivers, whatever else is needed to run on their device) that would be great. We don't have to do this in this PR/bounty, we can do another. But I would love to discuss and figure out how this can best be done. |
@@ -0,0 +1,21 @@ | |||
import unittest |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Run this test in circle ci ./.circleci/config.yml
@@ -0,0 +1,33 @@ | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Run this test in circle ci ./.circleci/config.yml
n_layers=12 | ||
) | ||
|
||
engine = PyTorchDynamicShardInferenceEngine( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this test complete? We need a test that tests the model splitting. Take a look at exo/inference/test_inference_engine.py
. You can just add the test there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed and looks good. Will work through notes to improve
I have updated my main fork branch with the pytorch interface changes. Please take a look and test. Thank you! |
Hey @risingsunomi I'm thinking of making this the default inference engine on linux machines. |
Will do and clean up more. |
@AlexCheema clean up finished and no conflicts with base branch |
torch not added as a dependency |
accelerate package needs to be installed |
Let me add the dependences to the exo setup.py install_requires |
Grpc fix jan242025
Node fixes jan242025
Hello all,
I’ve made some updates to the exo library based on the bounty mentioned in this tweet/X post. These changes aim to integrate PyTorch and expand access to various language models through Hugging Face’s
AutoModelForCausalLM
.What's New?
These updates enable the exo library to use PyTorch, allowing access to a broader range of language models.
Limitations and Bugs
Right now the ShardedHuggingFaceModel is focused on using LlamaForCausalLM from the huggingface transformers library. From that model we break it up using LLamaModel and the layers it contains. We can then select the layers and run the pytorch tensors over them as need. I focused on using llama3.1 8B as I could only slightly run that.
Due to my current hardware limitations (specifically GPU and VRAM), I wasn’t able to fully test this across multiple nodes. The model currently takes about 30 seconds per token to generate for me (I have slow GPUs), which might be related to the absence of caching (not implemented due to VRAM constraints). It’s running without reaching an EOT and the outputs seem random.
Request for Feedback
I’m sharing this in the hope that others can test it on more capable setups and provide feedback on how to enhance performance and stability.
Important Note on Meta LLaMA 3.1 Model
If you plan to test with the official Meta LLaMA 3.1 model, please note:
huggingface-cli
to download it.Chat API Update
Looking forward to any feedback or suggestions you might have.
Thank you