Merge pull request #46 from risingsunomi/node-fixes-jan242025
Node fixes jan242025
risingsunomi authored Jan 25, 2025
2 parents 386ac0b + 4e3e53e commit c3bde74
Showing 2 changed files with 59 additions and 29 deletions.
5 changes: 3 additions & 2 deletions exo/inference/torch/models/llama3.py
@@ -169,7 +169,7 @@ def forward(
curr_layers = [self.layers[i] for i in range(self.shard.start_layer, self.shard.end_layer + 1)]
for i, layer in enumerate(curr_layers):
if DEBUG >= 8:
print(f"\nhidden layer in H[{i}]\n{h}")
print(f"\nhidden layer in H[{self.shard.start_layer+i}]\n{h}")
print(f"\nmask\n{mask}\ninput_pos\n{input_pos}")
print(f"\noutput_hidden_states\n{self.output_hidden_states}\n")

@@ -189,7 +189,7 @@ def forward(
# hidden.append(h)

if DEBUG >= 8:
print(f"\nhidden layer out H[{i}]->H[{i + 1}]\n{h}\n")
print(f"\nhidden layer out H[{self.shard.start_layer+i}]->H[{self.shard.start_layer+i+1}]\n{h}\n")

if self.shard.is_last_layer():
# Apply normalization
@@ -386,6 +386,7 @@ def generate(
print(f"mask: {mask}")
print(f"input_pos: {input_pos}")


model_output = self.model(
tokens=tokens,
mask=mask,
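Note on the llama3.py change: the DEBUG >= 8 traces in forward() now print the absolute layer index (self.shard.start_layer + i) instead of the loop-local i, so hidden-state logs from different shards can be lined up against the full model. A minimal sketch of the mapping, assuming a hypothetical shard that owns layers 16-31 (field names taken from the diff):

from dataclasses import dataclass

@dataclass
class Shard:
  start_layer: int
  end_layer: int

shard = Shard(start_layer=16, end_layer=31)  # hypothetical second node
curr_layers = list(range(shard.start_layer, shard.end_layer + 1))
for i, layer_idx in enumerate(curr_layers):
  # before: logs showed the local index i (0, 1, 2, ...), repeated on every shard
  # after: logs show shard.start_layer + i (16, 17, ...), unique across the model
  assert shard.start_layer + i == layer_idx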
83 changes: 56 additions & 27 deletions exo/inference/torch/sharded_inference_engine.py
@@ -51,6 +51,9 @@ def __init__(self, shard_downloader: HFShardDownloader):
self.state = None
self.oom_cnt = 0

# cache settings
self.use_cache = bool(os.getenv("TORCH_USE_CACHE", "True").lower() == "true")

# device settings
if os.environ.get("TORCH_DEVICE"):
self.device = torch.device(os.environ["TORCH_DEVICE"])
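Note: the cache policy is now read once in __init__ from the TORCH_USE_CACHE environment variable and stored on self.use_cache (the same expression used to be evaluated inline when constructing ShardedLlamaModel in ensure_shard; see the last hunk). A small sketch of how the parsing behaves; the bool() wrapper in the diff is redundant since the comparison already returns a bool:

import os

# Defaults to enabled; any value other than "true" (case-insensitive) disables caching.
use_cache = os.getenv("TORCH_USE_CACHE", "True").lower() == "true"

# Hypothetical settings:
#   unset                  -> True
#   TORCH_USE_CACHE=false  -> False
#   TORCH_USE_CACHE=0      -> False ("0" is not "true", so caching is off)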
@@ -129,7 +132,7 @@ def encode_wrapper() -> np.ndarray:
total_response_length = tklng + self.sharded_model.max_generated_tokens

# setup cache
if not self.sharded_model.model.caches_are_enabled():
if not self.sharded_model.model.caches_are_enabled() and self.use_cache:
with self.device:
self.sharded_model.model.setup_caches(
bsz,
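Note: cache allocation in encode_wrapper is now gated on the new flag as well, so TORCH_USE_CACHE=false prevents setup_caches from being called even on a fresh model. A tiny sketch of the condition:

def should_setup_caches(caches_enabled: bool, use_cache: bool) -> bool:
  # allocate KV caches only when they are not already enabled and the user has not opted out
  return (not caches_enabled) and use_cache

assert should_setup_caches(caches_enabled=False, use_cache=True) is True
assert should_setup_caches(caches_enabled=False, use_cache=False) is False  # new behaviour
assert should_setup_caches(caches_enabled=True, use_cache=True) is False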
@@ -240,17 +243,16 @@ async def infer_tensor(
self.request_id = request_id if not self.request_id else self.request_id

hidden_state = None
input_tensor = None
if input_data.ndim == 3:
hidden_state = torch.tensor(input_data).to(self.device)
hidden_state = torch.tensor(input_data).to(
device=self.device,
dtype=self.model_config["torch_dtype"]
)
elif input_data.ndim == 2:
input_tensor = torch.tensor(input_data).to(self.device)
if self.state.tokens is not None:
self.state.tokens = torch.cat([
self.state.tokens.to(self.device),
input_tensor
], dim=-1).to(self.device)
else:
self.state.tokens = input_tensor.clone()
input_tensor = torch.tensor(input_data).to(
device=self.device
)

def infer_wrapper():
if DEBUG >= 4:
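Note on infer_tensor: a 3-D input (a hidden state handed over from an upstream node) is now converted with an explicit dtype taken from model_config["torch_dtype"] rather than inheriting numpy's dtype, a 2-D input (token ids) is only moved to the device, and the token-accumulation logic that used to live here has moved into infer_wrapper (next hunk). A minimal sketch of the conversion, with stand-ins for self.device and the model config:

import numpy as np
import torch

device = torch.device("cpu")                    # stand-in for self.device
model_config = {"torch_dtype": torch.float16}   # assumed to hold a torch dtype

input_data = np.zeros((1, 4, 8), dtype=np.float32)  # hypothetical hidden state

if input_data.ndim == 3:
  # cast activations to the model dtype instead of keeping float32 from numpy
  hidden_state = torch.tensor(input_data).to(device=device, dtype=model_config["torch_dtype"])
elif input_data.ndim == 2:
  # token ids keep their integer dtype; only the device is set
  input_tensor = torch.tensor(input_data).to(device=device)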
@@ -260,27 +262,49 @@ def infer_wrapper():

model_cache = self.sharded_model.model.caches_are_enabled()

if self.state.tokens is not None:
if input_data.ndim == 2 and input_tensor.size(-1) == 1:
self.state.tokens = torch.cat([
self.state.tokens.to(self.device),
input_tensor.clone()
], dim=-1).to(self.device)
else:
self.state.tokens = input_tensor.clone()

try:
in_tokens = self.state.tokens.clone().to(
device=self.device
)

in_input_pos = self.state.input_pos.clone().to(
device=self.device
)

in_mask = self.state.mask.clone().to(
device=self.device
)

if hidden_state is not None:
model_hs, model_logits = self.sharded_model.generate(
tokens=in_tokens,
hidden_state=hidden_state,
input_pos=self.state.input_pos.to(self.device),
mask=self.state.mask.to(self.device),
input_pos=in_input_pos,
mask=in_mask,
curr_pos=self.state.curr_pos
)
else:
if not model_cache:
model_hs, model_logits = self.sharded_model.generate(
tokens=self.state.tokens.to(self.device),
input_pos=self.state.input_pos.to(self.device),
mask=self.state.mask.to(self.device),
tokens=in_tokens,
input_pos=in_input_pos,
mask=in_mask,
curr_pos=self.state.curr_pos
)
else:
model_hs, model_logits = self.sharded_model.generate(
tokens=input_tensor,
input_pos=self.state.input_pos.to(self.device),
mask=self.state.mask.to(self.device),
input_pos=in_input_pos,
mask=in_mask,
curr_pos=self.state.curr_pos
)
except torch.cuda.OutOfMemoryError:
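Note on infer_wrapper: token accumulation now happens here and is conditional; a single-token update (input_tensor.size(-1) == 1) is appended to self.state.tokens, while anything longer replaces it. The state tensors (tokens, input_pos, mask) are cloned onto the device once before the generate call, and the three call paths remain: hidden-state input, no-cache (full token history), and cache-enabled (only the newest token). A minimal sketch of the accumulation rule with stand-in tensors:

import torch

state_tokens = torch.tensor([[1, 2, 3]])  # stand-in for self.state.tokens, shape [batch, seq]
input_tensor = torch.tensor([[4]])        # a single newly sampled token, shape [batch, 1]
input_ndim = 2                            # stand-in for input_data.ndim in the diff

if state_tokens is not None:
  if input_ndim == 2 and input_tensor.size(-1) == 1:
    # one sampled token arrives: append it to the running sequence
    state_tokens = torch.cat([state_tokens, input_tensor.clone()], dim=-1)
  else:
    # a full prompt (or multi-token chunk) arrives: restart the sequence
    state_tokens = input_tensor.clone()

print(state_tokens)  # tensor([[1, 2, 3, 4]])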
@@ -348,23 +372,28 @@ async def ensure_shard(self, shard: Shard):

# self.tokenizer = await _resolve_tokenizer(model_path)
self.tokenizer = await _resolve_tokenizer(self.model_path)

self.sharded_model = await asyncio.get_running_loop().run_in_executor(
self.executor,
functools.partial(
ShardedLlamaModel,

def start_model():
if DEBUG >= 4:
print("start_model called")

self.sharded_model = ShardedLlamaModel(
config=self.model_config,
shard=shard,
device=self.device,
dtype=self.model_config["torch_dtype"],
use_cache=bool(os.getenv("TORCH_USE_CACHE", "True").lower() == "true"),
),
)
use_cache=self.use_cache
)

# load sharded weights
load_weights_torch(
self.model_path,
self.sharded_model.model,
self.model_config
)

await asyncio.get_running_loop().run_in_executor(
self.executor,
functools.partial(load_weights_torch, self.model_path, self.sharded_model.model, self.model_config),
functools.partial(start_model),
)

async def load_checkpoint(self, shard: Shard, path: str):
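Note on ensure_shard: model construction and weight loading used to be two separate run_in_executor calls; they are now bundled into a local start_model() and submitted to the executor once, with the cache flag taken from self.use_cache instead of re-reading the environment. A rough sketch of the new shape, using hypothetical stand-ins for ShardedLlamaModel and load_weights_torch:

import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

def build_sharded_model():
  # stand-in for ShardedLlamaModel(config=..., shard=..., device=..., use_cache=self.use_cache)
  return object()

def load_weights(model):
  # stand-in for load_weights_torch(model_path, model, model_config)
  pass

executor = ThreadPoolExecutor(max_workers=1)

def start_model():
  # construct the model and load its sharded weights in one executor job
  model = build_sharded_model()
  load_weights(model)
  return model

async def ensure_shard():
  loop = asyncio.get_running_loop()
  return await loop.run_in_executor(executor, functools.partial(start_model))

asyncio.run(ensure_shard())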
