Add SIMD impl of `memset` for LoongArch (#547)

This commit is contained in:
hev 2025-07-02 23:55:19 +08:00 committed by GitHub
parent 259a198dc0
commit e9ad75685f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 126 additions and 2 deletions

View File

@ -72,7 +72,7 @@ pub fn memset<T: MemsetSafe>(dst: &mut [T], val: T) {
#[inline]
fn memset_raw(beg: *mut u8, end: *mut u8, val: u64) {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "loongarch64"))]
return unsafe { MEMSET_DISPATCH(beg, end, val) };
#[cfg(target_arch = "aarch64")]
@ -108,7 +108,7 @@ unsafe fn memset_fallback(mut beg: *mut u8, end: *mut u8, val: u64) {
}
}
/// Function pointer that `memset_raw` calls through on the architectures selected below.
///
/// It initially points at `memset_dispatch`; the loongarch64 `memset_dispatch` overwrites it
/// with the implementation it picked and then forwards the first call to that implementation.
// NOTE(review): stacked `cfg` attributes must all hold, so read literally the first line
// below restricts this item to x86/x86_64 and excludes loongarch64. The two lines look like
// the removed and added sides of a diff shown together -- confirm only the second remains.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "loongarch64"))]
static mut MEMSET_DISPATCH: unsafe fn(beg: *mut u8, end: *mut u8, val: u64) = memset_dispatch;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
@ -235,6 +235,130 @@ fn memset_avx2(mut beg: *mut u8, end: *mut u8, val: u64) {
}
}
/// Resolves the LoongArch `memset` implementation on first use.
///
/// Checks the running CPU for LASX first, then LSX, and otherwise settles on
/// `memset_fallback`. The selected routine is stored in `MEMSET_DISPATCH` and the
/// current request (`beg`, `end`, `val`) is then forwarded to it unchanged.
#[cfg(target_arch = "loongarch64")]
fn memset_dispatch(beg: *mut u8, end: *mut u8, val: u64) {
    use std::arch::is_loongarch_feature_detected;

    // Common pointer type of all candidates; matches the type of `MEMSET_DISPATCH`.
    type MemsetFn = unsafe fn(*mut u8, *mut u8, u64);

    // Widest vector extension wins; LSX is only probed when LASX is absent.
    let chosen: MemsetFn = if is_loongarch_feature_detected!("lasx") {
        memset_lasx
    } else if is_loongarch_feature_detected!("lsx") {
        memset_lsx
    } else {
        memset_fallback
    };

    // SAFETY: `chosen` was selected only after the matching CPU feature was detected, and the
    // arguments are passed through exactly as received.
    // NOTE(review): the write to the `static mut` is unsynchronized -- presumably accepted
    // because every caller would store the same pointer; confirm.
    unsafe {
        MEMSET_DISPATCH = chosen;
        chosen(beg, end, val)
    }
}
/// Fills the byte range `[beg, end)` with the 8-byte pattern `val` using 256-bit LASX stores.
///
/// Strategy:
/// 1. If at least 32 bytes are pending, write one (possibly unaligned) 32-byte vector at
///    `beg` and advance `beg` to the next 32-byte boundary.
/// 2. Write 128 bytes per iteration while at least 128 bytes remain.
/// 3. Write 16 bytes per iteration (LSX) while at least 16 bytes remain.
/// 4. Cover the last 0-15 bytes with at most two overlapping scalar stores.
///
/// Although this is declared as a safe `fn`, it writes through raw pointers: `beg..end` must
/// be a writable range inside one allocation with `beg <= end`, and the CPU must support
/// LASX.
// NOTE(review): `memset_lsx` is an `unsafe fn` while this one is not; the signature is left
// untouched here, but the two should probably agree.
// NOTE(review): skipping ahead to the alignment boundary keeps the pattern in phase only if
// `beg` is aligned to the period of `val` -- assumed to be guaranteed by the caller; confirm.
#[cfg(target_arch = "loongarch64")]
#[target_feature(enable = "lasx")]
fn memset_lasx(mut beg: *mut u8, end: *mut u8, val: u64) {
    // SAFETY: relies on the pointer contract described above; every store below is bounded by
    // an explicit check of the remaining length `end - beg`.
    unsafe {
        use std::arch::loongarch64::*;
        use std::mem::transmute as T;

        // `val` broadcast into all four 64-bit lanes.
        let fill: v32i8 = T(lasx_xvreplgr2vr_d(val as i64));

        if end.offset_from_unsigned(beg) >= 32 {
            lasx_xvst::<0>(fill, beg as *mut _);
            // Distance to the next 32-byte boundary (0 when already aligned). Computed from
            // the address instead of `align_offset`, which is documented as being allowed to
            // return `usize::MAX` and must not be relied on for correctness. The result is
            // at most 31, so `beg` stays strictly below `end`.
            let off = beg.addr().wrapping_neg() & 31;
            beg = beg.add(off);
        }

        if end.offset_from_unsigned(beg) >= 128 {
            loop {
                lasx_xvst::<0>(fill, beg as *mut _);
                lasx_xvst::<32>(fill, beg as *mut _);
                lasx_xvst::<64>(fill, beg as *mut _);
                lasx_xvst::<96>(fill, beg as *mut _);
                beg = beg.add(128);
                if end.offset_from_unsigned(beg) < 128 {
                    break;
                }
            }
        }

        if end.offset_from_unsigned(beg) >= 16 {
            // Fewer than 128 bytes are left: continue with 128-bit stores.
            let fill: v16i8 = T(lsx_vreplgr2vr_d(val as i64));
            loop {
                lsx_vst::<0>(fill, beg as *mut _);
                beg = beg.add(16);
                if end.offset_from_unsigned(beg) < 16 {
                    break;
                }
            }
        }

        // Tail: one store anchored at `beg` and one anchored at `end`; together they cover
        // the whole remainder (they overlap when it is not an exact power of two).
        if end.offset_from_unsigned(beg) >= 8 {
            // 8-15 bytes
            (beg as *mut u64).write_unaligned(val);
            (end.sub(8) as *mut u64).write_unaligned(val);
        } else if end.offset_from_unsigned(beg) >= 4 {
            // 4-7 bytes
            (beg as *mut u32).write_unaligned(val as u32);
            (end.sub(4) as *mut u32).write_unaligned(val as u32);
        } else if end.offset_from_unsigned(beg) >= 2 {
            // 2-3 bytes
            (beg as *mut u16).write_unaligned(val as u16);
            (end.sub(2) as *mut u16).write_unaligned(val as u16);
        } else if end.offset_from_unsigned(beg) >= 1 {
            // 1 byte
            beg.write(val as u8);
        }
    }
}
/// Fills the byte range `[beg, end)` with the 8-byte pattern `val` using 128-bit LSX stores.
///
/// Strategy:
/// 1. If at least 16 bytes are pending, write one (possibly unaligned) 16-byte vector at
///    `beg`, advance `beg` to the next 16-byte boundary, then write 32 bytes per iteration
///    while at least 32 bytes remain. A remainder of 16-31 bytes is finished with two
///    overlapping vector stores, one anchored at `beg` and one ending at `end`.
/// 2. Otherwise (or when fewer than 16 bytes are left after step 1) the last 0-15 bytes are
///    covered by at most two overlapping scalar stores.
///
/// # Safety
///
/// `beg..end` must be a writable range inside one allocation with `beg <= end`, and the CPU
/// must support LSX.
// NOTE(review): skipping ahead to the alignment boundary keeps the pattern in phase only if
// `beg` is aligned to the period of `val` -- assumed to be guaranteed by the caller; confirm.
#[cfg(target_arch = "loongarch64")]
#[target_feature(enable = "lsx")]
unsafe fn memset_lsx(mut beg: *mut u8, end: *mut u8, val: u64) {
    // SAFETY: relies on the caller contract above; every store below is bounded by an
    // explicit check of the remaining length `end - beg`.
    unsafe {
        use std::arch::loongarch64::*;
        use std::mem::transmute as T;

        if end.offset_from_unsigned(beg) >= 16 {
            // `val` broadcast into both 64-bit lanes.
            let fill: v16i8 = T(lsx_vreplgr2vr_d(val as i64));

            lsx_vst::<0>(fill, beg as *mut _);
            // Distance to the next 16-byte boundary (0 when already aligned). Computed from
            // the address instead of `align_offset`, which is documented as being allowed to
            // return `usize::MAX` and must not be relied on for correctness. The result is
            // at most 15, so `beg` stays strictly below `end`.
            let off = beg.addr().wrapping_neg() & 15;
            beg = beg.add(off);

            while end.offset_from_unsigned(beg) >= 32 {
                lsx_vst::<0>(fill, beg as *mut _);
                lsx_vst::<16>(fill, beg as *mut _);
                beg = beg.add(32);
            }

            if end.offset_from_unsigned(beg) >= 16 {
                // 16-31 bytes remaining: the second store ends exactly at `end`.
                lsx_vst::<0>(fill, beg as *mut _);
                lsx_vst::<-16>(fill, end as *mut _);
                return;
            }
        }

        // Tail: one store anchored at `beg` and one anchored at `end`; together they cover
        // the whole remainder (they overlap when it is not an exact power of two).
        if end.offset_from_unsigned(beg) >= 8 {
            // 8-15 bytes remaining
            (beg as *mut u64).write_unaligned(val);
            (end.sub(8) as *mut u64).write_unaligned(val);
        } else if end.offset_from_unsigned(beg) >= 4 {
            // 4-7 bytes remaining
            (beg as *mut u32).write_unaligned(val as u32);
            (end.sub(4) as *mut u32).write_unaligned(val as u32);
        } else if end.offset_from_unsigned(beg) >= 2 {
            // 2-3 bytes remaining
            (beg as *mut u16).write_unaligned(val as u16);
            (end.sub(2) as *mut u16).write_unaligned(val as u16);
        } else if end.offset_from_unsigned(beg) >= 1 {
            // 1 byte remaining
            beg.write(val as u8);
        }
    }
}
#[cfg(target_arch = "aarch64")]
unsafe fn memset_neon(mut beg: *mut u8, end: *mut u8, val: u64) {
unsafe {