ksys/gdt: Implement more TriggerParam copy functions

This commit is contained in:
Léo Lam
2020-10-31 15:23:18 +01:00
parent 1ccd65836d
commit 6c686fb962
3 changed files with 450 additions and 51 deletions
+407 -16
View File
@@ -18,6 +18,7 @@ TriggerParam::TriggerParam() {
cs.constructDefault();
mBitFlags.constructDefault();
mFlagChangeRecordIndices.fill(0);
mNumBoolFlagsPerCategory0.fill(0);
mNumBoolFlagsPerCategory.fill(0);
}
@@ -431,10 +432,9 @@ void makeFlagProxies(sead::PtrArray<FlagBase>& dest, const sead::PtrArray<FlagBa
}
}
// TODO: remove noinline
template <typename T>
[[gnu::noinline]] void addFlagCopyRecord(sead::ObjArray<TriggerParam::FlagCopyRecord>& records,
Flag<T>* flag, s32 sub_index, bool find_existing_record) {
void addFlagCopyRecord(sead::ObjArray<TriggerParam::FlagCopyRecord>& records, Flag<T>* flag,
s32 sub_index, bool find_existing_record) {
if (records.isFull())
return;
@@ -458,10 +458,152 @@ template <typename T>
record->name_hash = flag->getHash();
record->sub_index = sub_index;
}
record->bit_flag.makeAllZero();
record->bf.makeAllZero();
}
} // namespace
/**
* @param dest Destination flag array
* @param src Source flag array
* @param records Copy records (only used if add_records is true)
* @param counts Boolean flag counts per category (only used if T = bool)
* @param record_copies Whether copies should be recorded
* @param ignore_temp_flags Whether temporary flags (i.e. flags that aren't saved)
* should be ignored
* @param find_existing_record See addFlagCopyRecord.
* @param is_array Whether the flag arrays are associated with an array-type flag (e.g. bool array)
*/
template <typename T>
void copyFlags(sead::PtrArray<FlagBase>& dest, const sead::PtrArray<FlagBase>& src,
sead::ObjArray<TriggerParam::FlagCopyRecord>& records,
sead::SafeArray<s32, 15>& counts, bool record_copies, bool ignore_temp_flags,
bool find_existing_record, bool is_array) {
// This function only handles two cases:
//
// 1. If both the source and destination have the same number of flags, we assume that
// both arrays store the same flags in the exact same order.
//
// 2. If the source has fewer flags than the destination, it is assumed that src is a subset of
// dest (i.e. every flag that is in src is also in dest) and src forms a subsequence of dest.
// Any flag in dest that isn't in src will be reset to its initial value.
const auto update_records = [&](Flag<T>* flag, const Flag<T>* flag_src, s32 i) {
flag->setValue(flag_src->getValue());
s32 category, category_idx;
if constexpr (std::is_same<T, bool>()) {
category = flag->getCategory();
category_idx = category - 1;
}
if (record_copies) {
addFlagCopyRecord<T>(records, flag, is_array ? i : -1, find_existing_record);
if constexpr (std::is_same<T, bool>()) {
if (category > 0) {
if (flag->getValue())
++counts[category_idx];
else if (counts[category_idx] > 0)
--counts[category_idx];
}
}
} else {
if constexpr (std::is_same<T, bool>()) {
if (category > 0 && flag->getValue())
++counts[category_idx];
}
}
};
const auto dest_size = dest.size();
const auto src_size = src.size();
if (dest_size == src_size) {
for (s32 i = 0; i < dest_size; ++i) {
if (!dest[i])
continue;
if (ignore_temp_flags && !dest[i]->isSave())
continue;
if (record_copies && dest[i]->isSave())
continue;
if (!src[i]) {
if (!dest[i]->isSave())
dest[i]->resetToInitialValue();
} else {
auto* flag = static_cast<Flag<T>*>(dest[i]);
auto* flag_src = static_cast<Flag<T>*>(src[i]);
update_records(flag, flag_src, i);
}
}
} else if (src_size < dest_size) {
s32 j = 0;
for (s32 i = 0; i < dest_size; ++i) {
if (ignore_temp_flags && !dest[i]->isSave())
continue;
if (record_copies && dest[i]->isSave())
continue;
if (j >= src_size || src[j]->getHash() != dest[i]->getHash()) {
if (!dest[i]->isSave())
dest[i]->resetToInitialValue();
} else if (src[j]->getHash() == dest[i]->getHash()) {
auto* flag = static_cast<Flag<T>*>(dest[i]);
auto* flag_src = static_cast<Flag<T>*>(src[j]);
update_records(flag, flag_src, i);
++j;
}
}
}
}
template <typename T>
void copyFlagArrays(sead::PtrArray<sead::PtrArray<FlagBase>>& dest,
const sead::PtrArray<sead::PtrArray<FlagBase>>& src,
sead::ObjArray<TriggerParam::FlagCopyRecord>& records,
sead::SafeArray<s32, 15>& counts, bool record_copies, bool ignore_temp_flags,
bool find_existing_record) {
const auto dest_size = dest.size();
const auto src_size = src.size();
if (src_size == dest_size) {
for (s32 i = 0; i < dest_size; ++i) {
if (!src[i]) {
const auto n = dest[i]->size();
if (!dest[i]->at(0)->isSave()) {
for (s32 j = 0; j < n; ++j)
dest[i]->at(j)->resetToInitialValue();
}
} else {
copyFlags<T>(*dest[i], *src[i], records, counts, record_copies, ignore_temp_flags,
find_existing_record, true);
}
}
} else {
s32 j = 0;
for (s32 i = 0; i < dest_size; ++i) {
if (j >= src_size) {
const auto n = dest[i]->size();
if (!dest[i]->at(0)->isSave()) {
for (s32 k = 0; k < n; ++k)
dest[i]->at(k)->resetToInitialValue();
}
} else if (src[j]->at(0)->getHash() != dest[i]->at(0)->getHash()) {
const auto n = dest[i]->size();
if (!dest[i]->at(0)->isSave()) {
for (s32 k = 0; k < n; ++k)
dest[i]->at(k)->resetToInitialValue();
}
} else {
copyFlags<T>(*dest[i], *src[j], records, counts, record_copies, ignore_temp_flags,
find_existing_record, true);
++j;
}
}
}
}
void TriggerParam::copyFromGameDataRes(res::GameData* gdata, sead::Heap* heap) {
if (!gdata)
return;
@@ -1448,20 +1590,74 @@ void TriggerParam::resetAllFlagsToInitialValues() {
}
}
bool TriggerParam::getBoolIfCopied(bool* value, const sead::SafeString& name, bool x,
bool y) const {
if (mCopiedBoolFlags.isEmpty())
return false;
const u32 hash = sead::HashCRC32::calcStringHash(name);
for (s32 i = 0; i < mCopiedBoolFlags.size(); ++i) {
if (mCopiedBoolFlags[i]->name_hash == hash) {
if (!mCopiedBoolFlags[i]->checkBitFlags(x, y))
return false;
*value = static_cast<FlagBool*>(mBoolFlags[getFlagIndex(mBoolFlags, hash)])->getValue();
return true;
}
}
return false;
}
bool TriggerParam::getS32IfCopied(s32* value, const sead::SafeString& name, bool x, bool y) const {
if (mCopiedS32Flags.isEmpty())
return false;
const u32 hash = sead::HashCRC32::calcStringHash(name);
for (s32 i = 0; i < mCopiedS32Flags.size(); ++i) {
if (mCopiedS32Flags[i]->name_hash == hash) {
if (!mCopiedS32Flags[i]->checkBitFlags(x, y))
return false;
*value = static_cast<FlagS32*>(mS32Flags[getFlagIndex(mS32Flags, hash)])->getValue();
return true;
}
}
return false;
}
bool TriggerParam::getF32IfCopied(f32* value, const sead::SafeString& name, bool x, bool y) const {
if (mCopiedF32Flags.isEmpty())
return false;
const u32 hash = sead::HashCRC32::calcStringHash(name);
for (s32 i = 0; i < mCopiedF32Flags.size(); ++i) {
if (mCopiedF32Flags[i]->name_hash == hash) {
if (!mCopiedF32Flags[i]->checkBitFlags(x, y))
return false;
*value = static_cast<FlagF32*>(mF32Flags[getFlagIndex(mF32Flags, hash)])->getValue();
return true;
}
}
return false;
}
// FIXME: very incomplete
void TriggerParam::copyChangedFlags(TriggerParam& other, bool set_all_flags, bool y, bool z) {
void TriggerParam::copyChangedFlags(TriggerParam& other, bool set_all_flags, bool record_copies,
bool ignore_temp_flags) {
if (!set_all_flags) {
for (s32 i = 0; i < 3; ++i) {
auto lock = sead::makeScopedLock(other.mCriticalSections[i].ref());
for (s32 j = 0; j < mFlagChangeRecordIndices[i]; ++i) {
const FlagChangeRecord& record = mFlagChangeRecords[i].ref()[j];
const auto idx = record.index;
const auto sub_idx = record.sub_index;
switch (record.type) {
case FlagType::Bool: {
auto* flag = static_cast<FlagBool*>(mBoolFlags[record.index]);
const bool find_existing = mBitFlags.ref().isOn(BitFlag::_7);
flag->setValue(
static_cast<FlagBool*>(other.mBoolFlags[record.index])->getValue());
auto* flag = static_cast<FlagBool*>(mBoolFlags[idx]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(static_cast<FlagBool*>(other.mBoolFlags[idx])->getValue());
const auto category = flag->getCategory();
addFlagCopyRecord(mCopiedBoolFlags, flag, -1, find_existing);
@@ -1473,15 +1669,75 @@ void TriggerParam::copyChangedFlags(TriggerParam& other, bool set_all_flags, boo
}
break;
}
case FlagType::S32: {
auto* flag = static_cast<FlagS32*>(mS32Flags[idx]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(static_cast<FlagS32*>(other.mS32Flags[idx])->getValue());
addFlagCopyRecord(mCopiedS32Flags, flag, -1, find_existing);
break;
}
case FlagType::F32: {
auto* flag = static_cast<FlagF32*>(mF32Flags[idx]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(static_cast<FlagF32*>(other.mF32Flags[idx])->getValue());
addFlagCopyRecord(mCopiedF32Flags, flag, -1, find_existing);
break;
}
case FlagType::String: {
auto* flag = static_cast<FlagString*>(mStringFlags[idx]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(static_cast<FlagString*>(other.mStringFlags[idx])->getValue());
addFlagCopyRecord(mCopiedStringFlags, flag, -1, find_existing);
break;
}
case FlagType::String64: {
auto* flag = static_cast<FlagString64*>(mString64Flags[idx]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(
static_cast<FlagString64*>(other.mString64Flags[idx])->getValue());
addFlagCopyRecord(mCopiedString64Flags, flag, -1, find_existing);
break;
}
case FlagType::String256: {
auto* flag = static_cast<FlagString256*>(mString256Flags[idx]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(
static_cast<FlagString256*>(other.mString256Flags[idx])->getValue());
addFlagCopyRecord(mCopiedString256Flags, flag, -1, find_existing);
break;
}
case FlagType::Vector2f: {
auto* flag = static_cast<FlagVector2f*>(mVector2fFlags[idx]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(
static_cast<FlagVector2f*>(other.mVector2fFlags[idx])->getValue());
addFlagCopyRecord(mCopiedVector2fFlags, flag, -1, find_existing);
break;
}
case FlagType::Vector3f: {
auto* flag = static_cast<FlagVector3f*>(mVector3fFlags[idx]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(
static_cast<FlagVector3f*>(other.mVector3fFlags[idx])->getValue());
addFlagCopyRecord(mCopiedVector3fFlags, flag, -1, find_existing);
break;
}
case FlagType::Vector4f: {
auto* flag = static_cast<FlagVector4f*>(mVector4fFlags[idx]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(
static_cast<FlagVector4f*>(other.mVector4fFlags[idx])->getValue());
addFlagCopyRecord(mCopiedVector4fFlags, flag, -1, find_existing);
break;
}
case FlagType::BoolArray: {
auto* flag =
static_cast<FlagBool*>((*mBoolArrayFlags[record.index])[record.sub_index]);
const bool find_existing = mBitFlags.ref().isOn(BitFlag::_7);
flag->setValue(static_cast<FlagBool*>(
(*other.mBoolArrayFlags[record.index])[record.sub_index])
->getValue());
auto* flag = static_cast<FlagBool*>((*mBoolArrayFlags[idx])[sub_idx]);
auto* flag2 = static_cast<FlagBool*>((*other.mBoolArrayFlags[idx])[sub_idx]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(flag2->getValue());
const auto category = flag->getCategory();
addFlagCopyRecord(mCopiedBoolFlags, flag, record.sub_index, find_existing);
addFlagCopyRecord(mCopiedBoolFlags, flag, sub_idx, find_existing);
if (category > 0) {
if (flag->getValue())
@@ -1491,6 +1747,86 @@ void TriggerParam::copyChangedFlags(TriggerParam& other, bool set_all_flags, boo
}
break;
}
case FlagType::S32Array: {
auto* flag = static_cast<FlagS32*>((*mS32ArrayFlags[idx])[record.sub_index]);
auto* flag2 =
static_cast<FlagS32*>((*other.mS32ArrayFlags[idx])[record.sub_index]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(flag2->getValue());
addFlagCopyRecord(mCopiedS32Flags, flag, sub_idx, find_existing);
break;
}
case FlagType::F32Array: {
auto* flag = static_cast<FlagF32*>((*mF32ArrayFlags[idx])[record.sub_index]);
auto* flag2 =
static_cast<FlagF32*>((*other.mF32ArrayFlags[idx])[record.sub_index]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(flag2->getValue());
addFlagCopyRecord(mCopiedF32Flags, flag, sub_idx, find_existing);
break;
}
case FlagType::StringArray: {
auto* flag =
static_cast<FlagString*>((*mStringArrayFlags[idx])[record.sub_index]);
auto* flag2 =
static_cast<FlagString*>((*other.mStringArrayFlags[idx])[record.sub_index]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(flag2->getValue());
addFlagCopyRecord(mCopiedStringFlags, flag, sub_idx, find_existing);
break;
}
case FlagType::String64Array: {
auto* flag =
static_cast<FlagString64*>((*mString64ArrayFlags[idx])[record.sub_index]);
auto* flag2 = static_cast<FlagString64*>(
(*other.mString64ArrayFlags[idx])[record.sub_index]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(flag2->getValue());
addFlagCopyRecord(mCopiedString64Flags, flag, sub_idx, find_existing);
break;
}
case FlagType::String256Array: {
auto* flag =
static_cast<FlagString256*>((*mString256ArrayFlags[idx])[record.sub_index]);
auto* flag2 = static_cast<FlagString256*>(
(*other.mString256ArrayFlags[idx])[record.sub_index]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(flag2->getValue());
addFlagCopyRecord(mCopiedString256Flags, flag, sub_idx, find_existing);
break;
}
case FlagType::Vector2fArray: {
auto* flag =
static_cast<FlagVector2f*>((*mVector2fArrayFlags[idx])[record.sub_index]);
auto* flag2 = static_cast<FlagVector2f*>(
(*other.mVector2fArrayFlags[idx])[record.sub_index]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(flag2->getValue());
addFlagCopyRecord(mCopiedVector2fFlags, flag, sub_idx, find_existing);
break;
}
case FlagType::Vector3fArray: {
auto* flag =
static_cast<FlagVector3f*>((*mVector3fArrayFlags[idx])[record.sub_index]);
auto* flag2 = static_cast<FlagVector3f*>(
(*other.mVector3fArrayFlags[idx])[record.sub_index]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(flag2->getValue());
addFlagCopyRecord(mCopiedVector3fFlags, flag, sub_idx, find_existing);
break;
}
case FlagType::Vector4fArray: {
auto* flag =
static_cast<FlagVector4f*>((*mVector4fArrayFlags[idx])[record.sub_index]);
auto* flag2 = static_cast<FlagVector4f*>(
(*other.mVector4fArrayFlags[idx])[record.sub_index]);
const bool find_existing = shouldFindExistingCopyRecord();
flag->setValue(flag2->getValue());
addFlagCopyRecord(mCopiedVector4fFlags, flag, sub_idx, find_existing);
break;
}
case FlagType::Invalid:
break;
}
}
@@ -1499,6 +1835,61 @@ void TriggerParam::copyChangedFlags(TriggerParam& other, bool set_all_flags, boo
}
return;
}
copyFlags<bool>(mBoolFlags, other.mBoolFlags, mCopiedBoolFlags, mNumBoolFlagsPerCategory0,
record_copies, ignore_temp_flags, shouldFindExistingCopyRecord(), false);
copyFlags<s32>(mS32Flags, other.mS32Flags, mCopiedS32Flags, mNumBoolFlagsPerCategory0,
record_copies, ignore_temp_flags, shouldFindExistingCopyRecord(), false);
copyFlags<f32>(mF32Flags, other.mF32Flags, mCopiedF32Flags, mNumBoolFlagsPerCategory0,
record_copies, ignore_temp_flags, shouldFindExistingCopyRecord(), false);
copyFlags<sead::FixedSafeString<32>>(mStringFlags, other.mStringFlags, mCopiedStringFlags,
mNumBoolFlagsPerCategory0, record_copies,
ignore_temp_flags, shouldFindExistingCopyRecord(), false);
copyFlags<sead::FixedSafeString<64>>(mString64Flags, other.mString64Flags, mCopiedString64Flags,
mNumBoolFlagsPerCategory0, record_copies,
ignore_temp_flags, shouldFindExistingCopyRecord(), false);
copyFlags<sead::FixedSafeString<256>>(
mString256Flags, other.mString256Flags, mCopiedString256Flags, mNumBoolFlagsPerCategory0,
record_copies, ignore_temp_flags, shouldFindExistingCopyRecord(), false);
copyFlags<sead::Vector2f>(mVector2fFlags, other.mVector2fFlags, mCopiedVector2fFlags,
mNumBoolFlagsPerCategory0, record_copies, ignore_temp_flags,
shouldFindExistingCopyRecord(), false);
copyFlags<sead::Vector3f>(mVector3fFlags, other.mVector3fFlags, mCopiedVector3fFlags,
mNumBoolFlagsPerCategory0, record_copies, ignore_temp_flags,
shouldFindExistingCopyRecord(), false);
copyFlags<sead::Vector4f>(mVector4fFlags, other.mVector4fFlags, mCopiedVector4fFlags,
mNumBoolFlagsPerCategory0, record_copies, ignore_temp_flags,
shouldFindExistingCopyRecord(), false);
copyFlagArrays<bool>(mBoolArrayFlags, other.mBoolArrayFlags, mCopiedBoolFlags,
mNumBoolFlagsPerCategory0, record_copies, ignore_temp_flags,
shouldFindExistingCopyRecord());
copyFlagArrays<s32>(mS32ArrayFlags, other.mS32ArrayFlags, mCopiedS32Flags,
mNumBoolFlagsPerCategory0, record_copies, ignore_temp_flags,
shouldFindExistingCopyRecord());
copyFlagArrays<f32>(mF32ArrayFlags, other.mF32ArrayFlags, mCopiedF32Flags,
mNumBoolFlagsPerCategory0, record_copies, ignore_temp_flags,
shouldFindExistingCopyRecord());
copyFlagArrays<sead::FixedSafeString<32>>(
mStringArrayFlags, other.mStringArrayFlags, mCopiedStringFlags, mNumBoolFlagsPerCategory0,
record_copies, ignore_temp_flags, shouldFindExistingCopyRecord());
copyFlagArrays<sead::FixedSafeString<64>>(mString64ArrayFlags, other.mString64ArrayFlags,
mCopiedString64Flags, mNumBoolFlagsPerCategory0,
record_copies, ignore_temp_flags,
shouldFindExistingCopyRecord());
copyFlagArrays<sead::FixedSafeString<256>>(mString256ArrayFlags, other.mString256ArrayFlags,
mCopiedString256Flags, mNumBoolFlagsPerCategory0,
record_copies, ignore_temp_flags,
shouldFindExistingCopyRecord());
copyFlagArrays<sead::Vector2f>(mVector2fArrayFlags, other.mVector2fArrayFlags,
mCopiedVector2fFlags, mNumBoolFlagsPerCategory0, record_copies,
ignore_temp_flags, shouldFindExistingCopyRecord());
copyFlagArrays<sead::Vector3f>(mVector3fArrayFlags, other.mVector3fArrayFlags,
mCopiedVector3fFlags, mNumBoolFlagsPerCategory0, record_copies,
ignore_temp_flags, shouldFindExistingCopyRecord());
copyFlagArrays<sead::Vector4f>(mVector4fArrayFlags, other.mVector4fArrayFlags,
mCopiedVector4fFlags, mNumBoolFlagsPerCategory0, record_copies,
ignore_temp_flags, shouldFindExistingCopyRecord());
}
s32 TriggerParam::getBoolIdx(u32 name) const {