Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/granite_switch/vllm/core/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,33 @@ def __init__(
# num_adapters and max_lora_rank are config metadata, not runtime parameters.

# Detect layer properties (handles both standard and vLLM parallel layers)
if hasattr(base_layer, "weight"):
# We prefer explicit dimension attributes over weight.shape because
# quantized formats (e.g. BnB 4-bit) pack weights into different shapes
# (e.g. [total_elements//2, 1] for uint8-packed INT4).
if hasattr(base_layer, "input_size_per_partition"):
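# vLLM parallel layer — the per-partition sizes are authoritative even when the weight tensor is packed.
# NOTE(review): this branch still reads base_layer.weight below; confirm that every layer exposing
# input_size_per_partition also has .weight (GPTQ/AWQ-style layers are detected via qweight instead).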
in_features = base_layer.input_size_per_partition
out_features = base_layer.output_size_per_partition
device = base_layer.weight.device
dtype = base_layer.weight.dtype
elif hasattr(base_layer, "weight"):
in_features = base_layer.weight.shape[1]
out_features = base_layer.weight.shape[0]
device = base_layer.weight.device
dtype = base_layer.weight.dtype
elif hasattr(base_layer, "qweight"):
# Quantized layer
# Quantized layer (GPTQ/AWQ style)
in_features = base_layer.input_size
out_features = base_layer.output_size
device = base_layer.qweight.device
dtype = torch.float16
else:
raise ValueError(f"Unsupported base layer type: {type(base_layer)}")

# BnB quantization stores weights as uint8, but LoRA buffers need a floating-point dtype, so fall back to bfloat16.
if not dtype.is_floating_point:
dtype = torch.bfloat16

self.in_features = in_features
self.out_features = out_features

Expand Down
Loading