
Commit bc14a70

[Convert] Fix DeepSeek-R1 convert issue. (#145)
1 parent: 83f531b

File tree: 3 files changed, +3 −2 lines


README.md (1 addition, 1 deletion)

````diff
@@ -171,7 +171,7 @@ xFasterTransformer supports a different model format from Huggingface, but it's
 1. Download the huggingface format model firstly.
 2. After that, convert the model into xFasterTransformer format by using model convert module in xfastertransformer. If output directory is not provided, converted model will be placed into `${HF_DATASET_DIR}-xft`.
 ```
-python -c 'import xfastertransformer as xft; xft.LlamaConvert().convert("${HF_DATASET_DIR}","${OUTPUT_DIR}")'
+python -c "import xfastertransformer as xft; xft.DeepSeekR1Convert().convert('${HF_DATASET_DIR}', '${OUTPUT_DIR}')"
 ```
 ***PS: Due to the potential compatibility issues between the model file and the `transformers` version, please select the appropriate `transformers` version.***
````
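Besides swapping `LlamaConvert` for `DeepSeekR1Convert`, the new command also inverts the outer quoting: with single outer quotes, a POSIX shell never expands `${HF_DATASET_DIR}`, so Python would receive the literal text rather than a path. A minimal sketch of that shell behavior via `subprocess` (the path `/data/model` is a made-up stand-in, not from the commit):

```python
import os
import subprocess

# Hypothetical dataset path, purely for illustration.
env = dict(os.environ, HF_DATASET_DIR="/data/model")

# Old form: single outer quotes keep the shell from expanding the variable,
# so Python receives the literal text "${HF_DATASET_DIR}".
old = subprocess.run(
    """python3 -c 'print("${HF_DATASET_DIR}")'""",
    shell=True, env=env, capture_output=True, text=True,
).stdout.strip()

# New form: double outer quotes let the shell expand the variable before
# Python ever sees the string.
new = subprocess.run(
    '''python3 -c "print('${HF_DATASET_DIR}')"''',
    shell=True, env=env, capture_output=True, text=True,
).stdout.strip()

print(old)  # ${HF_DATASET_DIR}
print(new)  # /data/model
```

The inner/outer quote characters must differ either way; the fix simply puts the shell-expandable double quotes on the outside.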

README_CN.md (1 addition, 1 deletion)

````diff
@@ -172,7 +172,7 @@ xFasterTransformer supports a model format different from Huggingface's, but compatible with Fa
 1. First, download the model in huggingface format.
 2. Then, convert the model into xFasterTransformer format using the model convert module in xfastertransformer. If no output directory is provided, the converted model is placed into `${HF_DATASET_DIR}-xft` by default.
 ```
-python -c 'import xfastertransformer as xft; xft.LlamaConvert().convert("${HF_DATASET_DIR}","${OUTPUT_DIR}")'
+python -c "import xfastertransformer as xft; xft.DeepSeekR1Convert().convert('${HF_DATASET_DIR}', '${OUTPUT_DIR}')"
 ```
 ***PS: Due to potential compatibility issues between the model file and the `transformers` version, please select the appropriate `transformers` version.***
````

src/xfastertransformer/tools/convert.py (1 addition, 0 deletions)

```diff
@@ -102,6 +102,7 @@ def map_np_dtype_to_torch(dtype: np.dtype):
         np.float32: [torch.float32, torch.float32],
         np.float16: [torch.float16, torch.float16],
         np.uint16: [torch.bfloat16, torch.uint16],
+        np.uint8: [torch.float8_e4m3fn, torch.uint8],
     }
     if dtype in MAPPING:
         return MAPPING[dtype]
```
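The hunk above extends a lookup table that pairs each NumPy storage dtype with target torch dtypes; since NumPy has no native bfloat16 or float8 dtype, those tensors arrive as raw `uint16`/`uint8` bytes. A hedged sketch of that mapping, with the torch dtypes written as strings purely so it runs without torch installed (the real code uses `torch.bfloat16`, `torch.float8_e4m3fn`, etc. directly):

```python
import numpy as np

# Each NumPy storage dtype maps to a pair [target torch dtype, view dtype].
MAPPING = {
    np.float32: ["torch.float32", "torch.float32"],
    np.float16: ["torch.float16", "torch.float16"],
    # bfloat16 has no native NumPy dtype, so it arrives as raw uint16 bytes.
    np.uint16: ["torch.bfloat16", "torch.uint16"],
    # The commit's addition: DeepSeek-R1 FP8 weights arrive as raw uint8 bytes.
    np.uint8: ["torch.float8_e4m3fn", "torch.uint8"],
}

def map_np_dtype_to_torch(dtype):
    if dtype in MAPPING:
        return MAPPING[dtype]
    raise ValueError(f"Unsupported numpy dtype: {dtype}")

print(map_np_dtype_to_torch(np.uint8))  # ['torch.float8_e4m3fn', 'torch.uint8']
```

Without the `np.uint8` entry, converting DeepSeek-R1 checkpoints (whose weights are stored in FP8 `e4m3` format) would fall through to the unsupported-dtype branch, which is the failure this commit fixes.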
