diff --git a/src/CosmosDBSessionStateProviderAsync/CosmosDBSessionStateProviderAsync.cs b/src/CosmosDBSessionStateProviderAsync/CosmosDBSessionStateProviderAsync.cs index 8d3cd96..5f2595b 100644 --- a/src/CosmosDBSessionStateProviderAsync/CosmosDBSessionStateProviderAsync.cs +++ b/src/CosmosDBSessionStateProviderAsync/CosmosDBSessionStateProviderAsync.cs @@ -611,7 +611,7 @@ public override SessionStateStoreData CreateNewStoreData(HttpContextBase context staticObjects = GetSessionStaticObjects(context.ApplicationInstance.Context); } - return new SessionStateStoreData(new SessionStateItemCollection(), staticObjects, timeout); + return new SessionStateStoreData(CreateItemCollection(), staticObjects, timeout); } /// @@ -628,7 +628,7 @@ public override async Task CreateUninitializedItemAsync( string encodedBuf; - var item = new SessionStateStoreData(new SessionStateItemCollection(), + var item = new SessionStateStoreData(CreateItemCollection(), GetSessionStaticObjects(context.ApplicationInstance.Context), timeout); @@ -1038,6 +1038,28 @@ private static string GetEncodedStringFromMemoryStream(MemoryStream s) return Convert.ToBase64String(bytes.Array, bytes.Offset, bytes.Count); } + private static ISessionStateItemCollection CreateItemCollection() + { + return SessionStateModuleAsync.AllowConcurrentRequestsPerSession ? + new ConcurrentSessionStateItemCollection() as ISessionStateItemCollection : + new SessionStateItemCollection() as ISessionStateItemCollection; + } + + private static void SerializeItemCollection(ISessionStateItemCollection items, BinaryWriter writer) + { + if (items is ConcurrentSessionStateItemCollection concurrentItems) + concurrentItems.Serialize(writer); + else if (items is SessionStateItemCollection defaultItems) + defaultItems.Serialize(writer); + } + + private static ISessionStateItemCollection DeserializeItemCollection(BinaryReader reader) + { + return SessionStateModuleAsync.AllowConcurrentRequestsPerSession ? + ConcurrentSessionStateItemCollection.Deserialize(reader) as ISessionStateItemCollection : + SessionStateItemCollection.Deserialize(reader) as ISessionStateItemCollection; + } + private static void Serialize(SessionStateStoreData item, Stream stream) { bool hasItems = true; @@ -1060,7 +1082,7 @@ private static void Serialize(SessionStateStoreData item, Stream stream) if (hasItems) { - ((SessionStateItemCollection)item.Items).Serialize(writer); + SerializeItemCollection(item.Items, writer); } if (hasStaticObjects) @@ -1089,7 +1111,7 @@ internal static SessionStateStoreData DeserializeStoreData(HttpContextBase conte private static SessionStateStoreData Deserialize(HttpContextBase context, Stream stream) { int timeout; - SessionStateItemCollection sessionItems; + ISessionStateItemCollection sessionItems; bool hasItems; bool hasStaticObjects; HttpStaticObjectsCollection staticObjects; @@ -1106,11 +1128,11 @@ private static SessionStateStoreData Deserialize(HttpContextBase context, Stream if (hasItems) { - sessionItems = SessionStateItemCollection.Deserialize(reader); + sessionItems = DeserializeItemCollection(reader); } else { - sessionItems = new SessionStateItemCollection(); + sessionItems = CreateItemCollection(); } if (hasStaticObjects) diff --git a/src/SessionStateModule/InProcSessionStateStoreAsync.cs b/src/SessionStateModule/InProcSessionStateStoreAsync.cs index 5f43c14..c65edf0 100644 --- a/src/SessionStateModule/InProcSessionStateStoreAsync.cs +++ b/src/SessionStateModule/InProcSessionStateStoreAsync.cs @@ -437,7 +437,14 @@ private SessionStateStoreData CreateLegitStoreData( { if (sessionItems == null) { - sessionItems = new SessionStateItemCollection(); + if (SessionStateModuleAsync.AllowConcurrentRequestsPerSession) + { + sessionItems = new ConcurrentNonSerializingSessionStateItemCollection(); + } + else + { + sessionItems = new SessionStateItemCollection(); + } } if (staticObjects == null && context != null) diff --git a/src/SessionStateModule/Microsoft.AspNet.SessionState.SessionStateModule.csproj b/src/SessionStateModule/Microsoft.AspNet.SessionState.SessionStateModule.csproj index 7a54490..d6c4f52 100644 --- a/src/SessionStateModule/Microsoft.AspNet.SessionState.SessionStateModule.csproj +++ b/src/SessionStateModule/Microsoft.AspNet.SessionState.SessionStateModule.csproj @@ -66,8 +66,10 @@ + + diff --git a/src/SessionStateModule/Resources/SR.Designer.cs b/src/SessionStateModule/Resources/SR.Designer.cs index 44bc15e..423fd6c 100644 --- a/src/SessionStateModule/Resources/SR.Designer.cs +++ b/src/SessionStateModule/Resources/SR.Designer.cs @@ -60,6 +60,15 @@ internal SR() { } } + /// + /// Looks up a localized string similar to Unable to serialize the session state. For out-of-proc session stores, ASP.NET will serialize the session state objects, and as a result non-serializable objects or MarshalByRef objects are not permitted.. + /// + internal static string Cant_serialize_session_state { + get { + return ResourceManager.GetString("Cant_serialize_session_state", resourceCulture); + } + } + /// /// Looks up a localized string similar to Error occured when reading config secion '{0}'.. /// @@ -87,6 +96,15 @@ internal static string Invalid_session_custom_provider { } } + /// + /// Looks up a localized string similar to The session state information is invalid and might be corrupted.. + /// + internal static string Invalid_session_state { + get { + return ResourceManager.GetString("Invalid_session_state", resourceCulture); + } + } + /// /// Looks up a localized string similar to The custom session state store provider '{0}' is not found.. /// diff --git a/src/SessionStateModule/Resources/SR.resx b/src/SessionStateModule/Resources/SR.resx index 022bdfc..6c146d2 100644 --- a/src/SessionStateModule/Resources/SR.resx +++ b/src/SessionStateModule/Resources/SR.resx @@ -117,6 +117,9 @@ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + Unable to serialize the session state. For out-of-proc session stores, ASP.NET will serialize the session state objects, and as a result non-serializable objects or MarshalByRef objects are not permitted. + Error occured when reading config secion '{0}'. @@ -126,6 +129,9 @@ The custom session state store provider name '{0}' is invalid. + + The session state information is invalid and might be corrupted. + The custom session state store provider '{0}' is not found. diff --git a/src/SessionStateModule/SessionStateItemCollections.cs b/src/SessionStateModule/SessionStateItemCollections.cs new file mode 100644 index 0000000..c1e0324 --- /dev/null +++ b/src/SessionStateModule/SessionStateItemCollections.cs @@ -0,0 +1,774 @@ +using Microsoft.AspNet.SessionState.Resources; +using System; +using System.Collections; +using System.Collections.Specialized; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Web; +using HttpRuntime = System.Web.HttpRuntime; +using ISessionStateItemCollection = System.Web.SessionState.ISessionStateItemCollection; + +namespace Microsoft.AspNet.SessionState +{ + /// + /// A threadsafe collection of objects stored in session state. Does not serialize state. This class cannot be inherited. + /// + public sealed class ConcurrentNonSerializingSessionStateItemCollection : NameObjectCollectionBase, ISessionStateItemCollection, ICollection, IEnumerable + { + // It should be noted that the base NameObjectCollectionBase isn't really intrinsically threadsafe. + // We protect the basic operations with a lock here, but we don't protect GetEnumerator/Keys + // scenarios. It is incumbant upon the caller to ensure that they don't modify the collection while + // enumerating it. This is a limitation of the NameObjectCollectionBase class. + private object _collectionLock = new object(); + + /// Gets or sets a value indicating whether the collection has been marked as changed. + /// true if the contents have been changed; otherwise, false. + public bool Dirty { get; set; } + + /// Gets or sets a value in the collection by name. + /// The value in the collection with the specified name. If the specified key is not found, attempting to get it returns null, and attempting to set it creates a new element using the specified key. + /// The key name of the value in the collection. + public object this[string name] + { + get + { + lock (this._collectionLock) + { + object obj = base.BaseGet(name); + if (obj != null && !IsImmutable(obj)) + { + this.Dirty = true; + } + return obj; + } + } + set + { + lock (this._collectionLock) + { + base.BaseSet(name, value); + this.Dirty = true; + } + } + } + + /// Gets or sets a value in the collection by numerical index. + /// The value in the collection stored at the specified index. If the specified key is not found, attempting to get it returns null, and attempting to set it creates a new element using the specified key. + /// The numerical index of the value in the collection. + public object this[int index] + { + get + { + lock (this._collectionLock) + { + object obj = base.BaseGet(index); + if (obj != null && !IsImmutable(obj)) + { + this.Dirty = true; + } + return obj; + } + } + set + { + lock (this._collectionLock) + { + base.BaseSet(index, value); + this.Dirty = true; + } + } + } + + /// Gets a collection of the variable names for all values stored in the collection. + /// The collection that contains all the collection keys. + public override NameObjectCollectionBase.KeysCollection Keys + { + get + { + return base.Keys; + } + } + + /// Creates a new, empty object. + public ConcurrentNonSerializingSessionStateItemCollection() : base(Misc.CaseInsensitiveInvariantKeyComparer) + { + } + + /// Removes all values and keys from the session-state collection. + public void Clear() + { + lock (this._collectionLock) + { + base.BaseClear(); + this.Dirty = true; + } + } + + /// Returns an enumerator that can be used to read all the key names in the collection. + /// An that can iterate through the variable names in the session-state collection. + public override IEnumerator GetEnumerator() + { + return base.GetEnumerator(); + } + + internal static bool IsImmutable(object o) + { + return Misc.ImmutableTypes[o.GetType()] != null; + } + + /// Deletes an item from the collection. + /// The name of the item to delete from the collection. + public void Remove(string name) + { + lock (this._collectionLock) + { + base.BaseRemove(name); + this.Dirty = true; + } + } + + /// Deletes an item at a specified index from the collection. + /// The index of the item to remove from the collection. + /// + /// is less than zero.- or - is equal to or greater than . + public void RemoveAt(int index) + { + lock (this._collectionLock) + { + base.BaseRemoveAt(index); + this.Dirty = true; + } + } + } + + /// A thread-safe collection of objects stored in session state. This class cannot be inherited. + public sealed class ConcurrentSessionStateItemCollection : NameObjectCollectionBase, ISessionStateItemCollection, ICollection, IEnumerable + { + private const int NO_NULL_KEY = -1; + private const int SIZE_OF_INT32 = 4; + + private bool _dirty; + private KeyedCollection _serializedItems; + private Stream _stream; + private int _iLastOffset; + private object _serializedItemsLock = new object(); + + /// Gets or sets a value indicating whether the collection has been marked as changed. + /// true if the contents have been changed; otherwise, false. + public bool Dirty + { + get + { + return this._dirty; + } + set + { + this._dirty = value; + } + } + + /// Gets or sets a value in the collection by name. + /// The value in the collection with the specified name. If the specified key is not found, attempting to get it returns null, and attempting to set it creates a new element using the specified key. + /// The key name of the value in the collection. + public object this[string name] + { + get + { + lock (this._serializedItemsLock) + { + this.DeserializeItem(name, true); + object obj = base.BaseGet(name); + if (obj != null && !IsImmutable(obj)) + { + this._dirty = true; + } + return obj; + } + } + set + { + lock (this._serializedItemsLock) + { + this.MarkItemDeserialized(name); + base.BaseSet(name, value); + this._dirty = true; + } + } + } + + /// Gets or sets a value in the collection by numerical index. + /// The value in the collection stored at the specified index. If the specified key is not found, attempting to get it returns null, and attempting to set it creates a new element using the specified key. + /// The numerical index of the value in the collection. + public object this[int index] + { + get + { + lock (this._serializedItemsLock) + { + this.DeserializeItem(index); + object obj = base.BaseGet(index); + if (obj != null && !IsImmutable(obj)) + { + this._dirty = true; + } + return obj; + } + } + set + { + lock (this._serializedItemsLock) + { + this.MarkItemDeserialized(index); + base.BaseSet(index, value); + this._dirty = true; + } + } + } + + /// Gets a collection of the variable names for all values stored in the collection. + /// The collection that contains all the collection keys. + public override NameObjectCollectionBase.KeysCollection Keys + { + get + { + // Unfortunately, we have to deserialize all items first, because Keys.GetEnumerator might + // be called and we have the same problem as in GetEnumerator() below. Also, DeserializeAllItems + // take the lock to ensure consistency - which it does within. + this.DeserializeAllItems(); + return base.Keys; + } + } + + /// Creates a new, empty object. + public ConcurrentSessionStateItemCollection() : base(Misc.CaseInsensitiveInvariantKeyComparer) + { + } + + /// Removes all values and keys from the session-state collection. + public void Clear() + { + lock (this._serializedItemsLock) + { + if (this._serializedItems != null) + { + this._serializedItems.Clear(); + } + base.BaseClear(); + this._dirty = true; + } + } + + /// Returns an enumerator that can be used to read all the key names in the collection. + /// An that can iterate through the variable names in the session-state collection. + public override IEnumerator GetEnumerator() + { + // Have to deserialize all items; otherwise the enumerator won't work because we'll keep + // on changing the collection during individual item deserialization. Also, DeserializeAllItems + // take the lock to ensure consistency - which it does within. + this.DeserializeAllItems(); + return base.GetEnumerator(); + } + + internal static bool IsImmutable(object o) + { + return Misc.ImmutableTypes[o.GetType()] != null; + } + + /// Deletes an item from the collection. + /// The name of the item to delete from the collection. + public void Remove(string name) + { + lock (this._serializedItemsLock) + { + if (this._serializedItems != null) + { + this._serializedItems.Remove(name); + } + base.BaseRemove(name); + this._dirty = true; + } + } + + /// Deletes an item at a specified index from the collection. + /// The index of the item to remove from the collection. + /// + /// is less than zero.- or - is equal to or greater than . + public void RemoveAt(int index) + { + lock (this._serializedItemsLock) + { + if (this._serializedItems != null && index < this._serializedItems.Count) + { + this._serializedItems.RemoveAt(index); + } + base.BaseRemoveAt(index); + this._dirty = true; + } + } + + private void DeserializeAllItems() + { + if (_serializedItems == null) + { + return; + } + + lock (_serializedItemsLock) + { + for (int i = 0; i < _serializedItems.Count; i++) + { + DeserializeItem(_serializedItems.GetKey(i), false); + } + } + } + + private void DeserializeItem(int index) + { + // No-op if SessionStateItemCollection is not deserialized from a persistent storage. + if (_serializedItems == null) + { + return; + } + + lock (_serializedItemsLock) + { + // No-op if the item isn't serialized. + if (index >= _serializedItems.Count) + { + return; + } + + DeserializeItem(_serializedItems.GetKey(index), false); + } + } + + private void DeserializeItem(String name, bool check) + { + object val; + + lock (_serializedItemsLock) + { + if (check) + { + // No-op if SessionStateItemCollection is not deserialized from a persistent storage, + if (_serializedItems == null) + { + return; + } + + // User is asking for an item we don't have. + if (!_serializedItems.ContainsKey(name)) + { + return; + } + } + + Debug.Assert(_serializedItems != null); + Debug.Assert(_stream != null); + + SerializedItemPosition position = (SerializedItemPosition)_serializedItems[name]; + if (position.IsDeserialized) + { + // It has been deserialized already. + return; + } + + // Position the stream to the place where the item is stored. + _stream.Seek(position.Offset, SeekOrigin.Begin); + val = StateSerializationUtil.ReadValueFromStream(new BinaryReader(_stream)); + + BaseSet(name, val); + + // At the end, mark the item as deserialized by making the offset -1 + position.MarkDeserializedOffsetAndCheck(); + } + } + + private void MarkItemDeserialized(String name) + { + // No-op if SessionStateItemCollection is not deserialized from a persistent storage, + if (_serializedItems == null) + { + return; + } + + lock (_serializedItemsLock) + { + // If the serialized collection contains this key, mark it deserialized + if (_serializedItems.ContainsKey(name)) + { + // Mark the item as deserialized by making it -1. + ((SerializedItemPosition)_serializedItems[name]).MarkDeserializedOffset(); + } + } + } + + private void MarkItemDeserialized(int index) + { + // No-op if SessionStateItemCollection is not deserialized from a persistent storage, + if (_serializedItems == null) + { + return; + } + + lock (_serializedItemsLock) + { + // No-op if the item isn't serialized. + if (index >= _serializedItems.Count) + { + return; + } + + ((SerializedItemPosition)_serializedItems[index]).MarkDeserializedOffset(); + } + } + + private class KeyedCollection : NameObjectCollectionBase + { + internal object this[string name] + { + get + { + return base.BaseGet(name); + } + set + { + if (base.BaseGet(name) == null && value == null) + { + return; + } + base.BaseSet(name, value); + } + } + + internal object this[int index] + { + get + { + return base.BaseGet(index); + } + } + + internal KeyedCollection(int count) : base(count, Misc.CaseInsensitiveInvariantKeyComparer) + { + } + + internal void Clear() + { + base.BaseClear(); + } + + internal bool ContainsKey(string name) + { + return base.BaseGet(name) != null; + } + + internal string GetKey(int index) + { + return base.BaseGetKey(index); + } + + internal void Remove(string name) + { + base.BaseRemove(name); + } + + internal void RemoveAt(int index) + { + base.BaseRemoveAt(index); + } + } + + /// + /// Serializes the session state item collection to a stream. + /// + /// + public void Serialize(BinaryWriter writer) + { + int count; + int i; + long iOffsetStart; + long iValueStart; + string key; + object value; + long curPos; + byte[] buffer = null; + Stream baseStream = writer.BaseStream; + + lock (_serializedItemsLock) + { + count = Count; + writer.Write(count); + + if (count > 0) + { + if (BaseGet(null) != null) + { + // We have a value with a null key. Find its index. + for (i = 0; i < count; i++) + { + key = BaseGetKey(i); + if (key == null) + { + writer.Write(i); + break; + } + } + + Debug.Assert(i != count); + } + else + { + writer.Write(NO_NULL_KEY); + } + + // Write out all the keys. + for (i = 0; i < count; i++) + { + key = BaseGetKey(i); + if (key != null) + { + writer.Write(key); + } + } + + // Next, allocate space to store the offset: + // - We won't store the offset of first item because it's always zero. + // - The offset of an item is counted from the beginning of serialized values + // - But we will store the offset of the first byte off the last item because + // we need that to calculate the size of the last item. + iOffsetStart = baseStream.Position; + baseStream.Seek(SIZE_OF_INT32 * count, SeekOrigin.Current); + + iValueStart = baseStream.Position; + + for (i = 0; i < count; i++) + { + // See if that item has not be deserialized yet. + if (_serializedItems != null && + i < _serializedItems.Count && + !((SerializedItemPosition)_serializedItems[i]).IsDeserialized) + { + + SerializedItemPosition position = (SerializedItemPosition)_serializedItems[i]; + + Debug.Assert(_stream != null); + + // The item is read as serialized data from a store, and it's still + // serialized, meaning no one has referenced it. Just copy + // the bytes over. + + // Move the stream to the serialized data and copy it over to writer + _stream.Seek(position.Offset, SeekOrigin.Begin); + + if (buffer == null || buffer.Length < position.DataLength) + { + buffer = new Byte[position.DataLength]; + } + + _stream.Read(buffer, 0, position.DataLength); + + baseStream.Write(buffer, 0, position.DataLength); + } + else + { + value = BaseGet(i); + StateSerializationUtil.WriteValueToStream(value, writer); + } + + curPos = baseStream.Position; + + // Write the offset + baseStream.Seek(i * SIZE_OF_INT32 + iOffsetStart, SeekOrigin.Begin); + writer.Write((int)(curPos - iValueStart)); + + // Move back to current position + baseStream.Seek(curPos, SeekOrigin.Begin); + } + } + } + } + + /// + /// Deserializes the session state item collection from a stream. + /// + /// + /// + /// + public static ConcurrentSessionStateItemCollection Deserialize(BinaryReader reader) + { + ConcurrentSessionStateItemCollection d = new ConcurrentSessionStateItemCollection(); + int count; + int nullKey; + String key; + int i; + byte[] buffer; + + count = reader.ReadInt32(); + + if (count > 0) + { + nullKey = reader.ReadInt32(); + + d._serializedItems = new KeyedCollection(count); + + // First, deserialize all the keys + for (i = 0; i < count; i++) + { + if (i == nullKey) + { + key = null; + } + else + { + key = reader.ReadString(); + } + + // Need to set them with null value first, so that + // the order of them items is correct. + d.BaseSet(key, null); + } + + // Next, deserialize all the offsets + // First offset will be 0, and the data length will be the first read offset + int offset0 = reader.ReadInt32(); + d._serializedItems[d.BaseGetKey(0)] = new SerializedItemPosition(0, offset0); + + int offset1 = 0; + for (i = 1; i < count; i++) + { + offset1 = reader.ReadInt32(); + d._serializedItems[d.BaseGetKey(i)] = new SerializedItemPosition(offset0, offset1 - offset0); + offset0 = offset1; + } + + d._iLastOffset = offset0; + + // _iLastOffset is the first byte past the last item, which equals + // the total length of all serialized data + buffer = new byte[d._iLastOffset]; + int bytesRead = reader.BaseStream.Read(buffer, 0, d._iLastOffset); + if (bytesRead != d._iLastOffset) + { + throw new HttpException(String.Format(CultureInfo.CurrentCulture, SR.Invalid_session_state)); + } + d._stream = new MemoryStream(buffer); + } + + d._dirty = false; + + return d; + } + + private sealed class SerializedItemPosition + { + int _offset; + int _dataLength; + + internal SerializedItemPosition(int offset, int dataLength) + { + this._offset = offset; + this._dataLength = dataLength; + } + + internal int Offset + { + get { return _offset; } + } + + internal int DataLength + { + get { return _dataLength; } + } + + // Mark the item as deserialized by making the offset -1. + internal void MarkDeserializedOffset() + { + _offset = -1; + } + + internal void MarkDeserializedOffsetAndCheck() + { + if (_offset >= 0) + { + MarkDeserializedOffset(); + } + else + { + Debug.Fail("Offset shouldn't be negative inside MarkDeserializedOffsetAndCheck."); + } + } + + internal bool IsDeserialized + { + get { return _offset < 0; } + } + } + } + + internal sealed class Misc + { + private static Hashtable s_immutableTypes; + internal static Hashtable ImmutableTypes => s_immutableTypes; + + private static StringComparer s_caseInsensitiveInvariantKeyComparer; + + internal static StringComparer CaseInsensitiveInvariantKeyComparer + { + get + { + if (Misc.s_caseInsensitiveInvariantKeyComparer == null) + { + Misc.s_caseInsensitiveInvariantKeyComparer = StringComparer.Create(CultureInfo.InvariantCulture, true); + } + return Misc.s_caseInsensitiveInvariantKeyComparer; + } + } + + static Misc() + { + s_immutableTypes = new Hashtable(19); + Type type = typeof(string); + s_immutableTypes.Add(type, type); + type = typeof(int); + s_immutableTypes.Add(type, type); + type = typeof(bool); + s_immutableTypes.Add(type, type); + type = typeof(DateTime); + s_immutableTypes.Add(type, type); + type = typeof(decimal); + s_immutableTypes.Add(type, type); + type = typeof(byte); + s_immutableTypes.Add(type, type); + type = typeof(char); + s_immutableTypes.Add(type, type); + type = typeof(float); + s_immutableTypes.Add(type, type); + type = typeof(double); + s_immutableTypes.Add(type, type); + type = typeof(sbyte); + s_immutableTypes.Add(type, type); + type = typeof(short); + s_immutableTypes.Add(type, type); + type = typeof(long); + s_immutableTypes.Add(type, type); + type = typeof(ushort); + s_immutableTypes.Add(type, type); + type = typeof(uint); + s_immutableTypes.Add(type, type); + type = typeof(ulong); + s_immutableTypes.Add(type, type); + type = typeof(TimeSpan); + s_immutableTypes.Add(type, type); + type = typeof(Guid); + s_immutableTypes.Add(type, type); + type = typeof(IntPtr); + s_immutableTypes.Add(type, type); + type = typeof(UIntPtr); + s_immutableTypes.Add(type, type); + } + + public Misc() { } + } +} \ No newline at end of file diff --git a/src/SessionStateModule/SessionStateModuleAsync.cs b/src/SessionStateModule/SessionStateModuleAsync.cs index b1aeff3..2341193 100644 --- a/src/SessionStateModule/SessionStateModuleAsync.cs +++ b/src/SessionStateModule/SessionStateModuleAsync.cs @@ -102,6 +102,12 @@ public static SessionStateMode ConfigMode return s_configMode; } } + + /// + /// Indicates whether the session state module is configured to optimistically allow concurrent requests + /// + public static bool AllowConcurrentRequestsPerSession => AppSettings.AllowConcurrentRequestsPerSession; + /// /// Initialize the module /// diff --git a/src/SessionStateModule/StateSerializationUtil.cs b/src/SessionStateModule/StateSerializationUtil.cs new file mode 100644 index 0000000..9328d27 --- /dev/null +++ b/src/SessionStateModule/StateSerializationUtil.cs @@ -0,0 +1,315 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See the License.txt file in the project root for full license information. + +namespace Microsoft.AspNet.SessionState +{ + using Microsoft.AspNet.SessionState.Resources; + using System; + using System.Diagnostics; + using System.Globalization; + using System.IO; + using System.Runtime.Serialization.Formatters.Binary; + using System.Web; + using System.Web.SessionState; + + internal static class StateSerializationUtil + { + enum TypeID : byte + { + String = 1, + Int32, + Boolean, + DateTime, + Decimal, + Byte, + Char, + Single, + Double, + SByte, + Int16, + Int64, + UInt16, + UInt32, + UInt64, + TimeSpan, + Guid, + IntPtr, + UIntPtr, + Object, + Null, + } + + internal static Object ReadValueFromStream(BinaryReader reader) + { + TypeID id; + Object value = null; + + id = (TypeID)reader.ReadByte(); + switch (id) + { + case TypeID.String: + value = reader.ReadString(); + break; + + case TypeID.Int32: + value = reader.ReadInt32(); + break; + + case TypeID.Boolean: + value = reader.ReadBoolean(); + break; + + case TypeID.DateTime: + value = new DateTime(reader.ReadInt64()); + break; + + case TypeID.Decimal: + { + int[] bits = new int[4]; + for (int i = 0; i < 4; i++) + { + bits[i] = reader.ReadInt32(); + } + + value = new Decimal(bits); + } + break; + + case TypeID.Byte: + value = reader.ReadByte(); + break; + + case TypeID.Char: + value = reader.ReadChar(); + break; + + case TypeID.Single: + value = reader.ReadSingle(); + break; + + case TypeID.Double: + value = reader.ReadDouble(); + break; + + case TypeID.SByte: + value = reader.ReadSByte(); + break; + + case TypeID.Int16: + value = reader.ReadInt16(); + break; + + case TypeID.Int64: + value = reader.ReadInt64(); + break; + + case TypeID.UInt16: + value = reader.ReadUInt16(); + break; + + case TypeID.UInt32: + value = reader.ReadUInt32(); + break; + + case TypeID.UInt64: + value = reader.ReadUInt64(); + break; + + case TypeID.TimeSpan: + value = new TimeSpan(reader.ReadInt64()); + break; + + case TypeID.Guid: + { + byte[] bits = reader.ReadBytes(16); + value = new Guid(bits); + } + break; + + case TypeID.IntPtr: + if (IntPtr.Size == 4) + { + value = new IntPtr(reader.ReadInt32()); + } + else + { + Debug.Assert(IntPtr.Size == 8); + value = new IntPtr(reader.ReadInt64()); + } + break; + + case TypeID.UIntPtr: + if (UIntPtr.Size == 4) + { + value = new UIntPtr(reader.ReadUInt32()); + } + else + { + Debug.Assert(UIntPtr.Size == 8); + value = new UIntPtr(reader.ReadUInt64()); + } + break; + + case TypeID.Object: + BinaryFormatter formatter = new BinaryFormatter(); + if (SessionStateUtility.SerializationSurrogateSelector != null) + { + formatter.SurrogateSelector = SessionStateUtility.SerializationSurrogateSelector; + } + value = formatter.Deserialize(reader.BaseStream); + break; + + case TypeID.Null: + value = null; + break; + } + + return value; + } + + internal static void WriteValueToStream(Object value, BinaryWriter writer) + { + if (value == null) + { + writer.Write((byte)TypeID.Null); + } + else if (value is String) + { + writer.Write((byte)TypeID.String); + writer.Write((String)value); + } + else if (value is Int32) + { + writer.Write((byte)TypeID.Int32); + writer.Write((Int32)value); + } + else if (value is Boolean) + { + writer.Write((byte)TypeID.Boolean); + writer.Write((Boolean)value); + } + else if (value is DateTime) + { + writer.Write((byte)TypeID.DateTime); + writer.Write(((DateTime)value).Ticks); + } + else if (value is Decimal) + { + writer.Write((byte)TypeID.Decimal); + int[] bits = Decimal.GetBits((Decimal)value); + for (int i = 0; i < 4; i++) + { + writer.Write((int)bits[i]); + } + } + else if (value is Byte) + { + writer.Write((byte)TypeID.Byte); + writer.Write((byte)value); + } + else if (value is Char) + { + writer.Write((byte)TypeID.Char); + writer.Write((char)value); + } + else if (value is Single) + { + writer.Write((byte)TypeID.Single); + writer.Write((float)value); + } + else if (value is Double) + { + writer.Write((byte)TypeID.Double); + writer.Write((double)value); + } + else if (value is SByte) + { + writer.Write((byte)TypeID.SByte); + writer.Write((SByte)value); + } + else if (value is Int16) + { + writer.Write((byte)TypeID.Int16); + writer.Write((short)value); + } + else if (value is Int64) + { + writer.Write((byte)TypeID.Int64); + writer.Write((long)value); + } + else if (value is UInt16) + { + writer.Write((byte)TypeID.UInt16); + writer.Write((UInt16)value); + } + else if (value is UInt32) + { + writer.Write((byte)TypeID.UInt32); + writer.Write((UInt32)value); + } + else if (value is UInt64) + { + writer.Write((byte)TypeID.UInt64); + writer.Write((UInt64)value); + } + else if (value is TimeSpan) + { + writer.Write((byte)TypeID.TimeSpan); + writer.Write(((TimeSpan)value).Ticks); + } + else if (value is Guid) + { + writer.Write((byte)TypeID.Guid); + Guid guid = (Guid)value; + byte[] bits = guid.ToByteArray(); + writer.Write(bits); + } + else if (value is IntPtr) + { + writer.Write((byte)TypeID.IntPtr); + IntPtr v = (IntPtr)value; + if (IntPtr.Size == 4) + { + writer.Write((Int32)v.ToInt32()); + } + else + { + Debug.Assert(IntPtr.Size == 8); + writer.Write((Int64)v.ToInt64()); + } + } + else if (value is UIntPtr) + { + writer.Write((byte)TypeID.UIntPtr); + UIntPtr v = (UIntPtr)value; + if (UIntPtr.Size == 4) + { + writer.Write((UInt32)v.ToUInt32()); + } + else + { + Debug.Assert(UIntPtr.Size == 8); + writer.Write((UInt64)v.ToUInt64()); + } + } + else + { + writer.Write((byte)TypeID.Object); + BinaryFormatter formatter = new BinaryFormatter(); + if (SessionStateUtility.SerializationSurrogateSelector != null) + { + formatter.SurrogateSelector = SessionStateUtility.SerializationSurrogateSelector; + } + try + { + formatter.Serialize(writer.BaseStream, value); + } + catch (Exception innerException) + { + HttpException outerException = new HttpException(String.Format(CultureInfo.CurrentCulture, SR.Cant_serialize_session_state), innerException); + throw outerException; + } + } + } + } +} \ No newline at end of file diff --git a/src/SqlSessionStateProviderAsync/SqlSessionStateProviderAsync.cs b/src/SqlSessionStateProviderAsync/SqlSessionStateProviderAsync.cs index 7a6b3bf..ee91c33 100644 --- a/src/SqlSessionStateProviderAsync/SqlSessionStateProviderAsync.cs +++ b/src/SqlSessionStateProviderAsync/SqlSessionStateProviderAsync.cs @@ -41,7 +41,7 @@ public class SqlSessionStateProviderAsync : SessionStateStoreProviderAsyncBase private static RepositoryType s_repositoryType; private int _rqOrigStreamLen; - + /// /// Initialize the provider through the configuration /// @@ -172,6 +172,27 @@ internal static Func GetSessionStaticO } = SessionStateUtility.GetSessionStaticObjects; #endregion + private static ISessionStateItemCollection CreateItemCollection() + { + return SessionStateModuleAsync.AllowConcurrentRequestsPerSession ? + new ConcurrentSessionStateItemCollection() as ISessionStateItemCollection : + new SessionStateItemCollection() as ISessionStateItemCollection; + } + + private static void SerializeItemCollection(ISessionStateItemCollection items, BinaryWriter writer) + { + if (items is ConcurrentSessionStateItemCollection concurrentItems) + concurrentItems.Serialize(writer); + else if (items is SessionStateItemCollection defaultItems) + defaultItems.Serialize(writer); + } + + private static ISessionStateItemCollection DeserializeItemCollection(BinaryReader reader) + { + return SessionStateModuleAsync.AllowConcurrentRequestsPerSession ? + ConcurrentSessionStateItemCollection.Deserialize(reader) as ISessionStateItemCollection : + SessionStateItemCollection.Deserialize(reader) as ISessionStateItemCollection; + } private int? GetMaxRetryNum(NameValueCollection config) { @@ -209,7 +230,7 @@ public override SessionStateStoreData CreateNewStoreData(HttpContextBase context staticObjects = GetSessionStaticObjects(context.ApplicationInstance.Context); } - return new SessionStateStoreData(new SessionStateItemCollection(), staticObjects, timeout); + return new SessionStateStoreData(CreateItemCollection(), staticObjects, timeout); } /// @@ -231,7 +252,7 @@ public override async Task CreateUninitializedItemAsync( byte[] buf; int length; - var item = new SessionStateStoreData(new SessionStateItemCollection(), + var item = new SessionStateStoreData(CreateItemCollection(), GetSessionStaticObjects(context.ApplicationInstance.Context), timeout); @@ -478,7 +499,7 @@ private static void Serialize(SessionStateStoreData item, Stream stream) if (hasItems) { - ((SessionStateItemCollection)item.Items).Serialize(writer); + SerializeItemCollection(item.Items, writer); } if (hasStaticObjects) @@ -507,7 +528,7 @@ internal static SessionStateStoreData DeserializeStoreData(HttpContextBase conte private static SessionStateStoreData Deserialize(HttpContextBase context, Stream stream) { int timeout; - SessionStateItemCollection sessionItems; + ISessionStateItemCollection sessionItems; bool hasItems; bool hasStaticObjects; HttpStaticObjectsCollection staticObjects; @@ -525,10 +546,10 @@ private static SessionStateStoreData Deserialize(HttpContextBase context, Stream if (hasItems) { - sessionItems = SessionStateItemCollection.Deserialize(reader); + sessionItems = DeserializeItemCollection(reader); } else { - sessionItems = new SessionStateItemCollection(); + sessionItems = CreateItemCollection(); } if (hasStaticObjects) diff --git a/test/Microsoft.AspNet.SessionState.SqlSessionStateProviderAsync.Test/SqlSessionStateAsyncProviderTest.cs b/test/Microsoft.AspNet.SessionState.SqlSessionStateProviderAsync.Test/SqlSessionStateAsyncProviderTest.cs index 0412361..cd1de5a 100644 --- a/test/Microsoft.AspNet.SessionState.SqlSessionStateProviderAsync.Test/SqlSessionStateAsyncProviderTest.cs +++ b/test/Microsoft.AspNet.SessionState.SqlSessionStateProviderAsync.Test/SqlSessionStateAsyncProviderTest.cs @@ -7,6 +7,7 @@ namespace Microsoft.AspNet.SessionState.SqlSessionStateAsyncProvider.Test using Microsoft.Data.SqlClient; using Moq; using System; + using System.Collections.Generic; using System.Collections.Specialized; using System.Configuration; using System.IO; @@ -228,12 +229,18 @@ public void CreateNewStoreData_Should_Return_Empty_Store() Assert.Equal(TestTimeout, store.Timeout); } + public static IEnumerable SerializeSessionData() + { + yield return new object[] { true, new SessionStateItemCollection() }; + yield return new object[] { true, new SessionStateItemCollection() }; + yield return new object[] { false, new ConcurrentSessionStateItemCollection() }; + yield return new object[] { false, new ConcurrentSessionStateItemCollection() }; + } + [Theory] - [InlineData(true)] - [InlineData(false)] - public void Serialize_And_Deserialized_SessionStateStoreData_RoundTrip_Should_Work(bool enableCompression) + [MemberData(nameof(SerializeSessionData))] + public void Serialize_And_Deserialized_SessionStateStoreData_RoundTrip_Should_Work(bool enableCompression, ISessionStateItemCollection sessionCollection) { - var sessionCollection = new SessionStateItemCollection(); var now = DateTime.UtcNow; sessionCollection["test1"] = "test1"; sessionCollection["test2"] = now;